注意线性偏差 (AliBI)

这是《T rain Short,Test Long:使用线性偏差的注意力实现输入长度外推》一文中的 “使用线性偏差注意力 (AliBI)” 的实现。

这将用在注意力分数中添加偏差(注意力对数,在 softmax 之前)取代位置编码。这是一种在自回归任务上测试的相对方案,closeby代币的偏差更高,而遥远的代币的偏差更低。偏差在对数标度中呈线性减小(因为它在softmax之前),并且每个头部都有不同的斜率。

这是第-th 代币的注意力公式,

其中,第-th 个令牌的查询,最大是密钥数以及每个标头的要素数。请注意,上述等式之所以停止,是因为翻译是不变的(您可以在不更改结果的情况下向所有元素添加任何常量)。

以下是 AliBi 模型的训练代码

33import math
34from typing import Optional
35
36import torch
37from torch import nn
38
39from labml.logger import inspect
40from labml_nn.transformers.mha import MultiHeadAttention

为每个头部获取特定于头部的斜率

  • n_heads 是注意层中的头部数量

第一个头的斜率是

其余头部的斜率为几何序列,其比例与上面相同。

例如,当头数为时,斜率为

43def get_slopes(n_heads: int):

获得最接近 2 的幂n_heads 。如果不n_heads 是 2 的幂,那么我们首先计算斜率到最接近(较小)的 2 幂,然后再加上剩余的斜率。

62    n = 2 ** math.floor(math.log2(n_heads))

64    m_0 = 2.0 ** (-8.0 / n)

66    m = torch.pow(m_0, torch.arange(1, 1 + n))

如果不n_heads 是 2 的幂,那么我们将剩余的斜率相加。我们计算剩余的斜率(避免之前添加的斜率)。然后选择斜坡n_heads

71    if n < n_heads:

73        m_hat_0 = 2.0 ** (-4.0 / n)

请注意,我们会采取措施避免之前添加的斜坡。

76        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))

将斜坡与其余的斜坡连接起来。

78        m = torch.cat([m, m_hat])
79
80    return m

计算注意力偏差矩阵

  • n_heads 是注意层中的头部数量
  • mask 是形状的注意力面具[seq_len_q, seq_len_k]

这将返回一个[seq_len_q, seq_len_k, n_heads, ] 具有 AliBi 注意力偏差的形状矩阵。

83@torch.no_grad()
84def get_alibi_biases(n_heads: int, mask: torch.Tensor):

获取每个头部的斜率

95    m = get_slopes(n_heads).to(mask.device)

计算距离这里我们使用掩码计算距离。

既然它是因果掩码,我们也可以使用。distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]

102    distance = mask.cumsum(dim=-1)

将它们成对乘以得到 AliBi 偏差矩阵

105    return distance[:, :, None] * m[None, None, :]

注意线性偏差 (AliBI)

我们覆盖多头注意力

108class AlibiMultiHeadAttention(MultiHeadAttention):
115    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
116        super().__init__(heads, d_model, dropout_prob)

缓存 AliBi 的偏见

119        self.alibi_biases = None

query keyvalue 是存储查询向量集合的张量。它们有形状[seq_len, batch_size, d_model]

mask 有形状[seq_len, seq_len, batch_size]mask[i, j, b] 指示是否为批量查询b ,位置处的查询i 有权访问位置处的键值j

121    def forward(self, *,
122                query: torch.Tensor,
123                key: torch.Tensor,
124                value: torch.Tensor,
125                mask: Optional[torch.Tensor] = None):

AliBi 仅适用于因果口罩。

137        assert mask is not None
138        assert mask.shape[0] == mask.shape[1] and mask.shape[2] == 1

querykey 并且value 有形状[seq_len, batch_size, d_model]

141        seq_len, batch_size, _ = query.shape

将头部尺寸添加到蒙版并检查其形状。

144        mask = self.prepare_mask(mask, query.shape, key.shape)

准备querykeyvalue 进行注意力计算。然后这些就会有形状[seq_len, batch_size, heads, d_k]

148        query = self.query(query)
149        key = self.key(key)
150        value = self.value(value)

计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]

154        scores = self.get_scores(query, key)

音阶分数

157        scores *= self.scale

如果 AliBI 未被缓存,则创建偏差

160        if self.alibi_biases is None or self.alibi_biases.shape[1] < seq_len:

mask 有形状 seq_len、seq_len、1、1

162            self.alibi_biases = get_alibi_biases(scores.shape[-1], mask[:, :, 0, 0])

将 AliBi 偏见添加到注意力分数中。AliBi 偏见有形[seq_len, seq_len, n_heads]scores 也有形状[seq_len, seq_len, batch_size, n_heads]

167        scores += self.alibi_biases[:seq_len, :seq_len, None, :]

涂抹面膜

170        scores = scores.masked_fill(mask == 0, float('-inf'))

关注按键序列维度

174        attn = self.softmax(scores)

申请退学

177        attn = self.dropout(attn)

乘以值

181        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

连接多个头

184        x = x.reshape(seq_len, batch_size, -1)

输出层

187        return self.output(x)

查看斜率的简单测试功能。

190def _test_alibi():
194    inspect(get_slopes(12).tolist(), _n=-1)
195    from labml_nn.transformers.utils import subsequent_mask
196
197    mask = subsequent_mask(8)[:, :, 0]
198    inspect(mask)
199
200    inspect(get_alibi_biases(12, mask)[:, :, 3], _n=-1)

204if __name__ == '__main__':
205    _test_alibi()