快速称重变压器

论文《线性变形金刚是秘密的快速加权记忆系统》发现了线性自注意力和快速加权系统之间的相似之处,并据此修改了自我注意力更新规则。它还引入了一个更简单但有效的内核函数。

作者提供了该论文的正式实现,包括他们在论文中比较的其他变体。

快速举重

考虑一个输入序列或长度,每个步骤都是大小向量;即快速权重模型在每个步骤生成权重矩阵以生成输出

是外积 (),其中两个向量的元素相互乘以得出矩阵。是一个激活函数。并且是可训练的权重(参数)。是在每一步生成的快速权重。

线性自我注意力

原始变压器的自我注意力是,(为了清楚起见,省略了)

在哪里

线性化自我注意力背后的想法是用不同的内核替换 softmax 内核,这样我们就可以更快地计算出自我注意力函数的分母:

这给了

使用,我们可以有效地计算它们:

这与快速称重非常相似。

本文介绍了一种新的线性注意力投影函数、一种新的更新规则和标准化变更

以下是用于在 Tiny Shakespeare 数据集上训练快速权重转换器的训练代码和一本笔记本。

Open In Colab

95import torch
96from torch import nn
97
98from labml_helpers.module import Module
99from labml_nn.transformers.feed_forward import FeedForward
100from labml_nn.transformers.mha import PrepareForMultiHeadAttention
101from labml_nn.utils import clone_module_list

确定性无参数项目 (DPFP)

这是本文中引入的新投影函数。DPFP 从维度到维度的投影,其中是一个超参数。

where和的串联,用于给出大小为、和是 vector 的第-th 个元素,如果大于 vector 中的元素数量,则会滚动

基本上,它通过乘以移位的元素来创建一个新的向量

这将生成稀疏投影(只有少数元素为非零)和正交投影(对于大多数,除非非常相似。

规范化

本文介绍了一个简单的规范化

检查论文的推导。

104class DPFP(Module):
  • nu 是超参数
  • eps 是用于确保归一化时没有被零除的小值。
138    def __init__(self, nu: int = 1, eps: float = 1e-6):
143        super().__init__()
144        self.nu = nu
145        self.relu = nn.ReLU()
146        self.eps = eps
148    def forward(self, k: torch.Tensor):

得到

150        k = self.dpfp(k)

规范化依据

152        return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)

154    def dpfp(self, k: torch.Tensor):

159        x = self.relu(torch.cat([k, -k], dim=-1))

移位然后滚过去,得到

162        x_rolled = [x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)]

连接以获取

165        x_rolled = torch.cat(x_rolled, dim=-1)

串联的副本

167        x_repeat = torch.cat([x] * self.nu, dim=-1)

乘以它们,

173        return x_repeat * x_rolled

快速举重注意

本文介绍了一种新的计算更新规则。模型首先检索与键配对的当前值。然后存储检索到的值和输入的组合

where 是一个可训练的参数,是 sigmoid 函数。

请注意,我们不需要规范化项,因为它是规范化的。

176class FastWeightsAttention(Module):
204    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
205        super().__init__()

每头特征数

208        self.d_k = d_model // heads

头数

210        self.heads = heads

这些改变了query ,keyvalue 多头的注意力.

213        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
214        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
215        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

每个头的插值权重函数

218        self.interpolation_weight = nn.Sequential(
219            PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
220            nn.Sigmoid()
221        )

224        self.phi = phi

输出层

227        self.output = nn.Linear(d_model, d_model)

辍学

229        self.dropout = nn.Dropout(dropout_prob)
231    def forward(self, x: torch.Tensor):

获取步数

233        seq_len = x.shape[0]

适用于所有台阶和头部

235        query = self.phi(self.query(x))

适用于所有台阶和头部

237        key = self.phi(self.key(x))

适用于所有台阶和头部

239        value = self.value(x)

适用于所有台阶和头部

241        beta = self.interpolation_weight(x)

244        weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))

存储输出的列表

246        outputs = []

遍历各个步骤

249        for i in range(seq_len):

251            value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])

256            weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])

259            y = torch.einsum('bhvk,bhk->bhv', weights, query[i])

合并多个头部并追加到outputs

262            outputs.append(y.reshape(y.shape[0], -1))

将每一步的输出堆叠到单个张量中

265        x = torch.stack(outputs)

输出层

268        return self.output(x)

这是一个结合了自我关注和前馈网络的通用变压器层。

271class FastWeightsAttentionTransformerLayer(Module):
275    def __init__(self, *,
276                 d_model: int,
277                 attn: FastWeightsAttention,
278                 feed_forward: FeedForward,
279                 dropout_prob: float):
280        super().__init__()

变压器尺寸

282        self.size = d_model

快速举重注意模块

284        self.attn = attn

前馈网络

286        self.feed_forward = feed_forward

辍学层

288        self.dropout = nn.Dropout(dropout_prob)

归一化层

291        self.norm_self_attn = nn.LayerNorm([d_model])
292        self.norm_ff = nn.LayerNorm([d_model])
294    def forward(self, x: torch.Tensor):

计算快速权重自我注意

296        attn = self.attn(x)

添加自我关注的结果

298        x = x + self.dropout(attn)

标准化以进行前馈

301        z = self.norm_ff(x)

通过前馈网络

303        ff = self.feed_forward(z)

将前馈结果添加回来

305        x = x + self.dropout(ff)

308        return x

这是具有多个变压器层的通用变压器模块

311class FastWeightsAttentionTransformer(Module):
315    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
316        super().__init__()

制作变压器层的副本

318        self.layers = clone_module_list(layer, n_layers)

最终归一化层

320        self.norm = nn.LayerNorm([layer.size])
322    def forward(self, x: torch.Tensor):
323        for i, layer in enumerate(self.layers):

获取图层输出

325            x = layer(x)

规范化输出

328        return self.norm(x)