Group Normalization

This is a PyTorch implementation of the Group Normalization paper.

Batch Normalization works well for sufficiently large batch sizes but poorly for small ones, because it normalizes over the batch and the batch statistics become inaccurate when the batch is small. Training large models with large batch sizes is often not possible because of limited device memory.

This paper introduces Group Normalization, which normalizes a set of features together as a group. This is based on the observation that classical features such as SIFT and HOG are group-wise features. The paper proposes dividing feature channels into groups and then separately normalizing all channels within each group.

Formulation

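All normalization layers can be defined by the following computation, where a small constant $\epsilon$ is added for numerical stability:

$$\hat{x}_i = \frac{1}{\sigma_i} \left(x_i - \mu_i\right)$$

$$\mu_i = \frac{1}{m} \sum_{k \in \mathcal{S}_i} x_k \qquad \sigma_i = \sqrt{\frac{1}{m} \sum_{k \in \mathcal{S}_i} \left(x_k - \mu_i\right)^2 + \epsilon}$$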

where $x$ is the tensor representing the batch and $i$ is the index of a single value. For instance, for 2D images, $i = (i_N, i_C, i_H, i_W)$ is a 4-d vector indexing the image within the batch, the feature channel, the vertical coordinate, and the horizontal coordinate. $\mu_i$ and $\sigma_i$ are the mean and standard deviation.

$\mathcal{S}_i$ is the set of indexes over which the mean and standard deviation are calculated for index $i$, and $m$ is the size of $\mathcal{S}_i$, which is the same for all $i$.
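In tensor terms, $\mathcal{S}_i$ can be described by which dimensions are reduced: all values that agree on every kept dimension form one set. The following is a minimal sketch of the general computation under that assumption; `normalize_over` is a hypothetical helper name, not part of the original code.

```python
import torch


def normalize_over(x: torch.Tensor, dims: tuple, eps: float = 1e-5) -> torch.Tensor:
    # Hypothetical helper: values that agree on every dimension NOT in `dims`
    # form one set S_i, and the statistics are computed over its m elements.
    mean = x.mean(dim=dims, keepdim=True)                 # mu_i
    var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)  # sigma_i^2 before adding eps
    return (x - mean) / torch.sqrt(var + eps)             # (x_i - mu_i) / sigma_i


torch.manual_seed(0)
x = torch.randn(4, 6, 5, 5)            # [N, C, H, W]
y = normalize_over(x, dims=(0, 2, 3))  # S_i = all values sharing the channel i_C
```

Each choice of reduced dimensions below corresponds to one of the normalization variants.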

The definition of $\mathcal{S}_i$ is what distinguishes batch normalization, layer normalization, instance normalization, and group normalization.

Batch Normalization

$$\mathcal{S}_i = \{k \mid k_C = i_C\}$$

The values that share the same feature channel are normalized together.
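For a 4D input this means reducing over the batch and spatial dimensions and keeping the channel dimension. A short sketch (assuming training mode, no running statistics, and no affine parameters) that checks the manual computation against `torch.nn.functional.batch_norm`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)  # [N, C, H, W]
eps = 1e-5

# S_i = {k | k_C = i_C}: reduce over N, H and W, keep C
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
y = (x - mean) / torch.sqrt(var + eps)

# Training-mode batch norm with no running statistics and no affine parameters
reference = F.batch_norm(x, None, None, training=True, eps=eps)
print(torch.allclose(y, reference, atol=1e-5))
```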

Layer Normalization

$$\mathcal{S}_i = \{k \mid k_N = i_N\}$$

The values from the same sample in the batch are normalized together.
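Here only the batch dimension is kept, so the statistics are taken over channels and spatial positions of each sample. A sketch (assuming no affine parameters) compared with `torch.nn.functional.layer_norm` over all non-batch dimensions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)  # [N, C, H, W]
eps = 1e-5

# S_i = {k | k_N = i_N}: reduce over C, H and W, keep N
mean = x.mean(dim=(1, 2, 3), keepdim=True)
var = ((x - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)
y = (x - mean) / torch.sqrt(var + eps)

reference = F.layer_norm(x, x.shape[1:], eps=eps)
print(torch.allclose(y, reference, atol=1e-5))
```

Note that the result for one sample does not depend on the other samples in the batch, which is why layer normalization is insensitive to batch size.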

Instance Normalization

$$\mathcal{S}_i = \{k \mid k_N = i_N, k_C = i_C\}$$

The values from the same sample and the same feature channel are normalized together.
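Both the batch and channel dimensions are kept, so each channel of each sample is normalized over its spatial positions only. A sketch (assuming no affine parameters) compared with `torch.nn.functional.instance_norm`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)  # [N, C, H, W]
eps = 1e-5

# S_i = {k | k_N = i_N, k_C = i_C}: reduce over H and W, keep N and C
mean = x.mean(dim=(2, 3), keepdim=True)
var = ((x - mean) ** 2).mean(dim=(2, 3), keepdim=True)
y = (x - mean) / torch.sqrt(var + eps)

reference = F.instance_norm(x, eps=eps)
print(torch.allclose(y, reference, atol=1e-5))
```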

Group Normalization

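$$\mathcal{S}_i = \{k \mid k_N = i_N, \lfloor \frac{k_C}{C/G} \rfloor = \lfloor \frac{i_C}{C/G} \rfloor\}$$

where $G$ is the number of groups and $C$ is the number of channels, so each group contains $C/G$ consecutive channels.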

Group normalization normalizes values of the same sample and the same group of channels together.
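Because each group is a run of $C/G$ consecutive channels, reshaping to `[N, G, -1]` places every member of $\mathcal{S}_i$ in the same row. A sketch (assuming no affine parameters) compared with `torch.nn.functional.group_norm`, plus the two extremes: with $G = 1$ group normalization reduces to layer normalization, and with $G = C$ it reduces to instance normalization.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 6, 4, 4)  # [N, C, H, W]
N, C, H, W = x.shape
G = 3                        # number of groups; C must be divisible by G
eps = 1e-5

# Rows of x_grouped are exactly the sets S_i: same sample, same channel group
x_grouped = x.reshape(N, G, -1)
mean = x_grouped.mean(dim=-1, keepdim=True)
var = ((x_grouped - mean) ** 2).mean(dim=-1, keepdim=True)
y = ((x_grouped - mean) / torch.sqrt(var + eps)).reshape(N, C, H, W)

reference = F.group_norm(x, G, eps=eps)
print(torch.allclose(y, reference, atol=1e-5))

# G = 1 behaves like layer norm, G = C like instance norm
print(torch.allclose(F.group_norm(x, 1, eps=eps), F.layer_norm(x, x.shape[1:], eps=eps), atol=1e-5))
print(torch.allclose(F.group_norm(x, C, eps=eps), F.instance_norm(x, eps=eps), atol=1e-5))
```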

Here’s a CIFAR-10 classification model that uses group normalization.


import torch
from torch import nn

from labml_helpers.module import Module

Group Normalization Layer

class GroupNorm(Module):
  • groups is the number of groups the features are divided into
  • channels is the number of features in the input
  • eps is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
  • affine is whether to scale and shift the normalized value
    def __init__(self, groups: int, channels: int, *,
                 eps: float = 1e-5, affine: bool = True):
        super().__init__()

        assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
        self.groups = groups
        self.channels = channels

        self.eps = eps
        self.affine = affine

Create parameters for $\gamma$ and $\beta$ for scale and shift

        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))

x is a tensor of shape [batch_size, channels, *], where * denotes any number of (possibly zero) additional dimensions. For example, in a 2D image convolution it will be [batch_size, channels, height, width].

    def forward(self, x: torch.Tensor):

Keep the original shape

        x_shape = x.shape

Get the batch size

        batch_size = x_shape[0]

Sanity check to make sure the number of features is the same

        assert self.channels == x.shape[1]

Reshape into [batch_size, groups, n]

        x = x.view(batch_size, self.groups, -1)

Calculate the mean across the last dimension, i.e. the mean for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$

        mean = x.mean(dim=[-1], keepdim=True)

Calculate the mean of the squares across the last dimension, i.e. $\mathbb{E}[x^2_{(i_N, i_G)}]$ for each sample and channel group

        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

Variance for each sample and channel group $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$

        var = mean_x2 - mean ** 2
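The code computes the variance in one pass as $\mathbb{E}[x^2] - \mathbb{E}[x]^2$ rather than as $\mathbb{E}[(x - \mathbb{E}[x])^2]$. A small standalone check (on random data, an assumption made for illustration) that the two forms agree:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 64)  # [batch_size, groups, n]

mean = x.mean(dim=[-1], keepdim=True)
mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

one_pass = mean_x2 - mean ** 2                             # E[x^2] - E[x]^2, as in the layer
two_pass = ((x - mean) ** 2).mean(dim=[-1], keepdim=True)  # E[(x - E[x])^2]

print(torch.allclose(one_pass, two_pass, atol=1e-5))
```

The two forms are identical in exact arithmetic. In floating point, the one-pass form subtracts two nearly equal numbers when $|\mathbb{E}[x]|$ is large relative to the standard deviation, so it can lose precision in that regime.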

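Normalize: $\hat{x}_{(i_N, i_G)} = \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}$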

        x_norm = (x - mean) / torch.sqrt(var + self.eps)

Scale and shift channel-wise

        if self.affine:
            x_norm = x_norm.view(batch_size, self.channels, -1)
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
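The reshape from `[batch_size, groups, n]` to `[batch_size, channels, -1]` works because the channels of a group are stored contiguously inside that group's row. A small illustration with made-up sizes and hypothetical $\gamma$/$\beta$ values:

```python
import torch

batch_size, channels, groups = 2, 4, 2
spatial = 3                                  # e.g. H * W = 3 positions per channel
n = (channels // groups) * spatial           # values per group

torch.manual_seed(0)
x_norm = torch.randn(batch_size, groups, n)  # normalized values, [batch_size, groups, n]

# Channels 0-1 live in group 0 and channels 2-3 in group 1, each channel
# occupying a contiguous run of `spatial` values in its group's row.
x_c = x_norm.view(batch_size, channels, -1)  # [batch_size, channels, spatial]

scale = torch.tensor([1.0, 2.0, 3.0, 4.0])   # illustrative gamma, one per channel
shift = torch.tensor([0.0, 0.5, -0.5, 1.0])  # illustrative beta, one per channel
y = scale.view(1, -1, 1) * x_c + shift.view(1, -1, 1)  # broadcast over batch and space
```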

Reshape to original and return

        return x_norm.view(x_shape)

Simple test

def _test():
    from labml.logger import inspect

    x = torch.zeros([2, 6, 2, 4])
    inspect(x.shape)
    gn = GroupNorm(2, 6)

    x = gn(x)
    inspect(x.shape)


if __name__ == '__main__':
    _test()
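The test above only checks shapes. For readers without `labml_helpers` installed, here is a minimal self-contained sketch of the same layer, assuming plain `torch.nn.Module` can stand in for `labml_helpers`' `Module` (the class name `GroupNormSketch` is hypothetical). It keeps the one-pass variance from the layer above and compares the output with PyTorch's built-in `torch.nn.GroupNorm` using the same scale and shift.

```python
import torch
import torch.nn.functional as F
from torch import nn


class GroupNormSketch(nn.Module):
    # Same computation as GroupNorm above; subclassing nn.Module is an assumption
    def __init__(self, groups: int, channels: int, *, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
        self.groups, self.channels, self.eps, self.affine = groups, channels, eps, affine
        if affine:
            self.scale = nn.Parameter(torch.ones(channels))   # gamma
            self.shift = nn.Parameter(torch.zeros(channels))  # beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_shape = x.shape
        batch_size = x_shape[0]
        assert self.channels == x_shape[1]
        # reshape (instead of view) also accepts non-contiguous inputs
        x = x.reshape(batch_size, self.groups, -1)
        mean = x.mean(dim=[-1], keepdim=True)
        var = (x ** 2).mean(dim=[-1], keepdim=True) - mean ** 2  # one-pass variance
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        if self.affine:
            x_norm = x_norm.view(batch_size, self.channels, -1)
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
        return x_norm.view(x_shape)


torch.manual_seed(0)
ours = GroupNormSketch(2, 6)
reference = nn.GroupNorm(2, 6)
with torch.no_grad():
    # Random gamma/beta so the affine step is actually exercised
    ours.scale.copy_(torch.randn(6))
    ours.shift.copy_(torch.randn(6))
    reference.weight.copy_(ours.scale)
    reference.bias.copy_(ours.shift)

x = torch.randn(4, 6, 5, 7)
print(torch.allclose(ours(x), reference(x), atol=1e-4))
print(torch.allclose(GroupNormSketch(3, 6, affine=False)(x), F.group_norm(x, 3), atol=1e-4))
```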