import torch.nn as nn

from labml import experiment, logger
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.batch_norm import BatchNorm
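

# `CIFAR10Configs` already defines the dataset, training loop, and optimizer
# configurations for CIFAR-10 experiments, so this subclass only exists as a place to
# attach the experiment-specific model option registered below with `@option`.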
class Configs(CIFAR10Configs):
    pass


class SmallModel(CIFAR10VGGModel):
    # Create a convolution layer and the activations
    def conv_block(self, in_channels, out_channels) -> nn.Module:
        return nn.Sequential(
            # Convolution layer
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Batch normalization
            BatchNorm(out_channels, track_running_stats=False),
            # ReLU activation
            nn.ReLU(inplace=True),
        )
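        # Note: `track_running_stats=False` makes the `BatchNorm` layer normalize with
        # the statistics of the current batch even at evaluation time, instead of
        # keeping exponential moving averages of the mean and variance.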

    def __init__(self):
        # Create the model with the given convolution sizes (channels)
        super().__init__([[32, 32], [64, 64], [128], [128], [128]])
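        # Each inner list gives the output channels of the convolution layers in one
        # block of the VGG-style network built by `CIFAR10VGGModel`; the base class
        # stacks these blocks with pooling in between, so the sizes above give a much
        # smaller model than the standard VGG configurations.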


@option(Configs.model)
def _small_model(c: Configs):
    return SmallModel().to(c.device)
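
# `@option(Configs.model)` registers `_small_model` as the function that computes the
# `model` configuration, so `conf.model` below is built lazily the first time it is
# accessed.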


def main():
    # Create the experiment
    experiment.create(name='cifar10', comment='small model')
    # Create the configurations
    conf = Configs()
    # Load the configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    })
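    # The dictionary above overrides only the listed options; everything else keeps the
    # defaults defined in `CIFAR10Configs`.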
    # Set the model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Print the number of trainable parameters in the model
    logger.inspect(params=(sum(p.numel() for p in conf.model.parameters() if p.requires_grad)))
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
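
# Running this file directly executes `main()`. A minimal sketch of the command line,
# assuming `labml-nn` and its dependencies are installed and this module is saved as
# `cifar10.py` (the file name is an assumption, not part of the original source):
#
#     python cifar10.py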