这是《T rain Short,Test Long:使用线性偏差的注意力实现输入长度外推》一文中的 “使用线性偏差注意力 (AliBI)” 的实现。
这将用在注意力分数中添加偏差(注意力对数,在 softmax 之前)取代位置编码。这是一种在自回归任务上测试的相对方案,closeby代币的偏差更高,而遥远的代币的偏差更低。偏差在对数标度中呈线性减小(因为它在softmax之前),并且每个头部都有不同的斜率。
这是第-th 代币的注意力公式,
其中,是第-th 个令牌的查询,最大是密钥数以及每个标头的要素数。请注意,上述等式之所以停止,是因为翻译是不变的(您可以在不更改结果的情况下向所有元素添加任何常量)。
33import math
34from typing import Optional
35
36import torch
37from torch import nn
38
39from labml.logger import inspect
40from labml_nn.transformers.mha import MultiHeadAttention
43def get_slopes(n_heads: int):
获得最接近 2 的幂n_heads
。如果不n_heads
是 2 的幂,那么我们首先计算斜率到最接近(较小)的 2 幂,然后再加上剩余的斜率。
62 n = 2 ** math.floor(math.log2(n_heads))
64 m_0 = 2.0 ** (-8.0 / n)
66 m = torch.pow(m_0, torch.arange(1, 1 + n))
如果不n_heads
是 2 的幂,那么我们将剩余的斜率相加。我们计算剩余的斜率(避免之前添加的斜率)。然后选择斜坡n_heads
。
71 if n < n_heads:
73 m_hat_0 = 2.0 ** (-4.0 / n)
请注意,我们会采取措施避免之前添加的斜坡。
76 m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))
将斜坡与其余的斜坡连接起来。
78 m = torch.cat([m, m_hat])
79
80 return m
n_heads
是注意层中的头部数量mask
是形状的注意力面具[seq_len_q, seq_len_k]
这将返回一个[seq_len_q, seq_len_k, n_heads, ]
具有 AliBi 注意力偏差的形状矩阵。
83@torch.no_grad()
84def get_alibi_biases(n_heads: int, mask: torch.Tensor):
获取每个头部的斜率
95 m = get_slopes(n_heads).to(mask.device)
计算距离这里我们使用掩码计算距离。
既然它是因果掩码,我们也可以使用。distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]
102 distance = mask.cumsum(dim=-1)
将它们成对乘以得到 AliBi 偏差矩阵
105 return distance[:, :, None] * m[None, None, :]
115 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
116 super().__init__(heads, d_model, dropout_prob)
缓存 AliBi 的偏见
119 self.alibi_biases = None
query
key
和value
是存储查询、键和值向量集合的张量。它们有形状[seq_len, batch_size, d_model]
。
mask
有形状[seq_len, seq_len, batch_size]
并mask[i, j, b]
指示是否为批量查询b
,位置处的查询i
有权访问位置处的键值j
。
121 def forward(self, *,
122 query: torch.Tensor,
123 key: torch.Tensor,
124 value: torch.Tensor,
125 mask: Optional[torch.Tensor] = None):
AliBi 仅适用于因果口罩。
137 assert mask is not None
138 assert mask.shape[0] == mask.shape[1] and mask.shape[2] == 1
query
,key
并且value
有形状[seq_len, batch_size, d_model]
141 seq_len, batch_size, _ = query.shape
将头部尺寸添加到蒙版并检查其形状。
144 mask = self.prepare_mask(mask, query.shape, key.shape)
准备query
,key
并value
进行注意力计算。然后这些就会有形状[seq_len, batch_size, heads, d_k]
。
148 query = self.query(query)
149 key = self.key(key)
150 value = self.value(value)
计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]
。
154 scores = self.get_scores(query, key)
音阶分数
157 scores *= self.scale
如果 AliBI 未被缓存,则创建偏差
160 if self.alibi_biases is None or self.alibi_biases.shape[1] < seq_len:
mask
has shape [seq_len, seq_len, 1, 1]
162 self.alibi_biases = get_alibi_biases(scores.shape[-1], mask[:, :, 0, 0])
将 AliBi 偏见添加到注意力分数中。AliBi 偏见有形[seq_len, seq_len, n_heads]
状scores
也有形状[seq_len, seq_len, batch_size, n_heads]
167 scores += self.alibi_biases[:seq_len, :seq_len, None, :]
涂抹面膜
170 scores = scores.masked_fill(mask == 0, float('-inf'))
关注按键序列维度
174 attn = self.softmax(scores)
申请退学
177 attn = self.dropout(attn)
乘以值
181 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
连接多个头
184 x = x.reshape(seq_len, batch_size, -1)
输出层
187 return self.output(x)
查看斜率的简单测试功能。
190def _test_alibi():
194 inspect(get_slopes(12).tolist(), _n=-1)
195 from labml_nn.transformers.utils import subsequent_mask
196
197 mask = subsequent_mask(8)[:, :, 0]
198 inspect(mask)
199
200 inspect(get_alibi_biases(12, mask)[:, :, 3], _n=-1)
204if __name__ == '__main__':
205 _test_alibi()