11from typing import List
12
13import torch
14import torch.nn as nn
15from labml.logger import Text
16
17from labml import experiment, tracker, monit, logger
18from labml.configs import option
19from labml_helpers.metrics.simple_state import SimpleStateModule
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex, hook_model_outputs
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
26class AutoregressiveModel(Module):
31 def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
32 super().__init__()
トークン埋め込みモジュール
34 self.src_embed = nn.Embedding(n_vocab, d_model)
変圧器
36 self.transformer = transformer
最終レイヤー
38 self.generator = nn.Linear(d_model, n_vocab)
マスク
40 self.mask_x = None
41 self.mask_mem = None
43 def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
メモリの長さ
45 m_len = len(mem[0]) if mem else 0
トークンのマスクを後から作成
47 if self.mask_x is None or self.mask_x.shape[0] < len(x):
48 from labml_nn.transformers.utils import subsequent_mask
49 self.mask_x = subsequent_mask(len(x)).to(x.device)
メモリ用のオールワン (フルビジビリティ) マスクを作成
51 if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
52 self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
メモリがある場合はマスクを連結してください
55 if m_len:
56 mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
それ以外の場合は、後続のマスクを使用してください。
58 else:
59 mask = self.mask_x[:len(x), :len(x)]
トークンの埋め込み
62 x = self.src_embed(x)
変圧器に通してください
64 res, mem = self.transformer(x, mem, mask)
次のトークンのロジットを生成
66 res = self.generator(res)
68 return res, mem
71class Configs(NLPAutoRegressionConfigs):
78 model: AutoregressiveModel
トークンの埋め込みサイズ
81 d_model: int = 128
アテンションヘッドの数
83 heads: int = 4
脱落確率
85 dropout: float = 0.0
FFN 隠れレイヤーのフィーチャ数
87 d_ff: int = 256
変圧器層の数
89 n_layers: int = 6
保存するメモリの数
91 mem_len: int = 128
トレーニングと検証を切り替えるときにメモリを維持するステートモジュール
93 memory = SimpleStateModule()
95 def init(self):
トラッカー構成を設定
97 tracker.set_scalar("accuracy.*", True)
98 tracker.set_scalar("loss.*", True)
モジュール出力をログに記録するフックを追加
100 hook_model_outputs(self.mode, self.model, 'model')
これにより、精度メトリックの統計情報とメモリがトレーニングと検証用に別々に保持されます。
102 self.state_modules = [self.accuracy, self.memory]
記憶を連結し、古い記憶を削除して、記憶を最大限に活用してください。mem_len
104 def merge_memory(self, old_mem, new_mem):
メモリを使用しないように設定されている場合
111 if self.mem_len == 0:
112 return []
古いメモリと連結
115 if old_mem:
116 mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
117 else:
118 mem = new_mem
古い思い出を切り捨てる
121 if len(mem[0]) > self.mem_len:
122 mem = [m[-self.mem_len:] for m in mem]
125 return mem
127 def step(self, batch: any, batch_idx: BatchIndex):
データをデバイスに移動
133 data, target = batch[0].to(self.device), batch[1].to(self.device)
トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新
136 if self.mode.is_train:
137 tracker.add_global_step(data.shape[0] * data.shape[1])
モデル出力をキャプチャするかどうか
140 with self.mode.update(is_log_activations=batch_idx.is_last):
思い出をゲット
142 mem = self.memory.get()
モデルを実行
144 output, new_mem = self.model(data, mem)
マージメモリ
146 mem = self.merge_memory(mem, new_mem)
メモリーを更新
148 self.memory.set(mem)
クロスエントロピー損失の計算と記録
151 loss = self.loss_func(output, target)
152 tracker.add("loss.", loss)
精度の計算と記録
155 self.accuracy(output, target)
156 self.accuracy.track()
モデルのトレーニング
159 if self.mode.is_train:
勾配の計算
161 loss.backward()
クリップグラデーション
163 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
最適化の一歩を踏み出す
165 self.optimizer.step()
各エポックの最後のバッチでモデルパラメータと勾配を記録します
167 if batch_idx.is_last:
168 tracker.add('model', self.model)
グラデーションをクリア
170 self.optimizer.zero_grad()
追跡したメトリクスを保存する
173 tracker.save()
175 def sample(self):
起動プロンプト
181 prompt = self.prompt
印刷用の出力を収集
183 log = [(prompt, Text.subtle)]
記憶
185 mem = []
25トークンのサンプル
187 for i in monit.iterate('Sample', 25):
プロンプトをトークン化
189 data = self.text.text_to_i(prompt).unsqueeze(-1)
デバイスに移動
191 data = data.to(self.device)
モデル出力を取得
193 output, new_mem = self.model(data, mem)
モデル予測を取得 (欲張り)
195 output = output.argmax(dim=-1).squeeze(1)
予測をプロンプトに追加
197 prompt += self.prompt_separator + self.text.itos[output[-1]]
次のイテレーションでは最後の文字だけをモデルにフィードし、残りはメモリとして残ります
199 prompt = prompt[-1:]
ロギング用の予測を追加
201 log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
メモリを更新
203 mem = self.merge_memory(mem, new_mem)
サンプル出力を印刷する
206 logger.log(log)
209@option(Configs.model)
210def autoregressive_model(c: Configs):
214 from labml_nn.transformers.xl import RelativeMultiHeadAttention
215 from labml_nn.transformers.feed_forward import FeedForward
216 m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
217 TransformerXLLayer(d_model=c.d_model,
218 self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
219 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
220 dropout_prob=c.dropout), c.n_layers))
221 return m.to(c.device)
224def main():
実験を作成
229 experiment.create(name="transformer_xl", comment='')
コンフィグの作成
231 conf = Configs()
構成をロード
233 experiment.configs(conf,
オーバーライドする設定の辞書
235 {'tokenizer': 'character',
236 'text': 'tiny_shakespeare',
237 'optimizer.learning_rate': 1.,
238 'optimizer.optimizer': 'Noam',
239 'prompt': 'It is',
240 'prompt_separator': '',
241
242 'train_loader': 'sequential_train_loader',
243 'valid_loader': 'sequential_valid_loader',
244
245 'seq_len': 2,
246 'mem_len': 32,
247 'epochs': 128,
248 'batch_size': 32,
249 'inner_iterations': 25,
250 })
保存および読み込み用のモデルを設定する
253 experiment.add_pytorch_models({'model': conf.model})
実験を始める
256 with experiment.start():
TrainValidConfigs.run
258 conf.run()
262if __name__ == '__main__':
263 main()