import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.instance_norm import InstanceNorm
class Model(CIFAR10VGGModel):
    """CIFAR-10 model built on ``CIFAR10VGGModel`` with instance-normalized conv blocks.

    The base class is given a nested list of integers in ``__init__``;
    ``conv_block`` is overridden here so each block it builds uses
    ``InstanceNorm``.
    NOTE(review): presumably the base class calls ``conv_block`` once per
    entry of that list, treating each integer as an output-channel count and
    each inner list as one stage -- confirm against ``CIFAR10VGGModel``.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """Return one convolution block.

        The block is a ``nn.Sequential`` of:
        a 3x3 ``nn.Conv2d`` with padding 1, an ``InstanceNorm`` over
        ``out_channels``, and an in-place ``nn.ReLU``.

        :param in_channels: number of input channels of the convolution
        :param out_channels: number of output channels of the convolution
        :return: the assembled ``nn.Sequential`` module
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            InstanceNorm(out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        """Initialize the base model with the fixed layer configuration."""
        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])
@option(CIFAR10Configs.model)
def _model(c: CIFAR10Configs):
    """Create the model and move it to the device configured on ``c``.

    Decorated with ``@option(CIFAR10Configs.model)``.

    :param c: the experiment configurations; only ``c.device`` is read here
    :return: a ``Model`` instance moved to ``c.device``
    """
    return Model().to(c.device)
def main():
    """Create the experiment, load its configurations, and run training."""
    # Create experiment
    experiment.create(name='cifar10', comment='instance norm')
    # Create configurations
    conf = CIFAR10Configs()
    # Load configurations, overriding the optimizer and its learning rate
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    })
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()
# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()