This is a PyTorch implementation of Batch-Channel Normalization from the paper Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. We also have an annotated implementation of Weight Standardization.
Batch-Channel Normalization performs batch normalization followed by a channel normalization (similar to Group Normalization). When the batch size is small, a running mean and variance are used for batch normalization.
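For example, a minimal usage sketch (the import path is an assumption; BatchChannelNorm is the class implemented below):

import torch
# import path assumed; BatchChannelNorm is defined later in this file
from labml_nn.normalization.batch_channel_norm import BatchChannelNorm

x = torch.randn(2, 64, 32, 32)                  # a micro-batch of 2 samples with 64 channels
bcn = BatchChannelNorm(channels=64, groups=32)  # batch norm (estimated) followed by channel norm
y = bcn(x)
assert y.shape == x.shape                       # shape is preserved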
Here is the training code for training a VGG network that uses weight standardization to classify CIFAR-10 data.
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.normalization.batch_norm import BatchNorm
This first performs a batch normalization - either normal batch norm or a batch norm with estimated mean and variance (an exponential moving average of mean and variance over multiple batches). Then a channel normalization is performed.
class BatchChannelNorm(Module):
channels is the number of features in the input
groups is the number of groups the features are divided into
eps is $\epsilon$, used in $\sqrt{Var + \epsilon}$ for numerical stability
momentum is the momentum in taking the exponential moving average
estimate is whether to use running mean and variance for batch norm

    def __init__(self, channels: int, groups: int,
                 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
        super().__init__()
Use estimated batch norm or normal batch norm.
        if estimate:
            self.batch_norm = EstimatedBatchNorm(channels,
                                                 eps=eps, momentum=momentum)
        else:
            self.batch_norm = BatchNorm(channels,
                                        eps=eps, momentum=momentum)
Channel normalization
        self.channel_norm = ChannelNorm(channels, groups, eps)
    def forward(self, x):
        x = self.batch_norm(x)
        return self.channel_norm(x)
When the input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width, this normalizes each channel using running estimates of the mean and variance, with $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$:

$$\dot{X}_{b,c,h,w} = \gamma_c \frac{X_{b,c,h,w} - \hat{\mu}_c}{\sqrt{\hat{\sigma}^2_c + \epsilon}} + \beta_c$$

where,

$$\hat{\mu}_c \leftarrow (1 - r)\,\hat{\mu}_c + r\,\mu_c, \qquad \hat{\sigma}^2_c \leftarrow (1 - r)\,\hat{\sigma}^2_c + r\,\sigma^2_c$$

are the running mean and variances, $\mu_c$ and $\sigma^2_c$ are the mean and variance of channel $c$ over the current batch, and $r$ is the momentum for calculating the exponential moving average.
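As a minimal standalone sketch of these updates (tensor sizes and momentum value are arbitrary assumptions):

import torch

r = 0.1                                    # momentum
x = torch.randn(2, 64, 32, 32)             # micro-batch: B=2, C=64, H=W=32
exp_mean = torch.zeros(64)                 # running mean, one value per channel
exp_var = torch.ones(64)                   # running variance, one value per channel

x_flat = x.view(2, 64, -1)                 # [batch_size, channels, n]
mean = x_flat.mean(dim=[0, 2])             # per-channel mean over the batch
var = (x_flat ** 2).mean(dim=[0, 2]) - mean ** 2   # per-channel variance over the batch

exp_mean = (1 - r) * exp_mean + r * mean   # exponential moving average of the mean
exp_var = (1 - r) * exp_var + r * var      # exponential moving average of the variance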
class EstimatedBatchNorm(Module):
channels is the number of features in the input
eps is $\epsilon$, used in $\sqrt{Var + \epsilon}$ for numerical stability
momentum is the momentum in taking the exponential moving average
affine is whether to scale and shift the normalized value

    def __init__(self, channels: int,
                 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.channels = channels
Channel-wise transformation parameters
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
Tensors for $\hat{\mu}_c$ and $\hat{\sigma}^2_c$
        self.register_buffer('exp_mean', torch.zeros(channels))
        self.register_buffer('exp_var', torch.ones(channels))
x is a tensor of shape [batch_size, channels, *]. * denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be [batch_size, channels, height, width].

    def forward(self, x: torch.Tensor):
Keep old shape
        x_shape = x.shape
Get the batch size
        batch_size = x_shape[0]
Sanity check to make sure the number of features is correct
        assert self.channels == x.shape[1]
Reshape into [batch_size, channels, n]
        x = x.view(batch_size, self.channels, -1)
Update $\hat{\mu}_c$ and $\hat{\sigma}^2_c$ in training mode only
        if self.training:
No backpropagation through $\hat{\mu}_c$ and $\hat{\sigma}^2_c$
            with torch.no_grad():
Calculate the mean across the first and last dimensions; i.e. the per-channel mean $\mu_c$ over the current batch
                mean = x.mean(dim=[0, 2])
Calculate the mean of squared values across the first and last dimensions; i.e. the per-channel mean of $x^2$ over the current batch
                mean_x2 = (x ** 2).mean(dim=[0, 2])
Variance for each feature
                var = mean_x2 - mean ** 2
Update the exponential moving averages $\hat{\mu}_c \leftarrow (1-r)\,\hat{\mu}_c + r\,\mu_c$ and $\hat{\sigma}^2_c \leftarrow (1-r)\,\hat{\sigma}^2_c + r\,\sigma^2_c$

                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
Normalize
        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
Scale and shift
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
Reshape to original and return
        return x_norm.view(x_shape)
This is similar to Group Normalization, but the affine transform is done group-wise.
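As a small sketch of the difference (the standalone scale parameter here is purely illustrative):

import torch
from torch import nn

# PyTorch's GroupNorm learns a scale and shift per channel
gn = nn.GroupNorm(num_groups=32, num_channels=64)
print(gn.weight.shape)        # torch.Size([64])

# ChannelNorm below instead learns a scale and shift per group
scale = nn.Parameter(torch.ones(32))
print(scale.shape)            # torch.Size([32])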
class ChannelNorm(Module):
groups is the number of groups the features are divided into
channels is the number of features in the input
eps is $\epsilon$, used in $\sqrt{Var + \epsilon}$ for numerical stability
affine is whether to scale and shift the normalized value

    def __init__(self, channels, groups,
                 eps: float = 1e-5, affine: bool = True):
        super().__init__()
        self.channels = channels
        self.groups = groups
        self.eps = eps
        self.affine = affine
Parameters for the affine transformation. Note that these transforms are per group, unlike in group norm where they are done channel-wise.
        if self.affine:
            self.scale = nn.Parameter(torch.ones(groups))
            self.shift = nn.Parameter(torch.zeros(groups))
x is a tensor of shape [batch_size, channels, *]. * denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be [batch_size, channels, height, width].

    def forward(self, x: torch.Tensor):
Keep the original shape
        x_shape = x.shape
Get the batch size
        batch_size = x_shape[0]
Sanity check to make sure the number of features is the same
        assert self.channels == x.shape[1]
Reshape into [batch_size, groups, n]
        x = x.view(batch_size, self.groups, -1)
Calculate the mean across last dimension; i.e. the means for each sample and channel group
        mean = x.mean(dim=[-1], keepdim=True)
Calculate the mean of squared values across the last dimension; i.e. for each sample and channel group
        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
Variance for each sample and feature group
        var = mean_x2 - mean ** 2
Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
Scale and shift group-wise
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
Reshape to original and return
        return x_norm.view(x_shape)
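As a quick sanity-check sketch (not part of the original code): with the affine transform disabled, ChannelNorm computes the same per-group statistics as PyTorch's nn.GroupNorm, so the two outputs should match.

import torch
from torch import nn

x = torch.randn(2, 64, 32, 32)

cn = ChannelNorm(channels=64, groups=32, affine=False)
gn = nn.GroupNorm(num_groups=32, num_channels=64, affine=False)

assert torch.allclose(cn(x), gn(x), atol=1e-5)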