这是 PyTorch 中用于远程序列建模的压缩转换器的实现。
这是 Transfor mer XL 的扩展,它压缩了过去的记忆以提供更长的注意力范围。也就是说,最远的内存被压缩到内存中,压缩率在哪里。
压缩操作定义为。本文引入了多种选择,我们只实现了一维卷积,这似乎可以给出最佳结果。每个层都有单独的压缩操作,其中是层号。
由于使用 BPTT 训练压缩需要维护非常大的计算图(许多时间步长),因此该论文提出了自动编码损失和注意力重建损失。自动编码丢失对压缩存储器中的原始存储器进行解码并计算损失。注意力重建损失计算压缩内存和未压缩内存上的多头注意力结果,并得出两者之间的平均平方误差。我们在这里实现了后者,因为它可以提供更好的结果。
该实现使用层前标准化,而论文使用层后归一化。前层范数在 FFN 和自我注意力之前对层进行范数,并且残差连接中的直通未标准化。在标准变压器设置中,这应该更稳定。
以下是用于在 Tiny Shakespeare 数据集上训练压缩变压器模型的训练代码和笔记本。
53from typing import Optional, List
54
55import torch
56import torch.nn.functional as F
57from torch import nn
58
59from labml_helpers.module import Module, TypedModuleList
60from labml_nn.transformers.feed_forward import FeedForward
61from labml_nn.transformers.mha import PrepareForMultiHeadAttention
62from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
63from labml_nn.utils import clone_module_list
compression_rate
d_model
是嵌入的大小74 def __init__(self, compression_rate: int, d_model: int):
79 super().__init__()
80 self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)
mem
有形状[seq_len, batch, d_model]
82 def forward(self, mem: torch.Tensor):
排列的维度,mem
这样我们就可以在卷积层中运行它。卷积层接受以下形式[batch, features, sequence]
89 mem = mem.permute(1, 2, 0)
通过卷积层运行压缩内存来获取压缩内存
91 c_mem = self.conv(mem)
排列回原状[seq_len, batch, d_model]
93 return c_mem.permute(2, 0, 1)
96class CompressiveTransformerLayer(Module):
102 def __init__(self, *,
103 d_model: int,
104 self_attn: RelativeMultiHeadAttention,
105 feed_forward: FeedForward,
106 dropout_prob: float,
107 compress: Conv1dCompression):
115 super().__init__()
116 self.compress = compress
117 self.size = d_model
118 self.self_attn = self_attn
119 self.feed_forward = feed_forward
120 self.dropout = nn.Dropout(dropout_prob)
121 self.norm_self_attn = nn.LayerNorm([d_model])
122 self.norm_ff = nn.LayerNorm([d_model])
124 def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
如果没有内存,则返回令牌嵌入
133 if mem is None:
134 return z
如果有压缩的内存,则将其与内存连接起来
137 if c_mem is not None:
138 mem = torch.cat((c_mem, mem), dim=0)
通过规范化层运行内存
141 mem = self.norm_self_attn(mem)
连接规范化内存和规范化令牌嵌入
143 return torch.cat((mem, z), dim=0)
x
是形状的令牌级特征向量的张量[seq_len, batch_size, d_model]
mem
是过去令牌级别形状特征向量(内存)的张量[mem_len, batch_size, d_model]
c_mem
是压缩内存的张量[c_mem_len, batch_size, d_model]
mask
是形状的矩阵[seq_len, c_mem_len + mem_len + seq_len, batch_size]
或[seq_len, c_mem_len + mem_len + seq_len, 1]
。mask[i, j]
如果 tokeni
可以在处看到令牌,则为 truej
。145 def forward(self, *,
146 x: torch.Tensor,
147 mem: Optional[torch.Tensor],
148 c_mem: Optional[torch.Tensor],
149 mask: torch.Tensor):
在进行自我注意之前对向量进行归一化
159 z = self.norm_self_attn(x)
规范化并连接内存和压缩内存
161 m_z = self.concat_memory(z, mem, c_mem)
注意
163 self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
添加关注结果
165 x = x + self.dropout(self_attn)
标准化以进行前馈
168 z = self.norm_ff(x)
通过前馈网络
170 ff = self.feed_forward(z)
将前馈结果添加回来
172 x = x + self.dropout(ff)
175 return x
178class CompressiveTransformer(Module):
185 def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
186 super().__init__()
制作变压器层的副本
188 self.layers = clone_module_list(layer, n_layers)
最终归一化层
190 self.norm = nn.LayerNorm([layer.size])
x
是嵌入形状向量的令牌的张量[seq_len, batch_size, d_model]
mem
是过去令牌级别的张量列表,每个层的形状[mem_len, batch_size, d_model]
向量特征c_mem
是每层压缩内存[c_mem_len, batch_size, d_model]
的张量列表mask
是掩码矩阵192 def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):
用于存储令牌级特征向量的列表,这些向量将成为下一个连续批次的记忆。
203 new_mem = []
穿过每个变压器层
205 for i, layer in enumerate(self.layers):
添加到特征向量列表中
207 new_mem.append(x.detach())
记忆
209 m = mem[i] if mem else None
压缩内存
211 cm = c_mem[i] if c_mem else None
穿过变压器 XL 层
213 x = layer(x=x, mem=m, c_mem=cm, mask=mask)
最后,对向量进行归一化
215 return self.norm(x), new_mem
注意力重建损失使用未压缩的内存和压缩的内存重现自我注意力输出,并计算两者之间的均方误差。它在没有位置编码的情况下做到这一点。
当计算和训练具有注意力重建损失的压缩函数时,所有参数都将被冻结。这包括标准化后的键/值投影和偏差/缩放。
由于此损失可以独立于模型的交叉熵损失进行计算,因此您可以使用单独的仅更新优化器。但是,我们使用相同的优化器进行更新,因此在计算注意力重建损失时,我们会分离除梯度计算之外的所有其他参数。
218class AttentionReconstructionLoss:
layers
是压缩变压器层的列表
236 def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
240 self.layers = layers
241 self.loss_func = nn.MSELoss()
这是 “prepareFormultiHeadAttention” 的重新实现,其中投影是使用与梯度计算分离的参数完成的。
pmha
是 “prepareFormultiHeadAttion”x
是带有令牌嵌入的张量243 def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):
除嵌入尺寸之外的输入形状;[seq_len, batch_size]
。
253 head_shape = x.shape[:-1]
分离投影权重和偏差
256 weight = pmha.linear.weight.detach()
257 bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None
线性变换
259 x = F.linear(x, weight, bias)
将最后一个维度拆分成头部
262 x = x.view(*head_shape, pmha.heads, pmha.d_k)
输出具有形状[seq_len, batch_size, heads, d_k]
或[batch_size, d_model]
265 return x
这是 “多头注意” 的重新实现,它调用prepare_for_attn
而不是 “prep areFormultiHeadAttention” 来分离投影参数。
267 def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
计算查询、键和值预测
274 query = self.prepare_for_attn(layer.query, query)
275 key = self.prepare_for_attn(layer.key, key)
276 value = self.prepare_for_attn(layer.value, value)
计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]
。
280 scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
音阶分数
283 scores *= layer.scale
关注按键序列维度
287 attn = layer.softmax(scores)
乘以值
291 return torch.einsum("ijbh,jbhd->ibhd", attn, value)
分离移位和缩放参数的情况下执行图层归一化。
293 def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
分离 shift (bias
) 和缩放 (weight
) 参数
299 weight = ln.weight.detach() if ln.weight is not None else None
300 bias = ln.bias.detach() if ln.bias is not None else None
层规范化
303 return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)
这将计算一层的损失
305 def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
分离令牌嵌入和内存。
311 h = h.detach()
312 mem = mem.detach()
使用压缩内存。的参数是唯一未从梯度计算中分离出来的参数。
316 c_mem = layer.compress(mem)
规范化嵌入和记忆
319 h = self.norm(layer.norm_self_attn, h)
320 mem = self.norm(layer.norm_self_attn, mem)
321 c_mem = self.norm(layer.norm_self_attn, c_mem)
使用未压缩的内存计算注意力
324 attn_mem = self.attn(layer.self_attn, h, mem, mem)
使用压缩内存计算注意力
326 attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)
计算均方误差
329 return self.loss_func(attn_cmem, attn_mem)
331 def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):
计算每层的损失
333 losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]
损失总和
335 return sum(losses)