This is a PyTorch implementation of Weight Standardization from the paper *Micro-Batch Training with Batch-Channel Normalization and Weight Standardization*. We also have an annotated implementation of Batch-Channel Normalization.

Batch normalization **gives a smooth loss landscape** and
**avoids elimination singularities**.
Elimination singularities are nodes of the network that become
useless (e.g. a ReLU that gives 0 all the time).

However, batch normalization doesn’t work well when the batch size is too small, which happens when training large networks because of device memory limitations. The paper introduces Weight Standardization with Batch-Channel Normalization as a better alternative.

Weight Standardization:

1. Normalizes the gradients
2. Smooths the landscape (reduced Lipschitz constant)
3. Avoids elimination singularities

The Lipschitz constant bounds how steeply a function can change between any two points. That is, the Lipschitz constant $L$ is the smallest value that satisfies $\forall a,b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$, where $f: A \rightarrow \mathbb{R}^m$ and $A \subseteq \mathbb{R}^n$.
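As a rough numerical illustration of this definition (a sketch, not taken from the paper), the snippet below estimates the Lipschitz constant of an example function $f(x) = \sin(3x)$ by taking the largest slope over all pairs of points on a grid. The function, the grid, and the helper name `estimate_lipschitz` are assumptions chosen for illustration; the true constant for this function is 3.

```python
import math
from itertools import combinations


def estimate_lipschitz(f, points):
    """Lower-bound the Lipschitz constant of a scalar function `f` by taking
    the largest |f(a) - f(b)| / |a - b| over all pairs of sample points."""
    return max(abs(f(a) - f(b)) / abs(a - b) for a, b in combinations(points, 2))


def f(x):
    # Example function chosen for illustration; its slope never exceeds 3
    return math.sin(3 * x)


# 200 evenly spaced points on [0, 2*pi]
grid = [2 * math.pi * i / 199 for i in range(200)]
estimate = estimate_lipschitz(f, grid)
print(f"estimated Lipschitz constant: {estimate:.3f}")
```

Because only finitely many pairs are checked, the estimate is a lower bound that approaches 3 as the grid gets finer.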

Elimination singularities are avoided because weight standardization keeps the statistics of the outputs similar to those of the inputs. So as long as the inputs are normally distributed, the outputs remain close to normal. This prevents the outputs of nodes from always falling outside the active range of the activation function (e.g. always negative input for a ReLU).
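The effect can be seen in a toy setting. The following sketch is not from the paper: the tensor sizes, the all-negative weights, and the `eps` value are assumptions picked to force the failure case. It feeds non-negative inputs to units whose weights are all negative, so every ReLU would output 0, and then repeats the computation after standardizing each unit's weights.

```python
import torch

torch.manual_seed(0)

# Non-negative inputs, e.g. the outputs of a previous ReLU
x = torch.relu(torch.randn(1024, 64))

# 16 units whose weights are all negative: every pre-activation is <= 0,
# so a ReLU after this layer would always output 0
w = -(torch.rand(16, 64) + 0.1)
frac_active_before = (x @ w.t() > 0).float().mean().item()

# Standardize each unit's weights (each row) to zero mean and unit variance
var, mean = torch.var_mean(w, dim=1, keepdim=True)
w_hat = (w - mean) / torch.sqrt(var + 1e-5)
frac_active_after = (x @ w_hat.t() > 0).float().mean().item()

print(f"active before: {frac_active_before:.2f}, after: {frac_active_after:.2f}")
```

With zero-mean weights the pre-activations are centered around 0, so a sizeable fraction of them land in the active range of the ReLU instead of none.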

*Refer to the paper for proofs*.

Here is the training code for a VGG network that uses weight standardization to classify CIFAR-10 data. It uses a 2D-Convolution Layer with Weight Standardization.

`import torch`

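Weight standardization normalizes each row of the weight matrix $W \in \mathbb{R}^{O \times I}$ to zero mean and unit variance:

$$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}}{\sqrt{\sigma^2_{W_{i,\cdot}} + \epsilon}}$$

where $\mu_{W_{i,\cdot}}$ and $\sigma^2_{W_{i,\cdot}}$ are the mean and variance of row $i$ of $W$, and $\epsilon$ is a small constant for numerical stability. For a 2D-convolution layer, $O$ is the number of output channels ($O = C_{out}$) and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$).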
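For instance, a convolution with 3 input channels, 8 output channels, and a $3 \times 3$ kernel gives $O = 8$ and $I = 3 \times 3 \times 3 = 27$. A small sketch to check this, where the layer sizes are assumptions chosen for illustration:

```python
import torch
from torch import nn

# Hypothetical layer: 3 input channels, 8 output channels, 3x3 kernel
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# Conv2d weights are stored as (C_out, C_in, k_H, k_W)
c_out, c_in, k_h, k_w = conv.weight.shape
O = c_out
I = c_in * k_h * k_w

# Flatten each output channel's filter into one row of the O x I matrix
w = conv.weight.view(O, -1)
print(O, I, tuple(w.shape))  # 8 27 (8, 27)
```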

`def weight_standardization(weight: torch.Tensor, eps: float):`

Get $C_{out}$, $C_{in}$ and kernel shape

`    c_out, c_in, *kernel_shape = weight.shape`

Reshape $W$ to $O \times I$

`    weight = weight.view(c_out, -1)`

Calculate the variance and mean of each row of $W$

`    var, mean = torch.var_mean(weight, dim=1, keepdim=True)`

Normalize

`    weight = (weight - mean) / (torch.sqrt(var + eps))`

Change back to original shape and return

`    return weight.view(c_out, c_in, *kernel_shape)`
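Put together, the steps above form a complete function, and one way to use it is inside a convolution layer that standardizes its weights on every forward pass. The sketch below is a minimal illustration under stated assumptions: the class name `WSConv2d`, the default `eps` of `1e-5`, and the layer sizes are hypothetical, and the layer used by the actual training code may differ.

```python
import torch
from torch import nn
import torch.nn.functional as F


def weight_standardization(weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same steps as the walkthrough, assembled into one function
    c_out, c_in, *kernel_shape = weight.shape
    weight = weight.view(c_out, -1)
    var, mean = torch.var_mean(weight, dim=1, keepdim=True)
    weight = (weight - mean) / torch.sqrt(var + eps)
    return weight.view(c_out, c_in, *kernel_shape)


class WSConv2d(nn.Conv2d):
    """Hypothetical wrapper: a Conv2d that standardizes its weights
    before each convolution (assumes the default zero padding mode)."""

    def __init__(self, *args, eps: float = 1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The stored parameter stays unconstrained; only the weights
        # used in the convolution are standardized
        w = weight_standardization(self.weight, self.eps)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


torch.manual_seed(0)
layer = WSConv2d(3, 8, kernel_size=3, padding=1)
out = layer(torch.randn(2, 3, 32, 32))

# Each output channel's standardized filter has mean ~0 and variance ~1
w_hat = weight_standardization(layer.weight, layer.eps).view(8, -1)
row_mean = w_hat.mean(dim=1)
row_var = w_hat.var(dim=1)
print(tuple(out.shape))
```

Standardizing inside `forward` keeps the operation differentiable, so gradients flow through the mean and variance back to the raw weights.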