在 CIFA R 10 上训练 convMixer

此脚本在 CIFAR 10 数据集上训练 ConvMixer。

这并不是试图重现论文的结果。本文使用 PyTorch 图像模型 (timm) 中存在的图像增强进行训练。为了简单起见,我们没有这样做——这会导致我们的验证精度下降。

18from labml import experiment
19from labml.configs import option
20from labml_nn.experiments.cifar10 import CIFAR10Configs

配置

我们使用CIFAR10Configs 它来定义所有与数据集相关的配置、优化器和训练循环。

23class Configs(CIFAR10Configs):

补丁的大小,

32    patch_size: int = 2

补丁嵌入中的通道数,

34    d_model: int = 256

ConvMixer 层数或深度,

36    n_layers: int = 8

深度卷积的内核大小,

38    kernel_size: int = 7

任务中的类数

40    n_classes: int = 10

创建模型

43@option(Configs.model)
44def _conv_mixer(c: Configs):
48    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

创建混音器

51    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
52                     PatchEmbeddings(c.d_model, c.patch_size, 3),
53                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
56def main():

创建实验

58    experiment.create(name='ConvMixer', comment='cifar10')

创建配置

60    conf = Configs()

装载配置

62    experiment.configs(conf, {

优化器

64        'optimizer.optimizer': 'Adam',
65        'optimizer.learning_rate': 2.5e-4,

训练周期和批次大小

68        'epochs': 150,
69        'train_batch_size': 64,

简单的图像增强

72        'train_dataset': 'cifar10_train_augmented',

不要扩充图像以进行验证

74        'valid_dataset': 'cifar10_valid_no_augment',
75    })

设置保存/加载的模型

77    experiment.add_pytorch_models({'model': conf.model})

开始实验并运行训练循环

79    with experiment.start():
80        conf.run()

84if __name__ == '__main__':
85    main()