压缩变压器实验

这是一个带注释的 PyTorch 实验,用于训练压缩变压器模型。

11from typing import List, Tuple, NamedTuple
12
13import torch
14import torch.nn as nn
15
16from labml import experiment, tracker, monit, logger
17from labml.configs import option
18from labml.logger import Text
19from labml_helpers.metrics.simple_state import SimpleStateModule
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex, hook_model_outputs
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
24    CompressiveTransformerLayer, Conv1dCompression
27class CompressedMemory(NamedTuple):
28    mem: List[torch.Tensor]
29    c_mem: List[torch.Tensor]

自动回归模型

32class AutoregressiveModel(Module):
37    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
38        super().__init__()

令牌嵌入模块

40        self.src_embed = nn.Embedding(n_vocab, d_model)

变压器

42        self.transformer = transformer

最后一层

44        self.generator = nn.Linear(d_model, n_vocab)

口罩

46        self.mask_x = None
47        self.mask_mem = None
49    def forward(self, x: torch.Tensor, mem: CompressedMemory):

获取内存和压缩内存

51        if mem is not None:
52            mem, c_mem = mem.mem, mem.c_mem
53        else:
54            mem = []
55            c_mem = []

内存和压缩内存的总长度(用于掩码)

58        m_len = len(mem[0]) if mem else 0
59        if c_mem:
60            m_len += len(c_mem[0])

为令牌创建后续掩码

63        if self.mask_x is None or self.mask_x.shape[0] < len(x):
64            from labml_nn.transformers.utils import subsequent_mask
65            self.mask_x = subsequent_mask(len(x)).to(x.device)

为内存创建一个全一(完全可见性)掩码

67        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
68            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

如果有内存,则连接掩码

71        if m_len:
72            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

否则,仅使用后续的掩码

74        else:
75            mask = self.mask_x[:len(x), :len(x)]

令牌嵌入

78        x = self.src_embed(x)

用它穿过变压器

80        res, mem = self.transformer(x, mem, c_mem, mask)

生成下一个令牌的日志

82        res = self.generator(res)

84        return res, mem

配置

当我们开始实验时,默认配置可以而且将会被覆盖。

87class Configs(NLPAutoRegressionConfigs):
94    model: AutoregressiveModel

令牌嵌入大小

97    d_model: int = 128

注意头数量

99    heads: int = 4

辍学概率

101    dropout: float = 0.0

FFN 隐藏层中的要素数量

103    d_ff: int = 256

变压器层数

105    n_layers: int = 6

要保留的记忆数量

107    mem_len: int = 8

状态模块用于在训练和验证之间切换时保持记忆

109    memory = SimpleStateModule()

注意力重建损失

111    attention_reconstruction_loss: AttentionReconstructionLoss

压缩率

113    compression_rate: int = 4

压缩的内存长度

115    c_mem_len: int = 128
117    def init(self):

设置跟踪器配置

119        tracker.set_scalar("accuracy.*", True)
120        tracker.set_scalar("loss.*", True)

不要在终端中打印注意力重建损失

122        tracker.set_scalar("ar_loss.*", False)

向日志模块输出添加钩子

124        hook_model_outputs(self.mode, self.model, 'model')

这将使精度指标统计数据和记忆分开,以便训练和验证。

126        self.state_modules = [self.accuracy, self.memory]

连接新记忆并压缩最古老的记忆。

128    @torch.no_grad()
129    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
130            -> Tuple[CompressedMemory, List[torch.Tensor]]:

如果配置指定不使用内存

136        if self.mem_len == 0 and self.c_mem_len == 0:
137            return CompressedMemory([], []), []

获取内存和压缩内存

140        if mem is not None:
141            mem, c_mem = mem.mem, mem.c_mem
142        else:
143            mem, c_mem = [], []

将新记忆与旧记忆连接起来

146        if mem:
147            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
148        else:
149            mem = new_mem

如果记忆多于最早的记忆,则压缩最早的记忆mem_len

152        if len(mem[0]) > self.mem_len:

计算要制作的压缩记忆的数量,其中是我们拥有的记忆数量,是我们维护的最大记忆数(mem_len )。

156            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate

要压缩的内存数量

158            n_old = n_c_mem * self.compression_rate

用于保存每层需要压缩的内存的列表。

160            mem_to_compress = []

一个列表,用于保存每层未被压缩的记忆。

162            uncompressed_mem = []

遍历每层的记忆。

164            for m in mem:

在以下位置拆分记忆

166                cm, m = torch.split(m, [n_old, len(m) - n_old])

收集记忆进行压缩

168                mem_to_compress.append(cm)

收集剩余的记忆

170                uncompressed_mem.append(m)

更新记忆

172            mem = uncompressed_mem

压缩记忆

