自己回帰 NLP モデルトレーナー

11from typing import Callable
12
13import torch
14import torch.nn as nn
15from torch.utils.data import DataLoader, RandomSampler
16
17from labml import lab, monit, logger, tracker
18from labml.configs import option
19from labml.logger import Text
20from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
21from labml_helpers.device import DeviceConfigs
22from labml_helpers.metrics.accuracy import Accuracy
23from labml_helpers.module import Module
24from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
25from labml_nn.optimizers.configs import OptimizerConfigs

クロスエントロピー損失

28class CrossEntropyLoss(Module):
33    def __init__(self):
34        super().__init__()
35        self.loss = nn.CrossEntropyLoss()
37    def forward(self, outputs, targets):
38        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))

トレーナー構成

これには、NLP自己回帰タスクトレーニングの基本構成があります。すべてのプロパティは設定可能です。

41class NLPAutoRegressionConfigs(TrainValidConfigs):

オプティマイザー

52    optimizer: torch.optim.Adam

トレーニングデバイス

54    device: torch.device = DeviceConfigs()

自己回帰モデル

57    model: Module

テキストデータセット

59    text: TextDataset

バッチサイズ

61    batch_size: int = 16

シーケンスの長さ、またはコンテキストサイズ

63    seq_len: int = 512

ボキャブラリー内のトークンの数

65    n_tokens: int

トークナイザー

67    tokenizer: Callable = 'character'

サンプリングを開始するテキストプロンプト (説明用)

70    prompt: str

サンプリング時のトークンセパレーター (文字レベルのトークン化の場合は空白)

72    prompt_separator: str

モデルを定期的に保存するかどうか

75    is_save_models = True

損失関数

78    loss_func = CrossEntropyLoss()

精度機能

80    accuracy = Accuracy()

モデル埋め込みサイズ

82    d_model: int = 512

グラデーションクリッピング

84    grad_norm_clip: float = 1.0

トレーニングデータローダー

87    train_loader: DataLoader = 'shuffled_train_loader'

検証データローダー

89    valid_loader: DataLoader = 'shuffled_valid_loader'

データローダーは交換時にシャッフルされます

92    dataloader_shuffle_with_replacement: bool = False

モデルパラメーターと勾配を記録するかどうか (エポックごとに 1 回)。これらはレイヤーごとの統計情報をまとめたものですが、それでも非常に深いネットワークの多くの指標につながる可能性があります

97    is_log_model_params_grads: bool = False

モデルのアクティベーションをログに記録するかどうか (エポックごとに 1 回)。これらはレイヤーごとの統計情報をまとめたものですが、それでも非常に深いネットワークの多くの指標につながる可能性があります

102    is_log_model_activations: bool = False

初期化

104    def init(self):

トラッカー構成を設定

109        tracker.set_scalar("accuracy.*", True)
110        tracker.set_scalar("loss.*", True)
111        tracker.set_text("sampled", False)

モジュール出力をログに記録するフックを追加

113        hook_model_outputs(self.mode, self.model, 'model')

ステートモジュールとして精度を追加してください。この名前は、RNN のトレーニングと検証の間の状態を保存するためのものなので、おそらくわかりにくいでしょう。これにより、精度指標の統計情報がトレーニング用と検証用に別々に保持されます。

118        self.state_modules = [self.accuracy]

オーバーライドして他の指標を計算して記録する

120    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
122        pass

トレーニングまたは検証ステップ

124    def step(self, batch: any, batch_idx: BatchIndex):

トレーニング/評価モードの設定

130        self.model.train(self.mode.is_train)

データをデバイスに移動

133        data, target = batch[0].to(self.device), batch[1].to(self.device)

トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新

136        if self.mode.is_train:
137            tracker.add_global_step(data.shape[0] * data.shape[1])

モデル出力をキャプチャするかどうか

140        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):

モデル出力を取得します。RNN を使用する場合はステートのタプルを返します。これはまだ実装されていません。😜

144            output, *_ = self.model(data)

損失の計算と記録

147        loss = self.loss_func(output, target)
148        tracker.add("loss.", loss)

精度の計算と記録

151        self.accuracy(output, target)
152        self.accuracy.track()
153
154        self.other_metrics(output, target)

モデルのトレーニング

157        if self.mode.is_train:

勾配の計算

159            loss.backward()

クリップグラデーション

161            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

最適化の一歩を踏み出す

163            self.optimizer.step()

各エポックの最後のバッチでモデルパラメータと勾配を記録します

165            if batch_idx.is_last and self.is_log_model_params_grads:
166                tracker.add('model', self.model)

