これは、論文の「スイッチトランスフォーマー:シンプルで効率的なスパース性を備えた1兆パラメータモデルへのスケーリング」のミニチュアPyTorch実装です。私たちの実装には数百万のパラメーターしかなく、モデルの並列分散トレーニングは行いません。シングルGPUトレーニングを行いますが、論文に記載されているようにスイッチングという概念を実装しています
。Switch Transformer は、トークンに基づいてパラメーターを切り替えることにより、トークンごとに異なるパラメーターを使用します。したがって、各トークンで選択されるパラメータはごくわずかです。そのため、より多くのパラメーターを使用できますが、計算コストは少なくなります
。切り替えは、各トランスブロックの位置ごとのフィードフォワードネットワーク (FFN) で行われます。位置単位のフィードフォワードネットワークは、連続して完全に接続された2つの層で構成されています。スイッチトランスには複数のFFN(複数のエキスパート)がいて、ルーターに基づいてどれを使用するかを選択しました。出力はFFNを選択する確率のセットで、最も確率の高いものを選んで評価します。つまり、基本的に、計算コストは単一の FFN を使用する場合と同じです。私たちの実装では、FFNが多い場合や大きい場合は、すべて1つのGPUで実行されるため、うまく並列化できません。分散型セットアップでは、それぞれの FFN(それぞれが非常に大きい)を異なるデバイスに配置することになります
。この論文では、エキスパート(FFN)間で負荷を分散するための別の損失用語を紹介し、ルーティングのバランスが取れていない場合のトークンのドロップについて論じています。
これは、Tiny Shakespeareデータセットのスイッチトランスをトレーニングするためのトレーニングコードとノートブックです。
39import torch
40from torch import nn
41
42from labml_helpers.module import Module
43from labml_nn.transformers.feed_forward import FeedForward
44from labml_nn.transformers.mha import MultiHeadAttention
45from labml_nn.utils import clone_module_list
48class SwitchFeedForward(Module):
capacity_factor
理想的なバランスの取れた負荷に対する各専門家の能力が係数となるかdrop_tokens
エキスパートに送られるトークンの数がキャパシティを超える場合に、トークンをドロップするかどうかを指定しますis_scale_prob
FFN への入力にルーティング確率を掛けるかどうかを指定しますn_experts
専門家の数ですexpert
はエキスパートレイヤー、FFN モジュールですd_model
トークン埋め込みに含まれる機能の数d_ff
は FFN の隠れレイヤーにあるフィーチャの数ですdropout
脱落確率はFFNにあるのか53 def __init__(self, *,
54 capacity_factor: float,
55 drop_tokens: bool,
56 is_scale_prob: bool,
57 n_experts: int,
58 expert: FeedForward,
59 d_model: int):
70 super().__init__()
71
72 self.capacity_factor = capacity_factor
73 self.is_scale_prob = is_scale_prob
74 self.n_experts = n_experts
75 self.drop_tokens = drop_tokens
FFN のコピーを作成
78 self.experts = clone_module_list(expert, n_experts)
ルーティングレイヤーとソフトマックス
80 self.switch = nn.Linear(d_model, n_experts)
81 self.softmax = nn.Softmax(dim=-1)
x
形状のあるスイッチングモジュールへの入力です [seq_len, batch_size, d_model]
83 def forward(self, x: torch.Tensor):
形状をキャプチャして後で形状を変更
89 seq_len, batch_size, d_model = x.shape
シーケンスとバッチのディメンションをフラット化
91 x = x.view(-1, d_model)
各トークンのルーティング確率を取得します。ここで、はエキスパートの人数n_experts
、はトークン埋め込みの線形変換です
97 route_prob = self.softmax(self.switch(x))
最大ルーティング確率とルートを取得します。最も高い確率で専門家にルーティングします
101 route_prob_max, routes = torch.max(route_prob, dim=-1)
各エキスパートに送られるトークンのインデックスを取得
104 indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]
空のテンソルを初期化して出力を保存する
107 final_output = x.new_zeros(x.shape)
各専門家の能力。
113 capacity = int(self.capacity_factor * len(x) / self.n_experts)
各エキスパートにルーティングされたトークンの数。
115 counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])
ドロップされたトークンの空のリストを初期化
118 dropped = []
drop_tokens
True
そうである場合にのみトークンをドロップしてください。
120 if self.drop_tokens:
各エキスパートにトークンをドロップ
122 for i in range(self.n_experts):
エキスパートがキャパシティを超えていない場合は無視してください
124 if len(indexes_list[i]) <= capacity:
125 continue
ドロップする前にインデックスをシャッフルする
127 indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
容量を超えたトークンをドロップトークンとして集める
129 dropped.append(indexes_list[i][capacity:])
専門家が対応できる範囲でトークンのみを保管してください
131 indexes_list[i] = indexes_list[i][:capacity]
エキスパート FFN のアウトプットを取得
134 expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]
最終出力に割り当て
137 for i in range(self.n_experts):
138 final_output[indexes_list[i], :] = expert_output[i]
ドロップしたトークンをパススルーする
141 if dropped:
142 dropped = torch.cat(dropped)
143 final_output[dropped, :] = x[dropped, :]
144
145 if self.is_scale_prob:
エキスパートのアウトプットに確率を掛けます
147 final_output = final_output * route_prob_max.view(-1, 1)
148 else:
値をスケーリングするのではなく、勾配が流れるように乗算してください(これは実験したものです)。
151 final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)
最終出力の形状をに戻します [seq_len, batch_size, d_model]
154 final_output = final_output.view(seq_len, batch_size, d_model)
リターン
これらはロードバランシング、ロス、ロギングに使用されます。
165 return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
168class SwitchTransformerLayer(Module):
d_model
トークンの埋め込みサイズですattn
アテンションモジュールですfeed_forward
はフィードフォワードモジュール (この場合はスイッチングモジュール)dropout_prob
セルフアテンションとFFNの後に脱落する確率です176 def __init__(self, *,
177 d_model: int,
178 attn: MultiHeadAttention,
179 feed_forward: SwitchFeedForward,
180 dropout_prob: float):
187 super().__init__()
188 self.size = d_model
189 self.attn = attn
190 self.feed_forward = feed_forward
191 self.dropout = nn.Dropout(dropout_prob)
192 self.norm_self_attn = nn.LayerNorm([d_model])
193 self.norm_ff = nn.LayerNorm([d_model])
195 def forward(self, *,
196 x: torch.Tensor,
197 mask: torch.Tensor):
セルフアテンションを行う前にベクトルを正規化してください
199 z = self.norm_self_attn(x)
自己注意を向ける。つまり、キーと値は自己からのものだ
201 self_attn = self.attn(query=z, key=z, value=z, mask=mask)
セルフアテンションの結果を追加
203 x = x + self.dropout(self_attn)
フィードフォワード用に正規化
206 z = self.norm_ff(x)
スイッチングフィードフォワードネットワークを経由します
208 ff, counts, route_prob, n_dropped, route_prob_max = self.feed_forward(z)
フィードフォワードの結果を追加し直す
210 x = x + self.dropout(ff)
211
212 return x, counts, route_prob, n_dropped, route_prob_max
215class SwitchTransformer(Module):
220 def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
221 super().__init__()
トランスレイヤーのコピーを作成
223 self.layers = clone_module_list(layer, n_layers)
最終正規化レイヤー
225 self.norm = nn.LayerNorm([layer.size])
227 def forward(self, x: torch.Tensor, mask: torch.Tensor):
各変圧器層に通す
229 counts, route_prob, n_dropped, route_prob_max = [], [], [], []
230 for layer in self.layers:
231 x, f, p, n_d, p_max = layer(x=x, mask=mask)
232 counts.append(f)
233 route_prob.append(p)
234 n_dropped.append(n_d)
235 route_prob_max.append(p_max)
最後に、ベクトルを正規化します。
237 x = self.norm(x)
239 return x, torch.stack(counts), torch.stack(route_prob), n_dropped, torch.stack(route_prob_max)