Layer Normalization

This is a PyTorch implementation of Layer Normalization.

Limitations of Batch Normalization

  • You need to maintain running means and variances of each feature to use at inference time.
  • Tricky for RNNs: do you need different normalization statistics for each time step?
  • Doesn't work well with small batch sizes, and large NLP models are usually trained with small batch sizes.
  • Means and variances need to be computed across devices in distributed training.
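A minimal sketch, assuming a recent PyTorch, of two of the limitations above. nn.BatchNorm1d keeps running-statistic buffers that are updated during training, and in training mode it refuses a batch containing a single sample. nn.LayerNorm has no such buffers and handles one sample fine.

```python
import torch
from torch import nn

torch.manual_seed(0)
features = 4

bn = nn.BatchNorm1d(features)  # modules start in training mode
ln = nn.LayerNorm(features)

# Batch normalization keeps running estimates as buffers for use at inference time
print(bn.running_mean.shape, bn.running_var.shape)  # torch.Size([4]) torch.Size([4])

# A forward pass in training mode updates those running estimates
bn(torch.randn(8, features))
print(bn.running_mean)

# With a batch of one there are no batch statistics to compute
single = torch.randn(1, features)
raised = False
try:
    bn(single)
except ValueError as err:
    raised = True
    print("BatchNorm1d:", err)

# Layer normalization only uses the statistics of each sample's own features
out = ln(single)
print("LayerNorm:", out.shape)  # torch.Size([1, 4])
```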

Layer Normalization

Layer normalization is a simpler normalization method that works in a wider range of settings. It transforms the inputs to have zero mean and unit variance across the features. Note the contrast with batch normalization, which fixes the zero mean and unit variance for each feature across the samples of a batch. Layer normalization does it for each sample, across all of that sample's features, so its statistics do not depend on the rest of the batch.
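The difference in which axis gets normalized can be sketched on a [batch_size, features] tensor. This illustration computes training-mode batch normalization statistics by hand (no running averages, no affine parameters). It shows that a sample's layer-normalized output does not change when other samples join the batch, while its batch-normalized output does.

```python
import torch

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(4, 3)  # [batch_size, features]

def batch_norm_train(x):
    # One mean/variance per feature, computed across the batch (dim 0)
    mean = x.mean(dim=0, keepdim=True)
    var = x.var(dim=0, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

def layer_norm(x):
    # One mean/variance per sample, computed across the features (last dim)
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

# Normalize the first two samples on their own, then as part of the full batch
ln_small, ln_full = layer_norm(x[:2]), layer_norm(x)
bn_small, bn_full = batch_norm_train(x[:2]), batch_norm_train(x)

print(torch.allclose(ln_small, ln_full[:2]))  # True: independent of the other samples
print(torch.allclose(bn_small, bn_full[:2]))  # depends on the rest of the batch
```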

Layer normalization is generally used for NLP tasks.

We have used layer normalization in most of our transformer implementations.

from typing import Union, List
import torch
from torch import nn, Size
from labml_helpers.module import Module

Layer Normalization

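Layer normalization normalizes the input $X$ as follows.

When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, where $B$ is the batch size and $C$ is the number of features, with $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$:

$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{\operatorname{Var}}[X] + \epsilon}} + \beta$$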
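As a sanity check for the $B \times C$ case, here is a sketch that evaluates $\gamma (X - \mathbb{E}_C[X]) / \sqrt{\operatorname{Var}_C[X] + \epsilon} + \beta$ by hand with random $\gamma$ and $\beta$, and compares it with torch.nn.functional.layer_norm. The variance is the biased (population) variance, which is what PyTorch's layer norm uses.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, eps = 2, 5, 1e-5
X = torch.randn(B, C)
gamma = torch.randn(C)  # gain, one per feature
beta = torch.randn(C)   # bias, one per feature

# Expectation and biased variance over the feature dimension C
mean = X.mean(dim=-1, keepdim=True)                # [B, 1]
var = X.var(dim=-1, unbiased=False, keepdim=True)  # [B, 1]
ln = gamma * (X - mean) / torch.sqrt(var + eps) + beta

# Compare against PyTorch's functional layer norm
reference = F.layer_norm(X, (C,), weight=gamma, bias=beta, eps=eps)
print(torch.allclose(ln, reference, atol=1e-5))  # True
```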

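When input $X \in \mathbb{R}^{L \times B \times C}$ is a batch of a sequence of embeddings, where $B$ is the batch size, $C$ is the number of channels, and $L$ is the length of the sequence, with $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$:

$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{\operatorname{Var}}[X] + \epsilon}} + \beta$$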
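For a batch of sequences the statistics are still taken over the $C$ features only, so there is one mean and one variance per position and per sample. A minimal sketch, assuming the [seq_len, batch_size, features] layout this implementation uses for NLP inputs:

```python
import torch
from torch import nn

torch.manual_seed(0)
L, B, C = 7, 3, 8
X = torch.randn(L, B, C)  # [seq_len, batch_size, features]

# One mean and one variance per (position, sample) pair
mean = X.mean(dim=-1, keepdim=True)  # [L, B, 1]
var = X.var(dim=-1, unbiased=False, keepdim=True)
X_hat = (X - mean) / torch.sqrt(var + 1e-5)
print(mean.shape)  # torch.Size([7, 3, 1])

# Every embedding vector ends up with (roughly) zero mean and unit variance
print(X_hat.mean(dim=-1).abs().max() < 1e-5)
print(torch.allclose(X_hat.var(dim=-1, unbiased=False), torch.ones(L, B), atol=1e-3))

# nn.LayerNorm(C) with its initial gain of 1 and bias of 0 gives the same result
print(torch.allclose(nn.LayerNorm(C)(X), X_hat, atol=1e-5))
```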

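When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width, with $\gamma \in \mathbb{R}^{C \times H \times W}$ and $\beta \in \mathbb{R}^{C \times H \times W}$. This is not a widely used scenario.

$$\text{LN}(X) = \gamma \frac{X - \underset{C, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{C, H, W}{\operatorname{Var}}[X] + \epsilon}} + \beta$$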
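For image representations the statistics are taken jointly over channels, height and width, giving one mean and variance per image. The sketch below checks this against F.layer_norm over the last three dimensions. It also compares with nn.GroupNorm using a single group, which uses the same statistics; the two differ only in their affine parameters (shape [C] for GroupNorm versus [C, H, W] here), and both start as the identity.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 5
X = torch.randn(B, C, H, W)

# Statistics over C, H and W together: one mean and one variance per image
mean = X.mean(dim=(1, 2, 3), keepdim=True)  # [B, 1, 1, 1]
var = X.var(dim=(1, 2, 3), unbiased=False, keepdim=True)
X_hat = (X - mean) / torch.sqrt(var + 1e-5)

# Same as the functional layer norm over the last three dimensions
print(torch.allclose(X_hat, F.layer_norm(X, (C, H, W)), atol=1e-5))

# GroupNorm with one group uses the same statistics (per-channel affine, initially identity)
print(torch.allclose(X_hat, nn.GroupNorm(1, C)(X), atol=1e-5))
```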

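class LayerNorm(Module):
  • normalized_shape $S$ is the shape of the elements (except the batch). The input should then be $X \in \mathbb{R}^{* \times S[0] \times S[1] \times \dots \times S[n]}$
  • eps is $\epsilon$, used in $\sqrt{\operatorname{Var}[X] + \epsilon}$ for numerical stability
  • elementwise_affine specifies whether to scale and shift the normalized value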

We've tried to use the same argument names as PyTorch's LayerNorm implementation.
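For comparison, a short sketch of the built-in torch.nn.LayerNorm with the same argument names, using the shapes from the simple test at the end of this file. Note that the built-in module calls its learnable parameters weight and bias, where this implementation uses gain and bias.

```python
import torch
from torch import nn

x = torch.randn(2, 3, 2, 4)
ln = nn.LayerNorm(x.shape[2:], eps=1e-5, elementwise_affine=True)

print(ln.normalized_shape)             # (2, 4)
print(ln.weight.shape, ln.bias.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
print(ln(x).shape)                     # torch.Size([2, 3, 2, 4])

# Without the affine transform there are no learnable parameters
no_affine = nn.LayerNorm(4, elementwise_affine=False)
print(sum(p.numel() for p in no_affine.parameters()))  # 0
```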

    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
        super().__init__()

Convert normalized_shape to torch.Size

        if isinstance(normalized_shape, int):
            normalized_shape = torch.Size([normalized_shape])
        elif isinstance(normalized_shape, list):
            normalized_shape = torch.Size(normalized_shape)
        assert isinstance(normalized_shape, torch.Size)

        self.normalized_shape = normalized_shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine

Create parameters $\gamma$ and $\beta$ for gain and bias

        if self.elementwise_affine:
            self.gain = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))

x is a tensor of shape [*, S[0], S[1], ..., S[n]], where * can be any number of leading dimensions. For example, in an NLP task this will be [seq_len, batch_size, features].

    def forward(self, x: torch.Tensor):

Sanity check to make sure the shapes match

        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]

The dimensions to calculate the mean and variance over (the last len(normalized_shape) dimensions)

        dims = [-(i + 1) for i in range(len(self.normalized_shape))]

Calculate the mean over the normalized dimensions, $\mathbb{E}[X]$; i.e. one mean for each element of the leading dimensions

        mean = x.mean(dim=dims, keepdim=True)

Calculate the mean of the squares over the same dimensions, $\mathbb{E}[X^2]$

        mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)

Variance of all elements, $\operatorname{Var}[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$

        var = mean_x2 - mean ** 2

Normalize: $\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{\operatorname{Var}[X] + \epsilon}}$

        x_norm = (x - mean) / torch.sqrt(var + self.eps)

Scale and shift: $\text{LN}(X) = \gamma \hat{X} + \beta$

        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias

        return x_norm
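The forward pass above can be mirrored as a standalone function and checked against torch.nn.functional.layer_norm. This is a sketch, not part of the implementation: layer_norm_sketch is a hypothetical name, and it takes the gain and bias as plain tensors instead of module parameters.

```python
import torch
import torch.nn.functional as F

def layer_norm_sketch(x, normalized_shape, gain=None, bias=None, eps=1e-5):
    # Hypothetical helper that follows the same steps as LayerNorm.forward
    normalized_shape = torch.Size(normalized_shape)
    assert x.shape[-len(normalized_shape):] == normalized_shape
    # Reduce over the trailing dimensions: [-1, -2, ...]
    dims = [-(i + 1) for i in range(len(normalized_shape))]
    mean = x.mean(dim=dims, keepdim=True)
    mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
    var = mean_x2 - mean ** 2
    x_norm = (x - mean) / torch.sqrt(var + eps)
    if gain is not None:
        x_norm = gain * x_norm + bias
    return x_norm

torch.manual_seed(0)
x = torch.randn(2, 3, 2, 4)  # leading dims [2, 3], normalized shape [2, 4]
gain, bias = torch.randn(2, 4), torch.randn(2, 4)
out = layer_norm_sketch(x, [2, 4], gain, bias)
print(torch.allclose(out, F.layer_norm(x, (2, 4), gain, bias), atol=1e-5))  # True
```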

Simple test

def _test():
    from labml.logger import inspect
    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    ln = LayerNorm(x.shape[2:])
    x = ln(x)
    inspect(x.shape)
    inspect(ln.gain.shape)

if __name__ == '__main__':
    _test()
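The implementation computes the variance with the identity $\operatorname{Var}[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$. The identity is exact, but in floating point it can lose precision when the mean is large compared to the spread of the values. A small sketch comparing it with a two-pass estimate that subtracts the mean first; the offset of 1e4 is an arbitrary choice made for illustration, and the helper names are hypothetical.

```python
import torch

torch.manual_seed(0)
# Unit-variance values around a large offset
x64 = 1e4 + torch.randn(1000, dtype=torch.float64)
x32 = x64.float()

def one_pass_var(x):
    # The identity used in the implementation: E[x^2] - E[x]^2
    return (x ** 2).mean() - x.mean() ** 2

def two_pass_var(x):
    # Subtract the mean first, then average the squared deviations
    return ((x - x.mean()) ** 2).mean()

reference = two_pass_var(x64)
print(reference.item())          # close to 1
print(one_pass_var(x64).item())  # float64 has enough precision here
print(one_pass_var(x32).item())  # float32 can be far off
print(two_pass_var(x32).item())  # stays close to the reference
```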