CIFAR 10 でビジョントランスフォーマー (VIT) をトレーニングしましょう

11from labml import experiment
12from labml.configs import option
13from labml_nn.experiments.cifar10 import CIFAR10Configs
14from labml_nn.transformers import TransformerConfigs

コンフィギュレーション

データセットに関連するすべての構成、オプティマイザー、トレーニングループを定義するものを使用していますCIFAR10Configs

17class Configs(CIFAR10Configs):
27    transformer: TransformerConfigs

パッチのサイズ

30    patch_size: int = 4

分類ヘッドの隠れ層のサイズ

32    n_hidden_classification: int = 2048

タスク内のクラス数

34    n_classes: int = 10

トランスフォーマー構成の作成

37@option(Configs.transformer)
38def _transformer():
42    return TransformerConfigs()

モデル作成

45@option(Configs.model)
46def _vit(c: Configs):
50    from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
51        PatchEmbeddings
54    d_model = c.transformer.d_model

ビジョントランスフォーマーの作成

56    return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
57                             PatchEmbeddings(d_model, c.patch_size, 3),
58                             LearnedPositionalEmbeddings(d_model),
59                             ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)
62def main():

実験を作成

64    experiment.create(name='ViT', comment='cifar10')

構成の作成

66    conf = Configs()

構成をロード

68    experiment.configs(conf, {

オプティマイザー

70        'optimizer.optimizer': 'Adam',
71        'optimizer.learning_rate': 2.5e-4,

変圧器埋め込みサイズ

74        'transformer.d_model': 512,

トレーニングエポックとバッチサイズ

77        'epochs': 32,
78        'train_batch_size': 64,

トレーニング用の CIFAR 10 イメージの拡張

81        'train_dataset': 'cifar10_train_augmented',

検証用に CIFAR 10 イメージを拡張しないでください

83        'valid_dataset': 'cifar10_valid_no_augment',
84    })

保存/読み込み用のモデルを設定

86    experiment.add_pytorch_models({'model': conf.model})

実験を開始し、トレーニングループを実行します

88    with experiment.start():
89        conf.run()

93if __name__ == '__main__':
94    main()