用于分类的 NLP 模型训练器

11from collections import Counter
12from typing import Callable
13
14import torch
15import torchtext
16from torch import nn
17from torch.utils.data import DataLoader
18import torchtext.vocab
19from torchtext.vocab import Vocab
20
21from labml import lab, tracker, monit
22from labml.configs import option
23from labml_helpers.device import DeviceConfigs
24from labml_helpers.metrics.accuracy import Accuracy
25from labml_helpers.module import Module
26from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
27from labml_nn.optimizers.configs import OptimizerConfigs

训练器配置

它具有 NLP 分类任务培训的基本配置。所有属性都是可配置的。

30class NLPClassificationConfigs(TrainValidConfigs):

优化器

41    optimizer: torch.optim.Adam

训练设备

43    device: torch.device = DeviceConfigs()

自回归模型

46    model: Module

批量大小

48    batch_size: int = 16

序列的长度或上下文大小

50    seq_len: int = 512

词汇

52    vocab: Vocab = 'ag_news'

词汇中的代币数量

54    n_tokens: int

班级数

56    n_classes: int = 'ag_news'

分词器

58    tokenizer: Callable = 'character'

是否定期保存模型

61    is_save_models = True

亏损函数

64    loss_func = nn.CrossEntropyLoss()

精度函数

66    accuracy = Accuracy()

模型嵌入大小

68    d_model: int = 512

渐变剪切

70    grad_norm_clip: float = 1.0

训练数据加载器

73    train_loader: DataLoader = 'ag_news'

验证数据加载器

75    valid_loader: DataLoader = 'ag_news'

是否记录模型参数和梯度(每个纪元一次)。这些是每层的汇总统计数据,但它仍然可能导致非常深的网络的许多指标。

80    is_log_model_params_grads: bool = False

是否记录模型激活(每个纪元一次)。这些是每层的汇总统计数据,但它仍然可能导致非常深的网络的许多指标。

85    is_log_model_activations: bool = False

初始化

87    def init(self):

设置跟踪器配置

92        tracker.set_scalar("accuracy.*", True)
93        tracker.set_scalar("loss.*", True)

向日志模块输出添加钩子

95        hook_model_outputs(self.mode, self.model, 'model')

增加作为状态模块的精度。这个名字可能令人困惑,因为它旨在存储 RNN 的训练和验证之间的状态。这将使精度指标统计数据分开,以便进行训练和验证。

100        self.state_modules = [self.accuracy]

培训或验证步骤

102    def step(self, batch: any, batch_idx: BatchIndex):

将数据移动到设备

108        data, target = batch[0].to(self.device), batch[1].to(self.device)

在训练模式下更新全局步长(处理的令牌数)

111        if self.mode.is_train:
112            tracker.add_global_step(data.shape[1])

是否捕获模型输出

115        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):

获取模型输出。它在使用 RNN 时返回状态的元组。这还没有实现。😜

119            output, *_ = self.model(data)

计算并记录损失

122        loss = self.loss_func(output, target)
123        tracker.add("loss.", loss)

计算和记录精度

126        self.accuracy(output, target)
127        self.accuracy.track()

训练模型

130        if self.mode.is_train:

计算梯度

132            loss.backward()

剪辑渐变

134            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

采取优化器步骤

136            self.optimizer.step()

记录每个纪元最后一批的模型参数和梯度

138            if batch_idx.is_last and self.is_log_model_params_grads:
139                tracker.add('model', self.model)

清除渐变

141            self.optimizer.zero_grad()

保存跟踪的指标

144        tracker.save()
147@option(NLPClassificationConfigs.optimizer)
148def _optimizer(c: NLPClassificationConfigs):
153    optimizer = OptimizerConfigs()
154    optimizer.parameters = c.model.parameters()
155    optimizer.optimizer = 'Adam'
156    optimizer.d_model = c.d_model
157
158    return optimizer

基础英语分词器

我们在这个实验中使用角色等级分词器。你可以通过设置进行切换,

'tokenizer': 'basic_english',

开始实验时在配置字典中。

161@option(NLPClassificationConfigs.tokenizer)
162def basic_english():
176    from torchtext.data import get_tokenizer
177    return get_tokenizer('basic_english')

角色等级分词器

180def character_tokenizer(x: str):
184    return list(x)

角色级别分词器配置

187@option(NLPClassificationConfigs.tokenizer)
188def character():
192    return character_tokenizer

获取代币数量

195@option(NLPClassificationConfigs.n_tokens)
196def _n_tokens(c: NLPClassificationConfigs):
200    return len(c.vocab) + 2

将数据加载到批处理中的函数

203class CollateFunc:
  • tokenizer 是分词器函数
  • vocab 是词汇
  • seq_len 是序列的长度
  • padding_token 是大于文本长度时seq_len 用于填充的标记
  • classifier_token 是我们在输入末尾设置的[CLS] 令牌
208    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
216        self.classifier_token = classifier_token
217        self.padding_token = padding_token
218        self.seq_len = seq_len
219        self.vocab = vocab
220        self.tokenizer = tokenizer
  • batch 是由DataLoader
222    def __call__(self, batch):

输入数据张量,初始化为padding_token

228        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)

空标签张量

230        labels = torch.zeros(len(batch), dtype=torch.long)

循环浏览样本

233        for (i, (_label, _text)) in enumerate(batch):

设置标签

235            labels[i] = int(_label) - 1

标记输入文本

237            _text = [self.vocab[token] for token in self.tokenizer(_text)]

截断最多seq_len

239            _text = _text[:self.seq_len]

转置并添加到数据

241            data[:len(_text), i] = data.new_tensor(_text)

将序列中的最后一个令牌设置为[CLS]

244        data[-1, :] = self.classifier_token

247        return data, labels

AG 新闻数据集

这将加载 AG News 数据集并设置n_classesvocab train_loader 、和的值valid_loader

250@option([NLPClassificationConfigs.n_classes,
251         NLPClassificationConfigs.vocab,
252         NLPClassificationConfigs.train_loader,
253         NLPClassificationConfigs.valid_loader])
254def ag_news(c: NLPClassificationConfigs):

获取训练和验证数据集

263    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

将数据加载到内存

266    with monit.section('Load data'):
267        from labml_nn.utils import MapStyleDataset
270        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

获取分词器

273    tokenizer = c.tokenizer

创建计数器

276    counter = Counter()

从训练数据集中收集令牌

278    for (label, line) in train:
279        counter.update(tokenizer(line))

从验证数据集中收集令牌

281    for (label, line) in valid:
282        counter.update(tokenizer(line))

创建词汇

284    vocab = torchtext.vocab.vocab(counter, min_freq=1)

创建训练数据加载器

287    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
288                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

创建验证数据加载器

290    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
291                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

返回n_classes vocabtrain_loader 、和valid_loader

294    return 4, vocab, train_loader, valid_loader