スイッチトランス

これは、論文の「スイッチトランスフォーマー:シンプルで効率的なスパース性を備えた1兆パラメータモデルへのスケーリングのミニチュアPyTorch実装です。私たちの実装には数百万のパラメーターしかなく、モデルの並列分散トレーニングは行いません。シングルGPUトレーニングを行いますが、論文に記載されているようにスイッチングという概念を実装しています

Switch Transformer は、トークンに基づいてパラメーターを切り替えることにより、トークンごとに異なるパラメーターを使用します。したがって、各トークンで選択されるパラメータはごくわずかです。そのため、より多くのパラメーターを使用できますが、計算コストは少なくなります

切り替えは、各トランスブロックの位置ごとのフィードフォワードネットワーク (FFN) で行われます。位置単位のフィードフォワードネットワークは、連続して完全に接続された2つの層で構成されています。スイッチトランスには複数のFFN(複数のエキスパート)がいて、ルーターに基づいてどれを使用するかを選択しました。出力はFFNを選択する確率のセットで、最も確率の高いものを選んで評価します。つまり、基本的に、計算コストは単一の FFN を使用する場合と同じです。私たちの実装では、FFNが多い場合や大きい場合は、すべて1つのGPUで実行されるため、うまく並列化できません。分散型セットアップでは、それぞれの FFN(それぞれが非常に大きい)を異なるデバイスに配置することになります

この論文では、エキスパート(FFN)間で負荷を分散するための別の損失用語を紹介し、ルーティングのバランスが取れていない場合のトークンのドロップについて論じています。

これは、Tiny Shakespeareデータセットのスイッチトランスをトレーニングするためのトレーニングコードとノートブックです

Open In Colab

39import torch
40from torch import nn
41
42from labml_helpers.module import Module
43from labml_nn.transformers.feed_forward import FeedForward
44from labml_nn.transformers.mha import MultiHeadAttention
45from labml_nn.utils import clone_module_list

複数の FFN 間のルーティング

48class SwitchFeedForward(Module):
  • capacity_factor 理想的なバランスの取れた負荷に対する各専門家の能力が係数となるか
  • drop_tokens エキスパートに送られるトークンの数がキャパシティを超える場合に、トークンをドロップするかどうかを指定します
  • is_scale_prob FFN への入力にルーティング確率を掛けるかどうかを指定します
  • n_experts 専門家の数です
  • expert はエキスパートレイヤー、FFN モジュールです
  • d_model トークン埋め込みに含まれる機能の数
  • d_ff は FFN の隠れレイヤーにあるフィーチャの数です
  • dropout 脱落確率はFFNにあるのか
53    def __init__(self, *,
54                 capacity_factor: float,
55                 drop_tokens: bool,
56                 is_scale_prob: bool,
57                 n_experts: int,
58                 expert: FeedForward,
59                 d_model: int):
70        super().__init__()
71
72        self.capacity_factor = capacity_factor
73        self.is_scale_prob = is_scale_prob
74        self.n_experts = n_experts
75        self.drop_tokens = drop_tokens

FFN のコピーを作成

78        self.experts = clone_module_list(expert, n_experts)

ルーティングレイヤーとソフトマックス

80        self.switch = nn.Linear(d_model, n_experts)
81        self.softmax = nn.Softmax(dim=-1)
  • x 形状のあるスイッチングモジュールへの入力です [seq_len, batch_size, d_model]
83    def forward(self, x: torch.Tensor):

形状をキャプチャして後で形状を変更

89        seq_len, batch_size, d_model = x.shape

シーケンスとバッチのディメンションをフラット化

91        x = x.view(-1, d_model)

各トークンのルーティング確率を取得します。ここで、はエキスパートの人数n_expertsはトークン埋め込みの線形変換です

97        route_prob = self.softmax(self.switch(x))

最大ルーティング確率とルートを取得します。最も高い確率で専門家にルーティングします

101        route_prob_max, routes = torch.max(route_prob, dim=-1)

各エキスパートに送られるトークンのインデックスを取得

104        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]

空のテンソルを初期化して出力を保存する

107        final_output = x.new_zeros(x.shape)

各専門家の能力。

113        capacity = int(self.capacity_factor * len(x) / self.n_experts)

各エキスパートにルーティングされたトークンの数。

115        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])

