An Attention Free Transformer

This is a PyTorch implementation of the paper An Attention Free Transformer.

This paper replaces the self-attention layer with a new efficient operation that has memory complexity of $\mathcal{O}(Td)$, where $T$ is the sequence length and $d$ is the dimensionality of the embeddings.

The paper introduces AFT along with AFT-local and AFT-conv. Here we have implemented AFT-local, which pays attention to nearby tokens in an autoregressive model.

Attention Free Transformer

AFT (similar to MHA) first transforms the embeddings $X$ into query $Q = XW^Q$, key $K = XW^K$ and value $V = XW^V$ tensors with learned weights. The output for each position $t \in [1, T]$ is calculated with the following operation.

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'})}$$

Here $\odot$ is the element-wise product, $\sigma$ is a non-linearity (sigmoid) and $w \in \mathbb{R}^{T \times T}$ is a learned matrix of pair-wise position biases.

This means that we take a weighted average of the values and multiply it by the query. It eliminates the need to compute the $T \times T$ attention matrix $QK^\top$ that MHA requires, and therefore reduces the memory requirement.
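To make the formula concrete, here is a minimal sketch, not part of this implementation, that transcribes the equation directly (the function name naive_aft and the unbatched [seq_len, d_model] shapes are illustrative assumptions). It materializes a full seq_len × seq_len × d_model tensor, so it is for understanding the equation only and gives none of the memory savings.

import torch


def naive_aft(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, w: torch.Tensor):
    # q, k, v have shape [seq_len, d_model]; w is the [seq_len, seq_len] matrix of pair-wise position biases
    # weights[t, t', :] = exp(K_{t'} + w_{t,t'})
    weights = torch.exp(k[None, :, :] + w[:, :, None])
    # Numerator: sum over t' of exp(K + w) ⊙ V; denominator: sum over t' of exp(K + w)
    num = torch.einsum('tsd,sd->td', weights, v)
    den = weights.sum(dim=1)
    # Y_t = σ(Q_t) ⊙ (num / den)
    return torch.sigmoid(q) * num / den


q = k = v = torch.randn(6, 8)
assert naive_aft(q, k, v, torch.zeros(6, 6)).shape == (6, 8)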

AFT Local

AFT-local applies the learned pair-wise position biases only within a local window:

$$w_{t,t'} = \begin{cases} w_{t,t'}, & \text{if } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$

where $s \le T$ is the local window size.

Although $w_{t,t'}$ is zero outside the local window, the AFT operation still uses key-value pairs from all positions. This is different from local transformers, where embeddings outside the local window are completely invisible.
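As a small illustrative sketch (the variable names here are hypothetical), restricting the bias matrix to a local window is just an element-wise product with a band mask; the implementation below builds the same band with create_local_mask:

import torch

seq_len, s = 6, 3  # example sequence length and local window size

w = torch.randn(seq_len, seq_len)           # full pair-wise position biases
t = torch.arange(seq_len)
band = (t[None, :] - t[:, None]).abs() < s  # True where |t - t'| < s
w_local = w * band                          # unchanged inside the window, zero outside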

Here is the training code for an AFT-local model.


from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module

AFT Local Operation

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'})}$$

where,

$$w_{t,t'} = \begin{cases} w_{t,t'}, & \text{if } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$

class AFTLocal(Module):
  • d_model is the number of features in the query $Q$, key $K$ and value $V$ vectors.
  • seq_len is the sequence length $T$.
  • local_window_size is the local window size $s$.
  • bias is whether to include a bias parameter in the transformations for $Q$, $K$ and $V$.
    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
        super().__init__()

Local window size

        self.local_window_size = local_window_size

These transform the query $Q$, key $K$ and value $V$ vectors.

        self.query = nn.Linear(d_model, d_model, bias=bias)
        self.key = nn.Linear(d_model, d_model, bias=bias)
        self.value = nn.Linear(d_model, d_model, bias=bias)

Pair-wise positional biases $w \in \mathbb{R}^{T \times T}$

        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)

Mask for the local window $|t - t'| < s$

        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)

Activation

        self.activation = nn.Sigmoid()

Output layer

        self.output = nn.Linear(d_model, d_model)

Create local mask

This creates a boolean mask $m$ with $m_{t,t'} = 1$ if $|t - t'| < s$ and $0$ otherwise.

    @staticmethod
    def create_local_mask(seq_len, local_window_size):

Initialize to ones

        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)

Make entries with $t' - t \ge s$ zero

        local_mask = torch.tril(local_mask, local_window_size - 1)

Make entries with $t - t' \ge s$ zero

        local_mask = torch.triu(local_mask, -(local_window_size - 1))

        return local_mask

query, key and value are the tensors that store the collections of token embeddings for the query, key and value. They have shape [seq_len, batch_size, d_model].

mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):

query, key and value have shape [seq_len, batch_size, d_model]

        seq_len, _, _ = query.shape

        if mask is not None:

mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension is equal to $1$ it will be broadcast.

            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

Transform query, key and value embeddings

166        query = self.query(query)
167        key = self.key(key)
168        value = self.value(value)

Get the pair-wise position biases restricted to the local window,

$$w_{t,t'} = \begin{cases} w_{t,t'}, & \text{if } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$

using the local mask $m$, and set the biases of positions disallowed by mask to $-\infty$.

        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
        pos_bias = pos_bias.unsqueeze(-1)
        if mask is not None:
            pos_bias.masked_fill_(~mask, float('-inf'))

We compute $\exp(w_{t,t'})$, $\exp(K_{t'}) \odot V_{t'}$ and $\exp(K_{t'})$ separately and combine them with a matrix multiplication. We use einsum for clarity.

We subtract $\max_{t'}(K_{t'})$ and $\max_{t'}(w_{t,t'})$ before calculating the exponents to stabilize the softmax-like calculation.

If $x_i$ is large, $\exp(x_i)$ becomes huge and the computation of $\frac{\sum_i \exp(x_i) y_i}{\sum_i \exp(x_i)}$ becomes unstable. Subtracting a constant from every $x_i$ before exponentiating cancels out between the numerator and the denominator, so it does not change the result but keeps the exponents in a safe range. We therefore subtract $\max_i x_i$ to stabilize the computation.
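As a standalone numeric sketch of this trick (unrelated to the tensors in this method), shifting the inputs by their maximum leaves the ratio unchanged but keeps the exponentials finite:

import torch

x = torch.tensor([1000., 1001., 1002.])
vals = torch.tensor([1., 2., 3.])

naive = (torch.exp(x) * vals).sum() / torch.exp(x).sum()  # exp overflows to inf, result is nan
shifted = x - x.max()
stable = (torch.exp(shifted) * vals).sum() / torch.exp(shifted).sum()  # ≈ 2.58, the correct value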

        max_key = key.max(dim=0, keepdim=True)[0]
        max_pos_bias = pos_bias.max(dim=1, keepdim=True)[0]

        exp_key = torch.exp(key - max_key)

        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)

The numerator part $\sum_{t'=1}^{T} \exp(w_{t,t'}) \big(\exp(K_{t'}) \odot V_{t'}\big)$

        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)

The denominator part $\sum_{t'=1}^{T} \exp(w_{t,t'}) \exp(K_{t'})$

        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)

Output

        y = self.activation(query) * num / den

Output layer

        return self.output(y)
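Below is a hedged usage sketch, not part of the original file: it assumes the AFTLocal class above is in scope and builds a simple causal mask following the [seq_len, seq_len, batch_size] convention documented in forward (a batch dimension of 1 is broadcast).

import torch

seq_len, batch_size, d_model = 16, 2, 32
aft = AFTLocal(d_model=d_model, seq_len=seq_len, local_window_size=4)

x = torch.randn(seq_len, batch_size, d_model)
# Causal mask: query at position i may only attend to key-value positions j <= i
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

y = aft(query=x, key=x, value=x, mask=mask)
assert y.shape == (seq_len, batch_size, d_model)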

Test local mask

def _test_local_mask():
    from labml.logger import inspect
    inspect(AFTLocal.create_local_mask(10, 4))

if __name__ == '__main__':
    _test_local_mask()
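For intuition, here is a hedged sketch of the band pattern create_local_mask produces, shown for a smaller case (assuming AFTLocal from above is in scope):

# create_local_mask(5, 2) keeps positions with |t - t'| < 2:
#
#   True  True  False False False
#   True  True  True  False False
#   False True  True  True  False
#   False False True  True  True
#   False False False True  True
print(AFTLocal.create_local_mask(5, 2))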