MNIST Experiment for Batch Normalization

12import torch.nn as nn
13import torch.nn.functional as F
14import torch.utils.data
15
16from labml import experiment
17from labml.configs import option
18from labml_nn.experiments.mnist import MNISTConfigs
19from labml_nn.normalization.batch_norm import BatchNorm

Model definition

class Model(nn.Module):
    """
    Small convolutional network for MNIST with batch normalization after
    every hidden layer.

    For a `[batch_size, 1, 28, 28]` input it returns `[batch_size, 10]` logits.
    """

    def __init__(self):
        super().__init__()

        # Note that we omit the bias parameter of layers that are followed by
        # batch normalization: normalization subtracts the per-channel batch
        # mean, which cancels any constant bias, so the bias would be redundant.
        self.conv1 = nn.Conv2d(1, 20, 5, 1, bias=False)
        # Batch normalization with 20 channels (output of convolution layer).
        # The input to this layer will have shape
        # `[batch_size, 20, height(24), width(24)]`.
        self.bn1 = BatchNorm(20)

        self.conv2 = nn.Conv2d(20, 50, 5, 1, bias=False)
        # Batch normalization with 50 channels. The input to this layer will
        # have shape `[batch_size, 50, height(8), width(8)]`.
        self.bn2 = BatchNorm(50)

        self.fc1 = nn.Linear(4 * 4 * 50, 500, bias=False)
        # Batch normalization with 500 channels (output of fully connected
        # layer). The input to this layer will have shape `[batch_size, 500]`.
        self.bn3 = BatchNorm(500)

        # The output layer keeps its bias because no normalization follows it.
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x: torch.Tensor):
        """
        Compute class logits.

        :param x: images of shape `[batch_size, 1, 28, 28]`
        :return: logits of shape `[batch_size, 10]`
        """
        # conv -> batch norm -> ReLU -> 2x2 max-pool: `[batch_size, 20, 12, 12]`
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        # conv -> batch norm -> ReLU -> 2x2 max-pool: `[batch_size, 50, 4, 4]`
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        # Flatten each sample while keeping the batch dimension explicit.
        # `x.view(-1, 4 * 4 * 50)` would silently fold extra spatial positions
        # into the batch dimension for larger inputs (a 44x44 image yields 4
        # rows); `flatten(1)` instead makes `fc1` raise a shape error.
        x = x.flatten(1)
        x = F.relu(self.bn3(self.fc1(x)))
        return self.fc2(x)

Create model

We use the `MNISTConfigs` configurations and register a new option function that creates the model.

@option(MNISTConfigs.model)
def model(c: MNISTConfigs):
    """
    Build the batch-normalized MNIST model and place it on `c.device`.

    Registered as the `MNISTConfigs.model` option.
    """
    net = Model()
    net = net.to(c.device)
    return net
def main():
    """Create, configure and run the MNIST batch-normalization experiment."""
    # Configuration overrides: train with Adam at learning rate 0.001
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
    }

    # Create experiment
    experiment.create(name='mnist_batch_norm')
    # Create configurations
    conf = MNISTConfigs()
    # Load configurations
    experiment.configs(conf, overrides)

    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()