This is a PyTorch implementation of the Group Normalization paper.
Batch Normalization works well for large enough batch sizes but not for small ones, because it normalizes over the batch: with only a few samples, the batch mean and variance are noisy estimates. Training large models with large batch sizes is often not possible because of the limited memory capacity of the devices.
This paper introduces Group Normalization, which normalizes a set of features together as a group. It is motivated by the observation that classical features such as SIFT and HOG are group-wise features. The paper proposes dividing the feature channels into groups and normalizing each group separately, with all channels in a group sharing the same statistics.
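All normalization layers can be defined by the following computation.

$$\hat{x}_i = \frac{1}{\sigma_i} (x_i - \mu_i)$$

where $x$ is the tensor representing the batch, and $i$ is the index of a single value. For instance, for 2D images $i = (i_N, i_C, i_H, i_W)$ is a 4-d vector indexing the image within the batch, the feature channel, the vertical coordinate and the horizontal coordinate. $\mu_i$ and $\sigma_i$ are the mean and standard deviation:

$$\mu_i = \frac{1}{m} \sum_{k \in \mathcal{S}_i} x_k \qquad \sigma_i = \sqrt{\frac{1}{m} \sum_{k \in \mathcal{S}_i} (x_k - \mu_i)^2 + \epsilon}$$

$\mathcal{S}_i$ is the set of indexes across which the mean and standard deviation are calculated for index $i$. $m$ is the size of the set $\mathcal{S}_i$, which is the same for all $i$.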
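To make the generic formula concrete, here is a deliberately naive pure-Python sketch that builds $\mathcal{S}_i$ explicitly for every index $i$ from a membership predicate. The `normalize` helper and the dict-based "tensor" are illustrative assumptions, not part of the implementation; the example plugs in the instance-normalization set (same sample, same channel) that is defined just below.

```python
import itertools
import math


def normalize(x, shape, same_set, eps=1e-5):
    """Apply x_hat_i = (x_i - mu_i) / sigma_i for every index i.

    x        -- dict mapping index tuples (n, c, h, w) to float values
    same_set -- hypothetical predicate: same_set(k, i) is True when k is in S_i
    """
    indices = list(itertools.product(*(range(d) for d in shape)))
    out = {}
    for i in indices:
        s_i = [k for k in indices if same_set(k, i)]  # the set S_i
        m = len(s_i)                                   # its size m
        mu = sum(x[k] for k in s_i) / m
        sigma = math.sqrt(sum((x[k] - mu) ** 2 for k in s_i) / m + eps)
        out[i] = (x[i] - mu) / sigma
    return out


# A tiny [N, C, H, W] = [2, 4, 1, 2] batch with made-up deterministic values
SHAPE = (2, 4, 1, 2)
x = {i: float(2 * i[0] + 3 * i[1] + 5 * i[3])
     for i in itertools.product(*(range(d) for d in SHAPE))}

# Instance normalization: k is in S_i when it has the same sample and channel as i
inst = normalize(x, SHAPE, lambda k, i: k[0] == i[0] and k[1] == i[1])
```

This is quadratic in the number of elements and only meant to mirror the math; real implementations express $\mathcal{S}_i$ as a choice of tensor dimensions to reduce over.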
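The definition of $\mathcal{S}_i$ is different for Batch normalization, Layer normalization, and Instance normalization.

Batch Normalization

$$\mathcal{S}_i = \{k \mid k_C = i_C\}$$

The values that share the same feature channel are normalized together.

Layer Normalization

$$\mathcal{S}_i = \{k \mid k_N = i_N\}$$

The values from the same sample in the batch are normalized together.

Instance Normalization

$$\mathcal{S}_i = \{k \mid k_N = i_N, k_C = i_C\}$$

The values from the same sample and same feature channel are normalized together.

Group Normalization

$$\mathcal{S}_i = \Big\{k \;\Big|\; k_N = i_N, \Big\lfloor \frac{k_C}{C/G} \Big\rfloor = \Big\lfloor \frac{i_C}{C/G} \Big\rfloor\Big\}$$

where $G$ is the number of groups and $C$ is the number of channels.

Group normalization normalizes values of the same sample and the same group of channels together.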
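For a `[N, C, H, W]` tensor these four sets correspond to reducing over different dimensions. The sketch below assumes PyTorch; `norm_over` and `group_norm` are hypothetical helpers (no affine parameters), and the accompanying check compares them with the functional forms in `torch.nn.functional`.

```python
import torch
import torch.nn.functional as F


def norm_over(x, dims, eps=1e-5):
    """Normalize x with the mean and (biased) variance taken over `dims`, i.e. over S_i."""
    mean = x.mean(dim=dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


def group_norm(x, groups, eps=1e-5):
    """Split channels into groups of C/G consecutive channels, then normalize per (sample, group)."""
    n, c = x.shape[0], x.shape[1]
    xg = x.reshape(n, groups, c // groups, *x.shape[2:])
    dims = list(range(2, xg.dim()))  # channels inside the group plus all spatial dims
    return norm_over(xg, dims, eps).reshape(x.shape)


torch.manual_seed(0)
x = torch.randn(4, 6, 5, 5)
bn = norm_over(x, [0, 2, 3])   # batch norm: same channel
ln = norm_over(x, [1, 2, 3])   # layer norm: same sample
in_ = norm_over(x, [2, 3])     # instance norm: same sample and channel
gn = group_norm(x, groups=3)   # group norm: same sample and channel group
```

With `groups=1` group normalization reduces to layer normalization, and with `groups=C` it reduces to instance normalization.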
Here's a CIFAR-10 classification model that uses group normalization.
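A minimal sketch of a convolutional block built this way, assuming PyTorch's built-in `torch.nn.GroupNorm` as a stand-in for the layer implemented below. The `conv_block` helper and the layer sizes are hypothetical; this is not the CIFAR-10 model mentioned above. Because group normalization computes its statistics per sample, a sample's output does not depend on the rest of the batch, which is exactly what batch normalization lacks at small batch sizes.

```python
import torch
from torch import nn


def conv_block(in_channels: int, out_channels: int, groups: int) -> nn.Sequential:
    """Hypothetical conv -> group norm -> ReLU block."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.GroupNorm(groups, out_channels),
        nn.ReLU(),
    )


torch.manual_seed(0)
model = nn.Sequential(conv_block(3, 32, groups=8), conv_block(32, 64, groups=8))
model.train()  # GroupNorm keeps no running statistics, so train/eval mode does not matter

images = torch.randn(4, 3, 32, 32)  # CIFAR-10-sized inputs
batched = model(images)             # first image processed in a batch of 4
single = model(images[:1])          # the same image processed alone
```

Up to floating-point noise, `batched[:1]` and `single` match, so the model can be trained with very small batches.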
```python
import torch
from torch import nn

from labml_helpers.module import Module


class GroupNorm(Module):
    """Group Normalization Layer"""
```
- `groups` is the number of groups the features are divided into
- `channels` is the number of features in the input
- `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
- `affine` is whether to scale and shift the normalized value
```python
    def __init__(self, groups: int, channels: int, *,
                 eps: float = 1e-5, affine: bool = True):
        super().__init__()

        assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
        self.groups = groups
        self.channels = channels

        self.eps = eps
        self.affine = affine
```
Create parameters $\gamma$ and $\beta$ for scale and shift
```python
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
```
`x` is a tensor of shape `[batch_size, channels, *]`. `*` denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be `[batch_size, channels, height, width]`.
```python
    def forward(self, x: torch.Tensor):
```
Keep the original shape
```python
        x_shape = x.shape
```
Get the batch size
```python
        batch_size = x_shape[0]
```
Sanity check to make sure the number of features matches `channels`

```python
        assert self.channels == x.shape[1]
```
Reshape into `[batch_size, groups, n]`

```python
        x = x.view(batch_size, self.groups, -1)
```
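The reshape relies on the channel layout: viewing `[batch_size, channels, *]` as `[batch_size, groups, -1]` puts channel $c$ into group $\lfloor c / (C/G) \rfloor$, exactly the set in the group-normalization definition, for any number of trailing dimensions. A small check, assuming PyTorch; the hypothetical `channel_groups` helper fills every element with its own channel index so the grouping can be read off directly.

```python
import torch


def channel_groups(shape, groups):
    """For each (sample, group) row produced by the reshape, return the set of channel ids it holds."""
    batch_size, channels = shape[0], shape[1]
    # Tag every element with its channel index, broadcast over all trailing dims
    tag = torch.arange(channels).view(1, channels, *([1] * (len(shape) - 2))).expand(*shape)
    # expand() returns a non-contiguous tensor, so use reshape() (which copies) instead of view()
    rows = tag.reshape(batch_size, groups, -1)
    return [[set(rows[n, g].tolist()) for g in range(groups)] for n in range(batch_size)]


# Works the same for 1D, 2D and 3D feature maps
for shape in [(2, 6, 7), (2, 6, 4, 5), (2, 6, 2, 3, 4)]:
    print(shape, channel_groups(shape, groups=3)[0])  # [{0, 1}, {2, 3}, {4, 5}] after the shape
```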
Calculate the mean across the last dimension, i.e. the mean $\mathbb{E}[x^{(i_N, i_G)}]$ for each sample and channel group

```python
        mean = x.mean(dim=[-1], keepdim=True)
```
Calculate the squared mean across the last dimension, i.e. the mean $\mathbb{E}[(x^{(i_N, i_G)})^2]$ of the squared values for each sample and channel group

```python
        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
```
Variance for each sample and feature group: $Var[x^{(i_N, i_G)}] = \mathbb{E}[(x^{(i_N, i_G)})^2] - \mathbb{E}[x^{(i_N, i_G)}]^2$

```python
        var = mean_x2 - mean ** 2
```
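The layer computes the variance in one pass as $\mathbb{E}[x^2] - \mathbb{E}[x]^2$ rather than as $\mathbb{E}[(x - \mathbb{E}[x])^2]$. Both are the biased variance (dividing by $m$, matching $\sigma_i$ above), so they agree algebraically; a quick check, assuming PyTorch and a random `[batch_size, groups, n]` tensor.

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 4, 50, dtype=torch.float64)  # [batch_size, groups, n]

mean = x.mean(dim=[-1], keepdim=True)
mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
var_one_pass = mean_x2 - mean ** 2                              # as in the layer
var_two_pass = ((x - mean) ** 2).mean(dim=[-1], keepdim=True)  # definition of sigma_i^2
```

The one-pass form saves a second pass over the data, but it subtracts two nearly equal numbers, so it can lose precision in low-precision arithmetic when the mean is large relative to the spread.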
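Normalize: $\hat{x}^{(i_N, i_G)} = \frac{x^{(i_N, i_G)} - \mathbb{E}[x^{(i_N, i_G)}]}{\sqrt{Var[x^{(i_N, i_G)}] + \epsilon}}$

```python
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
```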
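As a worked example of this step, take a single (sample, group) row holding the made-up values 0, 5, 3, 8. Following the same one-pass route as the code: $\mathbb{E}[x] = 4$, $\mathbb{E}[x^2] = 24.5$, so $Var[x] = 24.5 - 16 = 8.5$ and each value becomes $(x - 4)/\sqrt{8.5 + \epsilon}$. In plain Python:

```python
import math

row = [0.0, 5.0, 3.0, 8.0]  # one (sample, group) row, made-up values
eps = 1e-5

mean = sum(row) / len(row)                    # 4.0
mean_x2 = sum(v * v for v in row) / len(row)  # 24.5
var = mean_x2 - mean ** 2                     # 8.5
x_norm = [(v - mean) / math.sqrt(var + eps) for v in row]
print([round(v, 3) for v in x_norm])          # [-1.372, 0.343, -0.343, 1.372]
```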
Scale and shift channel-wise: $y_{i_C} = \gamma_{i_C} \hat{x}_{i_C} + \beta_{i_C}$

```python
        if self.affine:
            x_norm = x_norm.view(batch_size, self.channels, -1)
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```
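$\gamma$ and $\beta$ have one entry per channel, not per group, which is why the normalized tensor is reshaped from `[batch_size, groups, n]` back to `[batch_size, channels, -1]` before they are applied. A self-contained sketch of the whole forward pass in plain PyTorch (the function name `group_norm_forward` is hypothetical; it mirrors the steps above instead of importing the class), checked against `torch.nn.functional.group_norm` with random per-channel weight and bias:

```python
import torch
import torch.nn.functional as F


def group_norm_forward(x, groups, scale, shift, eps=1e-5):
    """Hypothetical standalone version of GroupNorm.forward with affine=True."""
    batch_size, channels = x.shape[0], x.shape[1]
    xg = x.reshape(batch_size, groups, -1)                  # [batch_size, groups, n]
    mean = xg.mean(dim=[-1], keepdim=True)
    var = (xg ** 2).mean(dim=[-1], keepdim=True) - mean ** 2
    x_norm = (xg - mean) / torch.sqrt(var + eps)
    # Back to a per-channel layout so every channel gets its own gamma and beta
    x_norm = x_norm.view(batch_size, channels, -1)
    x_norm = scale.view(1, -1, 1) * x_norm + shift.view(1, -1, 1)
    return x_norm.view(x.shape)


torch.manual_seed(0)
x = torch.randn(2, 6, 4, 4)
scale, shift = torch.randn(6), torch.randn(6)
out = group_norm_forward(x, 3, scale, shift)
```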
Reshape to original and return
```python
        return x_norm.view(x_shape)
```
Simple test

```python
def _test():
    from labml.logger import inspect

    x = torch.zeros([2, 6, 2, 4])
    inspect(x.shape)
    gn = GroupNorm(2, 6)

    x = gn(x)
    inspect(x.shape)


if __name__ == '__main__':
    _test()
```