トランスフォーマー XL 実験

これは、トランスフォーマー xl モデルをトレーニングするための注釈付きの PyTorch 実験です。

11from typing import List
12
13import torch
14import torch.nn as nn
15from labml.logger import Text
16
17from labml import experiment, tracker, monit, logger
18from labml.configs import option
19from labml_helpers.metrics.simple_state import SimpleStateModule
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex, hook_model_outputs
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer

自動回帰モデル

26class AutoregressiveModel(Module):
31    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
32        super().__init__()

トークン埋め込みモジュール

34        self.src_embed = nn.Embedding(n_vocab, d_model)

変圧器

36        self.transformer = transformer

最終レイヤー

38        self.generator = nn.Linear(d_model, n_vocab)

マスク

40        self.mask_x = None
41        self.mask_mem = None
43    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):

メモリの長さ

45        m_len = len(mem[0]) if mem else 0

トークンのマスクを後から作成

47        if self.mask_x is None or self.mask_x.shape[0] < len(x):
48            from labml_nn.transformers.utils import subsequent_mask
49            self.mask_x = subsequent_mask(len(x)).to(x.device)

メモリ用のオールワン (フルビジビリティ) マスクを作成

51        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
52            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

メモリがある場合はマスクを連結してください

55        if m_len:
56            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

それ以外の場合は、後続のマスクを使用してください。

58        else:
59            mask = self.mask_x[:len(x), :len(x)]

トークンの埋め込み

62        x = self.src_embed(x)

変圧器に通してください

64        res, mem = self.transformer(x, mem, mask)

次のトークンのロジットを生成

66        res = self.generator(res)

68        return res, mem

コンフィギュレーション

デフォルトの設定は、実験を開始したときに上書きでき、また上書きされます。

71class Configs(NLPAutoRegressionConfigs):
78    model: AutoregressiveModel

トークンの埋め込みサイズ

81    d_model: int = 128

アテンションヘッドの数

83    heads: int = 4

脱落確率

85    dropout: float = 0.0

FFN 隠れレイヤーのフィーチャ数

87    d_ff: int = 256

変圧器層の数

89    n_layers: int = 6

保存するメモリの数

91    mem_len: int = 128

トレーニングと検証を切り替えるときにメモリを維持するステートモジュール

93    memory = SimpleStateModule()
95    def init(self):

トラッカー構成を設定

97        tracker.set_scalar("accuracy.*", True)
98        tracker.set_scalar("loss.*", True)

モジュール出力をログに記録するフックを追加

100        hook_model_outputs(self.mode, self.model, 'model')

これにより、精度メトリックの統計情報とメモリがトレーニングと検証用に別々に保持されます。

102        self.state_modules = [self.accuracy, self.memory]

記憶を連結し、古い記憶を削除して、記憶を最大限に活用してください。mem_len

104    def merge_memory(self, old_mem, new_mem):

メモリを使用しないように設定されている場合

111        if self.mem_len == 0:
112            return []

古いメモリと連結

115        if old_mem:
116            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
117        else:
118            mem = new_mem

古い思い出を切り捨てる

121        if len(mem[0]) > self.mem_len:
122            mem = [m[-self.mem_len:] for m in mem]

125        return mem

トレーニング/検証ステップ

127    def step(self, batch: any, batch_idx: BatchIndex):

データをデバイスに移動

133        data, target = batch[0].to(self.device), batch[1].to(self.device)

トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新

136        if self.mode.is_train:
137            tracker.add_global_step(data.shape[0] * data.shape[1])

モデル出力をキャプチャするかどうか

140        with self.mode.update(is_log_activations=batch_idx.is_last):

思い出をゲット

142            mem = self.memory.get()

モデルを実行

144            output, new_mem = self.model(data, mem)

マージメモリ

146            mem = self.merge_memory(mem, new_mem)

メモリーを更新

148            self.memory.set(mem)

クロスエントロピー損失の計算と記録

151        loss = self.loss_func(output, target)
152        tracker.add("loss.", loss)

精度の計算と記録

155        self.accuracy(output, target)
156        self.accuracy.track()

モデルのトレーニング

159        if self.mode.is_train:

勾配の計算

161            loss.backward()

クリップグラデーション

163            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

最適化の一歩を踏み出す

165            self.optimizer.step()

各エポックの最後のバッチでモデルパラメータと勾配を記録します

167            if batch_idx.is_last:
168                tracker.add('model', self.model)

グラデーションをクリア

170            self.optimizer.zero_grad()

追跡したメトリクスを保存する

173        tracker.save()

トレーニング中に定期的にサンプルを生成するサンプリング機能

175    def sample(self):

起動プロンプト

181        prompt = self.prompt

印刷用の出力を収集

183        log = [(prompt, Text.subtle)]

記憶

185        mem = []

25トークンのサンプル

187        for i in monit.iterate('Sample', 25):

プロンプトをトークン化

189            data = self.text.text_to_i(prompt).unsqueeze(-1)

デバイスに移動

191            data = data.to(self.device)

モデル出力を取得

193            output, new_mem = self.model(data, mem)

モデル予測を取得 (欲張り)

195            output = output.argmax(dim=-1).squeeze(1)

予測をプロンプトに追加

197            prompt += self.prompt_separator + self.text.itos[output[-1]]

次のイテレーションでは最後の文字だけをモデルにフィードし、残りはメモリとして残ります

199            prompt = prompt[-1:]

ロギング用の予測を追加

201            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

メモリを更新

203            mem = self.merge_memory(mem, new_mem)

サンプル出力を印刷する

206        logger.log(log)

自己回帰モデルを初期化

209@option(Configs.model)
210def autoregressive_model(c: Configs):
214    from labml_nn.transformers.xl import RelativeMultiHeadAttention
215    from labml_nn.transformers.feed_forward import FeedForward
216    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
217        TransformerXLLayer(d_model=c.d_model,
218                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
219                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
220                           dropout_prob=c.dropout), c.n_layers))
221    return m.to(c.device)

実験を実行する

224def main():

実験を作成

229    experiment.create(name="transformer_xl", comment='')

コンフィグの作成

231    conf = Configs()

構成をロード

233    experiment.configs(conf,

オーバーライドする設定の辞書

235                       {'tokenizer': 'character',
236                        'text': 'tiny_shakespeare',
237                        'optimizer.learning_rate': 1.,
238                        'optimizer.optimizer': 'Noam',
239                        'prompt': 'It is',
240                        'prompt_separator': '',
241
242                        'train_loader': 'sequential_train_loader',
243                        'valid_loader': 'sequential_valid_loader',
244
245                        'seq_len': 2,
246                        'mem_len': 32,
247                        'epochs': 128,
248                        'batch_size': 32,
249                        'inner_iterations': 25,
250                        })

保存および読み込み用のモデルを設定する

253    experiment.add_pytorch_models({'model': conf.model})

実験を始める

256    with experiment.start():

TrainValidConfigs.run

258        conf.run()

262if __name__ == '__main__':
263    main()