スイッチトランス実験

これは、スイッチトランスをトレーニングするための注釈付きPyTorch実験です。

Open In Colab

14import torch
15import torch.nn as nn
16
17from labml import experiment, tracker
18from labml.configs import option
19from labml_helpers.module import Module
20from labml_helpers.train_valid import BatchIndex
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs

自動回帰モデル

24class AutoregressiveModel(Module):
29    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
30        super().__init__()

トークン埋め込みモジュール

32        self.src_embed = nn.Embedding(n_vocab, d_model)

変圧器

34        self.transformer = transformer

最終レイヤー

36        self.generator = nn.Linear(d_model, n_vocab)
37        self.mask = None
39    def forward(self, x: torch.Tensor):

後続のマスクを初期化

41        if self.mask is None or self.mask.size(0) != len(x):
42            from labml_nn.transformers.utils import subsequent_mask
43            self.mask = subsequent_mask(len(x)).to(x.device)

トークンの埋め込み

45        x = self.src_embed(x)

変圧器に通してください

47        res, counts, route_prob, n_dropped, route_prob_max = self.transformer(x, self.mask)

次のトークンのロジットを生成

49        res = self.generator(res)

51        return res, counts, route_prob, n_dropped, route_prob_max

コンフィギュレーション

これは広がります。NLPAutoRegressionConfigs

デフォルトの設定は、実験を開始したときに上書きでき、また上書きされます。

54class Configs(NLPAutoRegressionConfigs):
63    model: AutoregressiveModel
64    transformer: Module

トークンの埋め込みサイズ

67    d_model: int = 128

アテンションヘッドの数

69    heads: int = 4

脱落確率

71    dropout: float = 0.0

FFN 隠れレイヤーのフィーチャ数

73    d_ff: int = 256

変圧器層の数

75    n_layers: int = 6

エキスパートの数

77    n_experts: int = 4

負荷分散係数

79    load_balancing_loss_ceof = 0.01

選択したエキスパートアウトプットをルーティング確率でスケーリングするかどうか

81    is_scale_prob: bool = True

トークンをドロップするかどうか

83    drop_tokens: bool = False

各モデルの容量を決定する容量係数

85    capacity_factor: float = 1.0
87    def init(self):
88        super().init()

トラッキングインジケータを初期化

90        tracker.set_scalar("lb_loss.*", False)
91        tracker.set_scalar("route.*", False)
92        tracker.set_scalar("dropped.*", False)

トレーニングまたは検証ステップ

94    def step(self, batch: any, batch_idx: BatchIndex):

データをデバイスに移動

100        data, target = batch[0].to(self.device), batch[1].to(self.device)

トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新

103        if self.mode.is_train:
104            tracker.add_global_step(data.shape[0] * data.shape[1])

モデル出力をキャプチャするかどうか

107        with self.mode.update(is_log_activations=batch_idx.is_last):

モデル出力を取得します。

109            output, counts, route_prob, n_dropped, route_prob_max = self.model(data)

クロスエントロピー損失の計算とクロスエントロピー損失

112        cross_entropy_loss = self.loss_func(output, target)

現在のバッチで処理されたトークンの総数

114        total = counts.sum(dim=-1, keepdims=True)

各エキスパートにルーティングされるトークンの割合は、argmax がと等しいトークンの数です。

118        route_frac = counts / total

平均ルーティング確率

121        route_prob = route_prob / total

負荷分散損失は単一レイヤーの損失であり、ここではすべてのレイヤーの損失の合計を求めています。

126        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()

トラック統計

129        tracker.add('dropped.', total.new_tensor(n_dropped) / total)
130        tracker.add('route.min.', route_frac.min())
131        tracker.add('route.max.', route_frac.max())
132        tracker.add('route.std.', route_frac.std())
133        tracker.add('route.max_prob.', route_prob_max)
134        tracker.add("loss.", cross_entropy_loss)
135        tracker.add("lb_loss.", load_balancing_loss)

複合損失。負荷分散損失には、次のような小さな値に設定された係数が乗算されます

140        loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss

精度の計算と記録

143        self.accuracy(output, target)
144        self.accuracy.track()

モデルのトレーニング

147        if self.mode.is_train:

勾配の計算

149            loss.backward()

クリップグラデーション

151            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

最適化の一歩を踏み出す

153            self.optimizer.step()

各エポックの最後のバッチでモデルパラメータと勾配を記録します

155            if batch_idx.is_last:
156                tracker.add('model', self.model)

グラデーションをクリア

158            self.optimizer.zero_grad()

追跡したメトリクスを保存する

161        tracker.save()

自己回帰モデルを初期化

164@option(Configs.model)
165def autoregressive_model(c: Configs):
169    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
170    return m.to(c.device)

スイッチトランスを初期化します

173@option(Configs.transformer)
174def switch_transformer(c: Configs):
178    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
179    from labml_nn.transformers import MultiHeadAttention
180    from labml_nn.transformers.feed_forward import FeedForward
181
182    return SwitchTransformer(
183        SwitchTransformerLayer(d_model=c.d_model,
184                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
185                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
186                                                              drop_tokens=c.drop_tokens,
187                                                              is_scale_prob=c.is_scale_prob,
188                                                              n_experts=c.n_experts,
189                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
190                                                              d_model=c.d_model),
191                               dropout_prob=c.dropout),
192        c.n_layers)

実験を実行する

195def main():

実験を作成

200    experiment.create(name="switch_transformer", comment='')

コンフィグの作成

202    conf = Configs()

構成をロード

204    experiment.configs(conf,

オーバーライドする設定の辞書

206                       {'tokenizer': 'character',
207                        'text': 'tiny_shakespeare',
208                        'optimizer.learning_rate': 1.,
209                        'optimizer.optimizer': 'Noam',
210                        'prompt': 'It is',
211                        'prompt_separator': '',
212
213                        'transformer': 'switch_transformer',
214                        'n_experts': 4,
215
216                        'drop_tokens': True,
217                        'capacity_factor': 1.2,
218
219                        'train_loader': 'shuffled_train_loader',
220                        'valid_loader': 'shuffled_valid_loader',
221
222                        'seq_len': 64,
223                        'epochs': 128,
224                        'batch_size': 32,
225                        'inner_iterations': 25,
226                        })

保存および読み込み用のモデルを設定する

229    experiment.add_pytorch_models({'model': conf.model})

実験を始める

232    with experiment.start():

TrainValidConfigs.run

234        conf.run()

238if __name__ == '__main__':
239    main()