此脚本在 CIFAR 10 数据集上训练 ConvMixer。
这并不是试图重现论文的结果。本文使用 PyTorch 图像模型 (timm) 中存在的图像增强进行训练。为了简单起见,我们没有这样做——这会导致我们的验证精度下降。
18from labml import experiment
19from labml.configs import option
20from labml_nn.experiments.cifar10 import CIFAR10Configs
23class Configs(CIFAR10Configs):
补丁的大小,
32 patch_size: int = 2
补丁嵌入中的通道数,
34 d_model: int = 256
ConvMixer 层数或深度,
36 n_layers: int = 8
深度卷积的内核大小,
38 kernel_size: int = 7
任务中的类数
40 n_classes: int = 10
43@option(Configs.model)
44def _conv_mixer(c: Configs):
48 from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings
创建混音器
51 return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
52 PatchEmbeddings(c.d_model, c.patch_size, 3),
53 ClassificationHead(c.d_model, c.n_classes)).to(c.device)
56def main():
创建实验
58 experiment.create(name='ConvMixer', comment='cifar10')
创建配置
60 conf = Configs()
装载配置
62 experiment.configs(conf, {
优化器
64 'optimizer.optimizer': 'Adam',
65 'optimizer.learning_rate': 2.5e-4,
训练周期和批次大小
68 'epochs': 150,
69 'train_batch_size': 64,
简单的图像增强
72 'train_dataset': 'cifar10_train_augmented',
不要扩充图像以进行验证
74 'valid_dataset': 'cifar10_valid_no_augment',
75 })
设置保存/加载的模型
77 experiment.add_pytorch_models({'model': conf.model})
开始实验并运行训练循环
79 with experiment.start():
80 conf.run()
84if __name__ == '__main__':
85 main()