This is a PyTorch implementation of the paper *An Attention Free Transformer*.

This paper replaces the self-attention layer with a new efficient operation that has memory complexity of $\mathcal{O}(Td)$, where $T$ is the sequence length and $d$ is the dimensionality of the embeddings.

The paper introduces AFT along with AFT-local and AFT-conv. Here we have implemented AFT-local, which applies learned position biases only to nearby tokens in an autoregressive model.

AFT (similar to MHA) first transforms the embeddings $X$ into query $Q = XW^Q$, key $K = XW^K$ and value $V = XW^V$ tensors with learned weights. The output for each position $t \in [1, T]$ is calculated with the following operation.

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp\big(K_{t'} + w_{t,t'}\big) \odot V_{t'}}{\sum_{t'=1}^{T} \exp\big(K_{t'} + w_{t,t'}\big)}$$

where $\odot$ is the element-wise product, $\sigma$ is a non-linearity (sigmoid) and $w \in \mathbb{R}^{T \times T}$ is a learned matrix of pair-wise position biases.

This means that we take a weighted average of the values and multiply it by the (sigmoid-activated) query. This eliminates the need to calculate the $T \times T$ attention matrix that MHA requires, and therefore reduces the memory requirement.
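As a minimal sketch of this operation (not the implementation below), assuming `w`, `q`, `k` and `v` are already computed for a single sequence, with `w` of shape `[T, T]` and `q`, `k`, `v` of shape `[T, d]`; the helper name `aft_full` is ours and there is no numerical stabilization here:

```
import torch

def aft_full(w: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # w: [T, T] pair-wise position biases; q, k, v: [T, d] (single sequence)
    num = torch.einsum('tj,jd->td', torch.exp(w), torch.exp(k) * v)  # weighted sum of values
    den = torch.einsum('tj,jd->td', torch.exp(w), torch.exp(k))      # normalization term
    return torch.sigmoid(q) * num / den                              # Y = sigma(Q) * num / den
```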

AFT-local applies the learned pair-wise position biases only locally:

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{for } \lvert t - t' \rvert < s \\ 0, & \text{otherwise} \end{cases}$$

where $s \le T$ is the local window size.

Although $w'_{t,t'}$ is $0$ outside the local window, the AFT operation still uses key-value pairs from other areas (they contribute with $\exp(0) = 1$). This is different from local transformers, where embeddings outside the local window are not visible at all.
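For example, with $T = 4$ and $s = 2$ the biases used by AFT-local form a band matrix:

$$w' = \begin{pmatrix} w_{1,1} & w_{1,2} & 0 & 0 \\ w_{2,1} & w_{2,2} & w_{2,3} & 0 \\ 0 & w_{3,2} & w_{3,3} & w_{3,4} \\ 0 & 0 & w_{4,3} & w_{4,4} \end{pmatrix}$$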

Here is the training code for an AFT-local model.

```
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
```

`class AFTLocal(Module):`

* `d_model` is the number of features in the `query`, `key` and `value` vectors.
* `seq_len` is $T$.
* `local_window_size` is the local window size $s$.
* `bias` is whether to have a bias parameter for the transformations for $Q$, $K$ and $V$.

`def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):`

`super().__init__()`

Local window size $s$

`self.local_window_size = local_window_size`

These transform the `query`, `key` and `value` vectors.

```
self.query = nn.Linear(d_model, d_model, bias=bias)
self.key = nn.Linear(d_model, d_model, bias=bias)
self.value = nn.Linear(d_model, d_model, bias=bias)
```

Pair-wise positional biases $w \in \mathbb{R}^{T \times T}$

`self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)`

Mask for $w_{t,t'}$

`self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)`

Activation $\sigma$

`self.activation = nn.Sigmoid()`

Output layer

`self.output = nn.Linear(d_model, d_model)`

```
@staticmethod
def create_local_mask(seq_len, local_window_size):
```

Initialize to ones

`local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)`

Make $t' - t \ge s$ zero

`local_mask = torch.tril(local_mask, local_window_size - 1)`

Make $t - t' \ge s$ zero

`local_mask = torch.triu(local_mask, -(local_window_size - 1))`

`return local_mask`
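For illustration, here is what the mask produced by the `tril`/`triu` logic above looks like for a short sequence with window size $s = 2$: a band of `True` values of width $2s - 1$ around the diagonal.

```
# AFTLocal.create_local_mask(5, 2) keeps positions with |t - t'| < 2:
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True, False, False],
#         [False,  True,  True,  True, False],
#         [False, False,  True,  True,  True],
#         [False, False, False,  True,  True]])
print(AFTLocal.create_local_mask(5, 2))
```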

`query`, `key` and `value` are the tensors that store a collection of token embeddings for *query*, *key* and *value*. They have shape `[seq_len, batch_size, d_model]`.

`mask` has shape `[seq_len, seq_len, batch_size]` and `mask[i, j, b]` indicates whether, for batch `b`, the query at position `i` has access to the key-value at position `j`.

```
138 def forward(self, *,
139 query: torch.Tensor,
140 key: torch.Tensor,
141 value: torch.Tensor,
142 mask: Optional[torch.Tensor] = None):
```

`query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`.

```
seq_len, _, _ = query.shape

if mask is not None:
```

`mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where the first dimension is the query dimension. If the query dimension is equal to $1$ it will be broadcasted.

```
    assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
    assert mask.shape[1] == key.shape[0]
    assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]
```
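For instance, a causal (autoregressive) mask that satisfies these constraints is a lower-triangular boolean matrix with a broadcastable batch dimension; this is only a sketch, not part of the module:

```
seq_len = 16  # example sequence length
# Causal mask of shape [seq_len, seq_len, 1]; the batch dimension broadcasts.
# mask[i, j, 0] is True when the query at position i may attend to position j <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)
```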

Transform query, key and value embeddings

```
query = self.query(query)
key = self.key(key)
value = self.value(value)
```

Get $w'_{t,t'}$ using the local mask and, if a `mask` is given, set excluded positions to $-\infty$ so that they vanish after the exponent

```
pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
pos_bias = pos_bias.unsqueeze(-1)
if mask is not None:
    pos_bias.masked_fill_(~mask, float('-inf'))
```

We compute $\exp(w_{t,t'})$, $\exp(K_{t'}) \odot V_{t'}$ and $\exp(K_{t'})$ separately and do a matrix multiplication. We use einsum for clarity.

We subtract $\max_{t'}(K_{t'})$ and $\max_{t'}(w_{t,t'})$ before calculating the exponents to stabilize the softmax calculation.

If $x_i$ is large, $\exp(x_i)$ becomes huge and the computation of $\frac{\sum\exp(x_i)y_i}{\sum\exp(x_i)}$ becomes unstable. Subtracting the same constant inside the exponents of the numerator and the denominator cancels out, so it does not change the result but helps stabilize the computation. Hence we subtract $\max_i(x_i)$.
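Concretely, with $M_K = \max_{t'}(K_{t'})$ and $M^{(t)}_w = \max_{t'}(w_{t,t'})$, the constant factor $\exp(-M^{(t)}_w)\exp(-M_K)$ appears in both the numerator and the denominator of $Y_t$ and cancels:

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^T \exp\big(w_{t,t'} - M^{(t)}_w\big) \odot \exp\big(K_{t'} - M_K\big) \odot V_{t'}}{\sum_{t'=1}^T \exp\big(w_{t,t'} - M^{(t)}_w\big) \odot \exp\big(K_{t'} - M_K\big)}$$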

```
max_key = key.max(dim=0, keepdims=True)[0]
max_pos_bias = pos_bias.max(dim=1, keepdims=True)[0]
```

$\exp \big(K_{t'} - \max_{t'}(K_{t'})\big)$

`exp_key = torch.exp(key - max_key)`

$\exp \big(w_{t,t'} - \max_{t'}(w_{t,t'})\big)$

`exp_pos_bias = torch.exp(pos_bias - max_pos_bias)`

The numerator part $\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'}) \odot V_{t'}$

`num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)`

The denominator part $\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'})$

`den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)`

Output

`y = self.activation(query) * num / den`

Output layer

`return self.output(y)`
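As a quick usage sketch (the sizes below are arbitrary, not part of the implementation), the layer can be exercised with random inputs and a causal mask:

```
# Hypothetical example: d_model=64, seq_len=16, local window 4, batch of 2.
aft = AFTLocal(d_model=64, seq_len=16, local_window_size=4)
x = torch.randn(16, 2, 64)  # [seq_len, batch_size, d_model]
causal_mask = torch.tril(torch.ones(16, 16, dtype=torch.bool)).unsqueeze(-1)  # [seq_len, seq_len, 1]
y = aft(query=x, key=x, value=x, mask=causal_mask)
assert y.shape == (16, 2, 64)  # output has the same shape as the input embeddings
```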

Test local mask

`def _test_local_mask():`

```
    from labml.logger import inspect
    inspect(AFTLocal.create_local_mask(10, 4))
```

```
if __name__ == '__main__':
    _test_local_mask()
```