This is a PyTorch implementation of DeepNorm from the paper DeepNet: Scaling Transformers to 1,000 Layers.
The paper proposes a method to stabilize extremely deep transformers through a new normalization function that replaces Post-LayerNorm, together with a weight initialization scheme. This combines the performance of Post-LayerNorm with the stability of Pre-LayerNorm. Transformers with DeepNorm are expected to train stably even without a learning rate warm-up.
The paper first shows that layer outputs (for the same input) change gradually during stable training, whereas during unstable training they change rapidly in the initial training steps. Gradual change is what happens when weights are initialized to small values and when a learning rate warm-up is used, both of which keep training stable. The authors use the idea of keeping the changes to layer outputs small to derive the new normalization and weight initialization mechanism.
Usually, the weights are initialized with Xavier or Kaiming initialization. This paper scales the weights (sets their gain) by a constant $\beta$ that depends on the size of the transformer.
DeepNorm suggests scaling the weights of the two linear transforms in the Feed-Forward Network, the value projection transform, and the output projection transform of the attention layer. Weights of these transforms are scaled by $\beta$ (that is, they have a gain equal to $\beta$).
The scaling is implemented in the transformer layer below, right after its sub-modules have been initialized.
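Here is a minimal sketch. It assumes plain PyTorch nn.Linear layers and Xavier-normal initialization; neither comes from this implementation, which scales the default initialization in place. Multiplying already-initialized weights by $\beta$ gives the same distribution as initializing them with gain $\beta$. In both cases the standard deviation is $\beta \sqrt{2 / (\text{fan\_in} + \text{fan\_out})}$. The value beta = 0.5 and the layer size are arbitrary choices for illustration.

```python
import math

import torch
from torch import nn

torch.manual_seed(0)

beta = 0.5  # illustrative; DeepNorm derives beta from the model depth
fan_in, fan_out = 512, 512

# Option 1: Xavier-normal initialization with gain = beta
a = nn.Linear(fan_in, fan_out)
nn.init.xavier_normal_(a.weight, gain=beta)

# Option 2: Xavier-normal with gain = 1, then scale the weights in place by beta
b = nn.Linear(fan_in, fan_out)
nn.init.xavier_normal_(b.weight)
with torch.no_grad():
    b.weight *= beta

# Both should have a standard deviation close to beta * sqrt(2 / (fan_in + fan_out))
expected_std = beta * math.sqrt(2.0 / (fan_in + fan_out))
print(a.weight.std().item(), b.weight.std().item(), expected_std)
```

Scaling in place after initialization means the rest of the model can keep its usual initializer. Only the selected projections need to be touched.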
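The normalization function is

$$x_{l + 1} = \mathrm{LN}\Big(\alpha x_l + G_l \big(x_l, \theta_l \big)\Big)$$

where $\alpha$ is a constant that depends on the depth of the transformer, $\mathrm{LN}$ is Layer Normalization, and $G_l(x_l, \theta_l)$ is the function of the $l$-th transformer sub-layer (FFN or attention).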
This function is used to replace Post-LayerNorm.
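Layer Normalization without its affine parameters is invariant to a positive rescaling of its input, apart from the small $\epsilon$. It follows that $\mathrm{LN}(\alpha x_l + G_l) \approx \mathrm{LN}(x_l + G_l / \alpha)$. A larger $\alpha$ therefore shrinks the relative contribution of the sub-layer output, and $\alpha = 1$ recovers Post-LayerNorm, $\mathrm{LN}(x_l + G_l)$.

The PyTorch sketch below checks both facts numerically. The random tensors stand in for $x_l$ and $G_l(x_l, \theta_l)$, and torch.nn.functional.layer_norm stands in for the labml LayerNorm.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 64
x = torch.randn(10, d_model)   # stand-in for x_l
gx = torch.randn(10, d_model)  # stand-in for G_l(x_l, theta_l)


def deep_norm(x, gx, alpha):
    # DeepNorm: x_{l+1} = LN(alpha * x_l + G_l(x_l, theta_l)), no affine parameters
    return F.layer_norm(alpha * x + gx, (d_model,))


def post_ln(x, gx):
    # Post-LayerNorm: x_{l+1} = LN(x_l + G_l(x_l, theta_l))
    return F.layer_norm(x + gx, (d_model,))


# alpha = 1 recovers Post-LayerNorm
same_as_post_ln = torch.allclose(deep_norm(x, gx, 1.0), post_ln(x, gx))

# Scale invariance: LN(alpha * x + g) matches LN(x + g / alpha) up to the eps term
alpha = 8.0
scaled = torch.allclose(deep_norm(x, gx, alpha),
                        F.layer_norm(x + gx / alpha, (d_model,)), atol=1e-4)
print(same_as_post_ln, scaled)

# As alpha grows, the output moves closer to LN(x_l), i.e. the residual dominates
ln_x = F.layer_norm(x, (d_model,))
for a in (1.0, 2.0, 8.0):
    print(a, (deep_norm(x, gx, a) - ln_x).abs().max().item())
```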
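The constants $\alpha$ and $\beta$ depend on the architecture:

$$
\begin{array}{c|cc|cc}
\text{Type} & \text{Enc-}\alpha & \text{Enc-}\beta & \text{Dec-}\alpha & \text{Dec-}\beta \\
\hline
\text{Encoder only} & (2N)^{\frac{1}{4}} & (8N)^{-\frac{1}{4}} & - & - \\
\text{Decoder only} & - & - & (2M)^{\frac{1}{4}} & (8M)^{-\frac{1}{4}} \\
\text{Enc-Dec} & 0.81 (N^4 M)^{\frac{1}{16}} & 0.87 (N^4 M)^{-\frac{1}{16}} & (3M)^{\frac{1}{4}} & (12M)^{-\frac{1}{4}}
\end{array}
$$

where $N$ is the number of layers in the encoder and $M$ is the number of layers in the decoder.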
Refer to the paper for derivation.
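The following is a small helper for computing the constants, offered as a sketch. The function name deepnorm_constants and its architecture labels are hypothetical. The formulas are the ones given in the DeepNet paper:

- encoder-only: $\alpha = (2N)^{1/4}$, $\beta = (8N)^{-1/4}$
- decoder-only: $\alpha = (2M)^{1/4}$, $\beta = (8M)^{-1/4}$
- encoder-decoder: encoder $\alpha = 0.81 (N^4 M)^{1/16}$, $\beta = 0.87 (N^4 M)^{-1/16}$; decoder $\alpha = (3M)^{1/4}$, $\beta = (12M)^{-1/4}$

```python
def deepnorm_constants(arch: str, n_enc: int = 0, n_dec: int = 0) -> dict:
    """Return DeepNorm alpha/beta constants (hypothetical helper).

    arch:  "encoder", "decoder", or "encoder_decoder"
    n_enc: N, the number of encoder layers
    n_dec: M, the number of decoder layers
    """
    if arch == "encoder":
        return {"enc_alpha": (2 * n_enc) ** 0.25, "enc_beta": (8 * n_enc) ** -0.25}
    if arch == "decoder":
        return {"dec_alpha": (2 * n_dec) ** 0.25, "dec_beta": (8 * n_dec) ** -0.25}
    if arch == "encoder_decoder":
        n4m = n_enc ** 4 * n_dec  # N^4 * M
        return {
            "enc_alpha": 0.81 * n4m ** (1 / 16),
            "enc_beta": 0.87 * n4m ** (-1 / 16),
            "dec_alpha": (3 * n_dec) ** 0.25,
            "dec_beta": (12 * n_dec) ** -0.25,
        }
    raise ValueError(f"unknown architecture: {arch}")


# Decoder-only model with M = 18 layers
print(deepnorm_constants("decoder", n_dec=18))
```

For a decoder-only model with $M = 18$ layers, this gives $\alpha = 36^{1/4} = \sqrt{6} \approx 2.449$ and $\beta = 144^{-1/4} = 1/\sqrt{12} \approx 0.289$.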
Here is an experiment implementation that uses DeepNorm.
from typing import Union, List

import torch
from torch import nn, Size

from labml_nn.normalization.layer_norm import LayerNorm
from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask
alpha is the coefficient $\alpha$
normalized_shape is the shape for LayerNorm
eps is $\epsilon$ for LayerNorm
elementwise_affine is a flag indicating whether to do an elementwise transformation in LayerNorm
class DeepNorm(nn.Module):
    def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
        super().__init__()

        self.alpha = alpha
        self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
x is the output from the previous layer, $x_l$
gx is the output of the current sub-layer, $G_l(x_l, \theta_l)$
    def forward(self, x: torch.Tensor, gx: torch.Tensor):
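        # x_{l+1} = LN(alpha * x_l + G_l(x_l, theta_l)): alpha scales the residual x_l, not the sub-layer output
        return self.layer_norm(self.alpha * x + gx)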
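If labml_nn is not installed, the module above can be sketched self-contained by swapping in torch.nn.LayerNorm for the labml LayerNorm. This assumes both compute the same standard layer normalization. The class name DeepNormSketch is hypothetical.

```python
import torch
from torch import nn


class DeepNormSketch(nn.Module):
    """Hypothetical stand-alone DeepNorm: x_{l+1} = LN(alpha * x_l + G_l(x_l, theta_l))."""

    def __init__(self, alpha: float, normalized_shape, eps: float = 1e-5,
                 elementwise_affine: bool = True):
        super().__init__()
        self.alpha = alpha
        self.layer_norm = nn.LayerNorm(normalized_shape, eps=eps,
                                       elementwise_affine=elementwise_affine)

    def forward(self, x: torch.Tensor, gx: torch.Tensor) -> torch.Tensor:
        # Up-weight the residual x_l by alpha, add the sub-layer output, then normalize
        return self.layer_norm(self.alpha * x + gx)


torch.manual_seed(0)
norm = DeepNormSketch(alpha=2.0, normalized_shape=[16])
x, gx = torch.randn(5, 3, 16), torch.randn(5, 3, 16)  # [seq_len, batch_size, d_model]
out = norm(x, gx)
print(out.shape)  # torch.Size([5, 3, 16])
```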
This implements a transformer decoder layer with DeepNorm. Encoder layers will have a similar form.
d_model is the token embedding size
self_attn is the self-attention module
feed_forward is the feed-forward module
deep_norm_alpha is the coefficient $\alpha$ in DeepNorm
deep_norm_beta is the constant $\beta$ used to scale the weight initialization
class DeepNormTransformerLayer(nn.Module):
    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: FeedForward,
                 deep_norm_alpha: float,
                 deep_norm_beta: float,
                 ):
        super().__init__()

        self.self_attn = self_attn
        self.feed_forward = feed_forward
DeepNorm after the attention and after the feed-forward network
        self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
        self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])
Scale weights after initialization
        with torch.no_grad():
Feed forward network linear transformations
            feed_forward.layer1.weight *= deep_norm_beta
            feed_forward.layer2.weight *= deep_norm_beta
Attention value projection
            self_attn.value.linear.weight *= deep_norm_beta
Attention output projection
            self_attn.output.weight *= deep_norm_beta
The mask will be initialized on the first call
        self.mask = None
x are the embeddings of shape [seq_len, batch_size, d_model]
    def forward(self, x: torch.Tensor):
Create causal mask
        if self.mask is None or self.mask.size(0) != len(x):
Subsequent mask, which prevents each token from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
Run through self-attention, i.e. the keys and values come from the same sequence as the queries
        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))
Pass through the feed-forward network
        x = self.feed_forward_norm(x, self.feed_forward(x))
        return x
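To see the pieces together without the labml_nn dependency, here is a hypothetical self-contained decoder layer. DeepNormDecoderLayerSketch is our name and is not part of any library. It swaps in torch.nn.MultiheadAttention, a two-layer ReLU feed-forward network and torch.nn.LayerNorm for the labml modules, and these substitutions are assumptions.

Two details of nn.MultiheadAttention shape the code:

- It packs the query, key and value projections into one in_proj_weight (in that order), so only its last d_model rows are scaled by $\beta$.
- Its boolean attn_mask marks with True the positions that may not be attended to, so the causal mask is built with torch.triu.

The usage uses the decoder-only constants $\alpha = (2M)^{1/4}$ and $\beta = (8M)^{-1/4}$.

```python
import torch
from torch import nn


class DeepNormDecoderLayerSketch(nn.Module):
    """Hypothetical stand-alone DeepNorm decoder layer (no labml_nn dependency)."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, alpha: float, beta: float):
        super().__init__()
        self.alpha = alpha
        # Default batch_first=False: inputs are [seq_len, batch_size, d_model]
        self.self_attn = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.attn_norm = nn.LayerNorm(d_model)
        self.ffn_norm = nn.LayerNorm(d_model)
        with torch.no_grad():
            # in_proj_weight stacks [W_q; W_k; W_v]; scale only the value projection
            self.self_attn.in_proj_weight[2 * d_model:] *= beta
            self.self_attn.out_proj.weight *= beta
            # Both linear transforms of the feed-forward network
            self.ffn[0].weight *= beta
            self.ffn[2].weight *= beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(0)
        # True above the diagonal = a query may not attend to that (future) key
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                          diagonal=1)
        attn_out, _ = self.self_attn(x, x, x, attn_mask=mask, need_weights=False)
        # DeepNorm: LN(alpha * x_l + G_l(x_l)) after attention and after the FFN
        x = self.attn_norm(self.alpha * x + attn_out)
        x = self.ffn_norm(self.alpha * x + self.ffn(x))
        return x


torch.manual_seed(0)
n_layers = 6  # M, number of decoder layers
alpha, beta = (2 * n_layers) ** 0.25, (8 * n_layers) ** -0.25
layer = DeepNormDecoderLayerSketch(d_model=32, n_heads=4, d_ff=64, alpha=alpha, beta=beta)
layer.eval()

x = torch.randn(7, 2, 32)  # [seq_len, batch_size, d_model]
y = layer(x)

# Perturb only the last token; earlier positions must not change under the causal mask
x2 = x.clone()
x2[-1] = torch.randn(2, 32)
y2 = layer(x2)
print(y.shape, torch.allclose(y[:-1], y2[:-1], atol=1e-5))
```

The final check exercises the causal mask. Perturbing the last token changes only the output at that position, because attention never looks ahead and both the feed-forward network and the normalization act position-wise.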