压缩变压器

这是 PyTorch用于远程序列建模的压缩转换器的实现。

这是 Transfor mer XL 的扩展,它压缩了过去的记忆以提供更长的注意力范围。也就是说,最远的内存被压缩到内存中,压缩率在哪里。

压缩操作

压缩操作定义为。本文引入了多种选择,我们只实现了一维卷积,这似乎可以给出最佳结果。每个层都有单独的压缩操作其中是层号。

训练压缩操作

由于使用 BPTT 训练压缩需要维护非常大的计算图(许多时间步长),因此该论文提出了自动编码损失注意力重建损失。自动编码丢失对压缩存储器中的原始存储器进行解码并计算损失。注意力重建损失计算压缩内存和未压缩内存上的多头注意力结果,并得出两者之间的平均平方误差。我们在这里实现了后者,因为它可以提供更好的结果。

该实现使用层前标准化,而论文使用层后归一化。前层范数在 FFN 和自我注意力之前对层进行范数,并且残差连接中的直通未标准化。在标准变压器设置中,这应该更稳定。

以下是用于在 Tiny Shakespeare 数据集上训练压缩变压器模型的训练代码和笔记本。

Open In Colab

53from typing import Optional, List
54
55import torch
56import torch.nn.functional as F
57from torch import nn
58
59from labml_helpers.module import Module, TypedModuleList
60from labml_nn.transformers.feed_forward import FeedForward
61from labml_nn.transformers.mha import PrepareForMultiHeadAttention
62from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
63from labml_nn.utils import clone_module_list

一维卷积压缩

这是一个nn.Conv1d 包含一些张量维度排列的简单包装。

66class Conv1dCompression(Module):
  • compression_rate
  • d_model 是嵌入的大小
74    def __init__(self, compression_rate: int, d_model: int):
79        super().__init__()
80        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

mem 有形状[seq_len, batch, d_model]

82    def forward(self, mem: torch.Tensor):

排列的维度,mem 这样我们就可以在卷积层中运行它。卷积层接受以下形式[batch, features, sequence]

89        mem = mem.permute(1, 2, 0)

通过卷积层运行压缩内存来获取压缩内存

91        c_mem = self.conv(mem)

排列回原状[seq_len, batch, d_model]

93        return c_mem.permute(2, 0, 1)

压缩变压器层

这是单个压缩变压器层的实现

96class CompressiveTransformerLayer(Module):
  • d_model 是令牌嵌入的大小
  • self_attn自我关注模块
  • feed_forward前馈模块
  • dropout_prob 是自我关注和 FFN 后退学的概率
  • compress 是压缩函数
102    def __init__(self, *,
103                 d_model: int,
104                 self_attn: RelativeMultiHeadAttention,
105                 feed_forward: FeedForward,
106                 dropout_prob: float,
107                 compress: Conv1dCompression):
115        super().__init__()
116        self.compress = compress
117        self.size = d_model
118        self.self_attn = self_attn
119        self.feed_forward = feed_forward
120        self.dropout = nn.Dropout(dropout_prob)
121        self.norm_self_attn = nn.LayerNorm([d_model])
122        self.norm_ff = nn.LayerNorm([d_model])

将标准化令牌嵌入与内存和压缩内存连接起来。

  • z 是层规范化令牌嵌入。
  • memc_mem 是内存和压缩内存(未规范化)。
124    def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):

如果没有内存,则返回令牌嵌入

133        if mem is None:
134            return z

如果有压缩的内存,则将其与内存连接起来

137        if c_mem is not None:
138            mem = torch.cat((c_mem, mem), dim=0)

通过规范化层运行内存

141        mem = self.norm_self_attn(mem)

连接规范化内存和规范化令牌嵌入

143        return torch.cat((mem, z), dim=0)
  • x 是形状的令牌级特征向量的张量[seq_len, batch_size, d_model]
  • mem 是过去令牌级别形状特征向量(内存)的张量[mem_len, batch_size, d_model]
  • c_mem 是压缩内存的张量[c_mem_len, batch_size, d_model]
  • mask 是形状的矩阵[seq_len, c_mem_len + mem_len + seq_len, batch_size][seq_len, c_mem_len + mem_len + seq_len, 1]mask[i, j] 如果 tokeni 可以在处看到令牌,则为 truej
145    def forward(self, *,
146                x: torch.Tensor,
147                mem: Optional[torch.Tensor],
148                c_mem: Optional[torch.Tensor],
149                mask: torch.Tensor):

在进行自我注意之前对向量进行归一化

159        z = self.norm_self_attn(x)

规范化并连接内存和压缩内存

