在 CIFAR 10 上训练一个小型模型

这在 CIFAR 10 上训练了一个小型模型,以测试蒸馏的益处有多大。

13import torch.nn as nn
14
15from labml import experiment, logger
16from labml.configs import option
17from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
18from labml_nn.normalization.batch_norm import BatchNorm

配置

我们使用CIFAR10Configs 它来定义所有与数据集相关的配置、优化器和训练循环。

21class Configs(CIFAR10Configs):
28    pass

适用于 CIFAR-10 分类的 VGG 样式模型

这源于通用的 VGG 风格架构

31class SmallModel(CIFAR10VGGModel):

创建卷积层和激活

38    def conv_block(self, in_channels, out_channels) -> nn.Module:
42        return nn.Sequential(

卷积层

44            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),

批量标准化

46            BatchNorm(out_channels, track_running_stats=False),

激活 ReLU

48            nn.ReLU(inplace=True),
49        )
51    def __init__(self):

使用给定的卷积大小(通道)创建模型

53        super().__init__([[32, 32], [64, 64], [128], [128], [128]])

创建模型

56@option(Configs.model)
57def _small_model(c: Configs):
61    return SmallModel().to(c.device)
64def main():

创建实验

66    experiment.create(name='cifar10', comment='small model')

创建配置

68    conf = Configs()

装载配置

70    experiment.configs(conf, {
71        'optimizer.optimizer': 'Adam',
72        'optimizer.learning_rate': 2.5e-4,
73    })

设置保存/加载的模型

75    experiment.add_pytorch_models({'model': conf.model})

打印模型中参数的数量

77    logger.inspect(params=(sum(p.numel() for p in conf.model.parameters() if p.requires_grad)))

开始实验并运行训练循环

79    with experiment.start():
80        conf.run()

84if __name__ == '__main__':
85    main()