14import torch
15import torch.nn as nn
16
17from labml import experiment, tracker
18from labml.configs import option
19from labml_helpers.module import Module
20from labml_helpers.train_valid import BatchIndex
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
24class AutoregressiveModel(Module):
29 def __init__(self, n_vocab: int, d_model: int, transformer: Module):
30 super().__init__()
トークン埋め込みモジュール
32 self.src_embed = nn.Embedding(n_vocab, d_model)
変圧器
34 self.transformer = transformer
最終レイヤー
36 self.generator = nn.Linear(d_model, n_vocab)
37 self.mask = None
39 def forward(self, x: torch.Tensor):
後続のマスクを初期化
41 if self.mask is None or self.mask.size(0) != len(x):
42 from labml_nn.transformers.utils import subsequent_mask
43 self.mask = subsequent_mask(len(x)).to(x.device)
トークンの埋め込み
45 x = self.src_embed(x)
変圧器に通してください
47 res, counts, route_prob, n_dropped, route_prob_max = self.transformer(x, self.mask)
次のトークンのロジットを生成
49 res = self.generator(res)
51 return res, counts, route_prob, n_dropped, route_prob_max
54class Configs(NLPAutoRegressionConfigs):
63 model: AutoregressiveModel
64 transformer: Module
トークンの埋め込みサイズ
67 d_model: int = 128
アテンションヘッドの数
69 heads: int = 4
脱落確率
71 dropout: float = 0.0
FFN 隠れレイヤーのフィーチャ数
73 d_ff: int = 256
変圧器層の数
75 n_layers: int = 6
エキスパートの数
77 n_experts: int = 4
負荷分散係数
79 load_balancing_loss_ceof = 0.01
選択したエキスパートアウトプットをルーティング確率でスケーリングするかどうか
81 is_scale_prob: bool = True
トークンをドロップするかどうか
83 drop_tokens: bool = False
各モデルの容量を決定する容量係数
85 capacity_factor: float = 1.0
87 def init(self):
88 super().init()
トラッキングインジケータを初期化
90 tracker.set_scalar("lb_loss.*", False)
91 tracker.set_scalar("route.*", False)
92 tracker.set_scalar("dropped.*", False)
94 def step(self, batch: any, batch_idx: BatchIndex):
データをデバイスに移動
100 data, target = batch[0].to(self.device), batch[1].to(self.device)
トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新
103 if self.mode.is_train:
104 tracker.add_global_step(data.shape[0] * data.shape[1])
モデル出力をキャプチャするかどうか
107 with self.mode.update(is_log_activations=batch_idx.is_last):
モデル出力を取得します。
109 output, counts, route_prob, n_dropped, route_prob_max = self.model(data)
クロスエントロピー損失の計算とクロスエントロピー損失
112 cross_entropy_loss = self.loss_func(output, target)
現在のバッチで処理されたトークンの総数
114 total = counts.sum(dim=-1, keepdims=True)
各エキスパートにルーティングされるトークンの割合は、argmax がと等しいトークンの数です。
118 route_frac = counts / total
平均ルーティング確率
121 route_prob = route_prob / total
負荷分散損失は単一レイヤーの損失であり、ここではすべてのレイヤーの損失の合計を求めています。
126 load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
トラック統計
129 tracker.add('dropped.', total.new_tensor(n_dropped) / total)
130 tracker.add('route.min.', route_frac.min())
131 tracker.add('route.max.', route_frac.max())
132 tracker.add('route.std.', route_frac.std())
133 tracker.add('route.max_prob.', route_prob_max)
134 tracker.add("loss.", cross_entropy_loss)
135 tracker.add("lb_loss.", load_balancing_loss)
140 loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss
精度の計算と記録
143 self.accuracy(output, target)
144 self.accuracy.track()
モデルのトレーニング
147 if self.mode.is_train:
勾配の計算
149 loss.backward()
クリップグラデーション
151 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
最適化の一歩を踏み出す
153 self.optimizer.step()
各エポックの最後のバッチでモデルパラメータと勾配を記録します
155 if batch_idx.is_last:
156 tracker.add('model', self.model)
グラデーションをクリア
158 self.optimizer.zero_grad()
追跡したメトリクスを保存する
161 tracker.save()
164@option(Configs.model)
165def autoregressive_model(c: Configs):
169 m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
170 return m.to(c.device)
173@option(Configs.transformer)
174def switch_transformer(c: Configs):
178 from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
179 from labml_nn.transformers import MultiHeadAttention
180 from labml_nn.transformers.feed_forward import FeedForward
181
182 return SwitchTransformer(
183 SwitchTransformerLayer(d_model=c.d_model,
184 attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
185 feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
186 drop_tokens=c.drop_tokens,
187 is_scale_prob=c.is_scale_prob,
188 n_experts=c.n_experts,
189 expert=FeedForward(c.d_model, c.d_ff, c.dropout),
190 d_model=c.d_model),
191 dropout_prob=c.dropout),
192 c.n_layers)
195def main():
実験を作成
200 experiment.create(name="switch_transformer", comment='')
コンフィグの作成
202 conf = Configs()
構成をロード
204 experiment.configs(conf,
オーバーライドする設定の辞書
206 {'tokenizer': 'character',
207 'text': 'tiny_shakespeare',
208 'optimizer.learning_rate': 1.,
209 'optimizer.optimizer': 'Noam',
210 'prompt': 'It is',
211 'prompt_separator': '',
212
213 'transformer': 'switch_transformer',
214 'n_experts': 4,
215
216 'drop_tokens': True,
217 'capacity_factor': 1.2,
218
219 'train_loader': 'shuffled_train_loader',
220 'valid_loader': 'shuffled_valid_loader',
221
222 'seq_len': 64,
223 'epochs': 128,
224 'batch_size': 32,
225 'inner_iterations': 25,
226 })
保存および読み込み用のモデルを設定する
229 experiment.add_pytorch_models({'model': conf.model})
実験を始める
232 with experiment.start():
TrainValidConfigs.run
234 conf.run()
238if __name__ == '__main__':
239 main()