ディープ・ノーム

Open In Colab

これは、論文「DeepNet: トランスフォーマーを1,000レイヤーにスケーリング」に掲載されたDeepNormをPyTorchで実装したものです

この論文では、LayerNormに代わる新しい正規化関数と重み初期化スキームにより、非常に深い変圧器を安定させる方法を提案しています。これは、ポストレイヤーノームのパフォーマンスとプレレイヤーノームの安定性を兼ね備えています。DeepNormsを搭載したトランスフォーマーは、学習率のウォームアップがなくても安定しているはずです

この論文はまず、(同じ入力の)レイヤー出力の変化が安定したトレーニング中に徐々に変化し、不安定な場合は最初のトレーニングステップで急速に変化することを示しています。これは、重みを小さい値に初期化し、トレーニングが安定しているところで学習率のウォームアップを行うと起こります。彼らは、層出力への変更を小さく抑えるという考えを利用して、新しい正規化と重み初期化のメカニズムを導き出しています

ウェイト初期化

通常、ウェイトは Xavier または Kaiming の初期化で初期化されます。このペーパーでは、トランスのサイズに応じて一定の割合でウェイトをスケーリング(ゲイン設定)します

DeepNormは、フィードフォワードネットワーク内の2つの線形変換、つまりアテンションレイヤーの値投影変換と出力投影変換の重みをスケーリングすることを提案しています。これらのトランスフォームの重みは (ゲインがと等しい) でスケーリングされます

スケーリングはで実装されています

正規化機能

ここで、は変圧器の深さに依存する定数、は層の正規化、は -番目の変圧器サブレイヤー (FFN または注意) の関数です。

この関数はポストレイヤーノルムの代わりに使われます。

と定数

ここで、はエンコーダーのレイヤー数、はデコーダーのレイヤー数です。

導出については論文を参照してください

DeepNormを使った実験的な実装です

73from typing import Union, List
74
75import torch
76from torch import nn, Size
77
78from labml_nn.normalization.layer_norm import LayerNorm
79from labml_nn.transformers import MultiHeadAttention
80from labml_nn.transformers.feed_forward import FeedForward
81from labml_nn.transformers.utils import subsequent_mask

ディープノルム正規化

84class DeepNorm(nn.Module):
  • alpha
  • normalized_shape はレイヤーノルムの形状です
  • eps レイヤーノルム用
  • elementwise_affine LayerNorm で要素単位の変換を行うかどうかを示すフラグです
91    def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
92                 eps: float = 1e-5,
93                 elementwise_affine: bool = True):
100        super().__init__()
101
102        self.alpha = alpha

[初期化]

104        self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
  • x 前のレイヤーからの出力です
  • gx 現在のサブレイヤーの出力です
106    def forward(self, x: torch.Tensor, gx: torch.Tensor):

112        return self.layer_norm(x + self.alpha * gx)

DeepNorm搭載トランスデコーダレイヤー

これはDeepNormでトランスデコーダーレイヤーを実装しています。エンコーダレイヤーも同様の形式になります

115class DeepNormTransformerLayer(nn.Module):
  • d_model トークンの埋め込みサイズです
  • self_attn セルフアテンションモジュールです
  • feed_forward フィードフォワードモジュールです
  • deep_norm_alpha はディープノームの係数
  • deep_norm_beta スケーリングウェイトの初期化では定数です
122    def __init__(self, *,
123                 d_model: int,
124                 self_attn: MultiHeadAttention,
125                 feed_forward: FeedForward,
126                 deep_norm_alpha: float,
127                 deep_norm_beta: float,
128                 ):
136        super().__init__()
137
138        self.self_attn = self_attn
139        self.feed_forward = feed_forward

DeepNorms アフターアテンションアンドフィードフォワードネットワーク

141        self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
142        self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])

初期化後にウェイトをスケーリング

145        with torch.no_grad():

フィードフォワードネットワークの線形変換

147            feed_forward.layer1.weight *= deep_norm_beta
148            feed_forward.layer2.weight *= deep_norm_beta

アテンション・バリュー・プロジェクション

151            self_attn.value.linear.weight *= deep_norm_beta

アテンションアウトプットプロジェクト

153            self_attn.output.weight *= deep_norm_beta

マスクは最初の呼び出しで初期化されます

156        self.mask = None
  • x 形が埋め込まれているものです [seq_len, batch_size, d_model]
158    def forward(self, x: torch.Tensor):

カジュアルマスクの作成

163        if self.mask is None or self.mask.size(0) != len(x):

次にマスクすると、トークンがマスクされ、将来のトークンが見えなくなります

165            self.mask = subsequent_mask(len(x)).to(x.device)

自己注意を向ける。つまり、キーと値は自己からのものだ

168        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))

フィードフォワードネットワークを通過

170        x = self.feed_forward_norm(x, self.feed_forward(x))

173        return x