An Attention Free Transformer

This is a PyTorch implementation of the paper An Attention Free Transformer.

This paper replaces the self-attention layer with a new efficient operation that has a memory complexity of $\mathcal{O}(Td)$, where $T$ is the sequence length and $d$ is the dimensionality of the embeddings.

The paper introduces AFT along with AFT-local and AFT-conv. Here we have implemented AFT-local, which pays attention to nearby tokens in an autoregressive model.

Attention Free Transformer

AFT (similar to MHA) first transforms the embeddings $X$ into query $Q = XW^Q$, key $K = XW^K$ and value $V = XW^V$ tensors with learned weights. The output for each position $t \in [1, T]$ is calculated with the following operation:

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^T \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^T \exp(K_{t'} + w_{t,t'})}$$

where $\odot$ is the element-wise product, $\sigma$ is a non-linearity (sigmoid) and $w \in \mathbb{R}^{T \times T}$ is a learned matrix of pair-wise position biases.

This means that we take a weighted average of the values, using the weights $\exp(K_{t'} + w_{t,t'})$, and multiply the result element-wise by $\sigma(Q_t)$. This eliminates the need to calculate the $T \times T$ attention matrix that MHA requires, and therefore reduces the memory requirement.
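To make the operation concrete, here is a minimal, dependency-free sketch of AFT-full for a single unbatched sequence. The helper names `sigmoid` and `aft_full` are hypothetical, the loops favour readability over speed, and no max-subtraction is applied, so very large values of $K$ or $w$ could overflow.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def aft_full(Q, K, V, w):
    """Naive AFT-full for one sequence (hypothetical helper, no batch dimension).

    Q, K, V: T rows of d floats each; w: T x T pair-wise position biases.
    Y[t][i] = sigmoid(Q[t][i]) * sum_t' exp(K[t'][i] + w[t][t']) * V[t'][i]
                                / sum_t' exp(K[t'][i] + w[t][t'])
    """
    T, d = len(Q), len(Q[0])
    Y = []
    for t in range(T):
        row = []
        for i in range(d):
            # Unnormalized weight of every position t' for query t and feature i
            weights = [math.exp(K[tp][i] + w[t][tp]) for tp in range(T)]
            num = sum(wt * V[tp][i] for tp, wt in enumerate(weights))
            den = sum(weights)
            # Gate the weighted average of the values with sigmoid(Q)
            row.append(sigmoid(Q[t][i]) * num / den)
        Y.append(row)
    return Y


# Usage: zero keys and zero biases give uniform weights, so Y_t = 0.5 * mean(V)
Y = aft_full(Q=[[0.0], [0.0]], K=[[0.0], [0.0]], V=[[1.0], [3.0]], w=[[0.0, 0.0], [0.0, 0.0]])
print(Y)  # [[1.0], [1.0]]
```

Note that only $\exp(K_{t'} + w_{t,t'})$ couples positions; there is no query-key dot product, which is what distinguishes AFT from MHA.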

AFT Local

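AFT Local only applies the learned pair-wise position biases locally:

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{if } \lvert t-t' \rvert < s \\ 0, & \text{otherwise} \end{cases}$$

where $s \le T$ is the local window size.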

Although $w'_{t,t'}$ is $0$ outside the local window, the AFT operation still uses key-value pairs from outside the window; they are simply weighted by $\exp(K_{t'})$ without a learned position bias. This is different from local transformers, where embeddings outside the local window are not visible at all.
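A toy illustration of this point, in plain Python for one query row and one feature. The helpers `local_bias` and `position_weights` are hypothetical, and the constant bias of $2.0$ is made up purely for the example.

```python
import math


def local_bias(w, s):
    """Hypothetical helper: w'[t][t'] = w[t][t'] if |t - t'| < s, else 0."""
    T = len(w)
    return [[w[t][tp] if abs(t - tp) < s else 0.0 for tp in range(T)]
            for t in range(T)]


def position_weights(k_col, w_row):
    """Normalized weights exp(K_t' + w'_t,t') / sum(...) for one query row and one feature."""
    raw = [math.exp(k + b) for k, b in zip(k_col, w_row)]
    total = sum(raw)
    return [r / total for r in raw]


T, s = 5, 2
w = [[2.0] * T for _ in range(T)]   # toy biases: every entry of w is 2.0
w_local = local_bias(w, s)          # row 0 becomes [2.0, 2.0, 0.0, 0.0, 0.0]
k_col = [0.0] * T                   # one feature of K, all zeros
weights = position_weights(k_col, w_local[0])
# Keys 0 and 1 (inside the window) get the larger weights, but keys 2..4
# (outside the window) still receive a non-zero share.
print([round(x, 3) for x in weights])
```

In a local transformer the weights of keys 2 to 4 would be exactly zero; here they are only missing the learned boost.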

Here is the training code for an AFT Local model.


from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module

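AFT Local Operation

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^T \exp(K_{t'} + w'_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^T \exp(K_{t'} + w'_{t,t'})}$$

where,

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{if } \lvert t-t' \rvert < s \\ 0, & \text{otherwise} \end{cases}$$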

class AFTLocal(Module):
  • d_model is the number of features in the query, key and value vectors.
  • seq_len is $T$
  • local_window_size is the local window size $s$
  • bias is whether to have a bias parameter in the transformations for $Q$, $K$ and $V$.
    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
        super().__init__()

Local window size $s$

        self.local_window_size = local_window_size

These transform the query, key and value vectors.

        self.query = nn.Linear(d_model, d_model, bias=bias)
        self.key = nn.Linear(d_model, d_model, bias=bias)
        self.value = nn.Linear(d_model, d_model, bias=bias)

Pair-wise positional biases $w \in \mathbb{R}^{T \times T}$

        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)

Mask for $w_{t,t'}$

        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)

Activation $\sigma$

        self.activation = nn.Sigmoid()

Output layer

        self.output = nn.Linear(d_model, d_model)

Create local mask

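This creates a mask for $w'_{t,t'}$:

$$m_{t,t'} = \begin{cases} 1, & \text{if } \lvert t-t' \rvert < s \\ 0, & \text{otherwise} \end{cases}$$

    @staticmethod
    def create_local_mask(seq_len, local_window_size):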

Initialize to ones

        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)

Make $t' - t \ge s$ zero

        local_mask = torch.tril(local_mask, local_window_size - 1)

Make $t - t' \ge s$ zero

        local_mask = torch.triu(local_mask, -(local_window_size - 1))
        return local_mask
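As a cross-check of the `tril`/`triu` construction, here is a pure-Python sketch (hypothetical helper names) that builds the same band once directly from the definition $\lvert t - t' \rvert < s$ and once from the diagonal-offset rules that `torch.tril` (keep column minus row $\le k$) and `torch.triu` (keep column minus row $\ge k$) apply, and confirms the two agree.

```python
def reference_local_mask(seq_len, local_window_size):
    """True where |t - t'| < s, computed directly from the definition."""
    return [[abs(t - tp) < local_window_size for tp in range(seq_len)]
            for t in range(seq_len)]


def tril_triu_mask(seq_len, local_window_size):
    """Mimics torch.tril(ones, s - 1) followed by torch.triu(..., -(s - 1)).

    tril with diagonal k keeps entries where column - row <= k;
    triu with diagonal -k keeps entries where column - row >= -k.
    """
    k = local_window_size - 1
    return [[(tp - t <= k) and (tp - t >= -k) for tp in range(seq_len)]
            for t in range(seq_len)]


for n, s in [(10, 4), (5, 1), (6, 6)]:
    assert reference_local_mask(n, s) == tril_triu_mask(n, s)

# Print the band for seq_len=6, s=2 ('1' = kept, '.' = zeroed)
for row in reference_local_mask(6, 2):
    print(''.join('1' if x else '.' for x in row))
```

For `seq_len=6` and `s=2` this prints a diagonal band three entries wide, narrowing to two entries in the first and last rows.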

query, key and value are the tensors that store a collection of token embeddings for the query, key and value. They have shape [seq_len, batch_size, d_model].

mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether for batch b, query at position i has access to key-value at position j.
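For an autoregressive model, one common choice (an assumption here; the class accepts any boolean mask of this layout) is a causal mask in which query $i$ can see key-value $j$ only when $j \le i$, with a batch dimension of size $1$ so that it broadcasts over the batch. A pure-Python sketch of that layout, using a hypothetical `causal_mask` helper:

```python
def causal_mask(seq_len):
    """Hypothetical helper: mask[i][j][0] is True when query i may see key-value j (j <= i).

    The last (batch) dimension has size 1 so it can broadcast over the batch.
    """
    return [[[j <= i] for j in range(seq_len)] for i in range(seq_len)]


m = causal_mask(4)
print(len(m), len(m[0]), len(m[0][0]))  # 4 4 1
print([m[2][j][0] for j in range(4)])   # [True, True, True, False]
```

In PyTorch the same tensor can be built with `torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)`.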

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):

query, key and value have shape [seq_len, batch_size, d_model]

        seq_len, _, _ = query.shape

        if mask is not None:

mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension (or the batch dimension) is equal to $1$, it will be broadcast.

            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

Transform query, key and value embeddings

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

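Get $w'_{t,t'}$ by applying the local mask to the position biases, then set the positions hidden by the attention mask to $-\infty$ so that their weight becomes $\exp(-\infty) = 0$.

        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
        pos_bias = pos_bias.unsqueeze(-1)
        if mask is not None:
            # Out-of-place so that pos_bias ([seq_len, seq_len, 1]) broadcasts to the mask's batch dimension
            pos_bias = pos_bias.masked_fill(~mask, float('-inf'))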
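The two masks play different roles, which a toy pure-Python sketch for a single query row makes visible. The helper `masked_row_weights` and the bias value $1.5$ are hypothetical: the local mask zeroes the bias outside the window (weight $\exp(0) = 1$), while the attention mask fills $-\infty$ (weight $0$).

```python
import math


def masked_row_weights(w_local_row, allowed_row):
    """Hypothetical helper: exp of one row of w' after hiding disallowed keys with -inf."""
    filled = [b if ok else float('-inf') for b, ok in zip(w_local_row, allowed_row)]
    return [math.exp(x) for x in filled]  # math.exp(-inf) == 0.0


t, T, s = 3, 5, 2                                                   # query 3, length 5, window 2
w_local_row = [1.5 if abs(t - tp) < s else 0.0 for tp in range(T)]  # toy w'_{3,t'}
allowed_row = [tp <= t for tp in range(T)]                          # causal mask row
weights = masked_row_weights(w_local_row, allowed_row)
# Keys 0 and 1: allowed but outside the window -> exp(0) = 1.0
# Keys 2 and 3: inside the window             -> exp(1.5)
# Key 4: in the future, masked                -> 0.0
print(weights)
```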

We compute $\exp(w_{t,t'})$, $\exp(K_{t'}) \odot V_{t'}$ and $\exp(K_{t'})$ separately and combine them with a matrix multiplication; this works because $\exp(K_{t'} + w_{t,t'}) = \exp(w_{t,t'}) \exp(K_{t'})$. We use einsum for clarity.
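A quick numerical check of that factorisation for one query row and one feature, in plain Python with made-up values: summing $\exp(K_{t'} + w_{t,t'}) V_{t'}$ directly gives the same numerator as contracting $\exp(w_{t,t'})$ with $\exp(K_{t'}) V_{t'}$.

```python
import math

# Toy values for one query row t, one feature, and T = 3 positions t'
w_row = [0.5, -1.0, 2.0]   # w_{t,t'}
K = [1.0, 0.0, -0.5]       # K_{t'}
V = [2.0, 4.0, 6.0]        # V_{t'}

# Direct form: sum_{t'} exp(K_{t'} + w_{t,t'}) * V_{t'}
direct = sum(math.exp(k + b) * v for k, b, v in zip(K, w_row, V))

# Factored form: exp(w) contracted with exp(K) * V over t'. This is what
# einsum('ijb,jbd->ibd', ...) does for every query i, batch b and feature d.
exp_w = [math.exp(b) for b in w_row]
exp_k_v = [math.exp(k) * v for k, v in zip(K, V)]
factored = sum(a * c for a, c in zip(exp_w, exp_k_v))

assert math.isclose(direct, factored)
```

The factored form lets the $T \times T$ term depend only on positions, while the key-value term is computed once per position.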

We subtract $\max_{t'}(K_{t'})$ and $\max_{t'}(w_{t,t'})$ before calculating the exponents to stabilize the softmax calculation.

If $x_i$ is large, $\exp(x_i)$ becomes huge and the computation of $\frac{\sum\exp(x_i)y_i}{\sum\exp(x_i)}$ becomes unstable. Subtracting a constant $c$ from every $x_i$ before calculating the exponent multiplies both the numerator and the denominator by $\exp(-c)$, so it cancels out and leaves the result unchanged. We therefore subtract $\max(x_i)$, which keeps every exponent at or below $0$.
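A small pure-Python sketch of the max-subtraction trick (hypothetical function names): both versions agree on small inputs, but only the stable one survives inputs around $1000$. In Python `math.exp(1000.0)` raises `OverflowError`; in PyTorch the same overflow would instead surface as `inf` and then `nan`.

```python
import math


def weighted_average_naive(x, y):
    e = [math.exp(xi) for xi in x]
    return sum(ei * yi for ei, yi in zip(e, y)) / sum(e)


def weighted_average_stable(x, y):
    m = max(x)  # constant shift; exp(-m) cancels between numerator and denominator
    e = [math.exp(xi - m) for xi in x]
    return sum(ei * yi for ei, yi in zip(e, y)) / sum(e)


x_small, y = [0.1, 0.2, 0.3], [1.0, 2.0, 3.0]
assert math.isclose(weighted_average_naive(x_small, y),
                    weighted_average_stable(x_small, y))

x_large = [1000.0, 1001.0, 1002.0]
try:
    weighted_average_naive(x_large, y)
    print("naive version succeeded")
except OverflowError:
    print("naive version overflows")  # math.exp(1000.0) raises OverflowError
print(weighted_average_stable(x_large, y))
```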

        max_key = key.max(dim=0, keepdim=True)[0]
        max_pos_bias = pos_bias.max(dim=1, keepdim=True)[0]

$\exp \big(K_{t'} - \max_{t'}(K_{t'})\big)$

        exp_key = torch.exp(key - max_key)

$\exp \big(w_{t,t'} - \max_{t'}(w_{t,t'})\big)$

        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)

The numerator part $\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'}) \odot V_{t'}$

        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)

The denominator part $\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'})$

        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)

Output

        y = self.activation(query) * num / den

Output layer

        return self.output(y)

Test local mask

def _test_local_mask():
    from labml.logger import inspect
    inspect(AFTLocal.create_local_mask(10, 4))


if __name__ == '__main__':
    _test_local_mask()