分類用 NLP モデルトレーナー

11from collections import Counter
12from typing import Callable
13
14import torch
15import torchtext
16from torch import nn
17from torch.utils.data import DataLoader
18import torchtext.vocab
19from torchtext.vocab import Vocab
20
21from labml import lab, tracker, monit
22from labml.configs import option
23from labml_helpers.device import DeviceConfigs
24from labml_helpers.metrics.accuracy import Accuracy
25from labml_helpers.module import Module
26from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
27from labml_nn.optimizers.configs import OptimizerConfigs

トレーナー構成

これには、NLP分類タスクトレーニングの基本的な構成があります。すべてのプロパティは設定可能です。

30class NLPClassificationConfigs(TrainValidConfigs):

オプティマイザー

41    optimizer: torch.optim.Adam

トレーニングデバイス

43    device: torch.device = DeviceConfigs()

自己回帰モデル

46    model: Module

バッチサイズ

48    batch_size: int = 16

シーケンスの長さ、またはコンテキストサイズ

50    seq_len: int = 512

語彙

52    vocab: Vocab = 'ag_news'

ボキャブラリー内のトークンの数

54    n_tokens: int

クラス数

56    n_classes: int = 'ag_news'

トークナイザー

58    tokenizer: Callable = 'character'

モデルを定期的に保存するかどうか

61    is_save_models = True

損失関数

64    loss_func = nn.CrossEntropyLoss()

精度機能

66    accuracy = Accuracy()

モデル埋め込みサイズ

68    d_model: int = 512

グラデーションクリッピング

70    grad_norm_clip: float = 1.0

トレーニングデータローダー

73    train_loader: DataLoader = 'ag_news'

検証データローダー

75    valid_loader: DataLoader = 'ag_news'

モデルパラメーターと勾配を記録するかどうか (エポックごとに 1 回)。これらはレイヤーごとの統計情報をまとめたものですが、それでも非常に深いネットワークの多くの指標につながる可能性があります

80    is_log_model_params_grads: bool = False

モデルのアクティベーションをログに記録するかどうか (エポックごとに 1 回)。これらはレイヤーごとの統計情報をまとめたものですが、それでも非常に深いネットワークの多くの指標につながる可能性があります

85    is_log_model_activations: bool = False

初期化

87    def init(self):

トラッカー構成を設定

92        tracker.set_scalar("accuracy.*", True)
93        tracker.set_scalar("loss.*", True)

モジュール出力をログに記録するフックを追加

95        hook_model_outputs(self.mode, self.model, 'model')

ステートモジュールとして精度を追加してください。この名前は、RNN のトレーニングと検証の間の状態を保存するためのものなので、おそらくわかりにくいでしょう。これにより、精度指標の統計情報がトレーニング用と検証用に別々に保持されます。

100        self.state_modules = [self.accuracy]

トレーニングまたは検証ステップ

102    def step(self, batch: any, batch_idx: BatchIndex):

データをデバイスに移動

108        data, target = batch[0].to(self.device), batch[1].to(self.device)

トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新

111        if self.mode.is_train:
112            tracker.add_global_step(data.shape[1])

モデル出力をキャプチャするかどうか

115        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):

モデル出力を取得します。RNN を使用する場合はステートのタプルを返します。これはまだ実装されていません。😜

119            output, *_ = self.model(data)

損失の計算と記録

122        loss = self.loss_func(output, target)
123        tracker.add("loss.", loss)

精度の計算と記録

126        self.accuracy(output, target)
127        self.accuracy.track()

モデルのトレーニング

130        if self.mode.is_train:

勾配の計算

132            loss.backward()

クリップグラデーション

134            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

最適化の一歩を踏み出す

136            self.optimizer.step()

各エポックの最後のバッチでモデルパラメータと勾配を記録します

138            if batch_idx.is_last and self.is_log_model_params_grads:
139                tracker.add('model', self.model)

グラデーションをクリア

141            self.optimizer.zero_grad()

追跡したメトリクスを保存する

144        tracker.save()
147@option(NLPClassificationConfigs.optimizer)
148def _optimizer(c: NLPClassificationConfigs):
153    optimizer = OptimizerConfigs()
154    optimizer.parameters = c.model.parameters()
155    optimizer.optimizer = 'Adam'
156    optimizer.d_model = c.d_model
157
158    return optimizer

ベーシック・イングリッシュ・トークナイザー

この実験では、キャラクターレベルのトークナイザーを使用します。設定で切り替えることができますが、

'tokenizer': 'basic_english',

実験を開始するときに構成辞書にあります。

161@option(NLPClassificationConfigs.tokenizer)
162def basic_english():
176    from torchtext.data import get_tokenizer
177    return get_tokenizer('basic_english')

キャラクターレベルトークナイザー

180def character_tokenizer(x: str):
184    return list(x)

キャラクターレベルのトークナイザー設定

187@option(NLPClassificationConfigs.tokenizer)
188def character():
192    return character_tokenizer

トークンの数を取得

195@option(NLPClassificationConfigs.n_tokens)
196def _n_tokens(c: NLPClassificationConfigs):
200    return len(c.vocab) + 2

データをバッチにロードする機能

203class CollateFunc:
  • tokenizer トークナイザー関数です
  • vocab はボキャブラリー
  • seq_len シーケンスの長さです
  • padding_token seq_len がテキストの長さより大きい場合にパディングに使用されるトークンです
  • classifier_token [CLS] 入力の最後に設定したトークンです
208    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
216        self.classifier_token = classifier_token
217        self.padding_token = padding_token
218        self.seq_len = seq_len
219        self.vocab = vocab
220        self.tokenizer = tokenizer
  • batch が収集したデータのバッチです DataLoader
222    def __call__(self, batch):

で初期化された入力データテンソル padding_token

228        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)

空ラベルテンソル

230        labels = torch.zeros(len(batch), dtype=torch.long)

サンプルをループ処理

233        for (i, (_label, _text)) in enumerate(batch):

ラベルを設定

235            labels[i] = int(_label) - 1

入力テキストをトークン化

237            _text = [self.vocab[token] for token in self.tokenizer(_text)]

最大まで切り捨て seq_len

239            _text = _text[:self.seq_len]

転置してデータに追加

241            data[:len(_text), i] = data.new_tensor(_text)

シーケンスの最後のトークンを次のように設定します。[CLS]

244        data[-1, :] = self.classifier_token

247        return data, labels

AG ニュースデータセット

これにより AG News データセットが読み込まれn_classes 、、vocab train_loadervalid_loader の値が設定されます。

250@option([NLPClassificationConfigs.n_classes,
251         NLPClassificationConfigs.vocab,
252         NLPClassificationConfigs.train_loader,
253         NLPClassificationConfigs.valid_loader])
254def ag_news(c: NLPClassificationConfigs):

トレーニングと検証のデータセットを入手

263    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

データをメモリに読み込む

266    with monit.section('Load data'):
267        from labml_nn.utils import MapStyleDataset
270        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

トークナイザーを入手

273    tokenizer = c.tokenizer

カウンターの作成

276    counter = Counter()

トレーニングデータセットからトークンを収集

278    for (label, line) in train:
279        counter.update(tokenizer(line))

検証データセットからトークンを収集

281    for (label, line) in valid:
282        counter.update(tokenizer(line))

ボキャブラリーの作成

284    vocab = torchtext.vocab.vocab(counter, min_freq=1)

トレーニングデータローダーの作成

287    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
288                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

検証データローダーの作成

290    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
291                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

リターンn_classes vocabtrain_loader 、、valid_loader

294    return 4, vocab, train_loader, valid_loader