CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization

import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.cifar10 import CIFAR10Configs
from labml_nn.normalization.batch_channel_norm import BatchChannelNorm
from labml_nn.normalization.weight_standardization.conv2d import Conv2d

Model

A VGG-style model that uses Weight Standardization and Batch-Channel Normalization; both techniques are sketched after the class definition.

class Model(Module):
    def __init__(self):
        super().__init__()
        layers = []
        # RGB input has 3 channels
        in_channels = 3
        # Five VGG blocks of 3x3 convolutions; each convolution is followed by
        # batch-channel normalization (with 32 channel groups) and ReLU
        for block in [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]]:
            for channels in block:
                layers += [Conv2d(in_channels, channels, kernel_size=3, padding=1),
                           BatchChannelNorm(channels, 32),
                           nn.ReLU(inplace=True)]
                in_channels = channels
            # Halve the feature map size after each block
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        # The five max-pools reduce the 32x32 input to 1x1, so this average
        # pool is effectively an identity
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        self.layers = nn.Sequential(*layers)
        # Final linear layer classifying into the 10 CIFAR-10 classes
        self.fc = nn.Linear(512, 10)

    def __call__(self, x):
        x = self.layers(x)
        # Flatten the 512 x 1 x 1 feature maps
        x = x.view(x.shape[0], -1)
        return self.fc(x)
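
How the two normalizations work

As a rough conceptual sketch (this is not the labml implementation, and the helper names below are made up for illustration): Weight Standardization re-parameterizes each convolution kernel so that, per output channel, the weights have zero mean and unit variance before the convolution is applied, and Batch-Channel Normalization can be viewed as a batch normalization followed by a channel (group) normalization. The paper's micro-batch variant additionally estimates the batch statistics with running averages, which this sketch omits.

import torch


def standardize_weight(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # The kernel has shape [out_channels, in_channels, height, width];
    # compute mean and variance over everything except the output channels
    mean = weight.mean(dim=[1, 2, 3], keepdim=True)
    var = weight.var(dim=[1, 2, 3], keepdim=True)
    # Zero mean and unit variance per output channel
    return (weight - mean) / torch.sqrt(var + eps)


class SketchBatchChannelNorm(nn.Module):
    # Illustrative composition built from standard PyTorch modules
    def __init__(self, channels: int, groups: int, eps: float = 1e-5):
        super().__init__()
        # Normalize over the batch, per channel
        self.batch_norm = nn.BatchNorm2d(channels, eps=eps)
        # Then normalize over channel groups, per sample
        self.channel_norm = nn.GroupNorm(groups, channels, eps=eps)

    def forward(self, x: torch.Tensor):
        return self.channel_norm(self.batch_norm(x))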

Create model

The @option decorator registers this function to compute the model option of CIFAR10Configs.

@option(CIFAR10Configs.model)
def model(c: CIFAR10Configs):
    return Model().to(c.device)
def main():

Create experiment

    experiment.create(name='cifar10', comment='weight standardization')

Create configurations

    conf = CIFAR10Configs()

Load configurations

    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'train_batch_size': 64,
    })

Start the experiment and run the training loop

    with experiment.start():
        conf.run()
if __name__ == '__main__':
    main()