グラデーションをクリア

168            self.optimizer.zero_grad()

追跡したメトリクスを保存する

171        tracker.save()

トレーニング中に定期的にサンプルを生成するサンプリング機能

173    def sample(self):

起動プロンプト

179        prompt = self.prompt

印刷用の出力を収集

181        log = [(prompt, Text.subtle)]

25トークンのサンプル

183        for i in monit.iterate('Sample', 25):

プロンプトをトークン化

185            data = self.text.text_to_i(prompt).unsqueeze(-1)
186            data = data.to(self.device)

モデル出力を取得

188            output, *_ = self.model(data)

モデル予測を取得 (欲張り)

190            output = output.argmax(dim=-1).squeeze()

予測をプロンプトに追加

192            prompt += self.prompt_separator + self.text.itos[output[-1]]

ロギング用の予測を追加

194            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
195
196        tracker.add({'sampled': prompt})

サンプル出力を印刷する

198        logger.log(log)
201@option(NLPAutoRegressionConfigs.optimizer)
202def _optimizer(c: NLPAutoRegressionConfigs):
207    optimizer = OptimizerConfigs()
208    optimizer.parameters = c.model.parameters()
209    optimizer.optimizer = 'Adam'
210    optimizer.d_model = c.d_model
211
212    return optimizer

トークンの数を取得

215@option(NLPAutoRegressionConfigs.n_tokens)
216def _n_tokens(c: NLPAutoRegressionConfigs):
220    return c.text.n_tokens

ベーシック・イングリッシュ・トークナイザー

この実験では、キャラクターレベルのトークナイザーを使用します。設定で切り替えることができますが、

'tokenizer': 'basic_english',

実験を開始するときに構成辞書にあります。

223@option(NLPAutoRegressionConfigs.tokenizer)
224def basic_english():
238    from torchtext.data import get_tokenizer
239    return get_tokenizer('basic_english')

キャラクターレベルトークナイザー

242def character_tokenizer(x: str):
246    return list(x)

キャラクターレベルのトークナイザー設定

249@option(NLPAutoRegressionConfigs.tokenizer)
250def character():
254    return character_tokenizer

小さなシェイクスピアデータセット

存在しない場合は URL からダウンロードします

257@option(NLPAutoRegressionConfigs.text)
258def tiny_shakespeare(c: NLPAutoRegressionConfigs):
264    return TextFileDataset(
265        lab.get_data_path() / 'tiny_shakespeare.txt',
266        c.tokenizer,
267        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

シーケンシャルトレーニングデータローダー

270@option(NLPAutoRegressionConfigs.train_loader)
271def sequential_train_loader(c: NLPAutoRegressionConfigs):
275    return SequentialDataLoader(text=c.text.train,
276                                dataset=c.text,
277                                batch_size=c.batch_size,
278                                seq_len=c.seq_len)

シーケンシャル検証データローダー

281@option(NLPAutoRegressionConfigs.valid_loader)
282def sequential_valid_loader(c: NLPAutoRegressionConfigs):
286    return SequentialDataLoader(text=c.text.valid,
287                                dataset=c.text,
288                                batch_size=c.batch_size,
289                                seq_len=c.seq_len)

トランスポーズバッチ

DataLoader 第 1 次元のバッチを収集します。最初にシーケンスになるようにトランスポーズする必要があります

292def transpose_batch(batch):
300    transposed_data = list(zip(*batch))

2 番目の次元に沿ってバッチを積み重ねる dim=1

302    src = torch.stack(transposed_data[0], dim=1)
303    tgt = torch.stack(transposed_data[1], dim=1)
304
305    return src, tgt

シャッフルされたトレーニングデータローダー

308@option(NLPAutoRegressionConfigs.train_loader)
309def shuffled_train_loader(c: NLPAutoRegressionConfigs):
313    dataset = SequentialUnBatchedDataset(text=c.text.train,
314                                         dataset=c.text,
315                                         seq_len=c.seq_len)
316    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)
317
318    return DataLoader(dataset,
319                      batch_size=c.batch_size,
320                      collate_fn=transpose_batch,
321                      sampler=sampler)

シャッフルされた検証データローダー

324@option(NLPAutoRegressionConfigs.valid_loader)
325def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
329    dataset = SequentialUnBatchedDataset(text=c.text.valid,
330                                         dataset=c.text,
331                                         seq_len=c.seq_len)
332    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)
333
334    return DataLoader(dataset,
335                      batch_size=c.batch_size,
336                      collate_fn=transpose_batch,
337                      sampler=sampler)