变压器 XL 实验

这是一个带注释的 PyTorch 实验,用于训练变压器 xl 模型。

11from typing import List
12
13import torch
14import torch.nn as nn
15from labml.logger import Text
16
17from labml import experiment, tracker, monit, logger
18from labml.configs import option
19from labml_helpers.metrics.simple_state import SimpleStateModule
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex, hook_model_outputs
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer

自动回归模型

26class AutoregressiveModel(Module):
31    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
32        super().__init__()

令牌嵌入模块

34        self.src_embed = nn.Embedding(n_vocab, d_model)

变压器

36        self.transformer = transformer

最后一层

38        self.generator = nn.Linear(d_model, n_vocab)

口罩

40        self.mask_x = None
41        self.mask_mem = None
43    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):

内存的长度

45        m_len = len(mem[0]) if mem else 0

为令牌创建后续掩码

47        if self.mask_x is None or self.mask_x.shape[0] < len(x):
48            from labml_nn.transformers.utils import subsequent_mask
49            self.mask_x = subsequent_mask(len(x)).to(x.device)

为内存创建一个全一(完全可见性)掩码

51        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
52            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

如果有内存,则连接掩码

55        if m_len:
56            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

否则,请使用后续的掩码

58        else:
59            mask = self.mask_x[:len(x), :len(x)]

令牌嵌入

62        x = self.src_embed(x)

用它穿过变压器

64        res, mem = self.transformer(x, mem, mask)

生成下一个令牌的日志

66        res = self.generator(res)

68        return res, mem

配置

当我们开始实验时,默认配置可以而且将会被覆盖

71class Configs(NLPAutoRegressionConfigs):
78    model: AutoregressiveModel

令牌嵌入大小

81    d_model: int = 128

注意头数量

83    heads: int = 4

辍学概率

85    dropout: float = 0.0

FFN 隐藏层中的要素数量

87    d_ff: int = 256

变压器层数

89    n_layers: int = 6

要保留的记忆数量

91    mem_len: int = 128

状态模块用于在训练和验证之间切换时保持记忆

93    memory = SimpleStateModule()
95    def init(self):

设置跟踪器配置

97        tracker.set_scalar("accuracy.*", True)
98        tracker.set_scalar("loss.*", True)

向日志模块输出添加钩子

100        hook_model_outputs(self.mode, self.model, 'model')

这将使精度指标统计数据和记忆分开,以便训练和验证。

102        self.state_modules = [self.accuracy, self.memory]

连接记忆并删除旧记忆以最大限度地保留内mem_len 存。

104    def merge_memory(self, old_mem, new_mem):

如果配置为不使用内存

111        if self.mem_len == 0:
112            return []

与旧内存串联

115        if old_mem:
116            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
117        else:
118            mem = new_mem

截断旧的记忆

121        if len(mem[0]) > self.mem_len:
122            mem = [m[-self.mem_len:] for m in mem]

125        return mem

培训/验证步骤

127    def step(self, batch: any, batch_idx: BatchIndex):

将数据移动到设备

133        data, target = batch[0].to(self.device), batch[1].to(self.device)

在训练模式下更新全局步长(处理的令牌数)

136        if self.mode.is_train:
137            tracker.add_global_step(data.shape[0] * data.shape[1])

是否捕获模型输出

140        with self.mode.update(is_log_activations=batch_idx.is_last):

获得回忆

142            mem = self.memory.get()

运行模型

144            output, new_mem = self.model(data, mem)

合并内存

146            mem = self.merge_memory(mem, new_mem)

更新记忆

148            self.memory.set(mem)

计算和记录交叉熵损失

151        loss = self.loss_func(output, target)
152        tracker.add("loss.", loss)

计算和记录精度

155        self.accuracy(output, target)
156        self.accuracy.track()

训练模型

159        if self.mode.is_train:

计算梯度

161            loss.backward()

剪辑渐变

163            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

采取优化器步骤

165            self.optimizer.step()

记录每个纪元最后一批的模型参数和梯度

167            if batch_idx.is_last:
168                tracker.add('model', self.model)

清除渐变

170            self.optimizer.zero_grad()

保存跟踪的指标

173        tracker.save()

采样功能可在训练时定期生成样本

175    def sample(self):

启动提示

181        prompt = self.prompt

收集输出以进行打印

183        log = [(prompt, Text.subtle)]

记忆

185        mem = []

样本 25 个代币

187        for i in monit.iterate('Sample', 25):

将提示符号化

189            data = self.text.text_to_i(prompt).unsqueeze(-1)

移至设备

191            data = data.to(self.device)

获取模型输出

193            output, new_mem = self.model(data, mem)

获取模型预测(贪婪)

195            output = output.argmax(dim=-1).squeeze(1)

将预测添加到提示符中

197            prompt += self.prompt_separator + self.text.itos[output[-1]]

在下一次迭代中只喂最后一个角色进行建模,其余部分将作为记忆进去

199            prompt = prompt[-1:]

添加日志记录的预测

201            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

更新内存

203            mem = self.merge_memory(mem, new_mem)

打印采样输出

206        logger.log(log)

初始化自回归模型

209@option(Configs.model)
210def autoregressive_model(c: Configs):
214    from labml_nn.transformers.xl import RelativeMultiHeadAttention
215    from labml_nn.transformers.feed_forward import FeedForward
216    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
217        TransformerXLLayer(d_model=c.d_model,
218                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
219                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
220                           dropout_prob=c.dropout), c.n_layers))
221    return m.to(c.device)

运行实验

224def main():

创建实验

229    experiment.create(name="transformer_xl", comment='')

创建配置

231    conf = Configs()

装载配置

233    experiment.configs(conf,

要覆盖的配置字典

235                       {'tokenizer': 'character',
236                        'text': 'tiny_shakespeare',
237                        'optimizer.learning_rate': 1.,
238                        'optimizer.optimizer': 'Noam',
239                        'prompt': 'It is',
240                        'prompt_separator': '',
241
242                        'train_loader': 'sequential_train_loader',
243                        'valid_loader': 'sequential_valid_loader',
244
245                        'seq_len': 2,
246                        'mem_len': 32,
247                        'epochs': 128,
248                        'batch_size': 32,
249                        'inner_iterations': 25,
250                        })

设置用于保存和加载的模型

253    experiment.add_pytorch_models({'model': conf.model})

开始实验

256    with experiment.start():

TrainValidConfigs.run

258        conf.run()

262if __name__ == '__main__':
263    main()