DeepNorm


This is a PyTorch implementation of DeepNorm from the paper DeepNet: Scaling Transformers to 1,000 Layers.

The paper proposes a method to stabilize extremely deep transformers through a new normalization function that replaces LayerNorm, together with a weight initialization scheme. This combines the performance of Post-LayerNorm and the stability of Pre-LayerNorm. Transformers with DeepNorm are supposed to be stable even without a learning rate warm-up.

The paper first shows that, for the same input, layer outputs change gradually during stable training, while in unstable training they change rapidly during the initial training steps. Stable training is observed when the weights are initialized to small values or when a learning rate warm-up is used. They use the idea of keeping the changes to layer outputs small to derive the new normalization and weight initialization scheme.

Weight Initializations

Usually, the weights are initialized with Xavier or Kaiming initializations. This paper scales the weights (sets their gain) by a constant $\beta$ that depends on the size of the transformer.

DeepNorm suggests scaling the weights of the two linear transforms in the Feed-Forward Network, the value projection transform, and the output projection transform of the attention layer. The weights of these transforms are scaled by (have a gain equal to) $\beta$.

The scaling is implemented in the __init__ of the DeepNormTransformerLayer class below.
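
As a rough sketch of what this gain scaling amounts to (not the code from this page; the helper name and the use of nn.init.xavier_normal_ are assumptions for illustration):

from torch import nn

def scale_init_(linear: nn.Linear, beta: float):
    # Xavier (Glorot) normal initialization with the gain set to beta,
    # as suggested for the FFN and the value/output projection weights
    nn.init.xavier_normal_(linear.weight, gain=beta)

layer = nn.Linear(512, 2048)
scale_init_(layer, beta=0.5)  # the beta value here is arbitrary, just for illustration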

Normalization Function

$$x_{l+1} = \mathrm{LN}\big(\alpha x_l + G_l(x_l, \theta_l)\big)$$

where $\alpha$ is a constant that depends on the depth of the transformer, $\mathrm{LN}$ is Layer Normalization, and $G_l(x_l, \theta_l)$ is the function of the $l$-th transformer sub-layer (FFN or attention).

This function is used to replace Post-LayerNorm.
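
A minimal functional sketch of this update, assuming normalization over the last dimension (the DeepNorm module below implements the same thing with a LayerNorm instance):

import torch
import torch.nn.functional as F

def deep_norm_update(x: torch.Tensor, gx: torch.Tensor, alpha: float):
    # x is x_l (output of the previous layer), gx is G_l(x_l, theta_l)
    return F.layer_norm(alpha * x + gx, x.shape[-1:])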

$\alpha$ and $\beta$ constants

  • Encoder-only: $\alpha = (2N)^{\frac{1}{4}}$, $\beta = (8N)^{-\frac{1}{4}}$
  • Decoder-only: $\alpha = (2M)^{\frac{1}{4}}$, $\beta = (8M)^{-\frac{1}{4}}$
  • Encoder-decoder: encoder layers use $\alpha = 0.81 (N^4 M)^{\frac{1}{16}}$, $\beta = 0.87 (N^4 M)^{-\frac{1}{16}}$; decoder layers use $\alpha = (3M)^{\frac{1}{4}}$, $\beta = (12M)^{-\frac{1}{4}}$

where $N$ is the number of layers in the encoder and $M$ is the number of layers in the decoder.

Refer to the paper for derivation.
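
For instance, for a decoder-only model the constants could be computed as below (the function name and the depth of 64 are just for illustration):

def decoder_only_constants(n_layers: int):
    # alpha = (2M)^(1/4) and beta = (8M)^(-1/4) for a decoder-only model with M layers
    alpha = (2 * n_layers) ** (1 / 4)
    beta = (8 * n_layers) ** (-1 / 4)
    return alpha, beta

# e.g. a 64-layer decoder
deep_norm_alpha, deep_norm_beta = decoder_only_constants(64)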

Here is an experiment implementation that uses DeepNorm.

from typing import Union, List

import torch
from torch import nn, Size

from labml_nn.normalization.layer_norm import LayerNorm
from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask

DeepNorm Normalization

class DeepNorm(nn.Module):
  • alpha is the $\alpha$ constant that scales the residual connection
  • normalized_shape is the shape for LayerNorm
  • eps is the $\epsilon$ used for numerical stability in LayerNorm
  • elementwise_affine is a flag indicating whether to do an elementwise transformation in LayerNorm
    def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
        super().__init__()

        self.alpha = alpha

Initialize LayerNorm

        self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
  • x is the output from the previous layer
  • gx is the output of the current sub-layer
    def forward(self, x: torch.Tensor, gx: torch.Tensor):

        return self.layer_norm(self.alpha * x + gx)
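
A quick usage sketch with arbitrary sizes and an arbitrary $\alpha$, just to show the call signature:

norm = DeepNorm(alpha=2.0, normalized_shape=[512])
x = torch.randn(10, 2, 512)   # output of the previous layer, [seq_len, batch_size, d_model]
gx = torch.randn(10, 2, 512)  # output of the current sub-layer
x_next = norm(x, gx)          # LN(alpha * x + gx)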

Transformer Decoder Layer with DeepNorm

This implements a transformer decoder layer with DeepNorm. Encoder layers will have a similar form.

class DeepNormTransformerLayer(nn.Module):
  • d_model is the token embedding size
  • self_attn is the self attention module
  • feed_forward is the feed forward module
  • deep_norm_alpha is the $\alpha$ coefficient in DeepNorm
  • deep_norm_beta is the $\beta$ constant for scaling the weight initialization
    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: FeedForward,
                 deep_norm_alpha: float,
                 deep_norm_beta: float,
                 ):
        super().__init__()

        self.self_attn = self_attn
        self.feed_forward = feed_forward

DeepNorm after the attention and the feed-forward network

        self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
        self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])

Scale weights after initialization

        with torch.no_grad():

Feed forward network linear transformations

            feed_forward.layer1.weight *= deep_norm_beta
            feed_forward.layer2.weight *= deep_norm_beta

Attention value projection

            self_attn.value.linear.weight *= deep_norm_beta

Attention output projection

            self_attn.output.weight *= deep_norm_beta

The mask will be initialized on the first call

        self.mask = None
  • x is the tensor of embeddings of shape [seq_len, batch_size, d_model]
    def forward(self, x: torch.Tensor):

Create causal mask

        if self.mask is None or self.mask.size(0) != len(x):

Subsequent mask, which masks out tokens from attending to future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)

Run through self attention, i.e. keys and values are from self

        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))

Pass through the feed-forward network

        x = self.feed_forward_norm(x, self.feed_forward(x))

        return x
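
A hedged sketch of stacking these layers into a small decoder, assuming the labml_nn MultiHeadAttention(heads, d_model) and FeedForward(d_model, d_ff) constructors; the layer count, model size, and other numbers are illustrative:

n_layers = 64
d_model = 512

# Decoder-only constants
alpha = (2 * n_layers) ** (1 / 4)
beta = (8 * n_layers) ** (-1 / 4)

layers = nn.ModuleList([
    DeepNormTransformerLayer(d_model=d_model,
                             self_attn=MultiHeadAttention(8, d_model),
                             feed_forward=FeedForward(d_model, 2048),
                             deep_norm_alpha=alpha,
                             deep_norm_beta=beta)
    for _ in range(n_layers)
])

x = torch.randn(10, 2, d_model)  # [seq_len, batch_size, d_model]
for layer in layers:
    x = layer(x)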