# 注意力线性偏差 (ALiBi)

33import math
34from typing import Optional
35
36import torch
37from torch import nn
38
39from labml.logger import inspect
40from labml_nn.transformers.mha import MultiHeadAttention

## 为每个头部获取特定于头部的斜率

• n_heads 是注意层中的头部数量

def get_slopes(n_heads: int):
    """
    ## Get head-specific slope for each head

    * `n_heads` is the number of heads in the attention layer.

    The slopes form a geometric sequence. When `n_heads` is a power of two the
    common ratio is `2^(-8 / n_heads)`, giving slopes
    `2^(-8/n), 2^(-2*8/n), ...`. Otherwise we take the slopes for the closest
    smaller power of two and append slopes sampled "in between" them
    (odd powers of the square root of the ratio).
    """
    # Closest power of two that does not exceed the number of heads
    closest_pow2 = 2 ** math.floor(math.log2(n_heads))
    # Common ratio (and first slope) for the power-of-two part
    ratio = 2.0 ** (-8.0 / closest_pow2)
    # Geometric sequence: ratio^1, ratio^2, ..., ratio^closest_pow2
    slopes = torch.pow(ratio, torch.arange(1, 1 + closest_pow2))

    # If the head count is not a power of two, fill the remaining heads with
    # intermediate slopes: odd powers of sqrt(ratio)
    if closest_pow2 < n_heads:
        extra_ratio = 2.0 ** (-4.0 / closest_pow2)
        extra_slopes = torch.pow(extra_ratio,
                                 torch.arange(1, 1 + 2 * (n_heads - closest_pow2), 2))
        slopes = torch.cat([slopes, extra_slopes])

    # Tensor of shape [n_heads] with one slope per head
    return slopes

## 计算注意力偏差矩阵

• n_heads 是注意层中的头部数量
• mask 是形状为 [seq_len_q, seq_len_k] 的注意力掩码

@torch.no_grad()
def get_alibi_biases(n_heads: int, mask: torch.Tensor):
    """
    ## Calculate the attention bias matrix

    * `n_heads` is the number of heads in the attention layer
    * `mask` is the attention mask of shape `[seq_len_q, seq_len_k]`

    Returns a bias tensor of shape `[seq_len_q, seq_len_k, n_heads]`.
    """
    # One slope per head, on the same device as the mask
    slopes = get_slopes(n_heads).to(mask.device)

    # Cumulative sum along the key axis of the (causal) mask gives the
    # relative position of each visible key within a query's row
    distance = mask.cumsum(dim=-1)

    # Broadcast distances against per-head slopes:
    # [seq_len_q, seq_len_k, 1] * [1, 1, n_heads]
    return distance[:, :, None] * slopes[None, None, :]

## 注意力线性偏差 (ALiBi)

class AlibiMultiHeadAttention(MultiHeadAttention):
    """
    ## Attention with Linear Biases (ALiBi)

    Multi-head attention that, instead of positional encodings, adds a
    head-specific linear bias (proportional to key distance) to the
    attention scores before the softmax.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

        # Lazily-built cache of ALiBi biases; filled on the first `forward`
        # call and rebuilt only when a longer sequence is seen
        self.alibi_biases = None

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store the collections
        of query, key and value vectors. They have shape
        `[seq_len, batch_size, d_model]`.

        `mask` has shape `[seq_len, seq_len, batch_size]`, where
        `mask[i, j, b]` indicates whether, for batch `b`, the query at
        position `i` may attend to the key-value at position `j`.
        """
        # ALiBi only works with causal masks: the mask must be square and
        # shared across the batch (batch dimension of size 1)
        assert mask is not None
        assert mask.shape[0] == mask.shape[1] and mask.shape[2] == 1

        # `query`, `key` and `value` have shape [seq_len, batch_size, d_model]
        seq_len, batch_size, _ = query.shape

        # Expand the mask for broadcasting over heads; after this, the
        # `[:, :, 0, 0]` indexing below relies on it being 4-dimensional,
        # presumably [seq_len_q, seq_len_k, 1, 1] — defined in the parent class
        mask = self.prepare_mask(mask, query.shape, key.shape)

        # Project inputs and split into heads (parent-class projections)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Raw attention scores; by MHA convention presumably shaped
        # [seq_len_q, seq_len_k, batch_size, heads] — the einsum below assumes this
        scores = self.get_scores(query, key)

        # Scale scores (typically 1/sqrt(d_k), from the parent class)
        scores *= self.scale

        # (Re)build the cached biases if absent or too short for this sequence
        if self.alibi_biases is None or self.alibi_biases.shape[1] < seq_len:
            # `mask[:, :, 0, 0]` drops the singleton batch/head dims;
            # `scores.shape[-1]` is the number of heads
            self.alibi_biases = get_alibi_biases(scores.shape[-1], mask[:, :, 0, 0])

        # Add the linear biases; broadcast over the batch dimension
        scores += self.alibi_biases[:seq_len, :seq_len, None, :]

        # Apply the mask: blocked positions get -inf so softmax assigns them 0
        scores = scores.masked_fill(mask == 0, float('-inf'))

        # Attention probabilities along the key dimension
        attn = self.softmax(scores)

        # Dropout on the attention weights
        attn = self.dropout(attn)

        # Weighted sum of values per head
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Concatenate heads back into d_model
        x = x.reshape(seq_len, batch_size, -1)

        # Final output projection
        return self.output(x)

def _test_alibi():
    """
    Simple sanity check: print the per-head slopes and the ALiBi bias
    matrix for one head, for a 12-head layer.

    Fix: `mask` was used below without ever being defined in this function
    (the original presumably built it with a project helper); construct a
    causal (lower-triangular) mask of shape [seq_len, seq_len] directly.
    """
    # Per-head slopes for 12 heads
    inspect(get_slopes(12).tolist(), _n=-1)

    # Causal mask: position i may attend to positions j <= i
    mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))

    # Bias matrix for head 3
    inspect(get_alibi_biases(12, mask)[:, :, 3], _n=-1)


if __name__ == '__main__':
    _test_alibi()