11from labml import experiment
12from labml.configs import option
13from labml_nn.experiments.cifar10 import CIFAR10Configs
14from labml_nn.transformers import TransformerConfigs
17class Configs(CIFAR10Configs):
27 transformer: TransformerConfigs
パッチのサイズ
30 patch_size: int = 4
分類ヘッドの隠れ層のサイズ
32 n_hidden_classification: int = 2048
タスク内のクラス数
34 n_classes: int = 10
トランスフォーマー構成の作成
37@option(Configs.transformer)
38def _transformer():
42 return TransformerConfigs()
45@option(Configs.model)
46def _vit(c: Configs):
50 from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
51 PatchEmbeddings
54 d_model = c.transformer.d_model
ビジョントランスフォーマーの作成
56 return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
57 PatchEmbeddings(d_model, c.patch_size, 3),
58 LearnedPositionalEmbeddings(d_model),
59 ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)
62def main():
実験を作成
64 experiment.create(name='ViT', comment='cifar10')
構成の作成
66 conf = Configs()
構成をロード
68 experiment.configs(conf, {
オプティマイザー
70 'optimizer.optimizer': 'Adam',
71 'optimizer.learning_rate': 2.5e-4,
変圧器埋め込みサイズ
74 'transformer.d_model': 512,
トレーニングエポックとバッチサイズ
77 'epochs': 32,
78 'train_batch_size': 64,
トレーニング用の CIFAR 10 イメージの拡張
81 'train_dataset': 'cifar10_train_augmented',
検証用に CIFAR 10 イメージを拡張しないでください
83 'valid_dataset': 'cifar10_valid_no_augment',
84 })
保存/読み込み用のモデルを設定
86 experiment.add_pytorch_models({'model': conf.model})
実験を開始し、トレーニングループを実行します
88 with experiment.start():
89 conf.run()
93if __name__ == '__main__':
94 main()