161        m_z = self.concat_memory(z, mem, c_mem)

注意

163        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)

添加关注结果

165        x = x + self.dropout(self_attn)

标准化以进行前馈

168        z = self.norm_ff(x)

通过前馈网络

170        ff = self.feed_forward(z)

将前馈结果添加回来

172        x = x + self.dropout(ff)

175        return x

压缩变压器型号

它由多个压缩变压器层组成

178class CompressiveTransformer(Module):
185    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
186        super().__init__()

制作变压器层的副本

188        self.layers = clone_module_list(layer, n_layers)

最终归一化层

190        self.norm = nn.LayerNorm([layer.size])
  • x 是嵌入形状向量的令牌的张量[seq_len, batch_size, d_model]
  • mem 是过去令牌级别的张量列表,每个层的形状[mem_len, batch_size, d_model] 向量特征
  • c_mem 是每层压缩内存[c_mem_len, batch_size, d_model] 的张量列表
  • mask 是掩码矩阵
192    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):

用于存储令牌级特征向量的列表,这些向量将成为下一个连续批次的记忆。

203        new_mem = []

穿过每个变压器层

205        for i, layer in enumerate(self.layers):

添加到特征向量列表中

207            new_mem.append(x.detach())

记忆

209            m = mem[i] if mem else None

压缩内存

211            cm = c_mem[i] if c_mem else None

穿过变压器 XL 层

213            x = layer(x=x, mem=m, c_mem=cm, mask=mask)

最后,对向量进行归一化

215        return self.norm(x), new_mem

注意力重建损失

注意力重建损失使用未压缩的内存和压缩的内存重现自我注意力输出,并计算两者之间的均方误差。它在没有位置编码的情况下做到这一点。

当计算和训练具有注意力重建损失的压缩函数时,所有参数都将被冻结。这包括标准化后的键/值投影和偏差/缩放。

由于此损失可以独立于模型的交叉熵损失进行计算,因此您可以使用单独的仅更新优化器。但是,我们使用相同的优化器进行更新,因此在计算注意力重建损失时,我们会分离除梯度计算之外的所有其他参数。

218class AttentionReconstructionLoss:

layers 是压缩变压器层的列表

236    def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
240        self.layers = layers
241        self.loss_func = nn.MSELoss()

这是 “prepareFormultiHeadAttention” 的重新实现,其中投影是使用与梯度计算分离的参数完成的。

  • x 是带有令牌嵌入的张量
  • 243    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):

    除嵌入尺寸之外的输入形状;[seq_len, batch_size]

    253        head_shape = x.shape[:-1]

    分离投影权重和偏差

    256        weight = pmha.linear.weight.detach()
    257        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None

    线性变换

    259        x = F.linear(x, weight, bias)

    将最后一个维度拆分成头部

    262        x = x.view(*head_shape, pmha.heads, pmha.d_k)

    输出具有形状[seq_len, batch_size, heads, d_k][batch_size, d_model]

    265        return x
    267    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):

    计算查询、键和值预测

    274        query = self.prepare_for_attn(layer.query, query)
    275        key = self.prepare_for_attn(layer.key, key)
    276        value = self.prepare_for_attn(layer.value, value)

    计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]

    280        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

    音阶分数

    283        scores *= layer.scale

    关注按键序列维度

    287        attn = layer.softmax(scores)

    乘以值

    291        return torch.einsum("ijbh,jbhd->ibhd", attn, value)
    在@@

    分离移位和缩放参数的情况下执行图层归一化。

    293    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):

    分离 shift (bias ) 和缩放 (weight ) 参数

    299        weight = ln.weight.detach() if ln.weight is not None else None
    300        bias = ln.bias.detach() if ln.bias is not None else None

    层规范化

    303        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)

    这将计算一层的损失

    305    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):

    分离令牌嵌入和内存。

    311        h = h.detach()
    312        mem = mem.detach()

    使用压缩内存。的参数是唯一未从梯度计算中分离出来的参数。

    316        c_mem = layer.compress(mem)

    规范化嵌入和记忆

    319        h = self.norm(layer.norm_self_attn, h)
    320        mem = self.norm(layer.norm_self_attn, mem)
    321        c_mem = self.norm(layer.norm_self_attn, c_mem)

    使用未压缩的内存计算注意力

    324        attn_mem = self.attn(layer.self_attn, h, mem, mem)

    使用压缩内存计算注意力

    326        attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)

    计算均方误差

    329        return self.loss_func(attn_cmem, attn_mem)
    331    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):

    计算每层的损失

    333        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]

    损失总和

    335        return sum(losses)