Batch Normalization

This is a PyTorch implementation of Batch Normalization from the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

Internal Covariate Shift

The paper defines Internal Covariate Shift as the change in the distribution of network activations due to the change in network parameters during training. For example, let's say there are two layers $l_1$ and $l_2$. At the beginning of training, the outputs of $l_1$ (the inputs to $l_2$) could follow the distribution $\mathcal{N}(0.5, 1)$. Then, after some training steps, they could move to $\mathcal{N}(0.6, 1.5)$. This is internal covariate shift.

Internal covariate shift adversely affects training speed because the later layers ($l_2$ in the above example) have to adapt to this shifted distribution.

By stabilizing the distribution, batch normalization reduces internal covariate shift.

Normalization

It is known that whitening improves training speed and convergence. Whitening is linearly transforming inputs to have zero mean, unit variance, and be uncorrelated.
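As a rough illustration of what whitening involves, the sketch below applies PCA whitening to a small synthetic batch. The data, the variable names, and the small `eps` added to the eigenvalues are assumptions made for this example; none of it is part of the implementation on this page.

```python
import torch

torch.manual_seed(0)

# Synthetic correlated data: 1000 samples, 3 features (illustrative values only)
mixing = torch.tensor([[2.0, 0.0, 0.0],
                       [1.0, 1.0, 0.0],
                       [0.5, 0.3, 1.0]])
x = torch.randn(1000, 3) @ mixing + 5.0

# Center each feature
x_centered = x - x.mean(dim=0)
# Covariance matrix of the features, shape [3, 3]
cov = x_centered.T @ x_centered / x.shape[0]
# Eigendecomposition of the symmetric covariance matrix
eigvals, eigvecs = torch.linalg.eigh(cov)
# Rotate onto the eigenvectors and rescale every direction to unit variance
eps = 1e-8
x_white = (x_centered @ eigvecs) / torch.sqrt(eigvals + eps)

# Covariance of the whitened data; it should be close to the identity matrix
cov_white = x_white.T @ x_white / x.shape[0]
```

The eigendecomposition and the matrix products over all features are what make full whitening costly compared with normalizing each feature on its own.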

Normalizing outside gradient computation doesn't work

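Normalizing outside the gradient computation using pre-computed (detached) means and variances doesn't work. For instance (ignoring variance), let

$$\hat{x} = x - \mathbb{E}[x]$$

where $x = u + b$, $b$ is a trained bias, and $\mathbb{E}[x]$ is computed outside the gradient computation (treated as a pre-computed constant).

Note that $b$ has no effect on $\hat{x}$: any change to $b$ is cancelled once the mean is recomputed. Therefore, $b$ will increase or decrease based on $\frac{\partial \mathcal{L}}{\partial \hat{x}}$, and keep on growing indefinitely with each training update while the loss stays the same. The paper notes that similar explosions happen with variances.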
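The drift of the bias can be reproduced in a few lines. This is a toy sketch: the input values, the learning rate, and the squared-error loss are arbitrary choices for illustration, and the mean is detached to mimic normalization outside the gradient computation.

```python
import torch

# Fixed inputs u and a trainable bias b (toy values)
u = torch.tensor([0.0, 1.0, 2.0, 3.0])
b = torch.zeros(1, requires_grad=True)
lr = 0.1

losses, biases = [], []
for step in range(50):
    x = u + b
    # The mean is treated as a constant, so no gradient flows through it
    x_hat = x - x.mean().detach()
    # Arbitrary loss on the normalized values (an assumption for this example)
    loss = ((x_hat - 1.0) ** 2).mean()
    loss.backward()
    with torch.no_grad():
        # Plain gradient descent step on the bias
        b -= lr * b.grad
        b.grad.zero_()
    losses.append(loss.item())
    biases.append(b.item())
```

Comparing `losses[0]` with `losses[-1]` and looking at `biases` shows the bias moving in the same direction at every step while the loss does not improve.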

Batch Normalization

Whitening is computationally expensive because you need to de-correlate and the gradients must flow through the full whitening calculation.

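The paper introduces a simplified version which they call Batch Normalization. The first simplification is that it normalizes each feature independently to have zero mean and unit variance:

$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{\mathrm{Var}[x^{(k)}]}}$$

where $x = (x^{(1)} ... x^{(d)})$ is the $d$-dimensional input.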

The second simplification is to use estimates of the mean and variance from the mini-batch for normalization, instead of calculating the mean and variance across the whole dataset.
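The two simplifications can be sketched directly with tensor operations. The batch below is synthetic, and the `eps` term is an assumed small constant added for numerical stability; it does not appear in the idealized per-feature formula.

```python
import torch

torch.manual_seed(0)
# A mini-batch of 8 samples with d = 4 features, deliberately shifted and scaled
x = torch.randn(8, 4) * 3.0 + 2.0

# Mini-batch estimates of the per-feature mean and (biased) variance
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)

# Normalize every feature independently
eps = 1e-5  # assumed value, for numerical stability only
x_hat = (x - mean) / torch.sqrt(var + eps)
```

Each column of `x_hat` now has approximately zero mean and unit variance, while correlations between features are left untouched.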

Normalizing each feature to zero mean and unit variance could affect what the layer can represent. As an example, the paper illustrates that if the inputs to a sigmoid are normalized, most of them will fall within the range where the sigmoid is approximately linear. To overcome this, each feature is scaled and shifted by two trained parameters $\gamma^{(k)}$ and $\beta^{(k)}$:

$$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$

where $\hat{x}^{(k)}$ is the normalized feature and $y^{(k)}$ is the output of the batch normalization layer.
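The paper points out that the scale and shift can undo the normalization entirely: with the scale set to the standard deviation and the shift set to the mean, the original activations are recovered. A minimal check of that statement, using synthetic data and illustrative variable names:

```python
import torch

torch.manual_seed(0)
x = torch.randn(16, 3) * 2.0 + 1.0

# Per-feature statistics and normalized values
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
x_hat = (x - mean) / torch.sqrt(var)

# Scale by the standard deviation and shift by the mean
gamma = torch.sqrt(var)
beta = mean
y = gamma * x_hat + beta
```

Here `y` matches `x` up to floating-point error, so the layer keeps the ability to represent the identity transform if that is what training favours.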

Note that when applying batch normalization after a linear transform like $Wu + b$, the bias parameter $b$ gets cancelled due to normalization. So you can and should omit the bias parameter in linear transforms right before the batch normalization.
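The cancellation can be checked numerically. This sketch uses PyTorch's built-in `nn.Linear` and `nn.BatchNorm1d` (not the class defined on this page) with arbitrary layer sizes, and compares the output with and without the bias term.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
linear = nn.Linear(4, 3)   # has a weight W and a randomly initialized bias b
bn = nn.BatchNorm1d(3)     # modules start in training mode
u = torch.randn(8, 4)

with torch.no_grad():
    # Batch normalization applied to W u + b
    with_bias = bn(linear(u))
    # Batch normalization applied to W u only
    without_bias = bn(F.linear(u, linear.weight))
```

The two outputs agree up to floating-point error because subtracting the per-feature batch mean removes any constant offset. In practice this is why layers are often created with `bias=False` when a batch normalization layer follows immediately.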

Batch normalization also makes the back propagation invariant to the scale of the weights and empirically it improves generalization, so it has regularization effects too.
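The scale invariance can be illustrated with a small experiment. The sketch below uses PyTorch's built-in `nn.BatchNorm1d`, a random weight matrix, an arbitrary scale `a`, and an arbitrary upstream gradient `g`; all of these are assumptions chosen for illustration. It compares the output and the gradient with respect to the input for $W$ and $aW$.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)                       # training mode by default
w = torch.randn(3, 4)                        # weight matrix W (toy values)
u = torch.randn(8, 4, requires_grad=True)    # layer input
g = torch.randn(8, 3)                        # arbitrary upstream gradient
a = 10.0                                     # arbitrary scale applied to W


def forward_and_input_grad(weight):
    # Batch normalization applied to the linear transform of u
    out = bn(u @ weight.T)
    # Gradient of a weighted sum of the outputs with respect to the input u
    (grad_u,) = torch.autograd.grad((out * g).sum(), u)
    return out.detach(), grad_u


out, grad_u = forward_and_input_grad(w)
out_scaled, grad_u_scaled = forward_and_input_grad(a * w)
```

Both the outputs and the input gradients agree closely; the small remaining difference comes from the `eps` term inside the square root.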

Inference

We need to know $\mathbb{E}[x^{(k)}]$ and $\mathrm{Var}[x^{(k)}]$ in order to perform the normalization. So during inference, you either need to go through the whole (or part of) dataset and find the mean and variance, or you can use an estimate calculated during training. The usual practice is to calculate an exponential moving average of mean and variance during the training phase and use that for inference.
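A minimal sketch of the exponential moving average, assuming synthetic mini-batches and a `momentum` of 0.1 that weights the newest batch statistic (the same convention used by the layer later on this page):

```python
import torch

torch.manual_seed(0)
momentum = 0.1

# Running estimates start at mean 0 and variance 1
exp_mean = torch.zeros(2)
exp_var = torch.ones(2)

for _ in range(200):
    # Synthetic mini-batch drawn with mean 3 and variance 4 for both features
    batch = torch.randn(256, 2) * 2.0 + 3.0
    mean = batch.mean(dim=0)
    var = batch.var(dim=0, unbiased=False)
    # Blend the previous estimate with the statistics of the current batch
    exp_mean = (1 - momentum) * exp_mean + momentum * mean
    exp_var = (1 - momentum) * exp_var + momentum * var
```

After enough batches the running estimates settle near the statistics of the data, which is what makes them usable in place of batch statistics at inference time.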

Here's the training code and a notebook for training a CNN classifier that uses batch normalization for the MNIST dataset.


import torch
from torch import nn

Batch Normalization Layer

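Batch normalization layer $\text{BN}$ normalizes the input $X$ as follows:

When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

$$\text{BN}(X) = \gamma \frac{X - \underset{B, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{B, H, W}{\mathrm{Var}}[X] + \epsilon}} + \beta$$

When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, where $B$ is the batch size and $C$ is the number of features, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

$$\text{BN}(X) = \gamma \frac{X - \underset{B}{\mathbb{E}}[X]}{\sqrt{\underset{B}{\mathrm{Var}}[X] + \epsilon}} + \beta$$

When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of sequence embeddings, where $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

$$\text{BN}(X) = \gamma \frac{X - \underset{B, L}{\mathbb{E}}[X]}{\sqrt{\underset{B, L}{\mathrm{Var}}[X] + \epsilon}} + \beta$$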
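For the image case, the reduction over the batch, height and width dimensions can be written out by hand and compared against PyTorch's built-in `nn.BatchNorm2d` in training mode. This is only a sanity-check sketch: the tensor sizes are arbitrary, and the scale and shift are left at their initial values of 1 and 0.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5)   # [B, C, H, W]
eps = 1e-5

# Statistics are taken over B, H and W, leaving one value per channel
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps)

# Built-in layer in training mode; its scale starts at 1 and its shift at 0
reference = nn.BatchNorm2d(3, eps=eps)
with torch.no_grad():
    expected = reference(x)
```

The embedding and sequence cases differ only in which dimensions are reduced: `dim=0` for `[B, C]` inputs and `dim=(0, 2)` for `[B, C, L]` inputs.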

class BatchNorm(nn.Module):
  • channels is the number of features in the input
  • eps is $\epsilon$, used in $\sqrt{\mathrm{Var}[x^{(k)}] + \epsilon}$ for numerical stability
  • momentum is the momentum in taking the exponential moving average
  • affine is whether to scale and shift the normalized value
  • track_running_stats is whether to calculate the moving averages of mean and variance

We've tried to use the same names for arguments as the PyTorch BatchNorm implementation.

    def __init__(self, channels: int, *,
                 eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True):
        super().__init__()

        self.channels = channels

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

Create parameters for $\gamma$ and $\beta$ for scale and shift

        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))

Create buffers to store exponential moving averages of mean and variance

        if self.track_running_stats:
            self.register_buffer('exp_mean', torch.zeros(channels))
            self.register_buffer('exp_var', torch.ones(channels))

x is a tensor of shape [batch_size, channels, *]. * denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be [batch_size, channels, height, width]

    def forward(self, x: torch.Tensor):

Keep the original shape

        x_shape = x.shape

Get the batch size

        batch_size = x_shape[0]

Sanity check to make sure the number of features is the same

        assert self.channels == x.shape[1]

Reshape into [batch_size, channels, n]

        x = x.view(batch_size, self.channels, -1)
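To see why a single reshape is enough, the following sketch applies the same `view` call to the three input layouts described earlier. The tensor sizes and dictionary keys are arbitrary and used only for illustration.

```python
import torch

channels = 3
inputs = {
    "embeddings [B, C]": torch.zeros(2, channels),
    "sequences [B, C, L]": torch.zeros(2, channels, 7),
    "images [B, C, H, W]": torch.zeros(2, channels, 4, 5),
}

flattened_shapes = {}
for name, x in inputs.items():
    # Collapse every dimension after the channel dimension into a single one
    flat = x.view(x.shape[0], channels, -1)
    flattened_shapes[name] = tuple(flat.shape)
```

Every layout ends up as `[batch_size, channels, n]`, so the statistics can always be taken over dimensions 0 and 2. Note that `view` requires a compatible memory layout; for inputs that are not contiguous (for example after a transpose), `reshape` would be the safer call.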

We will calculate the mini-batch mean and variance if we are in training mode or if we have not tracked exponential moving averages

        if self.training or not self.track_running_stats:

Calculate the mean across the first and last dimensions; i.e. the mean $\mathbb{E}[x^{(k)}]$ of each feature

            mean = x.mean(dim=[0, 2])

Calculate the mean of the squares across the first and last dimensions; i.e. $\mathbb{E}[(x^{(k)})^2]$ for each feature

            mean_x2 = (x ** 2).mean(dim=[0, 2])

Variance for each feature, $\mathrm{Var}[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$

            var = mean_x2 - mean ** 2
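The identity used for the variance can be checked against a direct computation. This standalone sketch uses a random tensor already in `[batch_size, channels, n]` layout and compares the result with the biased (population) variance from `Tensor.var`.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, 10)   # [batch_size, channels, n]

# Variance from the mean and the mean of squares
mean = x.mean(dim=[0, 2])
mean_x2 = (x ** 2).mean(dim=[0, 2])
var = mean_x2 - mean ** 2

# The same quantity computed directly as the biased variance
var_direct = x.var(dim=[0, 2], unbiased=False)
```

The two agree up to floating-point error. The one-pass form can lose precision when the mean is large relative to the spread of the values, which is worth keeping in mind for inputs that are far from zero.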

Update exponential moving averages

            if self.training and self.track_running_stats:
                # Detach the batch statistics so the buffers do not hold on to the autograd graph
                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean.detach()
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var.detach()

Use exponential moving averages as estimates

        else:
            mean = self.exp_mean
            var = self.exp_var

Normalize $$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{\mathrm{Var}[x^{(k)}] + \epsilon}}$$

        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)

Scale and shift $$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$

        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

Reshape to original and return

        return x_norm.view(x_shape)

Simple test

def _test():
    from labml.logger import inspect

    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    bn = BatchNorm(3)

    x = bn(x)
    inspect(x.shape)
    inspect(bn.exp_var.shape)


if __name__ == '__main__':
    _test()
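The switch between batch statistics in training and moving averages in inference can also be observed on PyTorch's built-in `nn.BatchNorm1d`. The sketch below uses the built-in layer so that it runs on its own; its buffers are named `running_mean` and `running_var` rather than `exp_mean` and `exp_var`, and the input data is synthetic.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(2, momentum=0.1)
x = torch.randn(32, 2) * 2.0 + 5.0

# Training mode: normalize with batch statistics and update the running averages
bn.train()
with torch.no_grad():
    y_train = bn(x)
# After one step the running mean is 0.9 * 0 + 0.1 * batch mean
updated_mean = bn.running_mean.clone()

# Evaluation mode: normalize with the stored running averages instead
bn.eval()
with torch.no_grad():
    y_eval = bn(x)
# The same computation written out by hand (scale is 1 and shift is 0 at initialization)
manual_eval = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
```

After a single training step the running averages are still far from the data statistics, so `y_eval` is not centred the way `y_train` is; the two only become similar once the averages have converged.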