11from labml import experiment
12from labml.configs import option
13from labml_nn.experiments.cifar10 import CIFAR10Configs
14from labml_nn.transformers import TransformerConfigs

配置

我们使用CIFAR10Configs 它来定义所有与数据集相关的配置、优化器和训练循环。

17class Configs(CIFAR10Configs):
27    transformer: TransformerConfigs

补丁的大小

30    patch_size: int = 4

分类头中隐藏层的大小

32    n_hidden_classification: int = 2048

任务中的类数

34    n_classes: int = 10

创建变压器配置

37@option(Configs.transformer)
38def _transformer():
42    return TransformerConfigs()

创建模型

45@option(Configs.model)
46def _vit(c: Configs):
50    from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
51        PatchEmbeddings

变压器配置中的变压器尺寸

54    d_model = c.transformer.d_model

创建视觉变压器

56    return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
57                             PatchEmbeddings(d_model, c.patch_size, 3),
58                             LearnedPositionalEmbeddings(d_model),
59                             ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)
62def main():

创建实验

64    experiment.create(name='ViT', comment='cifar10')

创建配置

66    conf = Configs()

装载配置

68    experiment.configs(conf, {

优化器

70        'optimizer.optimizer': 'Adam',
71        'optimizer.learning_rate': 2.5e-4,

变压器嵌入尺寸

74        'transformer.d_model': 512,

训练周期和批次大小

77        'epochs': 32,
78        'train_batch_size': 64,

增强 CIFAR 10 图像用于训练

81        'train_dataset': 'cifar10_train_augmented',

不要扩大 CIFAR 10 图像进行验证

83        'valid_dataset': 'cifar10_valid_no_augment',
84    })

设置保存/加载的模型

86    experiment.add_pytorch_models({'model': conf.model})

开始实验并运行训练循环

88    with experiment.start():
89        conf.run()

93if __name__ == '__main__':
94    main()