"""CIFAR-10 experiment with a VGG-style network that uses instance normalization.

Each convolution block is ``Conv2d -> InstanceNorm -> ReLU``. The rest of the
architecture and the training loop come from ``CIFAR10VGGModel`` and
``CIFAR10Configs``.
"""

import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.instance_norm import InstanceNorm


class Model(CIFAR10VGGModel):
    """VGG-style CIFAR-10 classifier whose conv blocks use instance normalization.

    Only ``conv_block`` is overridden; the overall layer layout is handled by
    ``CIFAR10VGGModel``.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """Build one convolution block: 3x3 conv, instance norm, ReLU.

        Args:
            in_channels: number of input channels to the convolution.
            out_channels: number of output channels of the convolution, also
                the channel count passed to ``InstanceNorm``.

        Returns:
            An ``nn.Sequential`` applying the three layers in order.
        """
        # NOTE(review): if InstanceNorm subtracts a per-channel mean, the conv
        # bias is redundant and ``bias=False`` could be used — left unchanged
        # here to keep the parameter set identical.
        return nn.Sequential(
            # kernel_size=3 with padding=1 keeps the spatial size unchanged.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            InstanceNorm(out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        # Channel configuration handed to CIFAR10VGGModel; presumably one
        # inner list per VGG stage, one entry per conv block — confirm
        # against CIFAR10VGGModel.
        super().__init__(
            [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]]
        )


@option(CIFAR10Configs.model)
def _model(c: CIFAR10Configs):
    """Create the model for ``CIFAR10Configs.model``.

    Registered via ``@option`` so the configs pick this builder for the
    ``model`` option; the model is moved to ``c.device``.
    """
    return Model().to(c.device)


def main():
    """Configure and run the CIFAR-10 instance-normalization experiment."""
    # Create the experiment
    experiment.create(name='cifar10', comment='instance norm')
    # Create the configurations
    conf = CIFAR10Configs()
    # Load the configurations, overriding the optimizer settings
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    })
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()