掩码语言模型 (MLM) 实验

这是一个带注释的 PyTorch 实验,用于训练一个蒙版语言模型

11from typing import List
12
13import torch
14from torch import nn
15
16from labml import experiment, tracker, logger
17from labml.configs import option
18from labml.logger import Text
19from labml_helpers.metrics.accuracy import Accuracy
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers import Encoder, Generator
24from labml_nn.transformers import TransformerConfigs
25from labml_nn.transformers.mlm import MLM

基于变压器的传销模型

28class TransformerMLM(nn.Module):
33    def __init__(self, *, encoder: Encoder, src_embed: Module, generator: Generator):
40        super().__init__()
41        self.generator = generator
42        self.src_embed = src_embed
43        self.encoder = encoder
45    def forward(self, x: torch.Tensor):

使用位置编码获取令牌嵌入

47        x = self.src_embed(x)

变压器编码

49        x = self.encoder(x, None)

输出的对数

51        y = self.generator(x)

返回结果(第二个值用于状态,因为我们的训练器也与 RNN 一起使用)

55        return y, None

配置

这继承自,NLPAutoRegressionConfigs 因为它有我们在这里重用的数据管道实现。我们已经实施了 MLM 的自定义训练步骤。

58class Configs(NLPAutoRegressionConfigs):

传销模型

69    model: TransformerMLM

变压器

71    transformer: TransformerConfigs

代币数量

74    n_tokens: int = 'n_tokens_mlm'

不应该被掩盖的代币

76    no_mask_tokens: List[int] = []

掩盖代币的概率

78    masking_prob: float = 0.15

用随机令牌替换掩码的概率

80    randomize_prob: float = 0.1

用原始令牌替换掩码的概率

82    no_change_prob: float = 0.1
84    mlm: MLM

[MASK] 令牌

87    mask_token: int

[PADDING] 令牌

89    padding_token: int

提示采样

92    prompt: str = [
93        "We are accounted poor citizens, the patricians good.",
94        "What authority surfeits on would relieve us: if they",
95        "would yield us but the superfluity, while it were",
96        "wholesome, we might guess they relieved us humanely;",
97        "but they think we are too dear: the leanness that",
98        "afflicts us, the object of our misery, is as an",
99        "inventory to particularise their abundance; our",
100        "sufferance is a gain to them Let us revenge this with",
101        "our pikes, ere we become rakes: for the gods know I",
102        "speak this in hunger for bread, not in thirst for revenge.",
103    ]

初始化

105    def init(self):

[MASK] 令牌

111        self.mask_token = self.n_tokens - 1

[PAD] 令牌

113        self.padding_token = self.n_tokens - 2
116        self.mlm = MLM(padding_token=self.padding_token,
117                       mask_token=self.mask_token,
118                       no_mask_tokens=self.no_mask_tokens,
119                       n_tokens=self.n_tokens,
120                       masking_prob=self.masking_prob,
121                       randomize_prob=self.randomize_prob,
122                       no_change_prob=self.no_change_prob)

精度度量度(忽略等于的标签[PAD]

125        self.accuracy = Accuracy(ignore_index=self.padding_token)

交叉熵损失(忽略等于的标签[PAD]

127        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)

129        super().init()

培训或验证步骤

131    def step(self, batch: any, batch_idx: BatchIndex):

将输入移至设备

137        data = batch[0].to(self.device)

在训练模式下更新全局步长(处理的令牌数)

140        if self.mode.is_train:
141            tracker.add_global_step(data.shape[0] * data.shape[1])

获取屏蔽的输入和标签

144        with torch.no_grad():
145            data, labels = self.mlm(data)

是否捕获模型输出

148        with self.mode.update(is_log_activations=batch_idx.is_last):

获取模型输出。它在使用 RNN 时返回状态的元组。这尚未实现。

152            output, *_ = self.model(data)

计算并记录损失

155        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
156        tracker.add("loss.", loss)

计算和记录精度

159        self.accuracy(output, labels)
160        self.accuracy.track()

训练模型

163        if self.mode.is_train:

计算梯度

165            loss.backward()

剪辑渐变

167            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

采取优化器步骤

169            self.optimizer.step()

记录每个纪元最后一批的模型参数和梯度

171            if batch_idx.is_last:
172                tracker.add('model', self.model)

清除渐变

174            self.optimizer.zero_grad()

保存跟踪的指标

177        tracker.save()

采样功能可在训练时定期生成样本

179    @torch.no_grad()
180    def sample(self):

填充的数据为空张量[PAD]

186        data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)

逐个添加提示

188        for i, p in enumerate(self.prompt):

获取代币索引

190            d = self.text.text_to_i(p)

添加到张量中

192            s = min(self.seq_len, len(d))
193            data[:s, i] = d[:s]

将张量移到当前设备

195        data = data.to(self.device)

获取屏蔽的输入和标签

198        data, labels = self.mlm(data)

获取模型输出

200        output, *_ = self.model(data)

打印生成的样本

203        for j in range(data.shape[1]):

从打印中收集输出

205            log = []

对于每个代币

207            for i in range(len(data)):

如果标签不是[PAD]

209                if labels[i, j] != self.padding_token:

获取预测

211                    t = output[i, j].argmax().item()

如果是可打印的字符

213                    if t < len(self.text.itos):

正确的预测

215                        if t == labels[i, j]:
216                            log.append((self.text.itos[t], Text.value))

预测不正确

218                        else:
219                            log.append((self.text.itos[t], Text.danger))

如果它不是可打印的字符

221                    else:
222                        log.append(('*', Text.danger))

如果标签是[PAD] (未遮罩),请打印原件。

224                elif data[i, j] < len(self.text.itos):
225                    log.append((self.text.itos[data[i, j]], Text.subtle))

打印

228            logger.log(log)

包括[PAD] 和在内的代币数量[MASK]

231@option(Configs.n_tokens)
232def n_tokens_mlm(c: Configs):
236    return c.text.n_tokens + 2

变压器配置

239@option(Configs.transformer)
240def _transformer_configs(c: Configs):

我们使用我们的可配置变压器实现

247    conf = TransformerConfigs()

设置嵌入和生成 logit 的词汇量大小

249    conf.n_src_vocab = c.n_tokens
250    conf.n_tgt_vocab = c.n_tokens

嵌入大小

252    conf.d_model = c.d_model

255    return conf

创建分类模型

258@option(Configs.model)
259def _model(c: Configs):
263    m = TransformerMLM(encoder=c.transformer.encoder,
264                       src_embed=c.transformer.src_embed,
265                       generator=c.transformer.generator).to(c.device)
266
267    return m
270def main():

创建实验

272    experiment.create(name="mlm")

创建配置

274    conf = Configs()

覆盖配置

276    experiment.configs(conf, {

批量大小

278        'batch_size': 64,
的@@

序列长度。我们使用较短的序列长度来更快地训练。否则训练需要很长时间。

281        'seq_len': 32,

训练 1024 个时代。

284        'epochs': 1024,

在训练和验证之间切换每个纪元的次数

287        'inner_iterations': 1,

变压器配置(与默认值相同)

290        'd_model': 128,
291        'transformer.ffn.d_ff': 256,
292        'transformer.n_heads': 8,
293        'transformer.n_layers': 6,

使用 Noam 优化器

296        'optimizer.optimizer': 'Noam',
297        'optimizer.learning_rate': 1.,
298    })

设置用于保存和加载的模型

301    experiment.add_pytorch_models({'model': conf.model})

开始实验

304    with experiment.start():

跑步训练

306        conf.run()

310if __name__ == '__main__':
311    main()