This is a PyTorch implementation of Layer Normalization.

Batch normalization has several limitations:

- You need to maintain running means and variances.
- It is tricky for RNNs: do you need different normalizations for each step?
- It doesn't work well with small batch sizes, and large NLP models are usually trained with small batch sizes.
- You need to compute means and variances across devices in distributed training.

Layer normalization is a simpler normalization method that works in a wider range of settings. It transforms the inputs to have zero mean and unit variance across the features. *Note that batch normalization fixes the zero mean and unit variance for each element (feature), with statistics computed across the batch.* Layer normalization instead does it for each sample in the batch, across all of its elements.
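To make the difference between the two normalization axes concrete, here is a minimal sketch using plain tensor operations. The tensor sizes, the seed and the `eps` value of `1e-5` are arbitrary choices for illustration, and the batch normalization half ignores running statistics and the learned scale and shift.

```python
import torch

torch.manual_seed(0)
# A toy batch: [batch_size, features], deliberately not zero-mean
x = torch.randn(4, 8) * 3.0 + 5.0

# Layer normalization: statistics over the feature axis, one pair per sample
ln_mean = x.mean(dim=-1, keepdim=True)
ln_var = x.var(dim=-1, unbiased=False, keepdim=True)
x_ln = (x - ln_mean) / torch.sqrt(ln_var + 1e-5)

# Batch normalization (training mode only): statistics over the batch axis,
# one pair per feature
bn_mean = x.mean(dim=0, keepdim=True)
bn_var = x.var(dim=0, unbiased=False, keepdim=True)
x_bn = (x - bn_mean) / torch.sqrt(bn_var + 1e-5)

print(ln_mean.shape, bn_mean.shape)
```

The layer normalization statistics have one entry per sample (shape `[4, 1]`), so they do not depend on the other samples in the batch, while the batch normalization statistics have one entry per feature (shape `[1, 8]`).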

Layer normalization is generally used for NLP tasks.

We have used layer normalization in most of the transformer implementations.

```
from typing import Union, List

import torch
from torch import nn, Size

from labml_helpers.module import Module
```

Layer normalization $LN$ normalizes the input $X$ as follows:

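When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, where $B$ is the batch size and $C$ is the number of features, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$. $LN(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}} + \beta$

When input $X \in \mathbb{R}^{L \times B \times C}$ is a batch of a sequence of embeddings, where $B$ is the batch size, $C$ is the number of channels and $L$ is the length of the sequence, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$. $LN(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}} + \beta$

When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width, $\gamma \in \mathbb{R}^{C \times H \times W}$ and $\beta \in \mathbb{R}^{C \times H \times W}$. This is not a widely used scenario. $LN(X) = \gamma \frac{X - \underset{C,H,W}{\mathbb{E}}[X]}{\sqrt{\underset{C,H,W}{Var}[X] + \epsilon}} + \beta$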
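The three cases can be checked against PyTorch's built-in `torch.nn.LayerNorm`. This is only an illustrative sketch: the sizes `B`, `C`, `L`, `H`, `W` and the dictionary labels are made up for the example, and it uses the built-in module rather than the class defined here.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, C, L, H, W = 2, 6, 5, 3, 4

# (input tensor, normalized_shape) for each of the three cases
cases = {
    'embeddings [B, C]': (torch.randn(B, C), [C]),
    'sequence [L, B, C]': (torch.randn(L, B, C), [C]),
    'images [B, C, H, W]': (torch.randn(B, C, H, W), [C, H, W]),
}

stats = {}
with torch.no_grad():
    for name, (x, normalized_shape) in cases.items():
        ln = nn.LayerNorm(normalized_shape)
        y = ln(x)
        # The trailing dimensions that the statistics are computed over
        dims = tuple(range(-len(normalized_shape), 0))
        # Shape of the scale parameter and the largest absolute mean left over
        stats[name] = (tuple(ln.weight.shape), y.mean(dim=dims).abs().max().item())
        print(name, stats[name])
```

In each case the scale parameter has the shape of the normalized trailing dimensions, and the mean over those dimensions is zero up to floating point error.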

`class LayerNorm(Module):`

- `normalized_shape` $S$ is the shape of the elements (except the batch). The input should then be $X \in \mathbb{R}^{* \times S[0] \times S[1] \times ... \times S[n]}$
- `eps` is $\epsilon$, used in $\sqrt{Var[X] + \epsilon}$ for numerical stability
- `elementwise_affine` is whether to scale and shift the normalized value

We've tried to use the same names for arguments as the PyTorch `LayerNorm` implementation.

```
    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
```

`super().__init__()`

Convert `normalized_shape` to `torch.Size`

```
        if isinstance(normalized_shape, int):
            normalized_shape = torch.Size([normalized_shape])
        elif isinstance(normalized_shape, list):
            normalized_shape = torch.Size(normalized_shape)
        assert isinstance(normalized_shape, torch.Size)
```
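The conversion can be tried in isolation. The helper below, `to_size`, is a hypothetical name used only for this sketch. Note that a plain tuple such as `(2, 4)` is not an instance of `torch.Size`, so it would presumably fail the assertion above. The sketch accepts tuples as well, which is an extension and not part of the original code.

```python
from typing import List, Tuple, Union

import torch


def to_size(normalized_shape: Union[int, List[int], Tuple[int, ...], torch.Size]) -> torch.Size:
    """Hypothetical helper: normalize the accepted argument types to `torch.Size`."""
    # `torch.Size` is a subclass of `tuple`, so check it first
    if isinstance(normalized_shape, torch.Size):
        return normalized_shape
    if isinstance(normalized_shape, int):
        return torch.Size([normalized_shape])
    if isinstance(normalized_shape, (list, tuple)):
        return torch.Size(normalized_shape)
    raise TypeError(f'Unsupported normalized_shape: {normalized_shape!r}')


print(to_size(8))
print(to_size([2, 4]))
print(isinstance((2, 4), torch.Size))
```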

```
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine
```

Create parameters for $γ$ and $β$ for gain and bias

```
        if self.elementwise_affine:
            self.gain = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
```
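A small sketch of what this parameter creation registers on the module. `AffineSketch` is a stand-in class invented for this example, built on plain `torch.nn.Module` rather than the `Module` helper used here. It also shows that the built-in `torch.nn.LayerNorm` names its scale parameter `weight` rather than `gain`, which matters if you try to load a state dict from one into the other.

```python
import torch
from torch import nn


class AffineSketch(nn.Module):
    """Stand-in module that only creates the gain and bias parameters."""

    def __init__(self, normalized_shape: torch.Size, elementwise_affine: bool = True):
        super().__init__()
        if elementwise_affine:
            # Initialized so that the affine transform starts as the identity
            self.gain = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))


with_affine = AffineSketch(torch.Size([2, 4]))
without_affine = AffineSketch(torch.Size([2, 4]), elementwise_affine=False)

names = [name for name, _ in with_affine.named_parameters()]
n_params = sum(p.numel() for p in with_affine.parameters())
builtin_names = [name for name, _ in nn.LayerNorm([2, 4]).named_parameters()]

print(names, n_params)
print(len(list(without_affine.parameters())))
print(builtin_names)
```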

`x` is a tensor of shape `[*, S[0], S[1], ..., S[n]]`. `*` could be any number of dimensions. For example, in an NLP task this will be `[seq_len, batch_size, features]`

`def forward(self, x: torch.Tensor):`

Sanity check to make sure the shapes match

`assert self.normalized_shape == x.shape[-len(self.normalized_shape):]`

The dimensions to calculate the mean and variance on

`dims = [-(i + 1) for i in range(len(self.normalized_shape))]`
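As a worked example of this indexing, the sketch below uses the same tensor shape as the simple test at the end, `[2, 3, 2, 4]`, with the last two dimensions normalized. The variable names mirror the ones in `forward`, but the snippet stands on its own.

```python
import torch

x = torch.zeros([2, 3, 2, 4])
normalized_shape = x.shape[2:]  # torch.Size([2, 4])

# One negative index per normalized dimension, counting from the end
dims = [-(i + 1) for i in range(len(normalized_shape))]
print(dims)  # [-1, -2]

# With keepdim=True the reduced dimensions stay as size 1,
# so the result broadcasts against x
mean = x.mean(dim=dims, keepdim=True)
print(mean.shape)  # torch.Size([2, 3, 1, 1])
```

Using negative indices means the leading `*` dimensions can be of any number without changing the code.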

Calculate the mean over all elements of each sample, $\mathbb{E}[X]$

`mean = x.mean(dim=dims, keepdim=True)`

Calculate the mean of the squares over the same elements, $\mathbb{E}[X^2]$

`mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)`

Variance of all elements, $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$

`var = mean_x2 - mean ** 2`
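The identity $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$ gives the biased (population) variance, which divides by the number of elements. A quick sketch comparing it with `torch.var(..., unbiased=False)` on random data follows. The tensor size and seed are arbitrary.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 16)

# Variance through the identity used in the implementation
mean = x.mean(dim=-1, keepdim=True)
mean_x2 = (x ** 2).mean(dim=-1, keepdim=True)
var_identity = mean_x2 - mean ** 2

# Reference: population variance computed by PyTorch
var_reference = x.var(dim=-1, unbiased=False, keepdim=True)

max_diff = (var_identity - var_reference).abs().max().item()
print(max_diff)
```

The two agree up to floating point error here. As a general caveat, this one-pass form can lose precision when the mean is large compared to the standard deviation, since it subtracts two nearly equal numbers.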

Normalize $\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{Var[X] + \epsilon}}$

`x_norm = (x - mean) / torch.sqrt(var + self.eps)`

Scale and shift $LN(x) = \gamma \hat{X} + \beta$

```
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias
```

`return x_norm`
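The steps of `forward` above can be collected into a standalone function and compared with `torch.nn.functional.layer_norm`. This is a sketch and not the class itself: `layer_norm_sketch` is a hypothetical name, and the input shape `[seq_len, batch_size, features] = [5, 2, 8]` and the random gain and bias are chosen only for the comparison.

```python
import torch
import torch.nn.functional as F


def layer_norm_sketch(x, normalized_shape, gain=None, bias=None, eps=1e-5):
    """Hypothetical functional version of the forward pass described above."""
    # Sanity check: trailing dimensions must match the normalized shape
    assert tuple(x.shape[-len(normalized_shape):]) == tuple(normalized_shape)
    dims = [-(i + 1) for i in range(len(normalized_shape))]
    mean = x.mean(dim=dims, keepdim=True)
    mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
    var = mean_x2 - mean ** 2
    x_norm = (x - mean) / torch.sqrt(var + eps)
    # Scale and shift only when both parameters are given
    if gain is not None and bias is not None:
        x_norm = gain * x_norm + bias
    return x_norm


torch.manual_seed(0)
x = torch.randn(5, 2, 8)
shape = torch.Size([8])
gain = torch.rand(shape) + 0.5
bias = torch.randn(shape)

ours = layer_norm_sketch(x, shape, gain, bias)
reference = F.layer_norm(x, shape, weight=gain, bias=bias, eps=1e-5)
max_diff = (ours - reference).abs().max().item()
print(max_diff)
```

On this input the two results should differ only by floating point error.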

Simple test

`def _test():`

```
    from labml.logger import inspect

    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    ln = LayerNorm(x.shape[2:])

    x = ln(x)
    inspect(x.shape)
    inspect(ln.gain.shape)
```

```
if __name__ == '__main__':
    _test()
```