圧縮変圧器実験

これは、圧縮トランスモデルをトレーニングするための注釈付きの PyTorch 実験です。

11from typing import List, Tuple, NamedTuple
12
13import torch
14import torch.nn as nn
15
16from labml import experiment, tracker, monit, logger
17from labml.configs import option
18from labml.logger import Text
19from labml_helpers.metrics.simple_state import SimpleStateModule
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex, hook_model_outputs
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
24    CompressiveTransformerLayer, Conv1dCompression
27class CompressedMemory(NamedTuple):
28    mem: List[torch.Tensor]
29    c_mem: List[torch.Tensor]

自動回帰モデル

32class AutoregressiveModel(Module):
37    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
38        super().__init__()

トークン埋め込みモジュール

40        self.src_embed = nn.Embedding(n_vocab, d_model)

変圧器

42        self.transformer = transformer

最終レイヤー

44        self.generator = nn.Linear(d_model, n_vocab)

マスク

46        self.mask_x = None
47        self.mask_mem = None
49    def forward(self, x: torch.Tensor, mem: CompressedMemory):

メモリと圧縮メモリを取得

51        if mem is not None:
52            mem, c_mem = mem.mem, mem.c_mem
53        else:
54            mem = []
55            c_mem = []

メモリと圧縮メモリの合計長 (マスク用)

58        m_len = len(mem[0]) if mem else 0
59        if c_mem:
60            m_len += len(c_mem[0])

トークンのマスクを後から作成

63        if self.mask_x is None or self.mask_x.shape[0] < len(x):
64            from labml_nn.transformers.utils import subsequent_mask
65            self.mask_x = subsequent_mask(len(x)).to(x.device)

メモリ用のオールワン (フルビジビリティ) マスクを作成

67        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
68            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

メモリがある場合はマスクを連結してください

71        if m_len:
72            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

それ以外の場合は、後続のマスクのみを使用してください

74        else:
75            mask = self.mask_x[:len(x), :len(x)]

トークンの埋め込み

78        x = self.src_embed(x)

変圧器に通してください

80        res, mem = self.transformer(x, mem, c_mem, mask)

次のトークンのロジットを生成

82        res = self.generator(res)

84        return res, mem

コンフィギュレーション

デフォルトの構成は、実験を開始するときに上書きできます。また、今後変更する予定です。

87class Configs(NLPAutoRegressionConfigs):
94    model: AutoregressiveModel

トークンの埋め込みサイズ

97    d_model: int = 128

アテンションヘッドの数

99    heads: int = 4

脱落確率

101    dropout: float = 0.0

FFN 隠れレイヤーのフィーチャ数

103    d_ff: int = 256

変圧器層の数

105    n_layers: int = 6

保存するメモリの数

107    mem_len: int = 8

トレーニングと検証を切り替えるときにメモリを維持するステートモジュール

109    memory = SimpleStateModule()

注意力再建ロス

111    attention_reconstruction_loss: AttentionReconstructionLoss

圧縮率

113    compression_rate: int = 4

圧縮メモリ長

115    c_mem_len: int = 128
117    def init(self):

トラッカー構成を設定

119        tracker.set_scalar("accuracy.*", True)
120        tracker.set_scalar("loss.*", True)

端末に注意再構成ロスを印刷しないでください

122        tracker.set_scalar("ar_loss.*", False)

モジュール出力をログに記録するフックを追加

124        hook_model_outputs(self.mode, self.model, 'model')

これにより、精度メトリックの統計情報とメモリがトレーニングと検証用に別々に保持されます。

126        self.state_modules = [self.accuracy, self.memory]

新しい記憶を連結し、最も古い記憶を圧縮します。

128    @torch.no_grad()
129    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
130            -> Tuple[CompressedMemory, List[torch.Tensor]]:

構成でメモリを使用しないよう指定されている場合

136        if self.mem_len == 0 and self.c_mem_len == 0:
137            return CompressedMemory([], []), []

メモリと圧縮メモリを取得

140        if mem is not None:
141            mem, c_mem = mem.mem, mem.c_mem
142        else:
143            mem, c_mem = [], []

新しい記憶と古い記憶をつなげる

146        if mem:
147            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
148        else:
149            mem = new_mem

より多くのメモリがある場合は、最も古いメモリを圧縮します mem_len

152        if len(mem[0]) > self.mem_len:

作成する圧縮メモリの数を計算します。ここでは保持するメモリの最大数、は保持するメモリの最大数 (mem_len )。

156            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate

圧縮するメモリの数

158            n_old = n_c_mem * self.compression_rate

レイヤーごとに圧縮する必要があるメモリを保存するためのリスト。

160            mem_to_compress = []

レイヤーごとに圧縮されないメモリを保存するためのリスト。

162            uncompressed_mem = []

各レイヤーのメモリを繰り返し処理します。

164            for m in mem:

思い出を分けて

166                cm, m = torch.split(m, [n_old, len(m) - n_old])

思い出を集めて圧縮

168                mem_to_compress.append(cm)

残りの思い出を集めよう

170                uncompressed_mem.append(m)

思い出を更新

172            mem = uncompressed_mem

思い出を圧縮

