CIFAR10 组归一化实验

12import torch.nn as nn
13
14from labml import experiment
15from labml.configs import option
16from labml_helpers.module import Module
17from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
18from labml_nn.normalization.group_norm import GroupNorm

VGG model for CIFAR-10 classification

This model derives from the generic VGG-style architecture (`CIFAR10VGGModel`) and uses group normalization after every convolution.

class Model(CIFAR10VGGModel):
    """
    VGG-style CIFAR-10 classifier whose convolution blocks use group normalization.

    Every block produced by `conv_block` is: 3x3 convolution -> `GroupNorm` -> ReLU.
    """

    def __init__(self, groups: int = 32):
        """
        :param groups: number of groups handed to each `GroupNorm` layer
        """
        # `conv_block` reads `self.groups`, so it has to exist before the base
        # constructor runs (presumably the base class builds its layers by
        # calling `conv_block` from `__init__` - keep this ordering).
        self.groups = groups
        # Layer specification forwarded unchanged to the base VGG model.
        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Build one convolution block.

        :param in_channels: number of input channels of the convolution
        :param out_channels: number of output channels of the convolution,
            also the channel count given to the normalization layer
        :return: `nn.Sequential` of convolution, group normalization and ReLU
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Use the `GroupNorm` imported at the top of the file; the previous
            # `fnorm.GroupNorm` referenced a name that is never defined.
            GroupNorm(self.groups, out_channels),
            nn.ReLU(inplace=True),
        )
class Configs(CIFAR10Configs):
    """
    Experiment configurations: the CIFAR-10 configs extended with the
    group-normalization setting.
    """

    # Number of groups used by the group normalization layers
    groups: int = 16

创建模型

@option(Configs.model)
def model(c: Configs):
    """
    Create the model for `Configs.model`.

    Builds a `Model` with `c.groups` groups and moves it to `c.device`.
    """
    net = Model(c.groups)
    return net.to(c.device)
def main():
    """Create the experiment, load its configurations and run the training loop."""
    # Create the experiment
    experiment.create(name='cifar10', comment='group norm')

    # Create the configurations
    configs = Configs()

    # Load the configurations together with the optimizer settings
    optimizer_settings = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    }
    experiment.configs(configs, optimizer_settings)

    # Start the experiment and run the training loop
    with experiment.start():
        configs.run()

# Run the experiment only when this file is executed as a script
if __name__ == "__main__":
    main()