このスクリプトは、CIFAR 10 データセットで ConvMixer をトレーニングします。
これは論文の結果を再現する試みではありません。この論文では、PyTorch画像モデル(timm)にある画像拡張をトレーニングに使用しています。簡略化のためにこれを行ったわけではありません。そのため、検証の精度が低下します
。18from labml import experiment
19from labml.configs import option
20from labml_nn.experiments.cifar10 import CIFAR10Configs
23class Configs(CIFAR10Configs):
パッチのサイズ、
32 patch_size: int = 2
パッチ埋め込みのチャンネル数、
34 d_model: int = 256
36 n_layers: int = 8
深さ方向の畳み込みのカーネルサイズ、
38 kernel_size: int = 7
タスク内のクラス数
40 n_classes: int = 10
43@option(Configs.model)
44def _conv_mixer(c: Configs):
48 from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings
ConvMixer の作成
51 return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
52 PatchEmbeddings(c.d_model, c.patch_size, 3),
53 ClassificationHead(c.d_model, c.n_classes)).to(c.device)
56def main():
実験を作成
58 experiment.create(name='ConvMixer', comment='cifar10')
構成の作成
60 conf = Configs()
構成をロード
62 experiment.configs(conf, {
オプティマイザー
64 'optimizer.optimizer': 'Adam',
65 'optimizer.learning_rate': 2.5e-4,
トレーニングエポックとバッチサイズ
68 'epochs': 150,
69 'train_batch_size': 64,
シンプルな画像補正
72 'train_dataset': 'cifar10_train_augmented',
検証のために画像を補足しないでください
74 'valid_dataset': 'cifar10_valid_no_augment',
75 })
保存/読み込み用のモデルを設定
77 experiment.add_pytorch_models({'model': conf.model})
実験を開始し、トレーニングループを実行します
79 with experiment.start():
80 conf.run()
84if __name__ == '__main__':
85 main()