11from typing import List
12
13import torch
14from torch import nn
15
16from labml import experiment, tracker, logger
17from labml.configs import option
18from labml.logger import Text
19from labml_helpers.metrics.accuracy import Accuracy
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers import Encoder, Generator
24from labml_nn.transformers import TransformerConfigs
25from labml_nn.transformers.mlm import MLM

MLM 用トランスフォーマーベースモデル

28class TransformerMLM(nn.Module):
33    def __init__(self, *, encoder: Encoder, src_embed: Module, generator: Generator):
40        super().__init__()
41        self.generator = generator
42        self.src_embed = src_embed
43        self.encoder = encoder
45    def forward(self, x: torch.Tensor):

位置エンコーディングによるトークンの埋め込みを取得

47        x = self.src_embed(x)

トランスエンコーダー

49        x = self.encoder(x, None)

出力用のロジット

51        y = self.generator(x)

結果を返します(トレーナーはRNNでも使用されるため、2番目の値は状態用です)

55        return y, None

コンフィギュレーション

NLPAutoRegressionConfigs これが継承されているのは、ここで再利用するデータパイプラインの実装があるからです。MLMからカスタムトレーニングステップを実装しました

58class Configs(NLPAutoRegressionConfigs):

MLM モデル

69    model: TransformerMLM

変圧器

71    transformer: TransformerConfigs

トークンの数

74    n_tokens: int = 'n_tokens_mlm'

マスクしてはいけないトークン

76    no_mask_tokens: List[int] = []

トークンをマスキングする確率

78    masking_prob: float = 0.15

マスクをランダムトークンに置き換える確率

80    randomize_prob: float = 0.1

マスクを元のトークンと交換する確率

82    no_change_prob: float = 0.1
84    mlm: MLM

[MASK] トークン

87    mask_token: int

[PADDING] トークン

89    padding_token: int

サンプリングを促す

92    prompt: str = [
93        "We are accounted poor citizens, the patricians good.",
94        "What authority surfeits on would relieve us: if they",
95        "would yield us but the superfluity, while it were",
96        "wholesome, we might guess they relieved us humanely;",
97        "but they think we are too dear: the leanness that",
98        "afflicts us, the object of our misery, is as an",
99        "inventory to particularise their abundance; our",
100        "sufferance is a gain to them Let us revenge this with",
101        "our pikes, ere we become rakes: for the gods know I",
102        "speak this in hunger for bread, not in thirst for revenge.",
103    ]

初期化

105    def init(self):

[MASK] トークン

111        self.mask_token = self.n_tokens - 1

[PAD] トークン

113        self.padding_token = self.n_tokens - 2
116        self.mlm = MLM(padding_token=self.padding_token,
117                       mask_token=self.mask_token,
118                       no_mask_tokens=self.no_mask_tokens,
119                       n_tokens=self.n_tokens,
120                       masking_prob=self.masking_prob,
121                       randomize_prob=self.randomize_prob,
122                       no_change_prob=self.no_change_prob)

精度指標 (と等しいラベルは無視してください[PAD] )

125        self.accuracy = Accuracy(ignore_index=self.padding_token)

クロスエントロピー損失 (と等しいラベルは無視してください) [PAD]

127        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)

129        super().init()

トレーニングまたは検証ステップ

131    def step(self, batch: any, batch_idx: BatchIndex):

入力をデバイスに移動

137        data = batch[0].to(self.device)

トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新

140        if self.mode.is_train:
141            tracker.add_global_step(data.shape[0] * data.shape[1])

マスクされた入力とラベルを取得

144        with torch.no_grad():
145            data, labels = self.mlm(data)

モデル出力をキャプチャするかどうか

148        with self.mode.update(is_log_activations=batch_idx.is_last):

モデル出力を取得します。RNN を使用する場合はステートのタプルを返します。これはまだ実装されていません。

152            output, *_ = self.model(data)

損失の計算と記録

155        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
156        tracker.add("loss.", loss)

精度の計算と記録

159        self.accuracy(output, labels)
160        self.accuracy.track()

