一款无注意的变形金刚

这是 PyTorch 对无注意力的变形金刚》一文的实现。

本文用一种新的高效运算取代了自我注意力层,该运算的存储复杂度为,其中是序列长度,是嵌入的维度。

本文介绍了 AFT 以及 AFT-Local 和 AFT-conv。这里我们实现了 aft-Local,它关注自回归模型中的 cloby 代币。

免注意变形金刚

A@@

FT(类似于 MHA)首先将嵌入转换为具有学习权重的查询、键和值张量。每个位置的输出都是通过以下运算计算的。

,其中是元素乘积,是非线性(sigmoid),是成对位置偏差的学习矩阵。

这意味着我们取值的加权平均值并将其乘以查询。这样就无需计算 MHA 所需的注意力矩阵,从而降低了内存需求。

AFT 本地

AFT Local 仅在本地应用学习的配对位置偏差:

,其中是本地窗口大小。

尽管不在本地窗口之外,但 AFT 操作仍使用来自其他区域的键值对。这与本地转换器不同,本地窗口之外的嵌入完全不可见。

以下是 AFT Local 模型的训练代码

59from typing import Optional
60
61import torch
62from torch import nn
63
64from labml_helpers.module import Module

AFT 本地操作

在哪里,

67class AFTLocal(Module):
  • d_modelquerykeyvalue 向量中的要素数。
  • seq_len
  • local_window_size 是本地窗口大小
  • bias 是是否为和的变换设置偏置参数
86    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
94        super().__init__()

本地窗口大小

97        self.local_window_size = local_window_size

这些变换querykeyvalue 向量。

99        self.query = nn.Linear(d_model, d_model, bias=bias)
100        self.key = nn.Linear(d_model, d_model, bias=bias)
101        self.value = nn.Linear(d_model, d_model, bias=bias)

成对位置偏差

103        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)

面具用于

105        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)

激活

107        self.activation = nn.Sigmoid()

输出层

109        self.output = nn.Linear(d_model, d_model)

创建局部蒙版

这会为以下对象创建遮罩

111    @staticmethod
112    def create_local_mask(seq_len, local_window_size):

初始化为一

128        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)

130        local_mask = torch.tril(local_mask, local_window_size - 1)

132        local_mask = torch.triu(local_mask, -(local_window_size - 1))

135        return local_mask

query keyvalue 是存储查询的令牌嵌入集合的张量。它们有形状[seq_len, batch_size, d_model]

mask 有形状[seq_len, seq_len, batch_size]mask[i, j, b] 指示是否为批量查询b ,位置处的查询i 有权访问位置处的键值j

137    def forward(self, *,
138                query: torch.Tensor,
139                key: torch.Tensor,
140                value: torch.Tensor,
141                mask: Optional[torch.Tensor] = None):

querykey 并且value 有形状[seq_len, batch_size, d_model]

153        seq_len, _, _ = query.shape
154
155        if mask is not None:

mask 有形状[seq_len_q, seq_len_k, batch_size] ,其中第一个维度是查询维度。如果查询维度等于它将被广播。

159            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
160            assert mask.shape[1] == key.shape[0]
161            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

转换查询、键和值嵌入

164        query = self.query(query)
165        key = self.key(key)
166        value = self.value(value)

得到

使用口罩

179        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
180        pos_bias = pos_bias.unsqueeze(-1)
181        pos_bias.masked_fill_(~mask, float('-inf'))

我们分别计算然后进行矩阵乘法。为了清楚起见,我们使用 einsum。

我们在计算指数之前减去和,以稳定softmax的计算。

if 大变大,计算变得不稳定。在计算分子和分母的指数之前减去一个常数将抵消。并且可以帮助稳定计算。所以我们减去以稳定计算。

203        max_key = key.max(dim=0, keepdims=True)[0]
204        max_pos_bias = pos_bias.max(dim=1,  keepdims=True)[0]

207        exp_key = torch.exp(key - max_key)

209        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)

分子部分

212        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)

分母部分

214        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)

输出

219        y = self.activation(query) * num / den

输出层

222        return self.output(y)

测试局部掩码

225def _test_local_mask():
229    from labml.logger import inspect
230    inspect(AFTLocal.create_local_mask(10, 4))

234if __name__ == '__main__':
235    _test_local_mask()