This is a PyTorch implementation of Batch Normalization from the paper *Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift*.

The paper defines *Internal Covariate Shift* as the change in the distribution of network activations due to the change in network parameters during training. For example, let's say there are two layers $l_{1}$ and $l_{2}$. At the beginning of training, the outputs of $l_{1}$ (the inputs to $l_{2}$) could be in distribution $N(0.5, 1)$. Then, after some training steps, they could move to $N(0.6, 1.5)$. This is *internal covariate shift*.

Internal covariate shift will adversely affect training speed because the later layers ($l_{2}$ in the above example) have to adapt to this shifted distribution.

By stabilizing the distribution, batch normalization minimizes the internal covariate shift.

It is known that whitening improves training speed and convergence. *Whitening* is linearly transforming inputs to have zero mean, unit variance, and be uncorrelated.

Normalizing outside the gradient computation using pre-computed (detached) means and variances doesn't work. For instance (ignoring variance), let $\hat{x} = x - E[x]$ where $x = u + b$, $b$ is a trained bias, and $E[x]$ is computed outside the gradient computation (a pre-computed constant).

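Note that $\hat{x}$ is unaffected by $b$: an update $b \leftarrow b + \Delta b$ also shifts $E[x]$ by $\Delta b$ once it is recomputed, so $u + (b + \Delta b) - E[u + (b + \Delta b)] = u + b - E[u + b]$. Therefore, $b$ will increase or decrease based on $\frac{\partial \mathcal{L}}{\partial x}$ and keep on growing indefinitely with each training update, while the loss does not change. The paper notes that similar explosions happen with variances.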
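Here's a toy sketch of this failure mode in plain Python, under our own assumptions: a fixed pre-activation `u`, a scalar bias `b`, and a made-up linear loss $\mathcal{L} = \sum \hat{x}$ whose gradient with respect to every $\hat{x}$ is $1$. Since $E[x]$ is treated as a constant, the gradient with respect to $b$ never changes, so $b$ keeps drifting while $\hat{x}$ (and therefore the loss) stays the same.

```python
# Toy illustration with hypothetical values: normalization with a detached mean
u = [0.5, -1.0, 2.0, 0.5]  # fixed pre-activations for a mini-batch
b = 0.0                     # trainable bias
lr = 0.1

biases, outputs = [], []
for step in range(5):
    x = [ui + b for ui in u]
    # E[x] is computed outside the gradient computation, i.e. treated as a constant
    mean = sum(x) / len(x)
    x_hat = [xi - mean for xi in x]
    # Made-up loss L = sum(x_hat), so dL/dx_hat = 1 for each element.
    # With mean treated as a constant, dL/db = sum(dL/dx_hat) = len(u).
    grad_b = float(len(u))
    b -= lr * grad_b
    biases.append(b)
    outputs.append(x_hat)

print(biases)       # b keeps moving in the same direction every step
print(outputs[0])   # x_hat is the same at every step (up to rounding)
print(outputs[-1])
```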

Whitening is computationally expensive because you need to de-correlate and the gradients must flow through the full whitening calculation.

The paper introduces a simplified version which they call *Batch Normalization*. The first simplification is that it normalizes each feature independently to have zero mean and unit variance: $\hat{x}_{(k)} = \frac{x_{(k)} - E[x_{(k)}]}{\sqrt{Var[x_{(k)}]}}$ where $x = (x_{(1)}, \dots, x_{(d)})$ is the $d$-dimensional input.

The second simplification is to use estimates of the mean $E[x_{(k)}]$ and variance $Var[x_{(k)}]$ from the mini-batch for normalization, instead of calculating the mean and variance across the whole dataset.
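A minimal sketch of these two simplifications, assuming a mini-batch `x` of shape `[batch_size, d]` (one row per sample, one column per feature). `normalize_features` is our own helper name; it uses the biased (divide-by-batch-size) variance of the mini-batch and, like the per-feature formula, leaves out the $\epsilon$ that the layer below adds for numerical stability.

```python
import torch

def normalize_features(x: torch.Tensor) -> torch.Tensor:
    """Normalize each column of a [batch_size, d] mini-batch to zero mean
    and unit variance, using statistics estimated from the mini-batch."""
    mean = x.mean(dim=0)                 # estimate of E[x_(k)] for each feature k
    var = ((x - mean) ** 2).mean(dim=0)  # biased estimate of Var[x_(k)]
    return (x - mean) / torch.sqrt(var)

torch.manual_seed(0)
# Made-up mini-batch: 32 samples, 4 features with very different scales
x = torch.randn(32, 4) * torch.tensor([1.0, 5.0, 0.1, 2.0]) + 3.0
x_hat = normalize_features(x)
print(x_hat.mean(dim=0))         # close to 0 for every feature
print((x_hat ** 2).mean(dim=0))  # close to 1 for every feature (the mean is ~0)
```

Each feature is normalized on its own, so the very different scales of the four columns no longer matter after normalization.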

Normalizing each feature to zero mean and unit variance could affect what the layer can represent. As an example, the paper illustrates that if the inputs to a sigmoid are normalized, most of them will fall within the $[-1, 1]$ range, where the sigmoid is approximately linear. To overcome this, each feature is scaled and shifted by two trained parameters $\gamma_{(k)}$ and $\beta_{(k)}$: $y_{(k)} = \gamma_{(k)} \hat{x}_{(k)} + \beta_{(k)}$, where $y_{(k)}$ is the output of the batch normalization layer.
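To see why $\gamma_{(k)}$ and $\beta_{(k)}$ preserve what the layer can represent, note (as the paper points out) that setting $\gamma_{(k)} = \sqrt{Var[x_{(k)}]}$ and $\beta_{(k)} = E[x_{(k)}]$ recovers the original activations. A small sketch checking this with mini-batch statistics; the data is random, and `gamma` and `beta` are picked by hand here rather than learned.

```python
import torch

torch.manual_seed(0)
x = torch.randn(16, 3) * 2.0 + 1.0      # made-up activations, [batch_size, d]
mean = x.mean(dim=0)                     # E[x_(k)] from the mini-batch
var = ((x - mean) ** 2).mean(dim=0)      # Var[x_(k)] from the mini-batch
x_hat = (x - mean) / torch.sqrt(var)     # normalized features

# Hand-picked (not learned) values that undo the normalization
gamma = torch.sqrt(var)
beta = mean
y = gamma * x_hat + beta
print(torch.allclose(y, x, atol=1e-5))
```

In practice the optimizer is free to move $\gamma$ and $\beta$ anywhere between this identity mapping and the plain normalized output.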

Note that when applying batch normalization after a linear transform like $Wu + b$, the bias parameter $b$ gets cancelled, because subtracting the mean removes any constant shift. So you can and should omit the bias parameter in linear transforms right before batch normalization.
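A quick sketch checking this cancellation, using a hypothetical helper `bn` that normalizes each feature with mini-batch statistics (without $\gamma$ and $\beta$). Adding the bias `b` shifts every sample of a feature by the same constant, which the mean subtraction removes.

```python
import torch

def bn(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Hypothetical helper: per-feature normalization with mini-batch statistics
    mean = x.mean(dim=0)
    var = ((x - mean) ** 2).mean(dim=0)
    return (x - mean) / torch.sqrt(var + eps)

torch.manual_seed(0)
u = torch.randn(8, 5)   # [batch_size, in_features]
W = torch.randn(3, 5)   # [out_features, in_features]
b = torch.randn(3)      # bias of the linear transform

with_bias = bn(u @ W.T + b)
without_bias = bn(u @ W.T)
print(torch.allclose(with_bias, without_bias, atol=1e-5))
```

This is why `nn.Linear` and `nn.Conv2d` layers placed right before a batch normalization layer are commonly created with `bias=False`.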

Batch normalization also makes backpropagation invariant to the scale of the weights, and empirically it improves generalization, so it has regularization effects too.
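Here's a sketch of the scale invariance, assuming the same kind of hypothetical `bn` helper and a made-up squared-error loss. With the weights scaled by $a$, the loss is unchanged (up to the small effect of $\epsilon$), and the gradient with respect to the scaled weights is $1/a$ times the original gradient, matching the paper's observation that larger weights lead to smaller gradients.

```python
import torch

def bn(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Hypothetical helper: per-feature normalization with mini-batch statistics
    mean = x.mean(dim=0)
    var = ((x - mean) ** 2).mean(dim=0)
    return (x - mean) / torch.sqrt(var + eps)

torch.manual_seed(0)
u = torch.randn(8, 5)        # [batch_size, in_features]
W = torch.randn(3, 5)        # [out_features, in_features]
target = torch.randn(8, 3)   # made-up regression target
a = 10.0                     # scale applied to the weights

def loss_and_grad(weight: torch.Tensor):
    # Treat `weight` as the trainable parameter; return the loss and dL/dweight
    weight = weight.clone().requires_grad_(True)
    loss = ((bn(u @ weight.T) - target) ** 2).sum()
    loss.backward()
    return loss.item(), weight.grad

loss_w, grad_w = loss_and_grad(W)
loss_aw, grad_aw = loss_and_grad(a * W)
print(loss_w, loss_aw)                                            # (almost) the same loss
print(torch.allclose(grad_aw * a, grad_w, rtol=1e-3, atol=1e-4))  # gradient scales by 1/a
```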

We need to know $E[x_{(k)}]$ and $Var[x_{(k)}]$ in order to perform the normalization. So during inference, you either need to go through the whole (or part of) dataset and find the mean and variance, or you can use an estimate calculated during training. The usual practice is to calculate an exponential moving average of mean and variance during the training phase and use that for inference.
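A sketch of that practice, assuming mini-batches drawn from a fixed, made-up distribution and using the same update rule and initial values (zeros for the mean, ones for the variance) as the layer below. After enough steps the running estimates settle close to the true statistics, and at inference time they are used in place of mini-batch statistics.

```python
import torch

torch.manual_seed(0)
momentum = 0.1
running_mean = torch.zeros(3)   # same initial values as the layer's buffers
running_var = torch.ones(3)

# Made-up data distribution for 3 features
true_mean = torch.tensor([3.0, -1.0, 0.5])
true_std = torch.tensor([2.0, 0.5, 1.0])

for step in range(200):
    batch = true_mean + true_std * torch.randn(64, 3)   # one training mini-batch
    mean = batch.mean(dim=0)
    var = ((batch - mean) ** 2).mean(dim=0)
    # Exponential moving averages of the mini-batch statistics
    running_mean = (1 - momentum) * running_mean + momentum * mean
    running_var = (1 - momentum) * running_var + momentum * var

print(running_mean)  # close to true_mean
print(running_var)   # close to true_std ** 2

# At inference, a single sample is normalized with the running estimates
sample = torch.tensor([[3.0, -1.0, 0.5]])
normalized = (sample - running_mean) / torch.sqrt(running_var + 1e-5)
```

Using running estimates means the output for a sample at inference time does not depend on which other samples happen to be in the same batch.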

Here's the training code and a notebook for training a CNN classifier that uses batch normalization on the MNIST dataset.

```
import torch
from torch import nn

from labml_helpers.module import Module
```

Batch normalization layer $BN$ normalizes the input $X$ as follows:

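When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width, with $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$:

$$BN(X) = \gamma \frac{X - E_{B,H,W}[X]}{\sqrt{Var_{B,H,W}[X] + \epsilon}} + \beta$$

When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, where $B$ is the batch size and $C$ is the number of features, with $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$:

$$BN(X) = \gamma \frac{X - E_{B}[X]}{\sqrt{Var_{B}[X] + \epsilon}} + \beta$$

When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of sequence embeddings, where $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence, with $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$:

$$BN(X) = \gamma \frac{X - E_{B,L}[X]}{\sqrt{Var_{B,L}[X] + \epsilon}} + \beta$$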
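The three cases differ only in which dimensions are averaged over: every dimension except the channel dimension. A sketch of a generic reference function, `bn_reference` (our own name), under that assumption, checked against PyTorch's `torch.nn.functional.batch_norm` in training mode, which also normalizes with the biased mini-batch statistics.

```python
import torch
import torch.nn.functional as F

def bn_reference(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor,
                 eps: float = 1e-5) -> torch.Tensor:
    # Average over every dimension except the channel dimension (dim 1):
    # [B, C] -> (0,), [B, C, L] -> (0, 2), [B, C, H, W] -> (0, 2, 3)
    dims = [0] + list(range(2, x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
    # Reshape gamma and beta to [1, C, 1, ...] so they broadcast over the other dims
    shape = [1, -1] + [1] * (x.dim() - 2)
    return gamma.view(*shape) * (x - mean) / torch.sqrt(var + eps) + beta.view(*shape)

torch.manual_seed(0)
C = 3
gamma, beta = torch.randn(C), torch.randn(C)
inputs = {
    "embeddings": torch.randn(8, C),         # [B, C]
    "sequences": torch.randn(8, C, 5),       # [B, C, L]
    "images": torch.randn(8, C, 4, 4),       # [B, C, H, W]
}
results = {}
for name, x in inputs.items():
    ours = bn_reference(x, gamma, beta)
    reference = F.batch_norm(x, None, None, weight=gamma, bias=beta,
                             training=True, eps=1e-5)
    results[name] = torch.allclose(ours, reference, atol=1e-5)
print(results)
```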

`class BatchNorm(Module):`

* `channels` is the number of features in the input
* `eps` is $\epsilon$, used in $\sqrt{Var[x_{(k)}] + \epsilon}$ for numerical stability
* `momentum` is the momentum in taking the exponential moving average
* `affine` is whether to scale and shift the normalized value
* `track_running_stats` is whether to calculate the moving averages of mean and variance

We've tried to use the same names for arguments as the PyTorch `BatchNorm` implementation.

```
    def __init__(self, channels: int, *,
                 eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True):
```

```
        super().__init__()

        self.channels = channels

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
```

Create parameters for $γ$ and $β$ for scale and shift

```
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
```

Create buffers to store exponential moving averages of mean $E[x_{(k)}]$ and variance $Var[x_{(k)}]$

```
        if self.track_running_stats:
            self.register_buffer('exp_mean', torch.zeros(channels))
            self.register_buffer('exp_var', torch.ones(channels))
```
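Buffers, unlike parameters, are not returned by `parameters()`, so the optimizer never updates them, but they are still saved in `state_dict()` and moved along with the module by `.to(device)`. A minimal sketch with a hypothetical `Stats` module showing the difference, and showing that assigning a new tensor to a registered buffer name (as `forward` does with `exp_mean` and `exp_var`) keeps it a buffer.

```python
import torch
from torch import nn

class Stats(nn.Module):
    # Hypothetical minimal module with one parameter and one buffer
    def __init__(self, channels: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(channels))
        self.register_buffer('exp_mean', torch.zeros(channels))

m = Stats(3)
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())
print(param_names)   # ['scale']  (what an optimizer receives)
print(buffer_names)  # ['exp_mean']
print(state_keys)    # ['scale', 'exp_mean']  (both are saved with the model)

# Assigning a tensor to a registered buffer name updates the buffer in place of the old one
m.exp_mean = torch.ones(3)
print(dict(m.named_buffers())['exp_mean'])
```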

`x` is a tensor of shape `[batch_size, channels, *]`. `*` denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be `[batch_size, channels, height, width]`.

`def forward(self, x: torch.Tensor):`

Keep the original shape

`x_shape = x.shape`

Get the batch size

`batch_size = x_shape[0]`

Sanity check to make sure the number of features is the same

`assert self.channels == x.shape[1]`

Reshape into `[batch_size, channels, n]`

`x = x.view(batch_size, self.channels, -1)`

We will calculate the mini-batch mean and variance if we are in training mode or if we have not tracked exponential moving averages

`if self.training or not self.track_running_stats:`

Calculate the mean across the first and last dimensions, i.e. the mean of each feature $E[x_{(k)}]$

`mean = x.mean(dim=[0, 2])`

Calculate the mean of the squares across the first and last dimensions, i.e. $E[(x_{(k)})^2]$ for each feature

`mean_x2 = (x ** 2).mean(dim=[0, 2])`

Variance for each feature $Var[x_{(k)}] = E[(x_{(k)})^2] - E[x_{(k)}]^2$

`var = mean_x2 - mean ** 2`
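The layer computes the variance with the shortcut $E[x^2] - E[x]^2$. Here's a sketch, on made-up data, comparing it with the two-pass form $E[(x - E[x])^2]$. In exact arithmetic they agree; in float32, when the mean is large compared to the spread, the shortcut subtracts two large, nearly equal numbers and can lose much of its precision, while the two-pass form is more robust.

```python
import torch

torch.manual_seed(0)
# Made-up activations, already reshaped to [batch_size, channels, n]
x = torch.randn(16, 3, 10, dtype=torch.float64)

mean = x.mean(dim=[0, 2])
var_shortcut = (x ** 2).mean(dim=[0, 2]) - mean ** 2              # E[x^2] - E[x]^2
var_two_pass = ((x - mean.view(1, -1, 1)) ** 2).mean(dim=[0, 2])  # E[(x - E[x])^2]

# Same data shifted by a large offset and stored in float32
y = (x + 1e4).float()
mean_y = y.mean(dim=[0, 2])
var_shortcut_y = (y ** 2).mean(dim=[0, 2]) - mean_y ** 2
var_two_pass_y = ((y - mean_y.view(1, -1, 1)) ** 2).mean(dim=[0, 2])

print(var_two_pass)    # reference variance per channel
print(var_two_pass_y)  # stays close to the reference
print(var_shortcut_y)  # may be far off because of cancellation
```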

Update exponential moving averages

```
        if self.training and self.track_running_stats:
            # Detach the batch statistics so the running averages don't keep the autograd graph alive
            self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean.detach()
            self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var.detach()
```

Use exponential moving averages as estimates

```
        else:
            mean = self.exp_mean
            var = self.exp_var
```

Normalize $\hat{x}_{(k)} = \frac{x_{(k)} - E[x_{(k)}]}{\sqrt{Var[x_{(k)}] + \epsilon}}$

`x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)`

Scale and shift $y_{(k)} = \gamma_{(k)} \hat{x}_{(k)} + \beta_{(k)}$

```
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```

Reshape to original and return

`return x_norm.view(x_shape)`

Simple test

`def _test():`

```
    from labml.logger import inspect

    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    bn = BatchNorm(3)

    x = bn(x)
    inspect(x.shape)
    inspect(bn.exp_var.shape)
```

```
if __name__ == '__main__':
    _test()
```
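One difference from PyTorch's built-in layer is worth knowing: per the PyTorch documentation, `nn.BatchNorm2d` normalizes with the biased mini-batch variance but stores the *unbiased* variance in its running average, whereas the layer above stores the biased `var` in `exp_var`. A sketch checking the built-in behaviour on the same input shape as `_test()`; it is standalone and does not use the `BatchNorm` class above.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(2, 3, 2, 4)                 # [batch_size, channels, height, width]
bn = nn.BatchNorm2d(3, momentum=0.1)        # running_mean starts at 0, running_var at 1
bn.train()
y = bn(x)

n = x.numel() // x.shape[1]                 # values per channel: 2 * 2 * 4 = 16
mean = x.mean(dim=[0, 2, 3])
biased_var = ((x - mean.view(1, -1, 1, 1)) ** 2).mean(dim=[0, 2, 3])
unbiased_var = biased_var * n / (n - 1)

# The output is normalized with the biased variance (gamma = 1, beta = 0 at init)
expected_y = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(biased_var.view(1, -1, 1, 1) + bn.eps)
print(torch.allclose(y, expected_y, atol=1e-5))
# ...but the running variance is updated with the unbiased variance
print(torch.allclose(bn.running_var, 0.9 * 1.0 + 0.1 * unbiased_var, atol=1e-5))
# The layer on this page would instead store 0.9 * 1.0 + 0.1 * biased_var in exp_var
```

With large batches the two estimates are nearly identical; the gap matters mostly for small batches like this one.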