Batch-Channel Normalization

This is a PyTorch implementation of Batch-Channel Normalization from the paper Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. We also have an annotated implementation of Weight Standardization.

Batch-Channel Normalization performs batch normalization followed by a channel normalization (similar to Group Normalization). When the batch size is small, a running mean and variance are used for batch normalization.

Here is the training code for a VGG network that uses weight standardization to classify CIFAR-10 data.


import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.normalization.batch_norm import BatchNorm

Batch-Channel Normalization

This first performs a batch normalization, either a normal batch norm or a batch norm with estimated mean and variance (exponential moving mean and variance over multiple batches). Then a channel normalization is performed.

class BatchChannelNorm(Module):
  • channels is the number of features in the input
  • groups is the number of groups the features are divided into
  • eps is $\epsilon$, used in $\sqrt{\mathrm{Var}[x] + \epsilon}$ for numerical stability
  • momentum is the momentum in taking the exponential moving average
  • estimate is whether to use running mean and variance for batch norm
    def __init__(self, channels: int, groups: int,
                 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
        super().__init__()

Use estimated batch norm or normal batch norm.

        if estimate:
            self.batch_norm = EstimatedBatchNorm(channels,
                                                 eps=eps, momentum=momentum)
        else:
            self.batch_norm = BatchNorm(channels,
                                        eps=eps, momentum=momentum)

Channel normalization

        self.channel_norm = ChannelNorm(channels, groups, eps)

    def forward(self, x):
        x = self.batch_norm(x)
        return self.channel_norm(x)
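To make the two-stage pipeline concrete, here is a minimal NumPy sketch of one training-mode forward pass. It assumes identity affine parameters (scale 1, shift 0) in both stages. The function `batch_channel_norm` is hypothetical and not part of labml_nn. It only mirrors the arithmetic of `EstimatedBatchNorm` followed by `ChannelNorm`. The usage runs it on a micro-batch that holds a single sample.

```python
import numpy as np


def batch_channel_norm(x, exp_mean, exp_var, groups, momentum=0.1, eps=1e-5):
    """Hypothetical sketch of one training-mode BCN pass with identity affine parameters.

    x has shape [batch_size, channels, *]; exp_mean and exp_var have shape [channels].
    Returns the output and the updated running statistics.
    """
    b, c = x.shape[0], x.shape[1]
    flat = x.reshape(b, c, -1)
    # Stage 1 (estimated batch norm): fold the batch statistics into the running ones;
    # flat.var is the biased variance, matching E[x^2] - E[x]^2 in the module
    exp_mean = (1 - momentum) * exp_mean + momentum * flat.mean(axis=(0, 2))
    exp_var = (1 - momentum) * exp_var + momentum * flat.var(axis=(0, 2))
    # ... then normalize with the running estimates, not with this batch's statistics
    y = (flat - exp_mean[None, :, None]) / np.sqrt(exp_var[None, :, None] + eps)
    # Stage 2 (channel norm): standardize each (sample, group) slice on its own
    g = y.reshape(b, groups, -1)
    g = (g - g.mean(axis=-1, keepdims=True)) / np.sqrt(g.var(axis=-1, keepdims=True) + eps)
    return g.reshape(x.shape), exp_mean, exp_var


# A micro-batch with one sample: each group still has 4 channels * 8 * 8 = 256 values
rng = np.random.default_rng(0)
x = rng.normal(3.0, 2.0, size=(1, 8, 8, 8))
y, m, v = batch_channel_norm(x, np.zeros(8), np.ones(8), groups=2)
print(y.shape)  # (1, 8, 8, 8)
```

The channel-norm stage computes its statistics per sample. Its output is therefore standardized within each (sample, group) slice even when the batch contains only one image.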

Estimated Batch Normalization

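When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width, and $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$ are the learned scale and shift,

$$\dot{X}_{\cdot, c, \cdot, \cdot} = \gamma_c \frac{X_{\cdot, c, \cdot, \cdot} - \hat{\mu}_c}{\sqrt{\hat{\sigma}^2_c + \epsilon}} + \beta_c$$

where,

$$\hat{\mu}_c \longleftarrow (1 - r)\,\hat{\mu}_c + r\,\mu_c, \qquad \mu_c = \frac{1}{BHW} \sum_{b,h,w} X_{b,c,h,w}$$

$$\hat{\sigma}^2_c \longleftarrow (1 - r)\,\hat{\sigma}^2_c + r\,\sigma^2_c, \qquad \sigma^2_c = \frac{1}{BHW} \sum_{b,h,w} \big(X_{b,c,h,w} - \mu_c\big)^2$$

$\hat{\mu}_c$ and $\hat{\sigma}^2_c$ are the running mean and variance, $\mu_c$ and $\sigma^2_c$ are the mean and variance of the current batch, and $r$ is the momentum for calculating the exponential moving average.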
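Here is a worked example of the running-mean update. This plain-Python sketch uses a hypothetical helper `ema_update` and applies the update repeatedly with $r = 0.1$. It starts from 0, the initial value of the `exp_mean` buffer, and feeds in the same batch mean of 2.0 at every step. After $n$ steps the estimate equals $2(1 - (1 - r)^n)$, so it lags behind the batch statistics.

```python
def ema_update(running, batch_stat, momentum):
    """One exponential-moving-average step: (1 - r) * running + r * batch_stat."""
    return (1 - momentum) * running + momentum * batch_stat


r = 0.1
running = 0.0          # exp_mean starts at zero
history = []
for step in range(20):
    running = ema_update(running, 2.0, r)   # the same batch mean every step
    history.append(running)

print(history[0])   # 0.2
print(history[-1])  # still below 2.0 after 20 steps
```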

class EstimatedBatchNorm(Module):
  • channels is the number of features in the input
  • eps is $\epsilon$, used in $\sqrt{\hat{\sigma}^2_c + \epsilon}$ for numerical stability
  • momentum is the momentum in taking the exponential moving average
  • affine is whether to scale and shift the normalized value
    def __init__(self, channels: int,
                 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
        super().__init__()

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.channels = channels

Channel-wise transformation parameters $\gamma_c$ and $\beta_c$

        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))

Tensors for $\hat{\mu}_c$ and $\hat{\sigma}^2_c$

        self.register_buffer('exp_mean', torch.zeros(channels))
        self.register_buffer('exp_var', torch.ones(channels))

x is a tensor of shape [batch_size, channels, *], where * denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be [batch_size, channels, height, width].
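The following small NumPy sketch shows how the `[batch_size, channels, *]` layout is handled. It uses NumPy's `reshape` as a stand-in for the module's `view`, which is an assumption that holds for contiguous arrays. All trailing dimensions are flattened into one axis, and per-channel statistics reduce over axes 0 and 2.

```python
import numpy as np

# Small stand-in for a batch of 2D feature maps: [batch_size, channels, height, width]
x = np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5)

# Flatten every trailing dimension into one axis of length n = height * width
flat = x.reshape(x.shape[0], x.shape[1], -1)
print(flat.shape)  # (2, 3, 20)

# Per-channel statistics then reduce over the batch axis and the flattened axis
per_channel_mean = flat.mean(axis=(0, 2))
assert np.allclose(per_channel_mean, x.mean(axis=(0, 2, 3)))

# With no trailing dimensions ([batch_size, channels]) the same reshape adds an axis of length 1
x1 = np.ones((4, 3))
print(x1.reshape(4, 3, -1).shape)  # (4, 3, 1)
```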

    def forward(self, x: torch.Tensor):

Keep old shape

        x_shape = x.shape

Get the batch size

        batch_size = x_shape[0]

Sanity check to make sure the number of features is correct

        assert self.channels == x.shape[1]

Reshape into [batch_size, channels, n]

        x = x.view(batch_size, self.channels, -1)

Update $\hat{\mu}_c$ and $\hat{\sigma}^2_c$ in training mode only

        if self.training:

No backpropagation through $\hat{\mu}_c$ and $\hat{\sigma}^2_c$

            with torch.no_grad():

Calculate the mean across the first and last dimensions; $\mu_c = \frac{1}{BHW} \sum_{b,h,w} X_{b,c,h,w}$

                mean = x.mean(dim=[0, 2])

Calculate the squared mean across the first and last dimensions; $\frac{1}{BHW} \sum_{b,h,w} X^2_{b,c,h,w}$

                mean_x2 = (x ** 2).mean(dim=[0, 2])

Variance for each feature; $\sigma^2_c = \frac{1}{BHW} \sum_{b,h,w} X^2_{b,c,h,w} - \mu_c^2$

                var = mean_x2 - mean ** 2
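The one-pass formula $\mathbb{E}[x^2] - \mathbb{E}[x]^2$ gives the biased (population) variance, which divides by $B \times n$ rather than $B \times n - 1$. The NumPy check below runs on random stand-in data and is a sketch, not part of the module. It compares the formula against `numpy.var` with and without `ddof=1`.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-in for an input already reshaped to [batch_size, channels, n]
x = rng.normal(loc=1.5, scale=3.0, size=(4, 6, 10))

mean = x.mean(axis=(0, 2))
mean_x2 = (x ** 2).mean(axis=(0, 2))
var = mean_x2 - mean ** 2  # E[x^2] - E[x]^2, the same expression the module uses

# Matches the population variance (divide by B * n) ...
print(np.allclose(var, x.var(axis=(0, 2))))           # True
# ... and not the sample variance (divide by B * n - 1)
print(np.allclose(var, x.var(axis=(0, 2), ddof=1)))   # False
```

One trade-off of the one-pass form is floating-point cancellation when the mean is large relative to the spread. The `eps` term also guards against tiny negative values that rounding can produce.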

Update exponential moving averages

                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var

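Normalize using the running estimates (in training mode too); $\frac{X_{\cdot, c, \cdot, \cdot} - \hat{\mu}_c}{\sqrt{\hat{\sigma}^2_c + \epsilon}}$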

        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)

Scale and shift

        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

Reshape to original and return

        return x_norm.view(x_shape)
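The NumPy sketch below mirrors `forward` in training mode, using a hypothetical function `estimated_batch_norm_train_step` with an identity affine transform. It illustrates one consequence of the design: the statistics are updated first, but the normalization uses the running estimates rather than the mean and variance of the current batch. Starting from the initial buffers, the first outputs are therefore far from zero-mean. After many batches from the same distribution, they settle near zero.

```python
import numpy as np


def estimated_batch_norm_train_step(x, exp_mean, exp_var, momentum=0.1, eps=1e-5):
    """Hypothetical sketch of EstimatedBatchNorm.forward in training mode, identity affine.

    x: [batch_size, channels, n]; exp_mean, exp_var: [channels].
    """
    mean = x.mean(axis=(0, 2))
    var = (x ** 2).mean(axis=(0, 2)) - mean ** 2
    exp_mean = (1 - momentum) * exp_mean + momentum * mean
    exp_var = (1 - momentum) * exp_var + momentum * var
    # Normalization uses the freshly updated running estimates, not mean/var of this batch
    y = (x - exp_mean[None, :, None]) / np.sqrt(exp_var[None, :, None] + eps)
    return y, exp_mean, exp_var


rng = np.random.default_rng(1)
x = rng.normal(loc=5.0, scale=2.0, size=(8, 3, 16))
# Initial buffers are zeros and ones; one step moves exp_mean only 10% of the way
y0, m0, v0 = estimated_batch_norm_train_step(x, np.zeros(3), np.ones(3))
print(y0.mean(axis=(0, 2)))  # far from 0

m, v = m0, v0
for _ in range(200):
    batch = rng.normal(loc=5.0, scale=2.0, size=(8, 3, 16))
    y, m, v = estimated_batch_norm_train_step(batch, m, v)
print(np.round(y.mean(axis=(0, 2)), 2), np.round(m, 2))
```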

Channel Normalization

This is similar to Group Normalization, but the affine transform is done group-wise.
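Here is a NumPy sketch of the channel normalization. It assumes `channels` is divisible by `groups`, and the function name `channel_norm` is hypothetical. The normalization step is the same per-sample, per-group standardization that Group Normalization performs. Only the affine step differs: there is one scale and one shift per group.

```python
import numpy as np


def channel_norm(x, groups, scale, shift, eps=1e-5):
    """Hypothetical sketch of ChannelNorm.forward; scale and shift have shape [groups]."""
    b = x.shape[0]
    g = x.reshape(b, groups, -1)
    mean = g.mean(axis=-1, keepdims=True)
    var = (g ** 2).mean(axis=-1, keepdims=True) - mean ** 2
    g = (g - mean) / np.sqrt(var + eps)
    # One scale and one shift per group, broadcast over every value in that group
    g = scale[None, :, None] * g + shift[None, :, None]
    return g.reshape(x.shape)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 6, 4, 4))
y = channel_norm(x, groups=3, scale=np.array([1.0, 2.0, 0.5]), shift=np.array([0.0, 1.0, -1.0]))
per_group = y.reshape(2, 3, -1)
# Each group's mean is its shift and its standard deviation is (almost exactly) its scale
print(per_group.mean(axis=-1))
print(per_group.std(axis=-1))
```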

class ChannelNorm(Module):
  • groups is the number of groups the features are divided into
  • channels is the number of features in the input
  • eps is $\epsilon$, used in $\sqrt{\mathrm{Var}[x] + \epsilon}$ for numerical stability
  • affine is whether to scale and shift the normalized value
    def __init__(self, channels, groups,
                 eps: float = 1e-5, affine: bool = True):
        super().__init__()
        self.channels = channels
        self.groups = groups
        self.eps = eps
        self.affine = affine

Parameters for affine transformation.

Note that these transforms are per group, unlike in group norm where they are transformed channel-wise.
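A group-wise affine transform is a special case of a channel-wise one: every channel in a group shares the same parameter. The NumPy sketch below checks this by repeating each group's scale `channels // groups` times. It uses stand-in data and only the scale; the shift behaves the same way. Consequently, `ChannelNorm` learns `2 * groups` affine parameters, whereas the channel-wise affine step of Group Normalization learns `2 * channels`.

```python
import numpy as np

channels, groups = 6, 3
per_group_scale = np.array([1.0, 2.0, 0.5])          # ChannelNorm: shape [groups]
# Equivalent channel-wise vector: each group's value repeated channels // groups times
per_channel_scale = np.repeat(per_group_scale, channels // groups)

rng = np.random.default_rng(0)
x_hat = rng.normal(size=(2, channels, 4, 4))   # stand-in for already-normalized values

# Group-wise affine, applied on the [batch_size, groups, n] view as in ChannelNorm
y_group = (per_group_scale[None, :, None] * x_hat.reshape(2, groups, -1)).reshape(x_hat.shape)
# Channel-wise affine, applied on [batch_size, channels, height, width] as in Group Normalization
y_channel = per_channel_scale[None, :, None, None] * x_hat
print(np.allclose(y_group, y_channel))  # True
```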

        if self.affine:
            self.scale = nn.Parameter(torch.ones(groups))
            self.shift = nn.Parameter(torch.zeros(groups))

x is a tensor of shape [batch_size, channels, *], where * denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be [batch_size, channels, height, width].

    def forward(self, x: torch.Tensor):

Keep the original shape

        x_shape = x.shape

Get the batch size

        batch_size = x_shape[0]

Sanity check to make sure the number of features is the same

        assert self.channels == x.shape[1]

Reshape into [batch_size, groups, n]

        x = x.view(batch_size, self.groups, -1)
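The reshape walks the channel axis in memory order. Each group therefore holds `channels // groups` consecutive channels, and channel `c` lands in group `c // (channels // groups)`. The NumPy sketch below uses made-up sizes to make the grouping visible. It assumes `channels` is divisible by `groups`, which the module itself does not check.

```python
import numpy as np

batch_size, channels, groups, height, width = 1, 8, 4, 2, 2
# Fill every value of channel c with the number c so the grouping is visible after reshaping
x = np.arange(channels).reshape(1, channels, 1, 1) * np.ones((batch_size, channels, height, width), dtype=int)

# Same layout change as x.view(batch_size, self.groups, -1) in ChannelNorm.forward
grouped = x.reshape(batch_size, groups, -1)

channels_in_group = {g: sorted(set(grouped[0, g].tolist())) for g in range(groups)}
print(channels_in_group)  # {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}
```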

Calculate the mean across the last dimension, i.e. the mean $\mathbb{E}[x_{(i_N, i_G)}]$ for each sample $i_N$ and channel group $i_G$

        mean = x.mean(dim=[-1], keepdim=True)

Calculate the squared mean across the last dimension, i.e. the mean of squares $\mathbb{E}[x^2_{(i_N, i_G)}]$ for each sample $i_N$ and channel group $i_G$

        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

Variance for each sample and feature group

        var = mean_x2 - mean ** 2

Normalize

        x_norm = (x - mean) / torch.sqrt(var + self.eps)

Scale and shift group-wise

        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

Reshape to original and return

        return x_norm.view(x_shape)