11from collections import Counter
12from typing import Callable
13
14import torch
15import torchtext
16from torch import nn
17from torch.utils.data import DataLoader
18import torchtext.vocab
19from torchtext.vocab import Vocab
20
21from labml import lab, tracker, monit
22from labml.configs import option
23from labml_helpers.device import DeviceConfigs
24from labml_helpers.metrics.accuracy import Accuracy
25from labml_helpers.module import Module
26from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
27from labml_nn.optimizers.configs import OptimizerConfigs
30class NLPClassificationConfigs(TrainValidConfigs):
オプティマイザー
41 optimizer: torch.optim.Adam
トレーニングデバイス
43 device: torch.device = DeviceConfigs()
自己回帰モデル
46 model: Module
バッチサイズ
48 batch_size: int = 16
シーケンスの長さ、またはコンテキストサイズ
50 seq_len: int = 512
語彙
52 vocab: Vocab = 'ag_news'
ボキャブラリー内のトークンの数
54 n_tokens: int
クラス数
56 n_classes: int = 'ag_news'
トークナイザー
58 tokenizer: Callable = 'character'
モデルを定期的に保存するかどうか
61 is_save_models = True
損失関数
64 loss_func = nn.CrossEntropyLoss()
精度機能
66 accuracy = Accuracy()
モデル埋め込みサイズ
68 d_model: int = 512
グラデーションクリッピング
70 grad_norm_clip: float = 1.0
トレーニングデータローダー
73 train_loader: DataLoader = 'ag_news'
検証データローダー
75 valid_loader: DataLoader = 'ag_news'
モデルパラメーターと勾配を記録するかどうか (エポックごとに 1 回)。これらはレイヤーごとの統計情報をまとめたものですが、それでも非常に深いネットワークの多くの指標につながる可能性があります
。80 is_log_model_params_grads: bool = False
モデルのアクティベーションをログに記録するかどうか (エポックごとに 1 回)。これらはレイヤーごとの統計情報をまとめたものですが、それでも非常に深いネットワークの多くの指標につながる可能性があります
。85 is_log_model_activations: bool = False
87 def init(self):
トラッカー構成を設定
92 tracker.set_scalar("accuracy.*", True)
93 tracker.set_scalar("loss.*", True)
モジュール出力をログに記録するフックを追加
95 hook_model_outputs(self.mode, self.model, 'model')
ステートモジュールとして精度を追加してください。この名前は、RNN のトレーニングと検証の間の状態を保存するためのものなので、おそらくわかりにくいでしょう。これにより、精度指標の統計情報がトレーニング用と検証用に別々に保持されます。
100 self.state_modules = [self.accuracy]
102 def step(self, batch: any, batch_idx: BatchIndex):
データをデバイスに移動
108 data, target = batch[0].to(self.device), batch[1].to(self.device)
トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新
111 if self.mode.is_train:
112 tracker.add_global_step(data.shape[1])
モデル出力をキャプチャするかどうか
115 with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
モデル出力を取得します。RNN を使用する場合はステートのタプルを返します。これはまだ実装されていません。😜
119 output, *_ = self.model(data)
損失の計算と記録
122 loss = self.loss_func(output, target)
123 tracker.add("loss.", loss)
精度の計算と記録
126 self.accuracy(output, target)
127 self.accuracy.track()
モデルのトレーニング
130 if self.mode.is_train:
勾配の計算
132 loss.backward()
クリップグラデーション
134 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
最適化の一歩を踏み出す
136 self.optimizer.step()
各エポックの最後のバッチでモデルパラメータと勾配を記録します
138 if batch_idx.is_last and self.is_log_model_params_grads:
139 tracker.add('model', self.model)
グラデーションをクリア
141 self.optimizer.zero_grad()
追跡したメトリクスを保存する
144 tracker.save()
147@option(NLPClassificationConfigs.optimizer)
148def _optimizer(c: NLPClassificationConfigs):
153 optimizer = OptimizerConfigs()
154 optimizer.parameters = c.model.parameters()
155 optimizer.optimizer = 'Adam'
156 optimizer.d_model = c.d_model
157
158 return optimizer
この実験では、キャラクターレベルのトークナイザーを使用します。設定で切り替えることができますが、
'tokenizer': 'basic_english',
実験を開始するときに構成辞書にあります。
161@option(NLPClassificationConfigs.tokenizer)
162def basic_english():
176 from torchtext.data import get_tokenizer
177 return get_tokenizer('basic_english')
180def character_tokenizer(x: str):
184 return list(x)
キャラクターレベルのトークナイザー設定
187@option(NLPClassificationConfigs.tokenizer)
188def character():
192 return character_tokenizer
トークンの数を取得
195@option(NLPClassificationConfigs.n_tokens)
196def _n_tokens(c: NLPClassificationConfigs):
200 return len(c.vocab) + 2
203class CollateFunc:
tokenizer
トークナイザー関数ですvocab
はボキャブラリーseq_len
シーケンスの長さですpadding_token
seq_len
がテキストの長さより大きい場合にパディングに使用されるトークンですclassifier_token
[CLS]
入力の最後に設定したトークンです208 def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
216 self.classifier_token = classifier_token
217 self.padding_token = padding_token
218 self.seq_len = seq_len
219 self.vocab = vocab
220 self.tokenizer = tokenizer
batch
が収集したデータのバッチです DataLoader
222 def __call__(self, batch):
で初期化された入力データテンソル padding_token
228 data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
空ラベルテンソル
230 labels = torch.zeros(len(batch), dtype=torch.long)
サンプルをループ処理
233 for (i, (_label, _text)) in enumerate(batch):
ラベルを設定
235 labels[i] = int(_label) - 1
入力テキストをトークン化
237 _text = [self.vocab[token] for token in self.tokenizer(_text)]
最大まで切り捨て seq_len
239 _text = _text[:self.seq_len]
転置してデータに追加
241 data[:len(_text), i] = data.new_tensor(_text)
シーケンスの最後のトークンを次のように設定します。[CLS]
244 data[-1, :] = self.classifier_token
247 return data, labels
250@option([NLPClassificationConfigs.n_classes,
251 NLPClassificationConfigs.vocab,
252 NLPClassificationConfigs.train_loader,
253 NLPClassificationConfigs.valid_loader])
254def ag_news(c: NLPClassificationConfigs):
トレーニングと検証のデータセットを入手
263 train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))
データをメモリに読み込む
266 with monit.section('Load data'):
267 from labml_nn.utils import MapStyleDataset
270 train, valid = MapStyleDataset(train), MapStyleDataset(valid)
トークナイザーを入手
273 tokenizer = c.tokenizer
カウンターの作成
276 counter = Counter()
トレーニングデータセットからトークンを収集
278 for (label, line) in train:
279 counter.update(tokenizer(line))
検証データセットからトークンを収集
281 for (label, line) in valid:
282 counter.update(tokenizer(line))
ボキャブラリーの作成
284 vocab = torchtext.vocab.vocab(counter, min_freq=1)
トレーニングデータローダーの作成
287 train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
288 collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
検証データローダーの作成
290 valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
291 collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
リターンn_classes
vocab
、train_loader
、、valid_loader
294 return 4, vocab, train_loader, valid_loader