Instance Normalization

This is a PyTorch implementation of Instance Normalization: The Missing Ingredient for Fast Stylization.

Instance normalization was introduced to improve style transfer. It is based on the observation that stylization should not depend on the contrast of the content image. The "contrast normalization" is

$$y_{t,i,j,k} = \frac{x_{t,i,j,k}}{\sum_{l=1}^{W} \sum_{m=1}^{H} x_{t,i,l,m}}$$

where $x$ is a batch of images with dimensions image index $t$, feature channel $i$, and spatial position $j, k$.
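In PyTorch terms this is just a division by the per-image, per-channel spatial sum. A minimal sketch (the tensor and its shape are made up for illustration):

import torch

# A batch of images x[t, i, j, k]: image index t, feature channel i, spatial position j, k
x = torch.rand(8, 3, 32, 32)
# "Contrast normalization": divide each pixel by the sum over the spatial positions
# of the same image and channel
y = x / x.sum(dim=(-2, -1), keepdim=True)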

Since it is hard for a convolutional network to learn "contrast normalization" on its own, the paper introduces an instance normalization layer that performs it explicitly.

Here's a CIFAR 10 classification model that uses instance normalization.

import torch
from torch import nn

Instance Normalization Layer

Instance normalization layer $\text{IN}$ normalizes the input $X$ as follows:

$$\text{IN}(X) = \gamma \frac{X - \underset{H, W}{\mathbb{E}}[X]}{\sqrt{\underset{H, W}{Var}[X] + \epsilon}} + \beta$$

The input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. The mean and variance are computed over the spatial dimensions $H, W$ for each sample and channel separately. $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$ are learnable parameters; the affine transformation with $\gamma$ and $\beta$ is optional.
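To make the reduced dimensions concrete, here is a short sketch (illustrative only, not part of the layer below) computing the same statistics with plain tensor operations, with batch normalization shown for comparison:

import torch

x = torch.rand(4, 8, 16, 16)  # [B, C, H, W]

# Instance normalization: statistics per (sample, channel), reduced over H, W
in_mean = x.mean(dim=(2, 3), keepdim=True)                  # shape [B, C, 1, 1]
in_var = x.var(dim=(2, 3), unbiased=False, keepdim=True)
x_in = (x - in_mean) / torch.sqrt(in_var + 1e-5)

# Batch normalization, for comparison: statistics per channel, reduced over B, H, W
bn_mean = x.mean(dim=(0, 2, 3), keepdim=True)               # shape [1, C, 1, 1]
bn_var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
x_bn = (x - bn_mean) / torch.sqrt(bn_var + 1e-5)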

class InstanceNorm(nn.Module):
  • channels is the number of features in the input
  • eps is $\epsilon$, used in $\sqrt{Var[X] + \epsilon}$ for numerical stability
  • affine is whether to scale and shift the normalized value
    def __init__(self, channels: int, *,
                 eps: float = 1e-5, affine: bool = True):
        super().__init__()

        self.channels = channels

        self.eps = eps
        self.affine = affine

Create parameters $\gamma$ and $\beta$ for scale and shift

        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))

x is a tensor of shape [batch_size, channels, *]. * denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be [batch_size, channels, height, width].

    def forward(self, x: torch.Tensor):

Keep the original shape

        x_shape = x.shape

Get the batch size

        batch_size = x_shape[0]

Sanity check to make sure the number of features is the same

        assert self.channels == x.shape[1]

Reshape into [batch_size, channels, n]

        x = x.view(batch_size, self.channels, -1)

Calculate the mean across the last dimension, i.e. the mean $\mathbb{E}[x_{t,i}]$ for each sample and feature channel

        mean = x.mean(dim=[-1], keepdim=True)

Calculate the squared mean across the last dimension, i.e. $\mathbb{E}[x_{t,i}^2]$ for each sample and feature channel

        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

Variance for each sample and feature channel, $Var[x_{t,i}] = \mathbb{E}[x_{t,i}^2] - \mathbb{E}[x_{t,i}]^2$

        var = mean_x2 - mean ** 2

Normalize

        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        x_norm = x_norm.view(batch_size, self.channels, -1)

Scale and shift

        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

Reshape to original and return

        return x_norm.view(x_shape)
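A quick way to sanity-check the layer (this comparison is an assumption of mine, not part of the original code) is against PyTorch's built-in nn.InstanceNorm2d. The built-in defaults to affine=False, so pass affine=True to match the freshly initialized scale and shift above:

def _compare_with_builtin():
    x = torch.rand(2, 6, 8, 8)
    # Our layer, with scale initialized to ones and shift to zeros
    ours = InstanceNorm(6)(x)
    # PyTorch's built-in instance normalization over 2D feature maps
    ref = nn.InstanceNorm2d(6, affine=True)(x)
    assert torch.allclose(ours, ref, atol=1e-5)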

Simple test

def _test():
    from labml.logger import inspect

    x = torch.zeros([2, 6, 2, 4])
    inspect(x.shape)
    norm = InstanceNorm(6)

    x = norm(x)
    inspect(x.shape)

if __name__ == '__main__':
    _test()
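The CIFAR 10 classification model mentioned at the top is defined in the accompanying experiment. As a rough, hypothetical sketch (layer sizes and structure are illustrative, not the experiment's actual configuration), such a model could use this layer like this:

class SimpleCIFAR10Model(nn.Module):
    # A hypothetical convolutional classifier for 32x32 RGB CIFAR 10 images,
    # applying the InstanceNorm layer defined above after each convolution
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), InstanceNorm(32), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), InstanceNorm(64), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), InstanceNorm(128), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor):
        return self.net(x)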