This is a PyTorch implementation of the paper An Attention Free Transformer.
This paper replaces the self-attention layer with a new efficient operation that has a memory complexity of $\mathcal{O}(Td)$, where $T$ is the sequence length and $d$ is the dimensionality of the embeddings.
The paper introduces AFT along with AFT-local and AFT-conv. Here we have implemented AFT-local, which applies the learned position biases only to nearby tokens, in an autoregressive model.
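AFT (similar to MHA) first transforms the embeddings $X$ into query $Q = XW^Q$, key $K = XW^K$ and value $V = XW^V$ tensors with learned weights. The output for each position $t \in [1, T]$ is calculated with the following operation:

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^T \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^T \exp(K_{t'} + w_{t,t'})}$$

where $\odot$ is the element-wise product, $\sigma$ is a non-linearity (sigmoid) and $w \in \mathbb{R}^{T \times T}$ is a learned matrix of pair-wise position biases.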
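This means that, for each position $t$, we take a weighted average of the values, with weights $\exp(K_{t'} + w_{t,t'})$ computed separately for every feature, and multiply it element-wise by $\sigma(Q_t)$. This eliminates the need to calculate the attention matrix that MHA requires, and therefore reduces the memory requirement.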
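As a sanity check on the formula, here is a naive reference in plain Python for a single sequence with no batch dimension. It is a minimal sketch, not the vectorized PyTorch implementation on this page: `aft_full_reference` is a hypothetical helper, every product is element-wise per feature, and it makes no attempt at numerical stability.

```python
import math
from typing import List

Vector = List[float]


def aft_full_reference(q: List[Vector], k: List[Vector], v: List[Vector],
                       w: List[List[float]]) -> List[Vector]:
    """Naive O(T^2 d) evaluation of
    Y_t = sigmoid(Q_t) * sum_t' exp(K_t' + w[t][t']) * V_t' / sum_t' exp(K_t' + w[t][t'])
    for one sequence. q, k, v are [T][d]; w is [T][T]."""
    seq_len, d = len(q), len(q[0])
    out = []
    for t in range(seq_len):
        y = []
        for f in range(d):  # every feature is handled independently
            num = 0.0
            den = 0.0
            for tp in range(seq_len):
                weight = math.exp(k[tp][f] + w[t][tp])
                num += weight * v[tp][f]
                den += weight
            gate = 1.0 / (1.0 + math.exp(-q[t][f]))  # sigmoid(Q_t)
            y.append(gate * num / den)
        out.append(y)
    return out
```

With $K = 0$ and $w = 0$ every weight is $1$, so each output is $\sigma(Q_t)$ times the plain mean of the values; a large $w_{t,t'}$ pulls $Y_t$ towards $\sigma(Q_t) \odot V_{t'}$.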
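AFT Local only applies the learned pair-wise position biases locally:

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{for } \lvert t-t' \rvert < s \\ 0, & \text{otherwise} \end{cases}$$

where $s \le T$ is the local window size.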
Although $w'_{t,t'}$ is $0$ outside the local window, the AFT operation still uses key-value pairs from other areas. This is different from local transformers, where embeddings outside the local window are not visible at all.
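The sketch below, again in plain Python with hypothetical helper names, builds $w'$ from $w$ and computes the normalized weights $\exp(K_{t'} + w'_{t,t'}) / \sum_{t''} \exp(K_{t''} + w'_{t,t''})$ for one query position and one feature. Because $w'_{t,t'}$ is $0$ rather than $-\infty$ outside the window, positions far from $t$ still receive a non-zero weight.

```python
import math
from typing import List


def local_bias(w: List[List[float]], s: int) -> List[List[float]]:
    """w'_{t,t'} = w_{t,t'} if |t - t'| < s, else 0."""
    n = len(w)
    return [[w[t][tp] if abs(t - tp) < s else 0.0 for tp in range(n)]
            for t in range(n)]


def aft_weights(k_row: List[float], w_row: List[float]) -> List[float]:
    """Normalized AFT weights for one query position and one feature.

    k_row holds K_{t'} for every key position t'; w_row holds w'_{t,t'}."""
    e = [math.exp(kk + ww) for kk, ww in zip(k_row, w_row)]
    total = sum(e)
    return [x / total for x in e]
```

For example, with $T = 5$, $s = 2$, $K = 0$ and $w_{t,t'} = 3$ everywhere, the weight of key position $4$ for query position $0$ is $1 / (2e^3 + 3)$: small, but not zero.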
Here is the training code for an AFT Local model.
```python
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module


class AFTLocal(Module):
```
- `d_model` is the number of features in the `query`, `key` and `value` vectors.
- `seq_len` is $T$.
- `local_window_size` is the local window size $s$.
- `bias` is whether to have a bias parameter for the transformations for $Q$, $K$ and $V$.

```python
    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
        super().__init__()
```
Local window size $s$

```python
        self.local_window_size = local_window_size
```
These transform the `query`, `key` and `value` vectors.

```python
        self.query = nn.Linear(d_model, d_model, bias=bias)
        self.key = nn.Linear(d_model, d_model, bias=bias)
        self.value = nn.Linear(d_model, d_model, bias=bias)
```
Pair-wise positional biases $w \in \mathbb{R}^{T \times T}$

```python
        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)
```
Mask for $w_{t,t'}$

```python
        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)
```
Activation $\sigma$

```python
        self.activation = nn.Sigmoid()
```
Output layer

```python
        self.output = nn.Linear(d_model, d_model)
```
Create the local mask $m_{t,t'}$, which is $1$ if $\lvert t-t' \rvert < s$ and $0$ otherwise.

```python
    @staticmethod
    def create_local_mask(seq_len, local_window_size):
        # Initialize to ones
        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
        # Make t' - t >= s zero
        local_mask = torch.tril(local_mask, local_window_size - 1)
        # Make t - t' >= s zero
        local_mask = torch.triu(local_mask, -(local_window_size - 1))

        return local_mask
```
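For reference, here is a dependency-free sketch of the same band mask in plain Python (`local_mask_reference` is a hypothetical name). It states the condition $\lvert t - t' \rvert < s$ directly, which is what the `tril`/`triu` pair produces: `tril` with diagonal $s - 1$ keeps $t' - t \le s - 1$, and `triu` with diagonal $-(s - 1)$ keeps $t' - t \ge -(s - 1)$.

```python
from typing import List


def local_mask_reference(seq_len: int, local_window_size: int) -> List[List[bool]]:
    """mask[t][t'] is True iff |t - t'| < local_window_size."""
    s = local_window_size
    return [[-(s - 1) <= j - i <= s - 1 for j in range(seq_len)]
            for i in range(seq_len)]


# Print the mask for seq_len = 6, s = 2 ('1' = kept, '.' = zeroed)
for row in local_mask_reference(6, 2):
    print(''.join('1' if x else '.' for x in row))
# 11....
# 111...
# .111..
# ..111.
# ...111
# ....11
```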
`query`, `key` and `value` are the tensors that store collections of token embeddings for query, key and value. They have shape `[seq_len, batch_size, d_model]`.

`mask` has shape `[seq_len, seq_len, batch_size]` and `mask[i, j, b]` indicates whether, for batch `b`, the query at position `i` has access to the key-value at position `j`.
```python
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
```

`query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`.

```python
        seq_len, _, _ = query.shape

        if mask is not None:
```
`mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where the first dimension is the query dimension. If the query dimension is equal to $1$, it will be broadcast; the same holds for a batch dimension of $1$.

```python
            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]
```
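Here is a minimal sketch, assuming an autoregressive (causal) setting, of a mask in the `[seq_len_q, seq_len_k, batch_size]` layout, together with the shape rule that the assertions in `forward` enforce. Plain Python lists stand in for tensors, and `causal_mask` and `mask_shape_ok` are hypothetical helpers.

```python
from typing import List, Tuple


def causal_mask(seq_len: int) -> List[List[List[bool]]]:
    """Mask of shape [seq_len, seq_len, 1]: mask[i][j][0] is True iff query i
    may see key j (j <= i). The size-1 last dimension is meant to broadcast
    over the batch."""
    return [[[j <= i] for j in range(seq_len)] for i in range(seq_len)]


def mask_shape_ok(mask_shape: Tuple[int, int, int], seq_len_q: int,
                  seq_len_k: int, batch_size: int) -> bool:
    """Same conditions as the assertions in forward():
    query and batch dimensions may be 1 (broadcast) or match exactly,
    and the key dimension must match exactly."""
    q, k, b = mask_shape
    return q in (1, seq_len_q) and k == seq_len_k and b in (1, batch_size)
```

In PyTorch, the same causal mask can be built with `torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)`; its trailing size-1 dimension then stands in for the batch.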
Transform query, key and value embeddings

```python
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
```
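Get $w'_{t,t'}$ by zeroing the position biases outside the local window, add a batch dimension, and set every position that `mask` hides to $-\infty$ so that it gets zero weight after the exponent.

```python
        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
        pos_bias = pos_bias.unsqueeze(-1)
        if mask is not None:
            # Out-of-place, so that `pos_bias` ([seq_len, seq_len, 1])
            # can broadcast to the mask's batch dimension
            pos_bias = pos_bias.masked_fill(~mask, float('-inf'))
```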
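To see what the masked bias looks like, here is a plain-Python sketch for a single batch element, assuming a causal mask (key $t'$ is visible to query $t$ only when $t' \le t$). `masked_local_bias` is a hypothetical helper that mirrors the three cases of the masking step: inside the window the learned bias is kept, outside it the bias is $0$, and masked positions become $-\infty$ so that $\exp$ maps them to $0$.

```python
import math
from typing import List

NEG_INF = float('-inf')


def masked_local_bias(w: List[List[float]], s: int) -> List[List[float]]:
    """Causal AFT-local bias for one batch element."""
    n = len(w)
    out = []
    for t in range(n):
        row = []
        for tp in range(n):
            if tp > t:
                row.append(NEG_INF)      # future key, hidden: exp(-inf) == 0
            elif abs(t - tp) < s:
                row.append(w[t][tp])     # inside the local window: keep w
            else:
                row.append(0.0)          # outside the window: bias dropped, key still visible
        out.append(row)
    return out
```

Note the difference between the last two cases: a bias of $0$ still gives the key weight $\exp(K_{t'})$, while $-\infty$ removes it entirely.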
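We compute $\exp(w'_{t,t'})$, $\exp(K_{t'}) \odot V_{t'}$ and $\exp(K_{t'})$ separately and combine them with `einsum`. We subtract $\max_{t'}(K_{t'})$ and $\max_{t'}(w'_{t,t'})$ before calculating the exponents to stabilize the softmax-like calculation.

If $x_i$ is large, $\exp(x_i)$ becomes huge and the computation of $\frac{\sum_i \exp(x_i) y_i}{\sum_i \exp(x_i)}$ becomes unstable. A constant subtracted from every $x_i$ before calculating the exponent contributes the same factor to the numerator and the denominator, so it cancels out, and it can help stabilize the computation. So we subtract $\max_i(x_i)$, which keeps every exponentiated term at most $1$.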
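```python
        # max_{t'} K_{t'}, per batch element and feature
        max_key = key.max(dim=0, keepdim=True)[0]
        # max_{t'} w'_{t,t'}, per query position t
        max_pos_bias = pos_bias.max(dim=1, keepdim=True)[0]
        # exp(K_{t'} - max_{t'} K_{t'})
        exp_key = torch.exp(key - max_key)
        # exp(w'_{t,t'} - max_{t'} w'_{t,t'})
        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)
```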
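The effect of subtracting the maximum can be checked on a one-dimensional toy case. This plain-Python sketch (hypothetical function names) compares the direct and the shifted computation of $\frac{\sum_i \exp(x_i) y_i}{\sum_i \exp(x_i)}$. Python's `math.exp` raises `OverflowError` for large arguments, whereas `torch.exp` would return `inf` and the ratio would become `nan`. In the layer, $\max_{t'} K_{t'}$ and $\max_{t'} w'_{t,t'}$ are both constant over $t'$, so each cancels in the same way.

```python
import math
from typing import List


def weighted_mean_naive(x: List[float], y: List[float]) -> float:
    """sum(exp(x_i) * y_i) / sum(exp(x_i)), computed directly."""
    num = sum(math.exp(xi) * yi for xi, yi in zip(x, y))
    den = sum(math.exp(xi) for xi in x)
    return num / den


def weighted_mean_stable(x: List[float], y: List[float]) -> float:
    """Same quantity with max(x) subtracted before exponentiating.
    The factor exp(-max(x)) appears in numerator and denominator and cancels."""
    m = max(x)
    num = sum(math.exp(xi - m) * yi for xi, yi in zip(x, y))
    den = sum(math.exp(xi - m) for xi in x)
    return num / den
```

For moderate inputs such as `x = [1.0, 2.0, 3.0]` both functions agree; for `x = [1000.0, 1001.0]` only the shifted version returns a finite result.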
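The numerator part $\sum_{t'=1}^T \exp(w'_{t,t'}) \exp(K_{t'}) \odot V_{t'}$, with the maxima subtracted inside the exponents as described above

```python
        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)
```

The denominator part $\sum_{t'=1}^T \exp(w'_{t,t'}) \exp(K_{t'})$

```python
        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)
```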
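The `einsum` pattern `'ijb,jbd->ibd'` contracts over the key position `j`: `out[i, b, d] = sum_j a[i, j, b] * x[j, b, d]`. The explicit-loop sketch below (plain Python, hypothetical name) spells this out; applied to `exp_pos_bias` and `exp_key * value` it gives the numerator, and applied to `exp_pos_bias` and `exp_key` it gives the denominator. Unlike the PyTorch call, it assumes both operands have the same batch size and does not broadcast a size-1 batch dimension.

```python
from typing import List

Tensor3 = List[List[List[float]]]


def einsum_ijb_jbd_to_ibd(a: Tensor3, x: Tensor3) -> Tensor3:
    """Explicit loops for einsum('ijb,jbd->ibd', a, x).

    a: [T_q, T_k, B] weights (e.g. exp of position biases)
    x: [T_k, B, D] per-key vectors (e.g. exp(K) * V)
    returns [T_q, B, D] with out[i][b][d] = sum_j a[i][j][b] * x[j][b][d]"""
    t_q, t_k = len(a), len(a[0])
    batch, dim = len(x[0]), len(x[0][0])
    return [[[sum(a[i][j][b] * x[j][b][d] for j in range(t_k))
              for d in range(dim)]
             for b in range(batch)]
            for i in range(t_q)]
```

For example, with weights `[[[1.0], [0.0]], [[0.5], [0.5]]]` and values `[[[2.0, 4.0]], [[6.0, 8.0]]]`, the first query copies the first key's vector and the second averages both, giving `[[[2.0, 4.0]], [[4.0, 6.0]]]`.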
Output $Y_t = \sigma(Q_t) \odot \frac{\text{num}_t}{\text{den}_t}$

```python
        y = self.activation(query) * num / den
```
Output layer

```python
        return self.output(y)
```
Test local mask

```python
def _test_local_mask():
    from labml.logger import inspect
    inspect(AFTLocal.create_local_mask(10, 4))


if __name__ == '__main__':
    _test_local_mask()
```