11from typing import Callable
12
13import torch
14import torch.nn as nn
15from torch.utils.data import DataLoader, RandomSampler
16
17from labml import lab, monit, logger, tracker
18from labml.configs import option
19from labml.logger import Text
20from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
21from labml_helpers.device import DeviceConfigs
22from labml_helpers.metrics.accuracy import Accuracy
23from labml_helpers.module import Module
24from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
25from labml_nn.optimizers.configs import OptimizerConfigs
28class CrossEntropyLoss(Module):
33 def __init__(self):
34 super().__init__()
35 self.loss = nn.CrossEntropyLoss()
37 def forward(self, outputs, targets):
38 return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
41class NLPAutoRegressionConfigs(TrainValidConfigs):
オプティマイザー
52 optimizer: torch.optim.Adam
トレーニングデバイス
54 device: torch.device = DeviceConfigs()
自己回帰モデル
57 model: Module
テキストデータセット
59 text: TextDataset
バッチサイズ
61 batch_size: int = 16
シーケンスの長さ、またはコンテキストサイズ
63 seq_len: int = 512
ボキャブラリー内のトークンの数
65 n_tokens: int
トークナイザー
67 tokenizer: Callable = 'character'
サンプリングを開始するテキストプロンプト (説明用)
70 prompt: str
サンプリング時のトークンセパレーター (文字レベルのトークン化の場合は空白)
72 prompt_separator: str
モデルを定期的に保存するかどうか
75 is_save_models = True
損失関数
78 loss_func = CrossEntropyLoss()
精度機能
80 accuracy = Accuracy()
モデル埋め込みサイズ
82 d_model: int = 512
グラデーションクリッピング
84 grad_norm_clip: float = 1.0
トレーニングデータローダー
87 train_loader: DataLoader = 'shuffled_train_loader'
検証データローダー
89 valid_loader: DataLoader = 'shuffled_valid_loader'
データローダーは交換時にシャッフルされます
92 dataloader_shuffle_with_replacement: bool = False
モデルパラメーターと勾配を記録するかどうか (エポックごとに 1 回)。これらはレイヤーごとの統計情報をまとめたものですが、それでも非常に深いネットワークの多くの指標につながる可能性があります
。97 is_log_model_params_grads: bool = False
モデルのアクティベーションをログに記録するかどうか (エポックごとに 1 回)。これらはレイヤーごとの統計情報をまとめたものですが、それでも非常に深いネットワークの多くの指標につながる可能性があります
。102 is_log_model_activations: bool = False
104 def init(self):
トラッカー構成を設定
109 tracker.set_scalar("accuracy.*", True)
110 tracker.set_scalar("loss.*", True)
111 tracker.set_text("sampled", False)
モジュール出力をログに記録するフックを追加
113 hook_model_outputs(self.mode, self.model, 'model')
ステートモジュールとして精度を追加してください。この名前は、RNN のトレーニングと検証の間の状態を保存するためのものなので、おそらくわかりにくいでしょう。これにより、精度指標の統計情報がトレーニング用と検証用に別々に保持されます。
118 self.state_modules = [self.accuracy]
オーバーライドして他の指標を計算して記録する
120 def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
122 pass
124 def step(self, batch: any, batch_idx: BatchIndex):
トレーニング/評価モードの設定
130 self.model.train(self.mode.is_train)
データをデバイスに移動
133 data, target = batch[0].to(self.device), batch[1].to(self.device)
トレーニングモード時にグローバルステップ (処理されたトークンの数) を更新
136 if self.mode.is_train:
137 tracker.add_global_step(data.shape[0] * data.shape[1])
モデル出力をキャプチャするかどうか
140 with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
モデル出力を取得します。RNN を使用する場合はステートのタプルを返します。これはまだ実装されていません。😜
144 output, *_ = self.model(data)
損失の計算と記録
147 loss = self.loss_func(output, target)
148 tracker.add("loss.", loss)
精度の計算と記録
151 self.accuracy(output, target)
152 self.accuracy.track()
153
154 self.other_metrics(output, target)
モデルのトレーニング
157 if self.mode.is_train:
勾配の計算
159 loss.backward()
クリップグラデーション
161 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
最適化の一歩を踏み出す
163 self.optimizer.step()
各エポックの最後のバッチでモデルパラメータと勾配を記録します
165 if batch_idx.is_last and self.is_log_model_params_grads:
166 tracker.add('model', self.model)
グラデーションをクリア
168 self.optimizer.zero_grad()
追跡したメトリクスを保存する
171 tracker.save()
173 def sample(self):
起動プロンプト
179 prompt = self.prompt
印刷用の出力を収集
181 log = [(prompt, Text.subtle)]
25トークンのサンプル
183 for i in monit.iterate('Sample', 25):
プロンプトをトークン化
185 data = self.text.text_to_i(prompt).unsqueeze(-1)
186 data = data.to(self.device)
モデル出力を取得
188 output, *_ = self.model(data)
モデル予測を取得 (欲張り)
190 output = output.argmax(dim=-1).squeeze()
予測をプロンプトに追加
192 prompt += self.prompt_separator + self.text.itos[output[-1]]
ロギング用の予測を追加
194 log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
195
196 tracker.add({'sampled': prompt})
サンプル出力を印刷する
198 logger.log(log)
201@option(NLPAutoRegressionConfigs.optimizer)
202def _optimizer(c: NLPAutoRegressionConfigs):
207 optimizer = OptimizerConfigs()
208 optimizer.parameters = c.model.parameters()
209 optimizer.optimizer = 'Adam'
210 optimizer.d_model = c.d_model
211
212 return optimizer
トークンの数を取得
215@option(NLPAutoRegressionConfigs.n_tokens)
216def _n_tokens(c: NLPAutoRegressionConfigs):
220 return c.text.n_tokens
この実験では、キャラクターレベルのトークナイザーを使用します。設定で切り替えることができますが、
'tokenizer': 'basic_english',
実験を開始するときに構成辞書にあります。
223@option(NLPAutoRegressionConfigs.tokenizer)
224def basic_english():
238 from torchtext.data import get_tokenizer
239 return get_tokenizer('basic_english')
242def character_tokenizer(x: str):
246 return list(x)
249@option(NLPAutoRegressionConfigs.tokenizer)
250def character():
254 return character_tokenizer
257@option(NLPAutoRegressionConfigs.text)
258def tiny_shakespeare(c: NLPAutoRegressionConfigs):
264 return TextFileDataset(
265 lab.get_data_path() / 'tiny_shakespeare.txt',
266 c.tokenizer,
267 url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
270@option(NLPAutoRegressionConfigs.train_loader)
271def sequential_train_loader(c: NLPAutoRegressionConfigs):
275 return SequentialDataLoader(text=c.text.train,
276 dataset=c.text,
277 batch_size=c.batch_size,
278 seq_len=c.seq_len)
281@option(NLPAutoRegressionConfigs.valid_loader)
282def sequential_valid_loader(c: NLPAutoRegressionConfigs):
286 return SequentialDataLoader(text=c.text.valid,
287 dataset=c.text,
288 batch_size=c.batch_size,
289 seq_len=c.seq_len)
292def transpose_batch(batch):
300 transposed_data = list(zip(*batch))
2 番目の次元に沿ってバッチを積み重ねる dim=1
302 src = torch.stack(transposed_data[0], dim=1)
303 tgt = torch.stack(transposed_data[1], dim=1)
304
305 return src, tgt
308@option(NLPAutoRegressionConfigs.train_loader)
309def shuffled_train_loader(c: NLPAutoRegressionConfigs):
313 dataset = SequentialUnBatchedDataset(text=c.text.train,
314 dataset=c.text,
315 seq_len=c.seq_len)
316 sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)
317
318 return DataLoader(dataset,
319 batch_size=c.batch_size,
320 collate_fn=transpose_batch,
321 sampler=sampler)
324@option(NLPAutoRegressionConfigs.valid_loader)
325def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
329 dataset = SequentialUnBatchedDataset(text=c.text.valid,
330 dataset=c.text,
331 seq_len=c.seq_len)
332 sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)
333
334 return DataLoader(dataset,
335 batch_size=c.batch_size,
336 collate_fn=transpose_batch,
337 sampler=sampler)