论文《线性变形金刚是秘密的快速加权记忆系统》发现了线性自注意力和快速加权系统之间的相似之处,并据此修改了自我注意力更新规则。它还引入了一个更简单但有效的内核函数。
作者提供了该论文的正式实现,包括他们在论文中比较的其他变体。
考虑一个输入序列或长度,每个步骤都是大小向量;即快速权重模型在每个步骤生成权重矩阵以生成输出,
是外积 (),其中两个向量的元素相互乘以得出矩阵。是一个激活函数。并且是可训练的权重(参数)。是在每一步生成的快速权重。
原始变压器的自我注意力是,(为了清楚起见,省略了)
在哪里
线性化自我注意力背后的想法是用不同的内核替换 softmax 内核,这样我们就可以更快地计算出自我注意力函数的分母:
这给了
使用和,我们可以有效地计算它们:
这与快速称重非常相似。
本文介绍了一种新的线性注意力投影函数、一种新的更新规则和标准化变更
以下是用于在 Tiny Shakespeare 数据集上训练快速权重转换器的训练代码和一本笔记本。
95import torch
96from torch import nn
97
98from labml_helpers.module import Module
99from labml_nn.transformers.feed_forward import FeedForward
100from labml_nn.transformers.mha import PrepareForMultiHeadAttention
101from labml_nn.utils import clone_module_list
这是本文中引入的新投影函数。DPFP 从维度到维度的投影,其中是一个超参数。
where 是和的串联,用于给出大小为、和。是 vector 的第-th 个元素,如果大于 vector 中的元素数量,则会滚动。
基本上,它通过乘以移位的元素来创建一个新的向量。
这将生成稀疏投影(只有少数元素为非零)和正交投影(对于大多数,除非和非常相似。
本文介绍了一个简单的规范化,
检查论文的推导。
104class DPFP(Module):
nu
是超参数。eps
是用于确保归一化时没有被零除的小值。138 def __init__(self, nu: int = 1, eps: float = 1e-6):
143 super().__init__()
144 self.nu = nu
145 self.relu = nn.ReLU()
146 self.eps = eps
148 def forward(self, k: torch.Tensor):
得到
150 k = self.dpfp(k)
规范化依据
152 return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)
154 def dpfp(self, k: torch.Tensor):
159 x = self.relu(torch.cat([k, -k], dim=-1))
移位然后滚过去,得到
162 x_rolled = [x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)]
连接以获取
165 x_rolled = torch.cat(x_rolled, dim=-1)
串联的副本
167 x_repeat = torch.cat([x] * self.nu, dim=-1)
乘以它们,
173 return x_repeat * x_rolled
本文介绍了一种新的计算更新规则。模型首先检索与键配对的当前值。然后存储检索到的值和输入的组合。
where 是一个可训练的参数,是 sigmoid 函数。
请注意,我们不需要规范化项,因为它是规范化的。
176class FastWeightsAttention(Module):
204 def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
205 super().__init__()
每头特征数
208 self.d_k = d_model // heads
头数
210 self.heads = heads
这些改变了query
,key
和value
多头的注意力.
213 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
214 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
215 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
每个头的插值权重函数
218 self.interpolation_weight = nn.Sequential(
219 PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
220 nn.Sigmoid()
221 )
224 self.phi = phi
输出层
227 self.output = nn.Linear(d_model, d_model)
辍学
229 self.dropout = nn.Dropout(dropout_prob)
231 def forward(self, x: torch.Tensor):
获取步数
233 seq_len = x.shape[0]
适用于所有台阶和头部
235 query = self.phi(self.query(x))
适用于所有台阶和头部
237 key = self.phi(self.key(x))
适用于所有台阶和头部
239 value = self.value(x)
适用于所有台阶和头部
241 beta = self.interpolation_weight(x)
244 weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))
存储输出的列表
246 outputs = []
遍历各个步骤
249 for i in range(seq_len):
251 value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])
256 weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])
259 y = torch.einsum('bhvk,bhk->bhv', weights, query[i])
合并多个头部并追加到outputs
262 outputs.append(y.reshape(y.shape[0], -1))
将每一步的输出堆叠到单个张量中
265 x = torch.stack(outputs)
输出层
268 return self.output(x)
这是一个结合了自我关注和前馈网络的通用变压器层。
271class FastWeightsAttentionTransformerLayer(Module):
275 def __init__(self, *,
276 d_model: int,
277 attn: FastWeightsAttention,
278 feed_forward: FeedForward,
279 dropout_prob: float):
280 super().__init__()
变压器尺寸
282 self.size = d_model
快速举重注意模块
284 self.attn = attn
前馈网络
286 self.feed_forward = feed_forward
辍学层
288 self.dropout = nn.Dropout(dropout_prob)
归一化层
291 self.norm_self_attn = nn.LayerNorm([d_model])
292 self.norm_ff = nn.LayerNorm([d_model])
294 def forward(self, x: torch.Tensor):
计算快速权重自我注意
296 attn = self.attn(x)
添加自我关注的结果
298 x = x + self.dropout(attn)
标准化以进行前馈
301 z = self.norm_ff(x)
通过前馈网络
303 ff = self.feed_forward(z)
将前馈结果添加回来
305 x = x + self.dropout(ff)
308 return x
这是具有多个变压器层的通用变压器模块
311class FastWeightsAttentionTransformer(Module):
315 def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
316 super().__init__()
制作变压器层的副本
318 self.layers = clone_module_list(layer, n_layers)
最终归一化层
320 self.norm = nn.LayerNorm([layer.size])
322 def forward(self, x: torch.Tensor):
323 for i, layer in enumerate(self.layers):
获取图层输出
325 x = layer(x)
规范化输出
328 return self.norm(x)