Batch-Channel Normalization

This is a PyTorch implementation of Batch-Channel Normalization from the paper Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. We also have an annotated implementation of Weight Standardization.

Batch-Channel Normalization performs a batch normalization followed by a channel normalization (similar to a Group Normalization). When the batch size is small, a running mean and variance is used for the batch normalization.

Here is the training code for a VGG network that uses Weight Standardization to classify CIFAR-10 data.


27import torch
28from torch import nn
29
30from labml_helpers.module import Module
31from labml_nn.normalization.batch_norm import BatchNorm

Batch-Channel Normalization

This first performs a batch normalization - either a normal batch norm or a batch norm with estimated mean and variance (an exponential moving mean/variance over multiple batches). Then a channel normalization is performed.

34class BatchChannelNorm(Module):
  • channels is the number of features in the input
  • groups is the number of groups the features are divided into
  • eps is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
  • momentum is the momentum in taking the exponential moving average
  • estimate is whether to use running mean and variance for batch norm
44    def __init__(self, channels: int, groups: int,
45                 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
53        super().__init__()

Use estimated batch norm or normal batch norm.

56        if estimate:
57            self.batch_norm = EstimatedBatchNorm(channels,
58                                                 eps=eps, momentum=momentum)
59        else:
60            self.batch_norm = BatchNorm(channels,
61                                        eps=eps, momentum=momentum)

Channel normalization

64        self.channel_norm = ChannelNorm(channels, groups, eps)
66    def forward(self, x):
67        x = self.batch_norm(x)
68        return self.channel_norm(x)
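
A minimal usage sketch (an illustration, not part of the source): channels should be divisible by groups, and the input has shape [batch_size, channels, *].

```python
import torch

# Illustrative usage of the BatchChannelNorm module defined above
bcn = BatchChannelNorm(channels=64, groups=8)

# A micro-batch of 2 feature maps with 64 channels
x = torch.randn(2, 64, 32, 32)
y = bcn(x)

assert y.shape == x.shape  # normalization preserves the shape
```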

Estimated Batch Normalization

When the input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width, and $\gamma \in \mathbb{R}^{C}$, $\beta \in \mathbb{R}^{C}$ are learnable parameters,

$$\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\sqrt{\hat{\sigma}^2_C + \epsilon}} + \beta_C$$

where,

$$\begin{aligned}
\hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b, h, w} X_{b, C, h, w} \\
\hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b, h, w} \big(X_{b, C, h, w} - \hat{\mu}_C\big)^2
\end{aligned}$$

are the running mean and variance. $r$ is the momentum for calculating the exponential mean.
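
As a small numeric sketch of the update rules above (illustrative only, with momentum $r = 0.1$ and made-up tensors):

```python
import torch

r = 0.1                                     # momentum
exp_mean = torch.zeros(3)                   # running mean, one entry per channel
exp_var = torch.ones(3)                     # running variance, one entry per channel

x = torch.randn(4, 3, 8, 8)                 # a batch of 4 images with 3 channels
mean = x.mean(dim=(0, 2, 3))                # per-channel batch mean
var = x.var(dim=(0, 2, 3), unbiased=False)  # per-channel batch variance

# One exponential-moving-average step, mirroring the formulas above
exp_mean = (1 - r) * exp_mean + r * mean
exp_var = (1 - r) * exp_var + r * var
```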

71class EstimatedBatchNorm(Module):
  • channels is the number of features in the input
  • eps is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
  • momentum is the momentum in taking the exponential moving average
  • affine is whether to scale and shift the normalized value
92    def __init__(self, channels: int,
93                 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
100        super().__init__()
101
102        self.eps = eps
103        self.momentum = momentum
104        self.affine = affine
105        self.channels = channels

Channel wise transformation parameters

108        if self.affine:
109            self.scale = nn.Parameter(torch.ones(channels))
110            self.shift = nn.Parameter(torch.zeros(channels))

Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$

113        self.register_buffer('exp_mean', torch.zeros(channels))
114        self.register_buffer('exp_var', torch.ones(channels))

x is a tensor of shape [batch_size, channels, *]. * denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be [batch_size, channels, height, width]

116    def forward(self, x: torch.Tensor):

Keep old shape

124        x_shape = x.shape

Get the batch size

126        batch_size = x_shape[0]

Sanity check to make sure the number of features is correct

129        assert self.channels == x.shape[1]

Reshape into [batch_size, channels, n]

132        x = x.view(batch_size, self.channels, -1)

Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only

135        if self.training:

No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$

137            with torch.no_grad():

Calculate the mean across the first and last dimensions; i.e. the mean of each channel over the batch and spatial positions

140                mean = x.mean(dim=[0, 2])

Calculate the squared mean across the first and last dimensions; i.e. the mean of $x^2$ for each channel over the batch and spatial positions

143                mean_x2 = (x ** 2).mean(dim=[0, 2])

Variance for each channel, $Var[x] = \mathbb{E}[x^2] - \mathbb{E}[x]^2$

146                var = mean_x2 - mean ** 2

Update exponential moving averages

153                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
154                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var

Normalize

158        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)

Scale and shift

163        if self.affine:
164            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

Reshape to original and return

167        return x_norm.view(x_shape)
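
A short sketch (an assumption about usage, not from the source) of how the running statistics behave: they are updated only in training mode, while the normalization always uses them.

```python
import torch

ebn = EstimatedBatchNorm(channels=16)
x = torch.randn(4, 16, 8, 8)

ebn.train()
_ = ebn(x)              # exp_mean / exp_var move towards the batch statistics
print(ebn.exp_mean[:4])

ebn.eval()
_ = ebn(x)              # running statistics stay unchanged in eval mode
print(ebn.exp_mean[:4])
```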

Channel Normalization

This is similar to Group Normalization, but the affine transform is done group-wise.

170class ChannelNorm(Module):
  • groups is the number of groups the features are divided into
  • channels is the number of features in the input
  • eps is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
  • affine is whether to scale and shift the normalized value
177    def __init__(self, channels, groups,
178                 eps: float = 1e-5, affine: bool = True):
185        super().__init__()
186        self.channels = channels
187        self.groups = groups
188        self.eps = eps
189        self.affine = affine

Parameters for affine transformation.

Note that these transforms are per group, unlike in group normalization where the affine parameters are per channel.

194        if self.affine:
195            self.scale = nn.Parameter(torch.ones(groups))
196            self.shift = nn.Parameter(torch.zeros(groups))

x is a tensor of shape [batch_size, channels, *]. * denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be [batch_size, channels, height, width]

198    def forward(self, x: torch.Tensor):

Keep the original shape

207        x_shape = x.shape

Get the batch size

209        batch_size = x_shape[0]

Sanity check to make sure the number of features is the same

211        assert self.channels == x.shape[1]

Reshape into [batch_size, groups, n]

214        x = x.view(batch_size, self.groups, -1)

Calculate the mean across last dimension; i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$

218        mean = x.mean(dim=[-1], keepdim=True)

Calculate the squared mean across last dimension; i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$

221        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

Variance for each sample and feature group $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$

224        var = mean_x2 - mean ** 2

Normalize

229        x_norm = (x - mean) / torch.sqrt(var + self.eps)

Scale and shift group-wise

233        if self.affine:
234            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

Reshape to original and return

237        return x_norm.view(x_shape)
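
To make the group-wise affine transform concrete, here is a short comparison sketch (an illustration, not from the source) against PyTorch's nn.GroupNorm, which keeps per-channel parameters:

```python
import torch
from torch import nn

cn = ChannelNorm(channels=64, groups=8)
gn = nn.GroupNorm(num_groups=8, num_channels=64)

print(cn.scale.shape)   # torch.Size([8])  - one scale per group
print(gn.weight.shape)  # torch.Size([64]) - one weight per channel

x = torch.randn(2, 64, 16, 16)
assert cn(x).shape == gn(x).shape == x.shape
```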