Weight Standardization

This is a PyTorch implementation of Weight Standardization from the paper Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. We also have an annotated implementation of Batch-Channel Normalization.

Batch normalization gives a smooth loss landscape and avoids elimination singularities. Elimination singularities are nodes of the network that become useless (e.g. a ReLU unit that always outputs 0).

However, batch normalization doesn't work well when the batch size is too small, because the batch statistics it normalizes with become noisy estimates. Small batches are common when training large networks because of device memory limitations. The paper introduces Weight Standardization with Batch-Channel Normalization as a better alternative.

Weight Standardization:
1. Normalizes the gradients
2. Smooths the loss landscape (reduces the Lipschitz constant)
3. Avoids elimination singularities

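The Lipschitz constant is the maximum slope a function has between two points. That is, $L$ is the Lipschitz constant of $f$ if $L$ is the smallest value that satisfies $\forall a, b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$, where $f: A \rightarrow \mathbb{R}^m$ and $A \subseteq \mathbb{R}^n$.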
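To make the definition concrete, here is a small sketch that estimates $L$ for a scalar function by sampling random pairs $(a, b)$ and keeping the largest slope $|f(a) - f(b)| / |a - b|$. The helper estimate_lipschitz is hypothetical and not part of this implementation; random sampling can only give a lower bound on the true constant.

```python
import math
import random


def estimate_lipschitz(f, low: float, high: float, n_pairs: int = 100_000, seed: int = 0) -> float:
    """Hypothetical helper: lower-bound estimate of the Lipschitz constant of a scalar
    function f on [low, high], taken as the largest slope over random pairs of points."""
    rng = random.Random(seed)
    best = 0.0
    for _ in range(n_pairs):
        a, b = rng.uniform(low, high), rng.uniform(low, high)
        if a == b:
            continue  # the slope is undefined for identical points
        best = max(best, abs(f(a) - f(b)) / abs(a - b))
    return best


# |d/dx sin(x)| = |cos(x)| <= 1, with equality at x = 0, so L = 1 on [-3, 3]
sin_est = estimate_lipschitz(math.sin, -3.0, 3.0)
# The linear function 3x has slope 3 between every pair of points, so L = 3
lin_est = estimate_lipschitz(lambda x: 3.0 * x, -3.0, 3.0)
print(sin_est, lin_est)
```

The estimate for sin approaches 1 from below as more pairs are sampled, while the linear function gives 3 up to floating-point rounding.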

Elimination singularities are avoided because weight standardization keeps the statistics of the outputs similar to those of the inputs. So as long as the inputs are normally distributed, the outputs remain close to normal. This keeps the outputs of nodes from always falling outside the active range of the activation function (e.g. an always-negative input to a ReLU).
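The mean-centering part of this argument can be checked numerically. The sketch below rests on assumptions chosen for illustration: a hypothetical fully connected layer (a 1×1 convolution behaves the same way) whose inputs come from a ReLU, so every input has the same positive mean, and raw weights deliberately shifted to be mostly positive. Because each standardized row sums to zero and all inputs share the same mean here, the expected pre-activation becomes zero, and a following ReLU is active for roughly half of the samples instead of almost all of them.

```python
import torch

torch.manual_seed(0)

# Assumption: non-negative inputs, as produced by a previous ReLU (each feature has mean ~0.4)
x = torch.relu(torch.randn(10_000, 64))  # 10,000 samples, I = 64 inputs

# Assumption: raw weights for O = 32 units, shifted so that they are mostly positive
w = torch.randn(32, 64) + 0.5

# Standardize each row over its I inputs, as weight_standardization does
var, mean = torch.var_mean(w, dim=1, keepdim=True)
w_hat = (w - mean) / torch.sqrt(var + 1e-5)

y_raw = x @ w.t()     # pre-activations with the raw weights
y_ws = x @ w_hat.t()  # pre-activations with the standardized weights

# Fraction of pre-activations that are positive, i.e. where a ReLU would pass a signal
frac_raw = (y_raw > 0).float().mean().item()
frac_ws = (y_ws > 0).float().mean().item()
print(frac_raw, frac_ws)
```

With the raw weights nearly every pre-activation is positive, so the ReLU behaves like an identity; with standardized weights the pre-activations are centered on zero.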

Refer to the paper for proofs.

Here is the code for training a VGG network that uses weight standardization to classify CIFAR-10 data. It uses a 2D convolution layer with weight standardization.
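As a rough idea of how a 2D convolution layer might apply weight standardization, here is a sketch of a hypothetical Conv2dWS subclass of nn.Conv2d that standardizes its weight on every forward pass. It is an illustration, not the layer used by the training code: the class name and the eps default of 1e-5 are assumptions, and because it calls F.conv2d directly it only handles the default zero padding mode.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def weight_standardization(weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same computation as the weight_standardization function in this section
    c_out, c_in, *kernel_shape = weight.shape
    weight = weight.view(c_out, -1)
    var, mean = torch.var_mean(weight, dim=1, keepdim=True)
    weight = (weight - mean) / torch.sqrt(var + eps)
    return weight.view(c_out, c_in, *kernel_shape)


class Conv2dWS(nn.Conv2d):
    """Hypothetical Conv2d that convolves with standardized weights.

    Only the default padding_mode='zeros' is handled, because F.conv2d is called directly.
    """

    def __init__(self, *args, eps: float = 1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # self.weight stays the trainable parameter; the standardized copy is rebuilt each call
        w = weight_standardization(self.weight, self.eps)
        return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)


conv = Conv2dWS(3, 16, kernel_size=3, padding=1)
out = conv(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 16, 32, 32])
```

Keeping the raw weight as the parameter and standardizing inside forward means gradients flow back through the standardization, so it shapes the backward pass as well as the forward pass.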


import torch

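Weight Standardization

$$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}}{\sigma_{W_{i,\cdot}}}$$

where,

$$\begin{aligned} W &\in \mathbb{R}^{O \times I} \\ \mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^{I} W_{i,j} \\ \sigma_{W_{i,\cdot}} &= \sqrt{\frac{1}{I} \sum_{j=1}^{I} W^2_{i,j} - \mu^2_{W_{i,\cdot}} + \epsilon} \end{aligned}$$

For a 2D-convolution layer, $O$ is the number of output channels ($O = C_{out}$) and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$).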
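For a concrete reading of these definitions, the sketch below first shows what $O$ and $I$ are for a convolution weight of shape $(C_{out}, C_{in}, k_H, k_W) = (16, 3, 3, 3)$, then evaluates the formula by hand for a toy row $W_{i,\cdot} = (1, 2, 3)$ with $I = 3$. The shapes and values are arbitrary illustrations; the toy row uses the $\frac{1}{I}$ statistics exactly as written above, giving $\mu = 2$, $\sigma = \sqrt{2/3 + \epsilon} \approx 0.8165$ and $\hat{W}_{i,\cdot} \approx (-1.2247, 0, 1.2247)$.

```python
import torch

# O and I for an example conv weight with C_out = 16, C_in = 3 and a 3x3 kernel
weight = torch.randn(16, 3, 3, 3)
O = weight.shape[0]                 # C_out
I = weight.reshape(O, -1).shape[1]  # C_in * k_H * k_W
print(O, I)  # 16 27

# The formula evaluated directly for a toy row W_i = (1, 2, 3), so I = 3 here
eps = 1e-5
w_i = torch.tensor([1.0, 2.0, 3.0])
mu = w_i.mean()                                        # (1 + 2 + 3) / 3 = 2
sigma = torch.sqrt((w_i ** 2).mean() - mu ** 2 + eps)  # sqrt(14/3 - 4 + eps) = sqrt(2/3 + eps)
w_hat = (w_i - mu) / sigma
print(w_hat)  # approximately (-1.2247, 0.0, 1.2247)
```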

def weight_standardization(weight: torch.Tensor, eps: float):

Get $C_{out}$, $C_{in}$ and the kernel shape

    c_out, c_in, *kernel_shape = weight.shape

Reshape $W$ to $O \times I$

    weight = weight.view(c_out, -1)

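Calculate $\mu_{W_{i,\cdot}}$ and $\sigma^2_{W_{i,\cdot}}$ for each output channel. Note that torch.var_mean returns the unbiased variance by default (dividing by $I - 1$ rather than $I$), so the standardized weights come out slightly smaller than the formula gives, by a factor of roughly $\sqrt{(I - 1)/I}$.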

    var, mean = torch.var_mean(weight, dim=1, keepdim=True)

Normalize $\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}}{\sigma_{W_{i,\cdot}}}$

    weight = (weight - mean) / torch.sqrt(var + eps)

Change back to the original shape and return

    return weight.view(c_out, c_in, *kernel_shape)
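A quick way to check the function is to standardize an arbitrary random weight and confirm that the shape is unchanged and that every output channel ends up with a mean close to 0 and a standard deviation close to 1. The weight shape, offset and scale below are arbitrary choices for illustration, and the function is copied so the snippet runs on its own.

```python
import torch


def weight_standardization(weight: torch.Tensor, eps: float):
    # Copy of the function above so that this snippet runs on its own
    c_out, c_in, *kernel_shape = weight.shape
    weight = weight.view(c_out, -1)
    var, mean = torch.var_mean(weight, dim=1, keepdim=True)
    weight = (weight - mean) / torch.sqrt(var + eps)
    return weight.view(c_out, c_in, *kernel_shape)


torch.manual_seed(0)
# Arbitrary raw conv weights: 64 filters over 32 input channels with a 3x3 kernel
w = 0.1 * torch.randn(64, 32, 3, 3) + 0.3
w_hat = weight_standardization(w, eps=1e-5)
flat = w_hat.view(64, -1)  # one row per output channel, I = 32 * 3 * 3 = 288

print(w_hat.shape)                                    # torch.Size([64, 32, 3, 3])
print(flat.mean(dim=1).abs().max())                   # close to 0
print(flat.std(dim=1).min(), flat.std(dim=1).max())   # both close to 1

# Shifting and positively rescaling the raw weights barely changes the result,
# since each filter is re-centered and re-scaled (exactly invariant only when eps = 0)
w_hat_2 = weight_standardization(3.0 * w - 1.0, eps=1e-5)
print((w_hat - w_hat_2).abs().max())  # small
```

The standard deviation is checked with Tensor.std, whose default unbiased estimator matches the default of torch.var_mean used inside the function.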