CIFAR 10 でのコンバージョンミキサーのトレーニング

このスクリプトは、CIFAR 10 データセットで ConvMixer をトレーニングします。

これは論文の結果を再現する試みではありません。この論文では、PyTorch画像モデル(timm)にある画像拡張をトレーニングに使用しています。簡略化のためにこれを行ったわけではありません。そのため、検証の精度が低下します

18from labml import experiment
19from labml.configs import option
20from labml_nn.experiments.cifar10 import CIFAR10Configs

コンフィギュレーション

データセットに関連するすべての構成、オプティマイザー、トレーニングループを定義するものを使用していますCIFAR10Configs

23class Configs(CIFAR10Configs):

パッチのサイズ、

32    patch_size: int = 2

パッチ埋め込みのチャンネル数、

34    d_model: int = 256
36    n_layers: int = 8

深さ方向の畳み込みのカーネルサイズ、

38    kernel_size: int = 7

タスク内のクラス数

40    n_classes: int = 10

モデル作成

43@option(Configs.model)
44def _conv_mixer(c: Configs):
48    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

ConvMixer の作成

51    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
52                     PatchEmbeddings(c.d_model, c.patch_size, 3),
53                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
56def main():

実験を作成

58    experiment.create(name='ConvMixer', comment='cifar10')

構成の作成

60    conf = Configs()

構成をロード

62    experiment.configs(conf, {

オプティマイザー

64        'optimizer.optimizer': 'Adam',
65        'optimizer.learning_rate': 2.5e-4,

トレーニングエポックとバッチサイズ

68        'epochs': 150,
69        'train_batch_size': 64,

シンプルな画像補正

72        'train_dataset': 'cifar10_train_augmented',

検証のために画像を補足しないでください

74        'valid_dataset': 'cifar10_valid_no_augment',
75    })

保存/読み込み用のモデルを設定

77    experiment.add_pytorch_models({'model': conf.model})

実験を開始し、トレーニングループを実行します

79    with experiment.start():
80        conf.run()

84if __name__ == '__main__':
85    main()