175            new_c_mem = []
176            for i, layer in enumerate(self.model.transformer.layers):
177                new_c_mem.append(layer.compress(mem_to_compress[i]))

新しく圧縮されたメモリを古い圧縮メモリと連結する

180            if c_mem:
181                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]

古い圧縮メモリがない場合

183            else:
184                c_mem = new_c_mem

古い思い出を切り捨てる

187            if len(c_mem[0]) > self.c_mem_len:
188                c_mem = [m[-self.c_mem_len:] for m in c_mem]

メモリの数が以下の場合、メモリは圧縮されません mem_len

190        else:
191            mem_to_compress = []

メモリと圧縮されたメモリを返します。再構成損失の計算には、圧縮されたメモリが必要です

195        return CompressedMemory(mem, c_mem), mem_to_compress

トレーニング/検証ステップ

197    def step(self, batch: any, batch_idx: BatchIndex):

データをデバイスに移動

203        data, target = batch[0].to(self.device), batch[1].to(self.device)

トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新

206        if self.mode.is_train:
207            tracker.add_global_step(data.shape[0] * data.shape[1])

モデル出力をキャプチャするかどうか

210        with self.mode.update(is_log_activations=batch_idx.is_last):

思い出をゲット

212            mem = self.memory.get()

モデルを実行

214            output, new_mem = self.model(data, mem)

メモリの統合と圧縮

216            mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)

メモリーを更新

218            self.memory.set(mem)

クロスエントロピー損失の計算と記録

221        loss = self.loss_func(output, target)
222        tracker.add("loss.", loss)

このステップで記憶が圧縮された場合の注意再構成損失を計算します。

225        if mem_to_compress:

注意を向けて再建ロス

227            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)

トラック・アテンション・リコンストラクション・ロス

229            tracker.add("ar_loss.", ar_loss)

損失に注意再構築損失を追加

231            loss = loss + ar_loss

精度の計算と記録

234        self.accuracy(output, target)
235        self.accuracy.track()

モデルのトレーニング

238        if self.mode.is_train:

勾配の計算

240            loss.backward()

クリップグラデーション

242            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

最適化の一歩を踏み出す

244            self.optimizer.step()

各エポックの最後のバッチでモデルパラメータと勾配を記録します

246            if batch_idx.is_last:
247                tracker.add('model', self.model)

グラデーションをクリア

249            self.optimizer.zero_grad()

追跡したメトリクスを保存する

252        tracker.save()

トレーニング中に定期的にサンプルを生成するサンプリング機能

254    def sample(self):

起動プロンプト

260        prompt = self.prompt

印刷用の出力を収集

262        log = [(prompt, Text.subtle)]

記憶

264        mem = CompressedMemory([], [])

25トークンのサンプル

266        for i in monit.iterate('Sample', 25):

プロンプトをトークン化

268            data = self.text.text_to_i(prompt).unsqueeze(-1)

デバイスに移動

270            data = data.to(self.device)

モデル出力を取得

272            output, new_mem = self.model(data, mem)

モデル予測を取得 (欲張り)

274            output = output.argmax(dim=-1).squeeze(1)

予測をプロンプトに追加

276            prompt += self.prompt_separator + self.text.itos[output[-1]]

次のイテレーションでは最後の文字だけをモデルにフィードし、残りはメモリとして残ります

278            prompt = prompt[-1:]

ロギング用の予測を追加

280            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

メモリの更新と圧縮

282            mem, _ = self.merge_compress_memory(mem, new_mem)

サンプル出力を印刷する

285        logger.log(log)

自己回帰モデルを初期化

288@option(Configs.model)
289def autoregressive_model(c: Configs):
293    from labml_nn.transformers.xl import RelativeMultiHeadAttention
294    from labml_nn.transformers.feed_forward import FeedForward
295    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
296        CompressiveTransformerLayer(d_model=c.d_model,
297                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
298                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
299                                    dropout_prob=c.dropout,
300                                    compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
301    return m.to(c.device)

注意力再構築ロスを初期化

304@option(Configs.attention_reconstruction_loss)
305def attention_reconstruction_loss(c: Configs):
309    return AttentionReconstructionLoss(c.model.transformer.layers)

実験を実行する

312def main():

実験を作成

317    experiment.create(name="compressive_transformer", comment='')

コンフィグの作成

319    conf = Configs()

構成をロード

321    experiment.configs(conf,

オーバーライドする設定の辞書

323                       {'tokenizer': 'character',
324                        'text': 'tiny_shakespeare',
325                        'optimizer.learning_rate': 2.5e-4,
326                        'optimizer.optimizer': 'AdamW',
327                        'prompt': 'It is',
328                        'prompt_separator': '',
329
330                        'train_loader': 'sequential_train_loader',
331                        'valid_loader': 'sequential_valid_loader',
332
333                        'seq_len': 8,
334                        'mem_len': 8,
335                        'epochs': 128,
336                        'batch_size': 32,
337                        'inner_iterations': 25,
338                        'compression_rate': 2,
339                        })

保存および読み込み用のモデルを設定する

342    experiment.add_pytorch_models({'model': conf.model})

実験を始める

345    with experiment.start():

TrainValidConfigs.run

347        conf.run()

351if __name__ == '__main__':
352    main()