CIFAR10 Experiment for Group Normalization

12import torch.nn as nn
13
14from labml import experiment
15from labml.configs import option
16from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel

VGG model for CIFAR-10 classification

This derives from the generic VGG style architecture.

class Model(CIFAR10VGGModel):
    """
    ### VGG model with Group Normalization for CIFAR-10

    Derives from the generic VGG-style architecture and inserts a
    Group Normalization layer after every convolution.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """Build one convolution block: `Conv2d -> GroupNorm -> ReLU`."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Fix: original referenced undefined `fnorm`; GroupNorm lives in `torch.nn`
            nn.GroupNorm(self.groups, out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self, groups: int = 32):
        # `self.groups` must be assigned *before* calling the base constructor,
        # because the base `__init__` invokes `conv_block`, which reads it.
        self.groups = groups
        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])
class Configs(CIFAR10Configs):
    """CIFAR-10 training configurations extended with a group-norm group count."""
    # Number of groups for the model's GroupNorm layers
    groups: int = 16

Create model

@option(Configs.model)
def model(c: Configs):
    """Create the group-normalized VGG model and move it to the configured device."""
    net = Model(c.groups)
    return net.to(c.device)
def main():
    """Set up and run the CIFAR-10 group normalization experiment."""
    # Create the experiment
    experiment.create(name='cifar10', comment='group norm')
    # Create configurations
    conf = Configs()
    # Load configuration overrides for the optimizer
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    }
    experiment.configs(conf, overrides)
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()

# Script entry point: run the experiment when executed directly.
if __name__ == '__main__':
    main()