Attention with Linear Biases (ALiBi)

This is an implementation of Attention with Linear Biases (ALiBi) from the paper Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.

This replaces positional encodings with biases added to the attention scores (attention logits, before the softmax). It is a relative positional scheme tested on autoregressive tasks, where the bias is higher for nearby tokens and lower for far-away tokens. The biases decrease linearly on the log scale (because they are added before the softmax) and each head has a different slope.

Here's the attention formula for the $i$-th token,

$$\begin{align}
\mathbf{a}_i
&= \text{softmax} \bigg( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[ -(i-1), \dots, -1, 0 \big] \bigg) \\
&= \text{softmax} \bigg( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[ 0, 1, \dots, (i-1) \big] \bigg)
\end{align}$$

where $\mathbf{q}_i \in \mathbb{R}^d$ is the query of the $i$-th token, $\mathbf{K} \in \mathbb{R}^{i \times d}$ are the keys up to $i$, $d$ is the number of features per head, and $m$ is the head-specific slope. Note that the above equality holds because $\text{softmax}$ is invariant to translations (you can add any constant to all elements without changing the result).
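
To make the formula concrete, here is a minimal standalone sketch for a single head (my own illustration; the function name is hypothetical and the $\frac{1}{\sqrt{d_k}}$ scaling used in the full attention module below is omitted):

import torch

def alibi_attention_weights(q_i: torch.Tensor, keys: torch.Tensor, m: float) -> torch.Tensor:
    # q_i:  query of the i-th token, shape [d]
    # keys: keys up to the i-th token, shape [i, d]
    # m:    head-specific slope
    i = keys.shape[0]
    # Attention logits q_i K^T, shape [i]
    logits = keys @ q_i
    # Linear biases m * [-(i - 1), ..., -1, 0]
    bias = m * torch.arange(-(i - 1), 1, dtype=logits.dtype)
    # Softmax over the biased logits gives the attention weights
    return torch.softmax(logits + bias, dim=-1)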

Here is the training code for an ALiBi model.


import math
from typing import Optional

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention

Get head-specific slope for each head

  • n_heads is the number of heads in the attention layer

The slope for the first head is

$$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$

where $n$ is the number of heads.

The slopes for the rest of the heads are in a geometric series with a ratio same as above.

For instance, when the number of heads is $8$ the slopes are

$$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$

def get_slopes(n_heads: int):

Get the closest power of 2 to n_heads. If n_heads is not a power of 2, then we first calculate the slopes for the closest (smaller) power of 2, and then add the remaining slopes.

    n = 2 ** math.floor(math.log2(n_heads))

    m_0 = 2.0 ** (-8.0 / n)

    m = torch.pow(m_0, torch.arange(1, 1 + n))

If n_heads is not a power of 2, then we add the remaining slopes. We calculate these as if there were $2n$ heads, skipping every other slope to avoid the ones added previously, and pick slopes up to n_heads.

    if n < n_heads:

        m_hat_0 = 2.0 ** (-4.0 / n)

Note that we take steps of $2$ to avoid the slopes added previously.

        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))

Concatenate the slopes with the remaining slopes.

        m = torch.cat([m, m_hat])

    return m
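
As a quick sanity check (my own example, values rounded), the slopes for 8 heads follow the geometric series above, and for 12 heads the first 8 slopes are reused and 4 extra slopes are interleaved from the 16-head series:

get_slopes(8)
# -> [0.5, 0.25, 0.125, ..., 0.00390625], i.e. 2^-1, 2^-2, ..., 2^-8

get_slopes(12)
# -> the 8 values above followed by approximately [0.7071, 0.3536, 0.1768, 0.0884],
#    i.e. 2^-0.5, 2^-1.5, 2^-2.5, 2^-3.5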

Calculate the attention biases matrix

  • n_heads is the number of heads in the attention layer
  • mask is the attention mask of shape [seq_len_q, seq_len_k]

This returns a matrix of shape [seq_len_q, seq_len_k, n_heads] with ALiBi attention biases.

@torch.no_grad()
def get_alibi_biases(n_heads: int, mask: torch.Tensor):

Get slopes for each head

    m = get_slopes(n_heads).to(mask.device)

Calculate the distances. Here we calculate them from the mask with a cumulative sum.

Since it's a causal mask, we could equivalently use distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]; the constant offset between the two doesn't matter because of the softmax.

    distance = mask.cumsum(dim=-1)

Multiply the distances by the head-specific slopes (broadcast over heads) to get the ALiBi bias matrix.

    return distance[:, :, None] * m[None, None, :]
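
For example (my own illustration with a simple lower-triangular causal mask):

causal_mask = torch.tril(torch.ones(4, 4))
biases = get_alibi_biases(8, causal_mask)
# biases.shape == (4, 4, 8)
# For head 0 (slope 0.5), row i of biases[:, :, 0] is 0.5 * [1, 2, ..., i + 1, i + 1, ...];
# entries past the causal boundary are removed later by the attention mask.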

Attention with Linear Biases (ALiBi)

We override Multi-Head Attention.

class AlibiMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

To cache the ALiBi biases

        self.alibi_biases = None

query, key and value are the tensors that store the collection of query, key and value vectors. They have shape [seq_len, batch_size, d_model].

mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):

ALiBi only works with causal masks.

        assert mask is not None
        assert mask.shape[0] == mask.shape[1] and mask.shape[2] == 1

query, key and value have shape [seq_len, batch_size, d_model]

        seq_len, batch_size, _ = query.shape

Add head dimension to mask and check its shape.

        mask = self.prepare_mask(mask, query.shape, key.shape)

Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].

        scores = self.get_scores(query, key)

Scale the scores: $\frac{Q K^\top}{\sqrt{d_k}}$

        scores *= self.scale

Create the ALiBi biases if they are not cached, or if the cached biases are smaller than seq_len.

        if self.alibi_biases is None or self.alibi_biases.shape[1] < seq_len:

mask has shape [seq_len, seq_len, 1, 1]

            self.alibi_biases = get_alibi_biases(scores.shape[-1], mask[:, :, 0, 0])

Add the ALiBi biases to the attention scores. The ALiBi biases have shape [seq_len, seq_len, n_heads] and scores has shape [seq_len, seq_len, batch_size, n_heads].

        scores += self.alibi_biases[:seq_len, :seq_len, None, :]

Apply mask

        scores = scores.masked_fill(mask == 0, float('-inf'))

Apply $\text{softmax}$ attention along the key sequence dimension.

        attn = self.softmax(scores)

Apply dropout

        attn = self.dropout(attn)

Multiply by values

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

Concatenate multiple heads

        x = x.reshape(seq_len, batch_size, -1)

Output layer

        return self.output(x)
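
A brief usage sketch (my own, with illustrative sizes; the causal mask has a batch dimension of 1 that broadcasts over the batch):

seq_len, batch_size, d_model, heads = 8, 2, 64, 4
x = torch.randn(seq_len, batch_size, d_model)
# Causal mask: mask[i, j, 0] is True when query i may attend to key j (j <= i)
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

mha = AlibiMultiHeadAttention(heads=heads, d_model=d_model)
out = mha(query=x, key=x, value=x, mask=mask)
# out.shape == (seq_len, batch_size, d_model)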

Simple test function to see the slopes.

def _test_alibi():
    inspect(get_slopes(12).tolist(), _n=-1)
    from labml_nn.transformers.utils import subsequent_mask

    mask = subsequent_mask(8)[:, :, 0]
    inspect(mask)

    inspect(get_alibi_biases(12, mask)[:, :, 3], _n=-1)

if __name__ == '__main__':
    _test_alibi()