11from collections import Counter
12from typing import Callable
13
14import torch
15import torchtext
16from torch import nn
17from torch.utils.data import DataLoader
18import torchtext.vocab
19from torchtext.vocab import Vocab
20
21from labml import lab, tracker, monit
22from labml.configs import option
23from labml_helpers.device import DeviceConfigs
24from labml_helpers.metrics.accuracy import Accuracy
25from labml_helpers.module import Module
26from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
27from labml_nn.optimizers.configs import OptimizerConfigs
30class NLPClassificationConfigs(TrainValidConfigs):
优化器
41 optimizer: torch.optim.Adam
训练设备
43 device: torch.device = DeviceConfigs()
自回归模型
46 model: Module
批量大小
48 batch_size: int = 16
序列的长度或上下文大小
50 seq_len: int = 512
词汇
52 vocab: Vocab = 'ag_news'
词汇中的代币数量
54 n_tokens: int
班级数
56 n_classes: int = 'ag_news'
分词器
58 tokenizer: Callable = 'character'
是否定期保存模型
61 is_save_models = True
亏损函数
64 loss_func = nn.CrossEntropyLoss()
精度函数
66 accuracy = Accuracy()
模型嵌入大小
68 d_model: int = 512
渐变剪切
70 grad_norm_clip: float = 1.0
训练数据加载器
73 train_loader: DataLoader = 'ag_news'
验证数据加载器
75 valid_loader: DataLoader = 'ag_news'
是否记录模型参数和梯度(每个纪元一次)。这些是每层的汇总统计数据,但它仍然可能导致非常深的网络的许多指标。
80 is_log_model_params_grads: bool = False
是否记录模型激活(每个纪元一次)。这些是每层的汇总统计数据,但它仍然可能导致非常深的网络的许多指标。
85 is_log_model_activations: bool = False
87 def init(self):
设置跟踪器配置
92 tracker.set_scalar("accuracy.*", True)
93 tracker.set_scalar("loss.*", True)
向日志模块输出添加钩子
95 hook_model_outputs(self.mode, self.model, 'model')
增加作为状态模块的精度。这个名字可能令人困惑,因为它旨在存储 RNN 的训练和验证之间的状态。这将使精度指标统计数据分开,以便进行训练和验证。
100 self.state_modules = [self.accuracy]
102 def step(self, batch: any, batch_idx: BatchIndex):
将数据移动到设备
108 data, target = batch[0].to(self.device), batch[1].to(self.device)
在训练模式下更新全局步长(处理的令牌数)
111 if self.mode.is_train:
112 tracker.add_global_step(data.shape[1])
是否捕获模型输出
115 with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
获取模型输出。它在使用 RNN 时返回状态的元组。这还没有实现。😜
119 output, *_ = self.model(data)
计算并记录损失
122 loss = self.loss_func(output, target)
123 tracker.add("loss.", loss)
计算和记录精度
126 self.accuracy(output, target)
127 self.accuracy.track()
训练模型
130 if self.mode.is_train:
计算梯度
132 loss.backward()
剪辑渐变
134 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
采取优化器步骤
136 self.optimizer.step()
记录每个纪元最后一批的模型参数和梯度
138 if batch_idx.is_last and self.is_log_model_params_grads:
139 tracker.add('model', self.model)
清除渐变
141 self.optimizer.zero_grad()
保存跟踪的指标
144 tracker.save()
147@option(NLPClassificationConfigs.optimizer)
148def _optimizer(c: NLPClassificationConfigs):
153 optimizer = OptimizerConfigs()
154 optimizer.parameters = c.model.parameters()
155 optimizer.optimizer = 'Adam'
156 optimizer.d_model = c.d_model
157
158 return optimizer
161@option(NLPClassificationConfigs.tokenizer)
162def basic_english():
176 from torchtext.data import get_tokenizer
177 return get_tokenizer('basic_english')
180def character_tokenizer(x: str):
184 return list(x)
角色级别分词器配置
187@option(NLPClassificationConfigs.tokenizer)
188def character():
192 return character_tokenizer
获取代币数量
195@option(NLPClassificationConfigs.n_tokens)
196def _n_tokens(c: NLPClassificationConfigs):
200 return len(c.vocab) + 2
203class CollateFunc:
tokenizer
是分词器函数vocab
是词汇seq_len
是序列的长度padding_token
是大于文本长度时seq_len
用于填充的标记classifier_token
是我们在输入末尾设置的[CLS]
令牌208 def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
216 self.classifier_token = classifier_token
217 self.padding_token = padding_token
218 self.seq_len = seq_len
219 self.vocab = vocab
220 self.tokenizer = tokenizer
batch
是由DataLoader
222 def __call__(self, batch):
输入数据张量,初始化为padding_token
228 data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
空标签张量
230 labels = torch.zeros(len(batch), dtype=torch.long)
循环浏览样本
233 for (i, (_label, _text)) in enumerate(batch):
设置标签
235 labels[i] = int(_label) - 1
标记输入文本
237 _text = [self.vocab[token] for token in self.tokenizer(_text)]
截断最多seq_len
239 _text = _text[:self.seq_len]
转置并添加到数据
241 data[:len(_text), i] = data.new_tensor(_text)
将序列中的最后一个令牌设置为[CLS]
244 data[-1, :] = self.classifier_token
247 return data, labels
250@option([NLPClassificationConfigs.n_classes,
251 NLPClassificationConfigs.vocab,
252 NLPClassificationConfigs.train_loader,
253 NLPClassificationConfigs.valid_loader])
254def ag_news(c: NLPClassificationConfigs):
获取训练和验证数据集
263 train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))
将数据加载到内存
266 with monit.section('Load data'):
267 from labml_nn.utils import MapStyleDataset
获取分词器
273 tokenizer = c.tokenizer
创建计数器
276 counter = Counter()
从训练数据集中收集令牌
278 for (label, line) in train:
279 counter.update(tokenizer(line))
从验证数据集中收集令牌
281 for (label, line) in valid:
282 counter.update(tokenizer(line))
创建词汇
284 vocab = torchtext.vocab.vocab(counter, min_freq=1)
创建训练数据加载器
287 train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
288 collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
创建验证数据加载器
290 valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
291 collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
返回n_classes
vocab
、train_loader
、和valid_loader
294 return 4, vocab, train_loader, valid_loader