CIFAR10 实例规范化实验

这演示了如何在卷积神经网络中使用实例归一化层进行分类。并不是说实例规范化是为风格转移而设计的,这只是一个演示。

16import torch.nn as nn
17
18from labml import experiment
19from labml.configs import option
20from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
21from labml_nn.normalization.instance_norm import InstanceNorm

用于 CIFAR-10 分类的 VGG 模型

这源于通用的 VGG 风格架构

24class Model(CIFAR10VGGModel):
31    def conv_block(self, in_channels, out_channels) -> nn.Module:
32        return nn.Sequential(
33            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
34            InstanceNorm(out_channels),
35            nn.ReLU(inplace=True),
36        )
38    def __init__(self):
39        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])

创建模型

42@option(CIFAR10Configs.model)
43def _model(c: CIFAR10Configs):
47    return Model().to(c.device)
50def main():

创建实验

52    experiment.create(name='cifar10', comment='instance norm')

创建配置

54    conf = CIFAR10Configs()

装载配置

56    experiment.configs(conf, {
57        'optimizer.optimizer': 'Adam',
58        'optimizer.learning_rate': 2.5e-4,
59    })

开始实验并运行训练循环

61    with experiment.start():
62        conf.run()

66if __name__ == '__main__':
67    main()