This is a PyTorch implementation of DeepNorm from the paper DeepNet: Scaling Transformers to 1,000 Layers.

The paper proposes a method to stabilize extremely deep transformers through a new normalization function, which replaces the standard Post-LayerNorm residual connection, and a weight initialization scheme. This combines the performance of Post-LayerNorm with the stability of Pre-LayerNorm. Transformers with DeepNorm are supposed to be stable even without a learning rate warm-up.

The paper first shows that changes to layer outputs (for the same input) are small and gradual during stable training, while they change rapidly during the initial training steps when training is unstable. Initializing weights to small values and using a learning rate warm-up both keep these changes small and the training stable. The authors use this idea of keeping the changes to layer outputs small to derive the new normalization and weight initialization scheme.

Usually, the weights are initialized with Xavier or Kaiming initialization. This paper instead scales the weights (sets the initialization gain) by a constant $β$ that depends on the size of the transformer.

DeepNorm suggests scaling the weights of the two linear transforms in the Feed-Forward Network, the value projection transform, and the output projection transform of the attention layer. The weights of these transforms are scaled by (i.e. initialized with a gain of) $β$.
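As a rough sketch (not the paper's exact code), this could be applied with PyTorch's Xavier initialization; the module names below are illustrative and `beta` stands for the DeepNorm $β$ constant:

```
from torch import nn

def scale_init(linear: nn.Linear, gain: float):
    # Xavier (Glorot) normal initialization with the given gain
    nn.init.xavier_normal_(linear.weight, gain=gain)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)

d_model, d_ff = 512, 2048
beta = 0.87  # illustrative value of the DeepNorm β constant

# Hypothetical sub-layer projections of one transformer layer
ffn_layer1 = nn.Linear(d_model, d_ff)     # first FFN transform
ffn_layer2 = nn.Linear(d_ff, d_model)     # second FFN transform
value_proj = nn.Linear(d_model, d_model)  # attention value projection
out_proj = nn.Linear(d_model, d_model)    # attention output projection

# Scale these transforms by β; other projections (e.g. query/key) keep gain 1
for m in (ffn_layer1, ffn_layer2, value_proj, out_proj):
    scale_init(m, gain=beta)
```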

DeepNorm combines the residual connection and normalization as

$x_{l+1}=LN(αx_{l}+G_{l}(x_{l},θ_{l}))$

where $α$ is a constant that depends on the depth of the transformer, $LN$ is Layer Normalization, and $G_{l}(x_{l},θ_{l})$ is the function of the $l$-th transformer sub-layer (FFN or attention).

This function replaces the Post-LayerNorm residual connection and normalization; with $α = 1$ it reduces to standard Post-LayerNorm.

Here, $N$ is the number of layers in the encoder and $M$ is the number of layers in the decoder; the constants $α$ and $β$ are set based on these. Refer to the paper for the exact values and their derivation.
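As a best-effort summary of the paper's table (please verify against the paper), the encoder-only and decoder-only cases use $α = (2N)^{1/4}$ and $β = (8N)^{-1/4}$ (with $N$ replaced by $M$ for a decoder-only model), while the encoder-decoder setting uses different expressions. A small helper along these lines, with an illustrative name:

```
def deep_norm_constants(n_layers: int):
    # α and β for an encoder-only or decoder-only transformer with `n_layers` layers;
    # the encoder-decoder setting uses different expressions (see the paper)
    alpha = (2 * n_layers) ** (1 / 4)
    beta = (8 * n_layers) ** (-1 / 4)
    return alpha, beta

# e.g. a 64-layer decoder-only model
deep_norm_alpha, deep_norm_beta = deep_norm_constants(64)
```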

Here is an experiment implementation that uses DeepNorm.

```
from typing import Union, List

import torch
from torch import nn, Size

from labml_nn.normalization.layer_norm import LayerNorm
from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask
```

`class DeepNorm(nn.Module):`

* `alpha` is $α$
* `normalized_shape` is the shape for LayerNorm $LN$
* `eps` is $ϵ$ for LayerNorm
* `elementwise_affine` is a flag indicating whether to do an elementwise transformation in LayerNorm

```
def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
             eps: float = 1e-5,
             elementwise_affine: bool = True):
```

```
super().__init__()

self.alpha = alpha
```

Initialize $LN$

`self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)`

* `x` is the output from the previous layer $x_{l}$
* `gx` is the output of the current sub-layer $G_{l}(x_{l},θ_{l})$

`def forward(self, x: torch.Tensor, gx: torch.Tensor):`

$x_{l+1}=LN(αx_{l}+G_{l}(x_{l},θ_{l}))$

`return self.layer_norm(self.alpha * x + gx)`

This implements a transformer decoder layer with DeepNorm. Encoder layers will have a similar form.

`class DeepNormTransformerLayer(nn.Module):`

* `d_model` is the token embedding size
* `self_attn` is the self attention module
* `feed_forward` is the feed forward module
* `deep_norm_alpha` is the $α$ coefficient in DeepNorm
* `deep_norm_beta` is the $β$ constant for scaling the weight initialization

```
def __init__(self, *,
             d_model: int,
             self_attn: MultiHeadAttention,
             feed_forward: FeedForward,
             deep_norm_alpha: float,
             deep_norm_beta: float,
             ):
```

```
super().__init__()

self.self_attn = self_attn
self.feed_forward = feed_forward
```

DeepNorm modules after the attention and feed-forward networks

```
self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])
```

Scale weights after initialization

`with torch.no_grad():`

Feed forward network linear transformations

```
feed_forward.layer1.weight *= deep_norm_beta
feed_forward.layer2.weight *= deep_norm_beta
```

Attention value projection

`self_attn.value.linear.weight *= deep_norm_beta`

Attention output projection

`self_attn.output.weight *= deep_norm_beta`

The mask will be initialized on the first call

`self.mask = None`

* `x` is the embeddings tensor of shape `[seq_len, batch_size, d_model]`

`def forward(self, x: torch.Tensor):`

Create causal mask

`if self.mask is None or self.mask.size(0) != len(x):`

Subsequent mask, will mask out tokens from seeing future tokens

`self.mask = subsequent_mask(len(x)).to(x.device)`

Run through self attention, i.e. keys and values are from self

`x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))`

Pass through the feed-forward network

`x = self.feed_forward_norm(x, self.feed_forward(x))`

`return x`
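Finally, a hedged usage sketch. It assumes the `MultiHeadAttention(heads, d_model)` and `FeedForward(d_model, d_ff)` constructor signatures from `labml_nn` and the decoder-only $α$, $β$ expressions above; treat those assumptions as illustrative rather than definitive:

```
import torch
from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward

n_layers = 64  # M, the number of decoder layers
d_model = 512

# Decoder-only constants (assumed from the paper): α = (2M)^{1/4}, β = (8M)^{-1/4}
alpha = (2 * n_layers) ** (1 / 4)
beta = (8 * n_layers) ** (-1 / 4)

layer = DeepNormTransformerLayer(
    d_model=d_model,
    self_attn=MultiHeadAttention(8, d_model),        # assumed constructor signature
    feed_forward=FeedForward(d_model, 4 * d_model),  # assumed constructor signature
    deep_norm_alpha=alpha,
    deep_norm_beta=beta,
)

x = torch.randn(10, 2, d_model)  # [seq_len, batch_size, d_model]
out = layer(x)
assert out.shape == x.shape
```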