这是 PyTorch 对《无注意力的变形金刚》一文的实现。
本文用一种新的高效运算取代了自我注意力层,该运算的存储复杂度为,其中是序列长度,是嵌入的维度。
本文介绍了 AFT 以及 AFT-Local 和 AFT-conv。这里我们实现了 aft-Local,它关注自回归模型中的 cloby 代币。
FT(类似于 MHA)首先将嵌入转换为具有学习权重的查询、键和值张量。每个位置的输出都是通过以下运算计算的。
,其中是元素乘积,是非线性(sigmoid),是成对位置偏差的学习矩阵。
这意味着我们取值的加权平均值并将其乘以查询。这样就无需计算 MHA 所需的注意力矩阵,从而降低了内存需求。
AFT Local 仅在本地应用学习的配对位置偏差:
,其中是本地窗口大小。
尽管不在本地窗口之外,但 AFT 操作仍使用来自其他区域的键值对。这与本地转换器不同,本地窗口之外的嵌入完全不可见。
59from typing import Optional
60
61import torch
62from torch import nn
63
64from labml_helpers.module import Module
67class AFTLocal(Module):
d_model
是query
、key
和value
向量中的要素数。seq_len
是local_window_size
是本地窗口大小bias
是是否为和的变换设置偏置参数。86 def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
94 super().__init__()
本地窗口大小
97 self.local_window_size = local_window_size
这些变换query
、key
和value
向量。
99 self.query = nn.Linear(d_model, d_model, bias=bias)
100 self.key = nn.Linear(d_model, d_model, bias=bias)
101 self.value = nn.Linear(d_model, d_model, bias=bias)
成对位置偏差
103 self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)
面具用于
105 self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)
激活
107 self.activation = nn.Sigmoid()
输出层
109 self.output = nn.Linear(d_model, d_model)
111 @staticmethod
112 def create_local_mask(seq_len, local_window_size):
初始化为一
128 local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
做零
130 local_mask = torch.tril(local_mask, local_window_size - 1)
做零
132 local_mask = torch.triu(local_mask, -(local_window_size - 1))
135 return local_mask
query
key
和value
是存储查询、键和值的令牌嵌入集合的张量。它们有形状[seq_len, batch_size, d_model]
。
mask
有形状[seq_len, seq_len, batch_size]
并mask[i, j, b]
指示是否为批量查询b
,位置处的查询i
有权访问位置处的键值j
。
137 def forward(self, *,
138 query: torch.Tensor,
139 key: torch.Tensor,
140 value: torch.Tensor,
141 mask: Optional[torch.Tensor] = None):
query
,key
并且value
有形状[seq_len, batch_size, d_model]
153 seq_len, _, _ = query.shape
154
155 if mask is not None:
mask
有形状[seq_len_q, seq_len_k, batch_size]
,其中第一个维度是查询维度。如果查询维度等于它将被广播。
159 assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
160 assert mask.shape[1] == key.shape[0]
161 assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]
转换查询、键和值嵌入
164 query = self.query(query)
165 key = self.key(key)
166 value = self.value(value)
179 pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
180 pos_bias = pos_bias.unsqueeze(-1)
181 pos_bias.masked_fill_(~mask, float('-inf'))
203 max_key = key.max(dim=0, keepdims=True)[0]
204 max_pos_bias = pos_bias.max(dim=1, keepdims=True)[0]
207 exp_key = torch.exp(key - max_key)
209 exp_pos_bias = torch.exp(pos_bias - max_pos_bias)
分子部分
212 num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)
分母部分
214 den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)
输出
219 y = self.activation(query) * num / den
输出层
222 return self.output(y)
测试局部掩码
225def _test_local_mask():
229 from labml.logger import inspect
230 inspect(AFTLocal.create_local_mask(10, 4))
234if __name__ == '__main__':
235 _test_local_mask()