11from typing import Callable
12
13import torch
14import torch.nn as nn
15from torch.utils.data import DataLoader, RandomSampler
16
17from labml import lab, monit, logger, tracker
18from labml.configs import option
19from labml.logger import Text
20from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
21from labml_helpers.device import DeviceConfigs
22from labml_helpers.metrics.accuracy import Accuracy
23from labml_helpers.module import Module
24from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
25from labml_nn.optimizers.configs import OptimizerConfigs
28class CrossEntropyLoss(Module):
33 def __init__(self):
34 super().__init__()
35 self.loss = nn.CrossEntropyLoss()
37 def forward(self, outputs, targets):
38 return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
41class NLPAutoRegressionConfigs(TrainValidConfigs):
优化器
52 optimizer: torch.optim.Adam
训练设备
54 device: torch.device = DeviceConfigs()
自回归模型
57 model: Module
文本数据集
59 text: TextDataset
批量大小
61 batch_size: int = 16
序列的长度或上下文大小
63 seq_len: int = 512
词汇中的代币数量
65 n_tokens: int
分词器
67 tokenizer: Callable = 'character'
开始采样的文本提示(用于说明)
70 prompt: str
采样时的令牌分隔符(对于字符级别标记化为空白)
72 prompt_separator: str
是否定期保存模型
75 is_save_models = True
亏损函数
78 loss_func = CrossEntropyLoss()
精度函数
80 accuracy = Accuracy()
模型嵌入大小
82 d_model: int = 512
渐变剪切
84 grad_norm_clip: float = 1.0
训练数据加载器
87 train_loader: DataLoader = 'shuffled_train_loader'
验证数据加载器
89 valid_loader: DataLoader = 'shuffled_valid_loader'
数据加载器随着替换而随机播放
92 dataloader_shuffle_with_replacement: bool = False
是否记录模型参数和梯度(每个纪元一次)。这些是每层的汇总统计数据,但它仍然可能导致非常深的网络的许多指标。
97 is_log_model_params_grads: bool = False
是否记录模型激活(每个纪元一次)。这些是每层的汇总统计数据,但它仍然可能导致非常深的网络的许多指标。
102 is_log_model_activations: bool = False
104 def init(self):
设置跟踪器配置
109 tracker.set_scalar("accuracy.*", True)
110 tracker.set_scalar("loss.*", True)
111 tracker.set_text("sampled", False)
向日志模块输出添加钩子
113 hook_model_outputs(self.mode, self.model, 'model')
增加作为状态模块的精度。这个名字可能令人困惑,因为它旨在存储 RNN 的训练和验证之间的状态。这将使精度指标统计数据分开,以便进行训练和验证。
118 self.state_modules = [self.accuracy]
覆盖以计算和记录其他指标
120 def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
122 pass
124 def step(self, batch: any, batch_idx: BatchIndex):
设置训练/评估模式
130 self.model.train(self.mode.is_train)
将数据移动到设备
133 data, target = batch[0].to(self.device), batch[1].to(self.device)
在训练模式下更新全局步长(处理的令牌数)
136 if self.mode.is_train:
137 tracker.add_global_step(data.shape[0] * data.shape[1])
是否捕获模型输出
140 with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
获取模型输出。它在使用 RNN 时返回状态的元组。这还没有实现。😜
144 output, *_ = self.model(data)
计算并记录损失
147 loss = self.loss_func(output, target)
148 tracker.add("loss.", loss)
计算和记录精度
151 self.accuracy(output, target)
152 self.accuracy.track()
153
154 self.other_metrics(output, target)
训练模型
157 if self.mode.is_train:
计算梯度
159 loss.backward()
剪辑渐变
161 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
采取优化器步骤
163 self.optimizer.step()
记录每个纪元最后一批的模型参数和梯度
165 if batch_idx.is_last and self.is_log_model_params_grads:
166 tracker.add('model', self.model)
清除渐变
168 self.optimizer.zero_grad()
保存跟踪的指标
171 tracker.save()
173 def sample(self):
启动提示
179 prompt = self.prompt
收集输出以进行打印
181 log = [(prompt, Text.subtle)]
样本 25 个代币
183 for i in monit.iterate('Sample', 25):
将提示符号化
185 data = self.text.text_to_i(prompt).unsqueeze(-1)
186 data = data.to(self.device)
获取模型输出
188 output, *_ = self.model(data)
获取模型预测(贪婪)
190 output = output.argmax(dim=-1).squeeze()
将预测添加到提示符中
192 prompt += self.prompt_separator + self.text.itos[output[-1]]
添加日志记录的预测
194 log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
195
196 tracker.add({'sampled': prompt})
打印采样输出
198 logger.log(log)
201@option(NLPAutoRegressionConfigs.optimizer)
202def _optimizer(c: NLPAutoRegressionConfigs):
207 optimizer = OptimizerConfigs()
208 optimizer.parameters = c.model.parameters()
209 optimizer.optimizer = 'Adam'
210 optimizer.d_model = c.d_model
211
212 return optimizer
获取代币数量
215@option(NLPAutoRegressionConfigs.n_tokens)
216def _n_tokens(c: NLPAutoRegressionConfigs):
220 return c.text.n_tokens
223@option(NLPAutoRegressionConfigs.tokenizer)
224def basic_english():
238 from torchtext.data import get_tokenizer
239 return get_tokenizer('basic_english')
242def character_tokenizer(x: str):
246 return list(x)
249@option(NLPAutoRegressionConfigs.tokenizer)
250def character():
254 return character_tokenizer
257@option(NLPAutoRegressionConfigs.text)
258def tiny_shakespeare(c: NLPAutoRegressionConfigs):
264 return TextFileDataset(
265 lab.get_data_path() / 'tiny_shakespeare.txt',
266 c.tokenizer,
267 url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
270@option(NLPAutoRegressionConfigs.train_loader)
271def sequential_train_loader(c: NLPAutoRegressionConfigs):
275 return SequentialDataLoader(text=c.text.train,
276 dataset=c.text,
277 batch_size=c.batch_size,
278 seq_len=c.seq_len)
281@option(NLPAutoRegressionConfigs.valid_loader)
282def sequential_valid_loader(c: NLPAutoRegressionConfigs):
286 return SequentialDataLoader(text=c.text.valid,
287 dataset=c.text,
288 batch_size=c.batch_size,
289 seq_len=c.seq_len)
292def transpose_batch(batch):
300 transposed_data = list(zip(*batch))
沿第二维度堆叠批次dim=1
302 src = torch.stack(transposed_data[0], dim=1)
303 tgt = torch.stack(transposed_data[1], dim=1)
304
305 return src, tgt
308@option(NLPAutoRegressionConfigs.train_loader)
309def shuffled_train_loader(c: NLPAutoRegressionConfigs):
313 dataset = SequentialUnBatchedDataset(text=c.text.train,
314 dataset=c.text,
315 seq_len=c.seq_len)
316 sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)
317
318 return DataLoader(dataset,
319 batch_size=c.batch_size,
320 collate_fn=transpose_batch,
321 sampler=sampler)
324@option(NLPAutoRegressionConfigs.valid_loader)
325def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
329 dataset = SequentialUnBatchedDataset(text=c.text.valid,
330 dataset=c.text,
331 seq_len=c.seq_len)
332 sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)
333
334 return DataLoader(dataset,
335 batch_size=c.batch_size,
336 collate_fn=transpose_batch,
337 sampler=sampler)