Layer normalization is a simpler normalization method that works on a wider range of settings. Layer normalization transforms the inputs to have zero mean and unit variance across the features. Note that batch normalization fixes zero mean and unit variance for each feature, computed across the batch, whereas layer normalization does it for each sample, computed across all features.
Layer normalization is generally used for NLP tasks.
We have used layer normalization in most of the transformer implementations.
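To make the difference concrete, here is a minimal sketch (plain PyTorch, independent of the class below) that normalizes the same tensor along the batch dimension, as batch normalization does, and along the feature dimension, as layer normalization does:

```python
import torch

x = torch.randn(4, 8)  # [batch_size, features]

# Batch-norm style: statistics per feature, computed across the batch
bn_style = (x - x.mean(dim=0)) / x.std(dim=0, unbiased=False)

# Layer-norm style: statistics per sample, computed across the features
ln_style = (x - x.mean(dim=-1, keepdim=True)) / x.std(dim=-1, unbiased=False, keepdim=True)

print(bn_style.mean(dim=0))   # ~0 for every feature column
print(ln_style.mean(dim=-1))  # ~0 for every sample row
```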
```python
from typing import Union, List

import torch
from torch import nn, Size

from labml_helpers.module import Module
```
Layer normalization $\text{LN}$ normalizes the input $X$ as follows:

When the input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, where $B$ is the batch size and $C$ is the number of features, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$:
$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{\mathrm{Var}}[X] + \epsilon}} + \beta$$

When the input $X \in \mathbb{R}^{L \times B \times C}$ is a batch of sequences of embeddings, where $B$ is the batch size, $C$ is the number of channels, and $L$ is the length of the sequence, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$:
$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{\mathrm{Var}}[X] + \epsilon}} + \beta$$

When the input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height, and $W$ is the width (this is not a widely used scenario), $\gamma \in \mathbb{R}^{C \times H \times W}$ and $\beta \in \mathbb{R}^{C \times H \times W}$:
$$\text{LN}(X) = \gamma \frac{X - \underset{C,H,W}{\mathbb{E}}[X]}{\sqrt{\underset{C,H,W}{\mathrm{Var}}[X] + \epsilon}} + \beta$$
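As a sanity check, the embeddings case above can be verified against PyTorch's built-in `torch.nn.functional.layer_norm`; this sketch uses hypothetical `gamma` and `beta` tensors as the gain and bias:

```python
import torch
import torch.nn.functional as F

B, C = 4, 16
X = torch.randn(B, C)
gamma, beta = torch.ones(C), torch.zeros(C)

# Manual layer norm over the feature dimension
mean = X.mean(dim=-1, keepdim=True)
var = X.var(dim=-1, unbiased=False, keepdim=True)
manual = gamma * (X - mean) / torch.sqrt(var + 1e-5) + beta

assert torch.allclose(manual, F.layer_norm(X, [C], gamma, beta, 1e-5), atol=1e-6)
```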
* `normalized_shape` $S$ is the shape of the elements (except the batch). The input should then be $X \in \mathbb{R}^{* \times S[0] \times S[1] \times \dots \times S[n]}$
* `eps` is $\epsilon$, used in $\sqrt{\mathrm{Var}[X] + \epsilon}$ for numerical stability
* `elementwise_affine` is whether to scale and shift the normalized value
We've tried to use the same names for arguments as PyTorch's `LayerNorm`.
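For example, mirroring `torch.nn.LayerNorm`, all of the following are equivalent ways to specify the shape (a sketch using PyTorch's built-in class):

```python
import torch
from torch import nn

# An int normalizes over the last dimension only
ln1 = nn.LayerNorm(16)
# A list or a torch.Size normalizes over the last len(shape) dimensions
ln2 = nn.LayerNorm([16])
ln3 = nn.LayerNorm(torch.Size([16]))

x = torch.randn(2, 10, 16)
assert ln1(x).shape == ln2(x).shape == ln3(x).shape == x.shape
```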
```python
class LayerNorm(Module):
    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
        super().__init__()
```
```python
        # Convert `normalized_shape` to `torch.Size`
        if isinstance(normalized_shape, int):
            normalized_shape = torch.Size([normalized_shape])
        elif isinstance(normalized_shape, list):
            normalized_shape = torch.Size(normalized_shape)
        assert isinstance(normalized_shape, torch.Size)
```
```python
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine
```
Create parameters for $\gamma$ and $\beta$ for gain and bias
```python
        if self.elementwise_affine:
            self.gain = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
```
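Since the gain starts at ones and the bias at zeros, the affine transform is initially the identity. The same holds for PyTorch's built-in class, as this quick check shows:

```python
import torch
from torch import nn

ln = nn.LayerNorm(8)
# weight (gain) is initialized to ones, bias to zeros
assert torch.all(ln.weight == 1) and torch.all(ln.bias == 0)
```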
`x` is a tensor of shape `[*, S[0], S[1], ..., S[n]]`. `*` could be any number of dimensions. For example, in an NLP task this will be `[seq_len, batch_size, features]`.
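For instance, with `normalized_shape = [features]`, a typical NLP input flows through like this (a sketch using `nn.LayerNorm` as a stand-in for the class defined here):

```python
import torch
from torch import nn

seq_len, batch_size, features = 10, 2, 16
x = torch.randn(seq_len, batch_size, features)

# Each (seq, batch) position is normalized over its feature vector
ln = nn.LayerNorm(features)
assert ln(x).shape == (seq_len, batch_size, features)
```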
```python
    def forward(self, x: torch.Tensor):
```
Sanity check to make sure the shapes match
```python
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
```
The dimensions to calculate the mean and variance on
```python
        dims = [-(i + 1) for i in range(len(self.normalized_shape))]
```
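For example, if `normalized_shape` is `[2, 4]`, this gives `dims = [-1, -2]`, i.e. the last two dimensions of `x`; a standalone run of the same comprehension:

```python
normalized_shape = [2, 4]
dims = [-(i + 1) for i in range(len(normalized_shape))]
print(dims)  # [-1, -2]
```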
Calculate the mean over the element dimensions, i.e. the mean $\mathbb{E}[X]$ for each sample
```python
        mean = x.mean(dim=dims, keepdim=True)
```
Calculate the mean of the squares of all elements, i.e. $\mathbb{E}[X^2]$ for each sample
```python
        mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
```
Variance of all elements, $\mathrm{Var}[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$
```python
        var = mean_x2 - mean ** 2
```
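Note that this identity gives the biased (population) variance; a quick check against `torch.var` with `unbiased=False`:

```python
import torch

x = torch.randn(3, 4)
mean = x.mean(dim=-1, keepdim=True)
mean_x2 = (x ** 2).mean(dim=-1, keepdim=True)
var = mean_x2 - mean ** 2

assert torch.allclose(var, x.var(dim=-1, unbiased=False, keepdim=True), atol=1e-6)
```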
Normalize $X$, i.e. $\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{\mathrm{Var}[X] + \epsilon}}$

```python
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
```
Scale and shift, $\text{LN}(x) = \gamma \hat{X} + \beta$
```python
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias
```
```python
        return x_norm
```
Simple test

```python
def _test():
    from labml.logger import inspect

    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    ln = LayerNorm(x.shape[2:])

    x = ln(x)
    inspect(x.shape)
    inspect(ln.gain.shape)
```
```python
if __name__ == '__main__':
    _test()
```