11from typing import List, Tuple, NamedTuple
12
13import torch
14import torch.nn as nn
15
16from labml import experiment, tracker, monit, logger
17from labml.configs import option
18from labml.logger import Text
19from labml_helpers.metrics.simple_state import SimpleStateModule
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex, hook_model_outputs
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
24 CompressiveTransformerLayer, Conv1dCompression
27class CompressedMemory(NamedTuple):
28 mem: List[torch.Tensor]
29 c_mem: List[torch.Tensor]
32class AutoregressiveModel(Module):
37 def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
38 super().__init__()
トークン埋め込みモジュール
40 self.src_embed = nn.Embedding(n_vocab, d_model)
変圧器
42 self.transformer = transformer
最終レイヤー
44 self.generator = nn.Linear(d_model, n_vocab)
マスク
46 self.mask_x = None
47 self.mask_mem = None
49 def forward(self, x: torch.Tensor, mem: CompressedMemory):
メモリと圧縮メモリを取得
51 if mem is not None:
52 mem, c_mem = mem.mem, mem.c_mem
53 else:
54 mem = []
55 c_mem = []
メモリと圧縮メモリの合計長 (マスク用)
58 m_len = len(mem[0]) if mem else 0
59 if c_mem:
60 m_len += len(c_mem[0])
トークンのマスクを後から作成
63 if self.mask_x is None or self.mask_x.shape[0] < len(x):
64 from labml_nn.transformers.utils import subsequent_mask
65 self.mask_x = subsequent_mask(len(x)).to(x.device)
メモリ用のオールワン (フルビジビリティ) マスクを作成
67 if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
68 self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
メモリがある場合はマスクを連結してください
71 if m_len:
72 mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
それ以外の場合は、後続のマスクのみを使用してください
74 else:
75 mask = self.mask_x[:len(x), :len(x)]
トークンの埋め込み
78 x = self.src_embed(x)
変圧器に通してください
80 res, mem = self.transformer(x, mem, c_mem, mask)
次のトークンのロジットを生成
82 res = self.generator(res)
84 return res, mem
87class Configs(NLPAutoRegressionConfigs):
94 model: AutoregressiveModel
トークンの埋め込みサイズ
97 d_model: int = 128
アテンションヘッドの数
99 heads: int = 4
脱落確率
101 dropout: float = 0.0
FFN 隠れレイヤーのフィーチャ数
103 d_ff: int = 256
変圧器層の数
105 n_layers: int = 6
保存するメモリの数
107 mem_len: int = 8
トレーニングと検証を切り替えるときにメモリを維持するステートモジュール
109 memory = SimpleStateModule()
注意力再建ロス
111 attention_reconstruction_loss: AttentionReconstructionLoss
圧縮率
113 compression_rate: int = 4
圧縮メモリ長
115 c_mem_len: int = 128
117 def init(self):
トラッカー構成を設定
119 tracker.set_scalar("accuracy.*", True)
120 tracker.set_scalar("loss.*", True)
端末に注意再構成ロスを印刷しないでください
122 tracker.set_scalar("ar_loss.*", False)
モジュール出力をログに記録するフックを追加
124 hook_model_outputs(self.mode, self.model, 'model')
これにより、精度メトリックの統計情報とメモリがトレーニングと検証用に別々に保持されます。
126 self.state_modules = [self.accuracy, self.memory]
新しい記憶を連結し、最も古い記憶を圧縮します。
128 @torch.no_grad()
129 def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
130 -> Tuple[CompressedMemory, List[torch.Tensor]]:
構成でメモリを使用しないよう指定されている場合
136 if self.mem_len == 0 and self.c_mem_len == 0:
137 return CompressedMemory([], []), []
メモリと圧縮メモリを取得
140 if mem is not None:
141 mem, c_mem = mem.mem, mem.c_mem
142 else:
143 mem, c_mem = [], []
新しい記憶と古い記憶をつなげる
146 if mem:
147 mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
148 else:
149 mem = new_mem
より多くのメモリがある場合は、最も古いメモリを圧縮します mem_len
152 if len(mem[0]) > self.mem_len:
作成する圧縮メモリの数を計算します。ここで、は保持するメモリの最大数、は保持するメモリの最大数 (mem_len
)。
156 n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate
圧縮するメモリの数
158 n_old = n_c_mem * self.compression_rate
レイヤーごとに圧縮する必要があるメモリを保存するためのリスト。
160 mem_to_compress = []
レイヤーごとに圧縮されないメモリを保存するためのリスト。
162 uncompressed_mem = []
各レイヤーのメモリを繰り返し処理します。
164 for m in mem:
思い出を分けて
166 cm, m = torch.split(m, [n_old, len(m) - n_old])
思い出を集めて圧縮
168 mem_to_compress.append(cm)
残りの思い出を集めよう
170 uncompressed_mem.append(m)
思い出を更新
172 mem = uncompressed_mem
思い出を圧縮
175 new_c_mem = []
176 for i, layer in enumerate(self.model.transformer.layers):
177 new_c_mem.append(layer.compress(mem_to_compress[i]))
新しく圧縮されたメモリを古い圧縮メモリと連結する
180 if c_mem:
181 c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
古い圧縮メモリがない場合
183 else:
184 c_mem = new_c_mem
古い思い出を切り捨てる
187 if len(c_mem[0]) > self.c_mem_len:
188 c_mem = [m[-self.c_mem_len:] for m in c_mem]
メモリの数が以下の場合、メモリは圧縮されません mem_len
190 else:
191 mem_to_compress = []
メモリと圧縮されたメモリを返します。再構成損失の計算には、圧縮されたメモリが必要です
。195 return CompressedMemory(mem, c_mem), mem_to_compress
197 def step(self, batch: any, batch_idx: BatchIndex):
データをデバイスに移動
203 data, target = batch[0].to(self.device), batch[1].to(self.device)
トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新
206 if self.mode.is_train:
207 tracker.add_global_step(data.shape[0] * data.shape[1])
モデル出力をキャプチャするかどうか
210 with self.mode.update(is_log_activations=batch_idx.is_last):
思い出をゲット
212 mem = self.memory.get()
モデルを実行
214 output, new_mem = self.model(data, mem)
メモリの統合と圧縮
216 mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
メモリーを更新
218 self.memory.set(mem)
クロスエントロピー損失の計算と記録
221 loss = self.loss_func(output, target)
222 tracker.add("loss.", loss)
このステップで記憶が圧縮された場合の注意再構成損失を計算します。
225 if mem_to_compress:
注意を向けて再建ロス
227 ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)
トラック・アテンション・リコンストラクション・ロス
229 tracker.add("ar_loss.", ar_loss)
損失に注意再構築損失を追加
231 loss = loss + ar_loss
精度の計算と記録
234 self.accuracy(output, target)
235 self.accuracy.track()
モデルのトレーニング
238 if self.mode.is_train:
勾配の計算
240 loss.backward()
クリップグラデーション
242 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
最適化の一歩を踏み出す
244 self.optimizer.step()
各エポックの最後のバッチでモデルパラメータと勾配を記録します
246 if batch_idx.is_last:
247 tracker.add('model', self.model)
グラデーションをクリア
249 self.optimizer.zero_grad()
追跡したメトリクスを保存する
252 tracker.save()
254 def sample(self):
起動プロンプト
260 prompt = self.prompt
印刷用の出力を収集
262 log = [(prompt, Text.subtle)]
記憶
264 mem = CompressedMemory([], [])
25トークンのサンプル
266 for i in monit.iterate('Sample', 25):
プロンプトをトークン化
268 data = self.text.text_to_i(prompt).unsqueeze(-1)
デバイスに移動
270 data = data.to(self.device)
モデル出力を取得
272 output, new_mem = self.model(data, mem)
モデル予測を取得 (欲張り)
274 output = output.argmax(dim=-1).squeeze(1)
予測をプロンプトに追加
276 prompt += self.prompt_separator + self.text.itos[output[-1]]
次のイテレーションでは最後の文字だけをモデルにフィードし、残りはメモリとして残ります
278 prompt = prompt[-1:]
ロギング用の予測を追加
280 log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
メモリの更新と圧縮
282 mem, _ = self.merge_compress_memory(mem, new_mem)
サンプル出力を印刷する
285 logger.log(log)
288@option(Configs.model)
289def autoregressive_model(c: Configs):
293 from labml_nn.transformers.xl import RelativeMultiHeadAttention
294 from labml_nn.transformers.feed_forward import FeedForward
295 m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
296 CompressiveTransformerLayer(d_model=c.d_model,
297 self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
298 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
299 dropout_prob=c.dropout,
300 compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
301 return m.to(c.device)
304@option(Configs.attention_reconstruction_loss)
305def attention_reconstruction_loss(c: Configs):
309 return AttentionReconstructionLoss(c.model.transformer.layers)
312def main():
実験を作成
317 experiment.create(name="compressive_transformer", comment='')
コンフィグの作成
319 conf = Configs()
構成をロード
321 experiment.configs(conf,
オーバーライドする設定の辞書
323 {'tokenizer': 'character',
324 'text': 'tiny_shakespeare',
325 'optimizer.learning_rate': 2.5e-4,
326 'optimizer.optimizer': 'AdamW',
327 'prompt': 'It is',
328 'prompt_separator': '',
329
330 'train_loader': 'sequential_train_loader',
331 'valid_loader': 'sequential_valid_loader',
332
333 'seq_len': 8,
334 'mem_len': 8,
335 'epochs': 128,
336 'batch_size': 32,
337 'inner_iterations': 25,
338 'compression_rate': 2,
339 })
保存および読み込み用のモデルを設定する
342 experiment.add_pytorch_models({'model': conf.model})
実験を始める
345 with experiment.start():
TrainValidConfigs.run
347 conf.run()
351if __name__ == '__main__':
352 main()