This is a PyTorch implementation of Instance Normalization: The Missing Ingredient for Fast Stylization.

Instance normalization was introduced to improve style transfer. It is based on the observation that stylization should not depend on the contrast of the content image. The “contrast normalization” is

$$y_{t,i,j,k} = \frac{x_{t,i,j,k}}{\sum_{l=1}^H \sum_{m=1}^W x_{t,i,l,m}}$$

where $x$ is a batch of images indexed by image index $t$, feature channel $i$, and spatial position $j, k$.

Since it’s hard for a convolutional network to learn “contrast normalization” on its own, the paper introduces instance normalization, a layer that performs this kind of per-image normalization explicitly.
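To make the contrast-independence idea concrete, here is a minimal, dependency-free sketch. It assumes the normalization divides each value by the sum over all spatial positions of the same image and channel. The helper name `contrast_normalize` and the 2x2 sample values are made up for illustration; they are not part of this implementation.

```python
def contrast_normalize(channel):
    """Divide every value of one channel of one image by the sum over its spatial positions."""
    total = sum(sum(row) for row in channel)
    return [[value / total for value in row] for row in channel]


# A hypothetical 2x2 single-channel image and a uniformly rescaled copy (all values x4)
image = [[1.0, 2.0], [3.0, 4.0]]
brighter = [[4.0 * value for value in row] for row in image]

normalized = contrast_normalize(image)
normalized_brighter = contrast_normalize(brighter)

# Both print the same values: the uniform rescaling is removed by the normalization
print(normalized)
print(normalized_brighter)
```

Multiplying the whole image by a constant leaves the normalized result unchanged, which is the kind of invariance the stylization network would otherwise have to learn.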

Here’s a CIFAR 10 classification model that uses instance normalization.

```
import torch
from torch import nn

from labml_helpers.module import Module
```

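The instance normalization layer $\text{IN}$ normalizes the input $X$ as follows:

$$\text{IN}(X) = \gamma \frac{X - \underset{H, W}{\mathbb{E}}[X]}{\sqrt{\underset{H, W}{Var}[X] + \epsilon}} + \beta$$

Here the input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. The mean and variance are taken over $H$ and $W$, separately for each instance and channel. $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$ are the per-channel scale and shift. The affine transformation with $\gamma$ and $\beta$ is optional.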
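As a rough illustration of what normalizing over $H$ and $W$ per instance and channel means, the following sketch works on plain nested lists of shape `[B][C][H][W]` and leaves out the optional affine step (equivalent to $\gamma = 1$, $\beta = 0$). The function name `instance_norm` and the sample values are assumptions made for this example, and the code is meant for reading rather than speed.

```python
import math


def instance_norm(x, eps=1e-5):
    """Normalize a nested list of shape [B][C][H][W] per instance and per channel."""
    out = []
    for instance in x:
        out_instance = []
        for channel in instance:
            # Statistics are taken over the H x W values of this one channel only
            values = [v for row in channel for v in row]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            denom = math.sqrt(var + eps)
            out_instance.append([[(v - mean) / denom for v in row] for row in channel])
        out.append(out_instance)
    return out


# Hypothetical batch with B=2, C=2, H=2, W=2
x = [
    [[[1.0, 2.0], [3.0, 4.0]], [[10.0, 10.0], [20.0, 20.0]]],
    [[[0.0, 8.0], [0.0, 8.0]], [[-1.0, 1.0], [-1.0, 1.0]]],
]
y = instance_norm(x)
```

Each `[H][W]` slice of `y` ends up with mean close to 0 and variance close to 1, regardless of what the other images in the batch contain.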

`class InstanceNorm(Module):`

- `channels` is the number of features in the input
- `eps` is $\epsilon$, used in $\sqrt{Var[X] + \epsilon}$ for numerical stability
- `affine` is whether to scale and shift the normalized value

```
    def __init__(self, channels: int, *,
                 eps: float = 1e-5, affine: bool = True):
```

```
        super().__init__()

        self.channels = channels

        self.eps = eps
        self.affine = affine
```

Create parameters for $\gamma$ and $\beta$ for scale and shift

```
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
```

`x` is a tensor of shape `[batch_size, channels, *]`. `*` denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be `[batch_size, channels, height, width]`.

`def forward(self, x: torch.Tensor):`

Keep the original shape

`x_shape = x.shape`

Get the batch size

`batch_size = x_shape[0]`

Sanity check to make sure the number of features is the same

`assert self.channels == x.shape[1]`

Reshape into `[batch_size, channels, n]`

`x = x.view(batch_size, self.channels, -1)`

Calculate the mean across the last dimension, i.e. the mean for each instance and feature $\mathbb{E}[x_{t,i}]$

`mean = x.mean(dim=[-1], keepdim=True)`

Calculate the mean of the squares across the last dimension, i.e. $\mathbb{E}[x_{t,i}^2]$ for each instance and feature

`mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)`

Variance for each instance and feature: $Var[x_{t,i}] = \mathbb{E}[x_{t,i}^2] - \mathbb{E}[x_{t,i}]^2$

`var = mean_x2 - mean ** 2`
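The identity used for the variance follows from expanding the square. As a short derivation, writing $\mu = \mathbb{E}[x_{t,i}]$ (a shorthand introduced only for this derivation):

$$
\begin{aligned}
Var[x_{t,i}] &= \mathbb{E}\big[(x_{t,i} - \mu)^2\big] \\
&= \mathbb{E}[x_{t,i}^2] - 2\mu\,\mathbb{E}[x_{t,i}] + \mu^2 \\
&= \mathbb{E}[x_{t,i}^2] - \mathbb{E}[x_{t,i}]^2
\end{aligned}
$$

This form needs only the two means computed above. In floating point it can lose precision when the mean is large compared to the spread of the values, so the result may differ slightly from a two-pass computation.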

Normalize

```
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        x_norm = x_norm.view(batch_size, self.channels, -1)
```

Scale and shift

```
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```
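The `view(1, -1, 1)` calls shape the per-channel parameters so they broadcast over the batch and the flattened spatial dimension. A plain-Python sketch of that broadcasting, on a nested list of shape `[B][C][N]`, might look like this; the helper name `scale_and_shift` and the numbers are illustrative assumptions, not part of this implementation.

```python
def scale_and_shift(x_norm, scale, shift):
    """Apply scale[c] * value + shift[c] to every value in channel c of every instance."""
    return [
        [[scale[c] * v + shift[c] for v in channel] for c, channel in enumerate(instance)]
        for instance in x_norm
    ]


# Hypothetical normalized input with B=1, C=2, N=2
x_norm = [[[-1.0, 1.0], [-1.0, 1.0]]]
scale = [2.0, 0.5]   # one scale value per channel
shift = [1.0, -1.0]  # one shift value per channel

shifted = scale_and_shift(x_norm, scale, shift)
print(shifted)  # [[[-1.0, 3.0], [-1.5, -0.5]]]
```

Every position in a channel shares the same pair of parameters, so the layer has only $2C$ learnable values no matter how large the images are.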

Reshape to original and return

`return x_norm.view(x_shape)`

Simple test

`def _test():`

```
    from labml.logger import inspect

    x = torch.zeros([2, 6, 2, 4])
    inspect(x.shape)
    bn = InstanceNorm(6)

    x = bn(x)
    inspect(x.shape)
```

```
if __name__ == '__main__':
    _test()
```