モデルのトレーニング

163        if self.mode.is_train:

勾配の計算

165            loss.backward()

クリップグラデーション

167            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

最適化の一歩を踏み出す

169            self.optimizer.step()

各エポックの最後のバッチでモデルパラメータと勾配を記録します

171            if batch_idx.is_last:
172                tracker.add('model', self.model)

グラデーションをクリア

174            self.optimizer.zero_grad()

追跡したメトリクスを保存する

177        tracker.save()

トレーニング中に定期的にサンプルを生成するサンプリング機能

179    @torch.no_grad()
180    def sample(self):

が入力されたデータのテンソルを空にします。[PAD]

186        data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)

プロンプトを 1 つずつ追加します

188        for i, p in enumerate(self.prompt):

トークンのインデックスを取得

190            d = self.text.text_to_i(p)

テンソルに追加

192            s = min(self.seq_len, len(d))
193            data[:s, i] = d[:s]

テンソルを現在のデバイスに移動

195        data = data.to(self.device)

マスクされた入力とラベルを取得

198        data, labels = self.mlm(data)

モデル出力を取得

200        output, *_ = self.model(data)

生成されたサンプルを印刷

203        for j in range(data.shape[1]):

印刷からの出力を収集

205            log = []

各トークンについて

207            for i in range(len(data)):

ラベルがそうでない場合 [PAD]

209                if labels[i, j] != self.padding_token:

予測を取得

211                    t = output[i, j].argmax().item()

印刷可能な文字の場合

213                    if t < len(self.text.itos):

正しい予測

215                        if t == labels[i, j]:
216                            log.append((self.text.itos[t], Text.value))

予測が間違っている

218                        else:
219                            log.append((self.text.itos[t], Text.danger))

印刷可能な文字でない場合

221                    else:
222                        log.append(('*', Text.danger))

ラベルが [PAD] (マスクされていない) 場合は、オリジナルを印刷してください。

224                elif data[i, j] < len(self.text.itos):
225                    log.append((self.text.itos[data[i, j]], Text.subtle))

プリント

228            logger.log(log)

[PAD] およびを含むトークンの数 [MASK]

231@option(Configs.n_tokens)
232def n_tokens_mlm(c: Configs):
236    return c.text.n_tokens + 2

変圧器構成

239@option(Configs.transformer)
240def _transformer_configs(c: Configs):
247    conf = TransformerConfigs()

埋め込みやロジットの生成に使用するボキャブラリーサイズを設定

249    conf.n_src_vocab = c.n_tokens
250    conf.n_tgt_vocab = c.n_tokens

埋め込みサイズ

252    conf.d_model = c.d_model

255    return conf

分類モデルの作成

258@option(Configs.model)
259def _model(c: Configs):
263    m = TransformerMLM(encoder=c.transformer.encoder,
264                       src_embed=c.transformer.src_embed,
265                       generator=c.transformer.generator).to(c.device)
266
267    return m
270def main():

実験を作成

272    experiment.create(name="mlm")

コンフィグの作成

274    conf = Configs()

オーバーライド設定

276    experiment.configs(conf, {

バッチサイズ

278        'batch_size': 64,

シーケンスの長さは 短いシーケンス長を使用してトレーニングを高速化します。そうしないと、トレーニングに時間がかかります。

281        'seq_len': 32,

1024 エポックのトレーニングを行います。

284        'epochs': 1024,

エポックごとにトレーニングと検証を切り替える

287        'inner_iterations': 1,

変圧器構成 (デフォルトと同じ)

290        'd_model': 128,
291        'transformer.ffn.d_ff': 256,
292        'transformer.n_heads': 8,
293        'transformer.n_layers': 6,

Noam オプティマイザを使う

296        'optimizer.optimizer': 'Noam',
297        'optimizer.learning_rate': 1.,
298    })

保存および読み込み用のモデルを設定する

301    experiment.add_pytorch_models({'model': conf.model})

実験を始める

304    with experiment.start():

トレーニングを実行

306        conf.run()

310if __name__ == '__main__':
311    main()