This is a PyTorch implementation of Layer Normalization.
Layer normalization is a simpler normalization method that works in a wider range of settings than batch normalization. It transforms the inputs to have zero mean and unit variance across the features. Note that batch normalization fixes the mean and variance of each feature across the batch, while layer normalization does it for each sample across all the features.
Layer normalization is generally used for NLP tasks.
We have used layer normalization in most of the transformer implementations.
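As a quick illustration of the difference, here is a minimal sketch using PyTorch's built-in nn.BatchNorm1d and nn.LayerNorm (not the implementation below); note which dimension ends up with zero mean:

import torch
from torch import nn

x = torch.randn(32, 10)  # [batch_size, features]

# Batch normalization: statistics per feature, computed across the batch
bn = nn.BatchNorm1d(10, affine=False)
print(bn(x).mean(dim=0))   # ~0 for each of the 10 features

# Layer normalization: statistics per sample, computed across the features
ln = nn.LayerNorm(10, elementwise_affine=False)
print(ln(x).mean(dim=-1))  # ~0 for each of the 32 samples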
from typing import Union, List

import torch
from torch import nn, Size

from labml_helpers.module import Module
Layer normalization $\text{LN}$ normalizes the input $X$ as follows:

When the input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, where $B$ is the batch size and $C$ is the number of features, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$, and
$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}} + \beta$$

When the input $X \in \mathbb{R}^{L \times B \times C}$ is a batch of sequences of embeddings, where $B$ is the batch size, $C$ is the number of channels, and $L$ is the length of the sequence, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$, and
$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}} + \beta$$

When the input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width, $\gamma \in \mathbb{R}^{C \times H \times W}$ and $\beta \in \mathbb{R}^{C \times H \times W}$, and
$$\text{LN}(X) = \gamma \frac{X - \underset{C, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{C, H, W}{Var}[X] + \epsilon}} + \beta$$
This is not a widely used scenario.
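For instance, here is a small sketch (with arbitrary shapes, using PyTorch's functional API rather than the class below) that computes the $L \times B \times C$ case directly from the definition above and checks it against torch.nn.functional.layer_norm:

import torch
import torch.nn.functional as F

seq_len, batch_size, channels = 5, 2, 8
x = torch.randn(seq_len, batch_size, channels)
gamma, beta = torch.ones(channels), torch.zeros(channels)

# Mean and (biased) variance over the channel dimension only
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
y = gamma * (x - mean) / torch.sqrt(var + 1e-5) + beta

# The built-in functional layer norm gives the same result
assert torch.allclose(y, F.layer_norm(x, (channels,), gamma, beta, eps=1e-5), atol=1e-5)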
class LayerNorm(Module):
* normalized_shape $S$ is the shape of the elements (except the batch).
  The input should then be $X \in \mathbb{R}^{* \times S[0] \times S[1] \times … \times S[n]}$
* eps is $\epsilon$, used in $\sqrt{Var[X] + \epsilon}$ for numerical stability
* elementwise_affine is whether to scale and shift the normalized value

We've tried to use the same names for arguments as the PyTorch LayerNorm implementation.
    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
        super().__init__()
        # Convert `normalized_shape` to a `torch.Size`, so that `int` and `list` arguments also work
        if isinstance(normalized_shape, int):
            normalized_shape = Size([normalized_shape])
        elif isinstance(normalized_shape, list):
            normalized_shape = Size(normalized_shape)
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine
Create parameters for $\gamma$ and $\beta$ for gain and bias
        if self.elementwise_affine:
            self.gain = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
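The gain and bias have exactly the normalized shape, so they broadcast over any leading batch or sequence dimensions. A quick shape check (a hedged sketch that assumes the imports at the top of the file; the shape [3, 4] is an arbitrary example):

ln = LayerNorm(torch.Size([3, 4]))
assert ln.gain.shape == (3, 4)  # one gain per element of the normalized shape
assert ln.bias.shape == (3, 4)  # likewise for the bias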
x is a tensor of shape [*, S[0], S[1], ..., S[n]]. * could be any number of dimensions. For example, in an NLP task this will be [seq_len, batch_size, features].
    def forward(self, x: torch.Tensor):
Sanity check to make sure the shapes match
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
The dimensions to calculate the mean and variance on
        dims = [-(i + 1) for i in range(len(self.normalized_shape))]
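For example (an illustrative check, not part of the class), if the normalized shape has two dimensions, the statistics are computed over the last two axes:

import torch

normalized_shape = torch.Size([3, 4])
dims = [-(i + 1) for i in range(len(normalized_shape))]
assert dims == [-1, -2]  # mean and variance are taken over the last two dimensions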
Calculate the mean over the normalized dimensions; i.e. the mean of the elements of each sample, $\mathbb{E}[X]$
        mean = x.mean(dim=dims, keepdim=True)
Calculate the mean of the squared elements; i.e. $\mathbb{E}[X^2]$ for each sample
        mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
Variance of all elements: $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$
        var = mean_x2 - mean ** 2
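This uses the identity $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$. A small sketch (with an arbitrary test tensor) confirming that it matches the biased variance PyTorch computes directly:

import torch

x = torch.randn(2, 3, 4)
dims = [-1]
mean = x.mean(dim=dims, keepdim=True)
var = (x ** 2).mean(dim=dims, keepdim=True) - mean ** 2
assert torch.allclose(var, x.var(dim=dims, unbiased=False, keepdim=True), atol=1e-6)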
Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
Scale and shift
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias

        return x_norm
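Since this follows the same definition as PyTorch's built-in torch.nn.LayerNorm, the outputs should match at initialization, when the gain is all ones and the bias all zeros. A minimal comparison sketch (not part of the original code; it assumes the LayerNorm class above and the imports at the top of the file, and the shapes are arbitrary):

x = torch.randn(5, 2, 8)      # [seq_len, batch_size, features]
ln = LayerNorm(x.shape[-1:])
ln_torch = torch.nn.LayerNorm(8)
assert torch.allclose(ln(x), ln_torch(x), atol=1e-5)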
Simple test
def _test():
    from labml.logger import inspect

    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    ln = LayerNorm(x.shape[2:])

    x = ln(x)
    inspect(x.shape)
    inspect(ln.gain.shape)
if __name__ == '__main__':
    _test()