This is a PyTorch implementation of Weight Standardization from the paper Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. We also have an annotated implementation of Batch-Channel Normalization.
Batch normalization gives a smooth loss landscape and avoids elimination singularities. Elimination singularities are nodes of the network that become useless (e.g. a ReLU that gives 0 all the time).
However, batch normalization doesn't work well when the batch size is too small, which happens when training large networks because of device memory limitations. The paper introduces Weight Standardization with Batch-Channel Normalization as a better alternative.
Weight Standardization:
1. Normalizes the gradients
2. Smooths the landscape (reduced Lipschitz constant)
3. Avoids elimination singularities
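The Lipschitz constant is the maximum slope a function has between two points. That is, $L$ is the Lipschitz constant where $L$ is the smallest value that satisfies $\forall a, b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$, where $f: A \rightarrow \mathbb{R}^m$ and $A \subseteq \mathbb{R}^n$.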
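To make the definition concrete, here is a minimal sketch that estimates the Lipschitz constant of a scalar function by sampling random pairs of points and taking the steepest slope found. The helper name `estimate_lipschitz` and the test function $\sin(3x)$ are assumptions chosen for illustration; they are not from the paper. Since only a finite number of pairs is sampled, the result is a lower bound on the true constant (which is 3 for $\sin(3x)$).

```python
import torch


def estimate_lipschitz(f, low: float, high: float, n_pairs: int = 10_000, seed: int = 0) -> float:
    """Largest sampled slope |f(a) - f(b)| / |a - b| (a lower bound on the true constant)."""
    g = torch.Generator().manual_seed(seed)
    # Sample pairs (a, b) uniformly from [low, high) in double precision
    a = torch.rand(n_pairs, generator=g, dtype=torch.float64) * (high - low) + low
    b = torch.rand(n_pairs, generator=g, dtype=torch.float64) * (high - low) + low
    # Guard against a == b to avoid division by zero
    dist = (a - b).abs().clamp_min(1e-12)
    return ((f(a) - f(b)).abs() / dist).max().item()


# sin(3x) has slope 3 * cos(3x), so its Lipschitz constant is 3
estimate = estimate_lipschitz(lambda x: torch.sin(3 * x), -1.0, 1.0)
```

A smoother landscape in this sense means the loss changes by a bounded amount for a given step in the weights.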
Elimination singularities are avoided because Weight Standardization keeps the statistics of the outputs similar to the inputs. So as long as the inputs are normally distributed the outputs remain close to normal. This prevents the outputs of nodes from always falling beyond the active range of the activation function (e.g. always negative input for a ReLU).
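The following toy experiment is a rough illustration of this effect, not a result from the paper. It assumes inputs with a positive mean and a linear layer whose weights have drifted to a strongly negative mean, so every ReLU input is negative. The helper `standardize_rows` is a hypothetical stand-in for standardizing each output row of the weight matrix to zero mean and unit variance.

```python
import torch


def standardize_rows(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: zero mean and unit variance for each output row."""
    var, mean = torch.var_mean(weight, dim=1, keepdim=True)
    return (weight - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
# Inputs with a positive mean (an assumption for this demo)
x = torch.randn(4096, 64) + 2.0
# Weights of 8 output nodes, shifted so that every row has a strongly negative mean
weight = 0.1 * torch.randn(8, 64) - 1.0

# Fraction of pre-activations that are positive, i.e. where a ReLU would be active
active_before = (x @ weight.t() > 0).float().mean().item()
active_after = (x @ standardize_rows(weight).t() > 0).float().mean().item()
```

With the raw weights the ReLU never fires in this setup, while with standardized rows the pre-activations are centered near zero, so roughly half of them are positive.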
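Weight Standardization normalizes each output row of the weight matrix: $\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}}{\sigma_{W_{i,\cdot}}}$, where $W \in \mathbb{R}^{O \times I}$, $\mu_{W_{i,\cdot}}$ is the mean of row $i$ and $\sigma_{W_{i,\cdot}} = \sqrt{\operatorname{Var}(W_{i,\cdot}) + \epsilon}$. For a 2D-convolution layer $O$ is the number of output channels ($O = C_{out}$) and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$).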
import torch


def weight_standardization(weight: torch.Tensor, eps: float):
    # Get C_out, C_in and the kernel shape
    c_out, c_in, *kernel_shape = weight.shape
    # Reshape W to O x I, one row per output channel
    weight = weight.view(c_out, -1)
    # Calculate the variance and mean of each row
    var, mean = torch.var_mean(weight, dim=1, keepdim=True)
    # Normalize each row
    weight = (weight - mean) / (torch.sqrt(var + eps))
    # Change back to original shape and return
    return weight.view(c_out, c_in, *kernel_shape)
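As a usage sketch, one way to apply the function is to standardize the weights on every forward pass of a convolution layer. The class name `WSConv2d` and the default `eps=1e-5` below are assumptions made for illustration, and the sketch only handles the default zero padding mode.

```python
import torch
from torch import nn
from torch.nn import functional as F


def weight_standardization(weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same computation as above, repeated so this example is self-contained
    c_out, c_in, *kernel_shape = weight.shape
    weight = weight.view(c_out, -1)
    var, mean = torch.var_mean(weight, dim=1, keepdim=True)
    weight = (weight - mean) / torch.sqrt(var + eps)
    return weight.view(c_out, c_in, *kernel_shape)


class WSConv2d(nn.Conv2d):
    """Hypothetical 2D convolution that standardizes its weights before use."""

    def __init__(self, *args, eps: float = 1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes the default padding_mode='zeros'
        weight = weight_standardization(self.weight, self.eps)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


conv = WSConv2d(3, 16, kernel_size=3, padding=1)
y = conv(torch.randn(2, 3, 32, 32))

# Each output channel's standardized weights should have mean ~0 and variance ~1
w_hat = weight_standardization(conv.weight, conv.eps).view(16, -1)
```

The stored parameter `conv.weight` stays unnormalized; gradients flow through the standardization step, which is where the gradient normalization effect comes from.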