This is a PyTorch implementation of Instance Normalization: The Missing Ingredient for Fast Stylization.
Instance normalization was introduced to improve style transfer. It is based on the observation that stylization should not depend on the contrast of the content image. The "contrast normalization" is

$$y_{t,i,j,k} = \frac{x_{t,i,j,k}}{\sum_{l=1}^{H} \sum_{m=1}^{W} x_{t,i,l,m}}$$

where $x$ is a batch of images with dimensions image index $t$, feature channel $i$, and spatial position $j, k$.
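As a rough illustration of the contrast normalization described above, the sketch below assumes it divides every pixel by the sum of its own channel over all spatial positions of the same image. The function name `contrast_normalize` and the toy input are hypothetical and not part of the original implementation.

```python
import torch


def contrast_normalize(x: torch.Tensor) -> torch.Tensor:
    """Divide each pixel by the spatial sum of its own image and channel.

    x has shape [batch_size, channels, height, width].
    """
    # Sum over the spatial positions, kept as size-1 dims for broadcasting
    total = x.sum(dim=(-2, -1), keepdim=True)
    return x / total


# Strictly positive toy input so that no per-channel sum is zero
x = torch.rand(2, 3, 4, 5) + 0.1
y = contrast_normalize(x)

# Sum of every [height, width] slice of the result
sums = y.sum(dim=(-2, -1))
```

After this operation every `[height, width]` slice sums to one, so rescaling the intensity of a single input image leaves its output unchanged.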
Since it is hard for a convolutional network to learn this "contrast normalization" by itself, the paper introduces instance normalization, which does that.
Here's a CIFAR 10 classification model that uses instance normalization.
import torch
from torch import nn

from labml_helpers.module import Module
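Instance normalization layer $\text{IN}$ normalizes the input $X$ as follows:

$$\text{IN}(X) = \gamma \frac{X - \underset{H, W}{\mathbb{E}}[X]}{\sqrt{\underset{H, W}{Var}[X] + \epsilon}} + \beta$$

Here the input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$. The affine transformation with $\gamma$ and $\beta$ is optional.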
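The normalization described above can be sketched as a small standalone function and compared with PyTorch's built-in `torch.nn.functional.instance_norm`. This is only an illustrative sketch: the name `instance_norm_reference` is hypothetical, and it assumes 4D image inputs with the mean and (biased) variance taken over height and width, separately for each sample and channel.

```python
import torch
import torch.nn.functional as F


def instance_norm_reference(x, gamma=None, beta=None, eps=1e-5):
    """Normalize a [batch_size, channels, height, width] tensor."""
    # Statistics over the spatial dims only, per sample and per channel
    mean = x.mean(dim=(-2, -1), keepdim=True)
    var = x.var(dim=(-2, -1), unbiased=False, keepdim=True)
    x_norm = (x - mean) / torch.sqrt(var + eps)
    # Optional per-channel affine transformation
    if gamma is not None and beta is not None:
        x_norm = gamma.view(1, -1, 1, 1) * x_norm + beta.view(1, -1, 1, 1)
    return x_norm


torch.manual_seed(0)
x = torch.randn(2, 6, 8, 8)
gamma = torch.rand(6) + 0.5   # illustrative per-channel scale
beta = torch.randn(6)         # illustrative per-channel shift

y = instance_norm_reference(x, gamma, beta)
y_builtin = F.instance_norm(x, weight=gamma, bias=beta, eps=1e-5)
```

Under these assumptions the two results should agree up to floating point error.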
channels is the number of features in the input
eps is $\epsilon$, used in $\sqrt{Var[X] + \epsilon}$ for numerical stability
affine is whether to scale and shift the normalized value
class InstanceNorm(Module):
    def __init__(self, channels: int, *,
                 eps: float = 1e-5, affine: bool = True):
        super().__init__()

        self.channels = channels

        self.eps = eps
        self.affine = affine
Create parameters for $\gamma$ and $\beta$ for scale and shift
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
x is a tensor of shape [batch_size, channels, *]. * denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be [batch_size, channels, height, width]
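To make the shape convention concrete, here is a small sketch of how inputs with different numbers of trailing dimensions collapse to three dimensions when everything after the channel dimension is flattened. The helper `flatten_trailing` and the example shapes are hypothetical and chosen only for illustration.

```python
import torch


def flatten_trailing(x: torch.Tensor) -> torch.Tensor:
    """Collapse every dimension after [batch_size, channels] into one."""
    batch_size, channels = x.shape[0], x.shape[1]
    return x.view(batch_size, channels, -1)


# Inputs with 0, 1 and 2 trailing dimensions
features = flatten_trailing(torch.zeros(2, 6))       # [2, 6] -> [2, 6, 1]
sequence = flatten_trailing(torch.zeros(2, 6, 10))   # [2, 6, 10] -> [2, 6, 10]
image = flatten_trailing(torch.zeros(2, 6, 2, 4))    # [2, 6, 2, 4] -> [2, 6, 8]
```

Note that with no trailing dimensions each sample and channel holds a single value, so its variance is zero and the normalized value is zero as well.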
    def forward(self, x: torch.Tensor):
Keep the original shape
        x_shape = x.shape
Get the batch size
        batch_size = x_shape[0]
Sanity check to make sure the number of features is the same
        assert self.channels == x.shape[1]
Reshape into [batch_size, channels, n]
        x = x.view(batch_size, self.channels, -1)
Calculate the mean across the last dimension, i.e. the mean $\mathbb{E}[x_{t,i}]$ for each feature of each sample
        mean = x.mean(dim=[-1], keepdim=True)
Calculate the mean of the squares across the last dimension, i.e. $\mathbb{E}[x^2_{t,i}]$ for each feature of each sample
        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
Variance for each feature of each sample, $Var[x_{t,i}] = \mathbb{E}[x^2_{t,i}] - \mathbb{E}[x_{t,i}]^2$
        var = mean_x2 - mean ** 2
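The variance step uses the identity that the variance equals the mean of the squares minus the square of the mean. A quick sketch, on a hypothetical random tensor already reshaped to [batch_size, channels, n], compares this with PyTorch's biased variance. It uses `float64` because the subtraction can lose precision in lower precision when the mean is large compared with the spread.

```python
import torch

torch.manual_seed(0)
# Hypothetical input of shape [batch_size, channels, n]
x = torch.randn(2, 6, 32, dtype=torch.float64)

# Variance through the two means, as in the step above
mean = x.mean(dim=-1, keepdim=True)
mean_x2 = (x ** 2).mean(dim=-1, keepdim=True)
var = mean_x2 - mean ** 2

# Biased (population) variance computed directly
var_direct = x.var(dim=-1, unbiased=False, keepdim=True)
```

Both tensors have shape [2, 6, 1], one variance per sample and channel.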
Normalize, $\hat{x}_{t,i} = \frac{x_{t,i} - \mathbb{E}[x_{t,i}]}{\sqrt{Var[x_{t,i}] + \epsilon}}$
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        x_norm = x_norm.view(batch_size, self.channels, -1)
Scale and shift
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
Reshape to original and return
        return x_norm.view(x_shape)
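The scale and shift step relies on broadcasting: reshaping a per-channel vector to [1, channels, 1] lets it apply across every sample and every flattened position. A minimal sketch with made-up values (the tensors below are illustrative, not taken from the original code):

```python
import torch

batch_size, channels = 2, 3
# Stand-in for a normalized tensor of shape [batch_size, channels, n]
x_norm = torch.ones(batch_size, channels, 4)
scale = torch.tensor([1.0, 2.0, 3.0])     # one scale per channel
shift = torch.tensor([0.0, 10.0, 20.0])   # one shift per channel

# view(1, -1, 1) gives shape [1, channels, 1], which broadcasts over
# the batch dimension and the flattened positions
out = scale.view(1, -1, 1) * x_norm + shift.view(1, -1, 1)
```

Every element of channel 0 becomes 1, every element of channel 1 becomes 12, and every element of channel 2 becomes 23, regardless of sample or position.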
def _test():
    from labml.logger import inspect

    x = torch.zeros([2, 6, 2, 4])
    inspect(x.shape)
    bn = InstanceNorm(6)

    x = bn(x)
    inspect(x.shape)
if __name__ == '__main__':
    _test()
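The test above only inspects shapes. A slightly stronger check is sketched below using PyTorch's built-in `nn.InstanceNorm2d` as a stand-in, which is an assumption made to keep the snippet self-contained; it is not the class defined in this file. It looks at the statistics of each sample and channel, and confirms that a sample is normalized the same way whether or not other samples are in the batch.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Built-in layer as a stand-in; affine=False gives the pure normalization
norm = nn.InstanceNorm2d(6, eps=1e-5, affine=False)

# Random input with a non-zero mean and non-unit variance
x = torch.randn(2, 6, 8, 8) * 3.0 + 5.0
y = norm(x)

# Statistics of each (sample, channel) slice over the spatial positions
per_slice_mean = y.mean(dim=(-2, -1))
per_slice_var = y.var(dim=(-2, -1), unbiased=False)

# Normalizing the first sample alone gives the same values,
# since no statistics are shared across the batch
y_first_alone = norm(x[:1])
```

The per-slice means should be close to zero and the per-slice variances close to one, and `y_first_alone` should match `y[:1]`.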