175            new_c_mem = []
176            for i, layer in enumerate(self.model.transformer.layers):
177                new_c_mem.append(layer.compress(mem_to_compress[i]))

将新压缩的存储器与旧的压缩存储器连接起来

180            if c_mem:
181                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]

如果没有旧的压缩记忆

183            else:
184                c_mem = new_c_mem

截断旧的记忆

187            if len(c_mem[0]) > self.c_mem_len:
188                c_mem = [m[-self.c_mem_len:] for m in c_mem]

如果内存数量少于mem_len

190        else:
191            mem_to_compress = []

返回被压缩的记忆和记忆。重建损失计算需要被压缩的记忆。

195        return CompressedMemory(mem, c_mem), mem_to_compress

培训/验证步骤

197    def step(self, batch: any, batch_idx: BatchIndex):

将数据移动到设备

203        data, target = batch[0].to(self.device), batch[1].to(self.device)

在训练模式下更新全局步长(处理的令牌数)

206        if self.mode.is_train:
207            tracker.add_global_step(data.shape[0] * data.shape[1])

是否捕获模型输出

210        with self.mode.update(is_log_activations=batch_idx.is_last):

获得回忆

212            mem = self.memory.get()

运行模型

214            output, new_mem = self.model(data, mem)

合并和压缩内存

216            mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)

更新记忆

218            self.memory.set(mem)

计算和记录交叉熵损失

221        loss = self.loss_func(output, target)
222        tracker.add("loss.", loss)

如果在此步骤中记忆被压缩,则计算注意力重建损失

225        if mem_to_compress:

引起注意重建损失

227            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)

追踪注意力重建损失

229            tracker.add("ar_loss.", ar_loss)

将注意力重建损失增加到损失

231            loss = loss + ar_loss

计算和记录精度

234        self.accuracy(output, target)
235        self.accuracy.track()

训练模型

238        if self.mode.is_train:

计算梯度

240            loss.backward()

剪辑渐变

242            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

采取优化器步骤

244            self.optimizer.step()

记录每个纪元最后一批的模型参数和梯度

246            if batch_idx.is_last:
247                tracker.add('model', self.model)

清除渐变

249            self.optimizer.zero_grad()

保存跟踪的指标

252        tracker.save()

采样功能可在训练时定期生成样本

254    def sample(self):

启动提示

260        prompt = self.prompt

收集输出以进行打印

262        log = [(prompt, Text.subtle)]

记忆

264        mem = CompressedMemory([], [])

样本 25 个代币

266        for i in monit.iterate('Sample', 25):

将提示符号化

268            data = self.text.text_to_i(prompt).unsqueeze(-1)

移至设备

270            data = data.to(self.device)

获取模型输出

272            output, new_mem = self.model(data, mem)

获取模型预测(贪婪)

274            output = output.argmax(dim=-1).squeeze(1)

将预测添加到提示符中

276            prompt += self.prompt_separator + self.text.itos[output[-1]]

在下一次迭代中只喂最后一个角色进行建模,其余部分将作为记忆进去

278            prompt = prompt[-1:]

添加日志记录的预测

280            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

更新和压缩内存

282            mem, _ = self.merge_compress_memory(mem, new_mem)

打印采样输出

285        logger.log(log)

初始化自回归模型

288@option(Configs.model)
289def autoregressive_model(c: Configs):
293    from labml_nn.transformers.xl import RelativeMultiHeadAttention
294    from labml_nn.transformers.feed_forward import FeedForward
295    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
296        CompressiveTransformerLayer(d_model=c.d_model,
297                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
298                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
299                                    dropout_prob=c.dropout,
300                                    compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
301    return m.to(c.device)

初始化注意力重建损失

304@option(Configs.attention_reconstruction_loss)
305def attention_reconstruction_loss(c: Configs):
309    return AttentionReconstructionLoss(c.model.transformer.layers)

运行实验

312def main():

创建实验

317    experiment.create(name="compressive_transformer", comment='')

创建配置

319    conf = Configs()

装载配置

321    experiment.configs(conf,

要覆盖的配置字典

323                       {'tokenizer': 'character',
324                        'text': 'tiny_shakespeare',
325                        'optimizer.learning_rate': 2.5e-4,
326                        'optimizer.optimizer': 'AdamW',
327                        'prompt': 'It is',
328                        'prompt_separator': '',
329
330                        'train_loader': 'sequential_train_loader',
331                        'valid_loader': 'sequential_valid_loader',
332
333                        'seq_len': 8,
334                        'mem_len': 8,
335                        'epochs': 128,
336                        'batch_size': 32,
337                        'inner_iterations': 25,
338                        'compression_rate': 2,
339                        })

设置用于保存和加载的模型

342    experiment.add_pytorch_models({'model': conf.model})

开始实验

345    with experiment.start():

TrainValidConfigs.run

347        conf.run()

351if __name__ == '__main__':
352    main()