12import torch.nn as nn
13
14from labml import experiment
15from labml.configs import option
16from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
class Model(CIFAR10VGGModel):
    """
    VGG-style model for CIFAR-10 classification with group normalization.

    The overall layer layout comes from `CIFAR10VGGModel`; this subclass only
    overrides `conv_block` so that every convolution is followed by a
    group-normalization layer.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Build one convolution block: 3x3 convolution -> group norm -> ReLU.

        :param in_channels: number of input channels of the convolution
        :param out_channels: number of output channels of the convolution; this
            is also the channel count normalized by the group-norm layer, so it
            must be divisible by `self.groups` (`nn.GroupNorm` raises
            `ValueError` otherwise)
        :return: an `nn.Sequential` containing the three layers
        """
        return nn.Sequential(
            # padding=1 with a 3x3 kernel keeps the spatial size unchanged
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Normalize the `out_channels` channels in `self.groups` groups
            nn.GroupNorm(self.groups, out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self, groups: int = 32):
        """
        :param groups: number of groups used by every group-normalization layer
        """
        # `conv_block` reads `self.groups`, so it has to be set before the base
        # constructor runs; presumably `CIFAR10VGGModel.__init__` builds its
        # layers by calling `conv_block` -- confirm in the base class.
        self.groups = groups
        # Output channel counts of the convolution blocks, presumably one inner
        # list per VGG stage -- TODO confirm against `CIFAR10VGGModel`
        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])
class Configs(CIFAR10Configs):
    """
    Training configurations for the CIFAR-10 group-normalization experiment.

    Extends `CIFAR10Configs` with the number of normalization groups that is
    handed to `Model`.
    """

    # Number of groups
    groups: int = 16
@option(Configs.model)
def model(c: Configs):
    """
    Create the group-norm VGG model and move it to the configured device.

    Registered via `@option` as the `Configs.model` option.
    """
    net = Model(groups=c.groups)
    return net.to(c.device)
def main():
    """
    Create, configure and run the CIFAR-10 group-norm experiment.
    """
    # Create experiment
    experiment.create(name='cifar10', comment='group norm')
    # Create configurations
    conf = Configs()
    # Load configurations (override the optimizer settings)
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    }
    experiment.configs(conf, overrides)
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()
# Run the experiment only when this file is executed as a script, not on import
if __name__ == '__main__':
    main()