11from typing import List
12
13import torch
14from torch import nn
15
16from labml import experiment, tracker, logger
17from labml.configs import option
18from labml.logger import Text
19from labml_helpers.metrics.accuracy import Accuracy
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers import Encoder, Generator
24from labml_nn.transformers import TransformerConfigs
25from labml_nn.transformers.mlm import MLM
28class TransformerMLM(nn.Module):
33 def __init__(self, *, encoder: Encoder, src_embed: Module, generator: Generator):
40 super().__init__()
41 self.generator = generator
42 self.src_embed = src_embed
43 self.encoder = encoder
45 def forward(self, x: torch.Tensor):
位置エンコーディングによるトークンの埋め込みを取得
47 x = self.src_embed(x)
トランスエンコーダー
49 x = self.encoder(x, None)
出力用のロジット
51 y = self.generator(x)
結果を返します(トレーナーはRNNでも使用されるため、2番目の値は状態用です)
55 return y, None
NLPAutoRegressionConfigs
これが継承されているのは、ここで再利用するデータパイプラインの実装があるからです。MLMからカスタムトレーニングステップを実装しました
58class Configs(NLPAutoRegressionConfigs):
MLM モデル
69 model: TransformerMLM
変圧器
71 transformer: TransformerConfigs
トークンの数
74 n_tokens: int = 'n_tokens_mlm'
マスクしてはいけないトークン
76 no_mask_tokens: List[int] = []
トークンをマスキングする確率
78 masking_prob: float = 0.15
マスクをランダムトークンに置き換える確率
80 randomize_prob: float = 0.1
マスクを元のトークンと交換する確率
82 no_change_prob: float = 0.1
84 mlm: MLM
[MASK]
トークン
87 mask_token: int
[PADDING]
トークン
89 padding_token: int
サンプリングを促す
92 prompt: str = [
93 "We are accounted poor citizens, the patricians good.",
94 "What authority surfeits on would relieve us: if they",
95 "would yield us but the superfluity, while it were",
96 "wholesome, we might guess they relieved us humanely;",
97 "but they think we are too dear: the leanness that",
98 "afflicts us, the object of our misery, is as an",
99 "inventory to particularise their abundance; our",
100 "sufferance is a gain to them Let us revenge this with",
101 "our pikes, ere we become rakes: for the gods know I",
102 "speak this in hunger for bread, not in thirst for revenge.",
103 ]
105 def init(self):
[MASK]
トークン
111 self.mask_token = self.n_tokens - 1
[PAD]
トークン
113 self.padding_token = self.n_tokens - 2
116 self.mlm = MLM(padding_token=self.padding_token,
117 mask_token=self.mask_token,
118 no_mask_tokens=self.no_mask_tokens,
119 n_tokens=self.n_tokens,
120 masking_prob=self.masking_prob,
121 randomize_prob=self.randomize_prob,
122 no_change_prob=self.no_change_prob)
精度指標 (と等しいラベルは無視してください[PAD]
)
125 self.accuracy = Accuracy(ignore_index=self.padding_token)
クロスエントロピー損失 (と等しいラベルは無視してください) [PAD]
127 self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)
129 super().init()
131 def step(self, batch: any, batch_idx: BatchIndex):
入力をデバイスに移動
137 data = batch[0].to(self.device)
トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新
140 if self.mode.is_train:
141 tracker.add_global_step(data.shape[0] * data.shape[1])
マスクされた入力とラベルを取得
144 with torch.no_grad():
145 data, labels = self.mlm(data)
モデル出力をキャプチャするかどうか
148 with self.mode.update(is_log_activations=batch_idx.is_last):
モデル出力を取得します。RNN を使用する場合はステートのタプルを返します。これはまだ実装されていません。
152 output, *_ = self.model(data)
損失の計算と記録
155 loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
156 tracker.add("loss.", loss)
精度の計算と記録
159 self.accuracy(output, labels)
160 self.accuracy.track()
モデルのトレーニング
163 if self.mode.is_train:
勾配の計算
165 loss.backward()
クリップグラデーション
167 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
最適化の一歩を踏み出す
169 self.optimizer.step()
各エポックの最後のバッチでモデルパラメータと勾配を記録します
171 if batch_idx.is_last:
172 tracker.add('model', self.model)
グラデーションをクリア
174 self.optimizer.zero_grad()
追跡したメトリクスを保存する
177 tracker.save()
179 @torch.no_grad()
180 def sample(self):
が入力されたデータのテンソルを空にします。[PAD]
186 data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)
プロンプトを 1 つずつ追加します
188 for i, p in enumerate(self.prompt):
トークンのインデックスを取得
190 d = self.text.text_to_i(p)
テンソルに追加
192 s = min(self.seq_len, len(d))
193 data[:s, i] = d[:s]
テンソルを現在のデバイスに移動
195 data = data.to(self.device)
マスクされた入力とラベルを取得
198 data, labels = self.mlm(data)
モデル出力を取得
200 output, *_ = self.model(data)
生成されたサンプルを印刷
203 for j in range(data.shape[1]):
印刷からの出力を収集
205 log = []
各トークンについて
207 for i in range(len(data)):
ラベルがそうでない場合 [PAD]
209 if labels[i, j] != self.padding_token:
予測を取得
211 t = output[i, j].argmax().item()
印刷可能な文字の場合
213 if t < len(self.text.itos):
正しい予測
215 if t == labels[i, j]:
216 log.append((self.text.itos[t], Text.value))
予測が間違っている
218 else:
219 log.append((self.text.itos[t], Text.danger))
印刷可能な文字でない場合
221 else:
222 log.append(('*', Text.danger))
ラベルが [PAD]
(マスクされていない) 場合は、オリジナルを印刷してください。
224 elif data[i, j] < len(self.text.itos):
225 log.append((self.text.itos[data[i, j]], Text.subtle))
プリント
228 logger.log(log)
[PAD]
およびを含むトークンの数 [MASK]
231@option(Configs.n_tokens)
232def n_tokens_mlm(c: Configs):
236 return c.text.n_tokens + 2
239@option(Configs.transformer)
240def _transformer_configs(c: Configs):
247 conf = TransformerConfigs()
埋め込みやロジットの生成に使用するボキャブラリーサイズを設定
249 conf.n_src_vocab = c.n_tokens
250 conf.n_tgt_vocab = c.n_tokens
埋め込みサイズ
252 conf.d_model = c.d_model
255 return conf
分類モデルの作成
258@option(Configs.model)
259def _model(c: Configs):
263 m = TransformerMLM(encoder=c.transformer.encoder,
264 src_embed=c.transformer.src_embed,
265 generator=c.transformer.generator).to(c.device)
266
267 return m
270def main():
実験を作成
272 experiment.create(name="mlm")
コンフィグの作成
274 conf = Configs()
オーバーライド設定
276 experiment.configs(conf, {
バッチサイズ
278 'batch_size': 64,
シーケンスの長さは 短いシーケンス長を使用してトレーニングを高速化します。そうしないと、トレーニングに時間がかかります。
281 'seq_len': 32,
1024 エポックのトレーニングを行います。
284 'epochs': 1024,
エポックごとにトレーニングと検証を切り替える
287 'inner_iterations': 1,
変圧器構成 (デフォルトと同じ)
290 'd_model': 128,
291 'transformer.ffn.d_ff': 256,
292 'transformer.n_heads': 8,
293 'transformer.n_layers': 6,
保存および読み込み用のモデルを設定する
301 experiment.add_pytorch_models({'model': conf.model})
実験を始める
304 with experiment.start():
トレーニングを実行
306 conf.run()
310if __name__ == '__main__':
311 main()