ドロップされたトークンの空のリストを初期化

118        dropped = []

drop_tokens True そうである場合にのみトークンをドロップしてください。

120        if self.drop_tokens:

各エキスパートにトークンをドロップ

122            for i in range(self.n_experts):

エキスパートがキャパシティを超えていない場合は無視してください

124                if len(indexes_list[i]) <= capacity:
125                    continue

ドロップする前にインデックスをシャッフルする

127                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]

容量を超えたトークンをドロップトークンとして集める

129                dropped.append(indexes_list[i][capacity:])

専門家が対応できる範囲でトークンのみを保管してください

131                indexes_list[i] = indexes_list[i][:capacity]

エキスパート FFN のアウトプットを取得

134        expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]

最終出力に割り当て

137        for i in range(self.n_experts):
138            final_output[indexes_list[i], :] = expert_output[i]

ドロップしたトークンをパススルーする

141        if dropped:
142            dropped = torch.cat(dropped)
143            final_output[dropped, :] = x[dropped, :]
144
145        if self.is_scale_prob:

エキスパートのアウトプットに確率を掛けます

147            final_output = final_output * route_prob_max.view(-1, 1)
148        else:

値をスケーリングするのではなく、勾配が流れるように乗算してください(これは実験したものです)。

151            final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)

最終出力の形状をに戻します [seq_len, batch_size, d_model]

154        final_output = final_output.view(seq_len, batch_size, d_model)

リターン

  • 最終出力
  • 各エキスパートにルーティングされるトークンの数
  • 各専門家の確率の合計
  • ドロップされたトークンの数。
  • 選択した専門家のルーティング確率

これらはロードバランシング、ロス、ロギングに使用されます。

165        return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max

スイッチトランスブロック

これは、スイッチフィードフォワードモジュールの追加出力を処理する点で、通常のトランスブロックと同じです

168class SwitchTransformerLayer(Module):
  • d_model トークンの埋め込みサイズです
  • attn アテンションモジュールです
  • feed_forward はフィードフォワードモジュール (この場合はスイッチングモジュール)
  • dropout_prob セルフアテンションとFFNの後に脱落する確率です
176    def __init__(self, *,
177                 d_model: int,
178                 attn: MultiHeadAttention,
179                 feed_forward: SwitchFeedForward,
180                 dropout_prob: float):
187        super().__init__()
188        self.size = d_model
189        self.attn = attn
190        self.feed_forward = feed_forward
191        self.dropout = nn.Dropout(dropout_prob)
192        self.norm_self_attn = nn.LayerNorm([d_model])
193        self.norm_ff = nn.LayerNorm([d_model])
195    def forward(self, *,
196                x: torch.Tensor,
197                mask: torch.Tensor):

セルフアテンションを行う前にベクトルを正規化してください

199        z = self.norm_self_attn(x)

自己注意を向ける。つまり、キーと値は自己からのものだ

201        self_attn = self.attn(query=z, key=z, value=z, mask=mask)

セルフアテンションの結果を追加

203        x = x + self.dropout(self_attn)

フィードフォワード用に正規化

206        z = self.norm_ff(x)

スイッチングフィードフォワードネットワークを経由します

208        ff, counts, route_prob, n_dropped, route_prob_max = self.feed_forward(z)

フィードフォワードの結果を追加し直す

210        x = x + self.dropout(ff)
211
212        return x, counts, route_prob, n_dropped, route_prob_max

スイッチトランス

215class SwitchTransformer(Module):
220    def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
221        super().__init__()

トランスレイヤーのコピーを作成

223        self.layers = clone_module_list(layer, n_layers)

最終正規化レイヤー

225        self.norm = nn.LayerNorm([layer.size])
227    def forward(self, x: torch.Tensor, mask: torch.Tensor):

各変圧器層に通す

229        counts, route_prob, n_dropped, route_prob_max = [], [], [], []
230        for layer in self.layers:
231            x, f, p, n_d, p_max = layer(x=x, mask=mask)
232            counts.append(f)
233            route_prob.append(p)
234            n_dropped.append(n_d)
235            route_prob_max.append(p_max)

最後に、ベクトルを正規化します。

237        x = self.norm(x)

239        return x, torch.stack(counts), torch.stack(route_prob), n_dropped, torch.stack(route_prob_max)