This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need, written in PyTorch. The implementation is inspired by the Annotated Transformer.

Here is the training code that uses a basic transformer with MHA for autoregressive NLP.

Here is an experiment implementation that trains a simple transformer.

```
import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module
```

This module does a linear transformation and splits the vector into the given number of heads for multi-head attention. It is used to transform the **key**, **query**, and **value** vectors.

`class PrepareForMultiHeadAttention(Module):`

```
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
```

Linear layer for linear transform

`self.linear = nn.Linear(d_model, heads * d_k, bias=bias)`

Number of heads

`self.heads = heads`

Number of dimensions in vectors in each head

`self.d_k = d_k`

`def forward(self, x: torch.Tensor):`

Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`. We apply the linear transformation to the last dimension and split that into the heads.

`head_shape = x.shape[:-1]`

Linear transform

`x = self.linear(x)`

Split last dimension into heads

`x = x.view(*head_shape, self.heads, self.d_k)`

Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`.

`return x`
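A minimal sketch of the same split, assuming plain `torch.nn.Module` in place of labml's `Module` (the class name `PrepareHeadsSketch` is hypothetical). It shows that one linear layer produces all heads at once and that `view` only reshapes the last dimension, so head `h` is the slice `h * d_k` to `(h + 1) * d_k` of the linear output.

```python
import torch
from torch import nn


class PrepareHeadsSketch(nn.Module):
    """Hypothetical standalone version of PrepareForMultiHeadAttention,
    built on torch.nn.Module instead of labml_helpers' Module."""

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool = True):
        super().__init__()
        # One linear layer produces the features for all heads at once
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep every leading dimension, e.g. [seq_len, batch_size]
        head_shape = x.shape[:-1]
        x = self.linear(x)
        # Split the last dimension (heads * d_k) into [heads, d_k]
        return x.view(*head_shape, self.heads, self.d_k)


prep = PrepareHeadsSketch(d_model=16, heads=4, d_k=4)
print(prep(torch.randn(10, 2, 16)).shape)  # torch.Size([10, 2, 4, 4])
print(prep(torch.randn(2, 16)).shape)      # torch.Size([2, 4, 4])
```

Both input layouts from the text work because only the last dimension is touched.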

This computes scaled multi-headed attention for the given `query`, `key` and `value` vectors.

$\mathrm{Attention}(Q, K, V) = \underset{seq}{\mathrm{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$

In simple terms, it finds the keys that match the query and gets the values of those keys.

It uses the dot product of the query and key as the indicator of how well they match. Before taking the $\mathrm{softmax}$, the dot products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing the softmax to give very small gradients when $d_k$ is large.
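The scaling argument can be checked numerically. This sketch assumes, as the paper's footnote does, that query and key entries are independent with mean 0 and variance 1; their dot product then has variance $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1. The sample size and $d_k = 512$ are arbitrary illustrative choices.

```python
import math

import torch

torch.manual_seed(0)
n_samples, d_k = 10_000, 512

# Query and key entries drawn i.i.d. from N(0, 1)
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

raw = (q * k).sum(dim=-1)        # unscaled dot products, variance about d_k
scaled = raw / math.sqrt(d_k)    # scaled dot products, variance about 1

print(f"std unscaled: {raw.std().item():.1f}")     # close to sqrt(512) ≈ 22.6
print(f"std scaled:   {scaled.std().item():.2f}")  # close to 1.0
```

Without the scaling, logits of this magnitude push the softmax toward a one-hot output, where its gradients are close to zero.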

Softmax is calculated along the axis of the sequence (or time).
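In this implementation the scores are laid out as `[seq_len_q, seq_len_k, batch_size, heads]`, so "along the sequence" means along `dim=1`, the key axis. A small sketch with random scores (sizes chosen arbitrarily) shows that each query's weights over the keys sum to 1:

```python
import torch

torch.manual_seed(0)
seq_len_q, seq_len_k, batch_size, heads = 3, 5, 2, 4

# Scores laid out as [seq_len_q, seq_len_k, batch_size, heads]
scores = torch.randn(seq_len_q, seq_len_k, batch_size, heads)

# dim=1 is the key axis: each query distributes its weight over all key positions
attn = torch.softmax(scores, dim=1)

# Summing over keys gives 1 for every (query, batch, head) triple
print(attn.sum(dim=1))
```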

`class MultiHeadAttention(Module):`

`heads` is the number of heads. `d_model` is the number of features in the `query`, `key` and `value` vectors.

`def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):`

`super().__init__()`

Number of features per head

`self.d_k = d_model // heads`
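`d_model // heads` rounds down, and nothing in the constructor checks that `d_model` is divisible by `heads`. When it is not, the concatenated heads have `heads * d_k < d_model` features and the output layer fails at run time. The helper `output_layer_accepts` below is a hypothetical illustration of that mismatch:

```python
import torch
from torch import nn


def output_layer_accepts(d_model: int, heads: int) -> bool:
    """Hypothetical check: does the concatenated multi-head output fit the output layer?"""
    d_k = d_model // heads
    output = nn.Linear(d_model, d_model)
    # After concatenating heads, the last dimension is heads * d_k
    concatenated = torch.randn(4, 2, heads * d_k)
    try:
        output(concatenated)
        return True
    except RuntimeError:
        # nn.Linear raises when the input feature size does not match
        return False


print(output_layer_accepts(12, 3))  # True: 3 heads * 4 = 12
print(output_layer_accepts(10, 3))  # False: 3 heads * 3 = 9, not 10
```

A guard such as `assert d_model % heads == 0` in `__init__` would surface the problem at construction time; that check is a suggestion, not part of the implementation above.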

Number of heads

`self.heads = heads`

These transform the `query`, `key` and `value` vectors for multi-headed attention.

```
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        # The value projection always has a bias, independent of the `bias` argument
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
```

Softmax for attention along the time dimension of `key`

`self.softmax = nn.Softmax(dim=1)`

Output layer

`self.output = nn.Linear(d_model, d_model)`

Dropout

`self.dropout = nn.Dropout(dropout_prob)`

Scaling factor before the softmax

`self.scale = 1 / math.sqrt(self.d_k)`

We store the attention weights so that they can be used for logging or other computations if needed.

`self.attn = None`

This method can be overridden for other variations like relative attention.

`def get_scores(self, query: torch.Tensor, key: torch.Tensor):`

Calculate $QK^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$

`return torch.einsum('ibhd,jbhd->ijbh', query, key)`
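The `einsum` string reads as: for each query position `i`, key position `j`, batch `b` and head `h`, sum over the feature index `d`. As a sanity check, this sketch (random tensors, arbitrary sizes) compares it against the more familiar batched `Q @ K^T` formulation after permuting axes:

```python
import torch

torch.manual_seed(0)
seq_q, seq_k, batch_size, heads, d_k = 3, 5, 2, 4, 8
query = torch.randn(seq_q, batch_size, heads, d_k)
key = torch.randn(seq_k, batch_size, heads, d_k)

# S[i, j, b, h] = sum_d Q[i, b, h, d] * K[j, b, h, d]
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

# Same result with batched matmul: move (batch, heads) to the front,
# compute Q @ K^T, then move the two sequence axes back to the front
q = query.permute(1, 2, 0, 3)  # [batch, heads, seq_q, d_k]
k = key.permute(1, 2, 0, 3)    # [batch, heads, seq_k, d_k]
via_matmul = (q @ k.transpose(-2, -1)).permute(2, 3, 0, 1)  # [seq_q, seq_k, batch, heads]

print(scores.shape)  # torch.Size([3, 5, 2, 4])
print(torch.allclose(scores, via_matmul, atol=1e-5))
```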

`mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where the first dimension is the query dimension. If the query dimension is equal to $1$, it will be broadcast.

`def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):`

```
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
```

Same mask applied to all heads.

`mask = mask.unsqueeze(-1)`

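The resulting mask has shape `[seq_len_q, seq_len_k, batch_size, 1]`; its size-1 axes broadcast over the heads (and over queries or the batch where those axes are also 1), matching the scores of shape `[seq_len_q, seq_len_k, batch_size, heads]`.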

`return mask`
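As an illustration of these shapes, this sketch builds a causal mask (an assumed example; the module accepts any mask with the asserted shape), repeats the `unsqueeze(-1)` step, and applies it the way `forward` does with `masked_fill`. The size-1 axes broadcast over batch and heads:

```python
import torch

seq_len, batch_size, heads = 4, 2, 3

# Hypothetical causal mask: query i may attend to key j only when j <= i.
# Shape [seq_len_q, seq_len_k, 1]; the size-1 batch axis broadcasts over the batch.
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)

# The step prepare_mask performs after its shape checks: append a heads axis
mask = mask.unsqueeze(-1)  # [4, 4, 1, 1]

scores = torch.zeros(seq_len, seq_len, batch_size, heads)
masked = scores.masked_fill(mask == 0, float('-inf'))
print(masked.shape)        # torch.Size([4, 4, 2, 3])
print(masked[0, :, 0, 0])  # query 0 keeps key 0; keys 1 to 3 become -inf
```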

`query`, `key` and `value` are the tensors that store a collection of *query*, *key* and *value* vectors. They have shape `[seq_len, batch_size, d_model]`.

`mask` has shape `[seq_len, seq_len, batch_size]`, and `mask[i, j, b]` indicates whether, for batch `b`, the query at position `i` has access to the key-value at position `j`.
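Masks do not need a full query axis. For padding, a hypothetical `lengths` tensor (number of real tokens per batch item) can be turned into a mask of shape `[1, seq_len_k, batch_size]`, which passes the checks in `prepare_mask` and is broadcast to every query. A boolean mask works with `forward` because `mask == 0` is `True` exactly where the mask is `False`:

```python
import torch

seq_len, batch_size = 5, 2
# Hypothetical lengths: batch item 1 has 2 padding positions at the end
lengths = torch.tensor([5, 3])

# mask[0, j, b] is True when key position j is a real token for batch item b.
# The query axis has size 1, so the same row is broadcast to every query.
key_positions = torch.arange(seq_len)                  # [seq_len_k]
mask = key_positions[:, None] < lengths[None, :]       # [seq_len_k, batch_size]
mask = mask.unsqueeze(0)                               # [1, seq_len_k, batch_size]

print(mask.shape)     # torch.Size([1, 5, 2])
print(mask[0, :, 1])  # keys 3 and 4 are padding for batch item 1
```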

```
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
```

`query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`.

```
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)
```

Prepare `query`, `key` and `value` for attention computation. These will then have shape `[seq_len, batch_size, heads, d_k]`.

```
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
```

Compute attention scores $QK^\top$. This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.

`scores = self.get_scores(query, key)`

Scale scores $\frac{QK^\top}{\sqrt{d_k}}$

`scores *= self.scale`

Apply mask

```
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
```
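Filling with $-\infty$ makes $e^{-\infty} = 0$, so masked keys get exactly zero weight after the softmax while the remaining weights still sum to 1. One consequence worth knowing: if a query row is masked entirely, the softmax has nothing to normalize and returns `nan`. A small sketch with made-up scores (2D for readability, key axis `dim=1`):

```python
import torch

scores = torch.tensor([[2.0, 1.0, 0.5],
                       [0.3, 0.2, 0.1]])
mask = torch.tensor([[1, 1, 0],
                     [0, 0, 0]])  # the second row masks every key

masked = scores.masked_fill(mask == 0, float('-inf'))
attn = torch.softmax(masked, dim=1)

print(attn[0])  # last weight is exactly 0, the first two sum to 1
print(attn[1])  # every entry is nan: softmax over all -inf is undefined
```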

$\mathrm{softmax}$ attention along the key sequence dimension $\underset{seq}{\mathrm{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)$

`attn = self.softmax(scores)`

Save attentions if debugging

`tracker.debug('attn', attn)`

Apply dropout

`attn = self.dropout(attn)`

Multiply by values $\underset{seq}{\mathrm{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$

`x = torch.einsum("ijbh,jbhd->ibhd", attn, value)`
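This `einsum` forms, for each query `i`, a weighted sum of the value vectors over key positions `j`. A hand-built example (batch and heads of size 1, weights chosen by hand rather than produced by a softmax) makes this concrete: a one-hot row copies a single value vector and a uniform row averages them.

```python
import torch

seq_len, batch_size, heads, d_k = 3, 1, 1, 2
value = torch.tensor([[1.0, 0.0],
                      [0.0, 1.0],
                      [5.0, 5.0]]).view(seq_len, batch_size, heads, d_k)

# attn[i, j, b, h]: query 0 copies value 2, query 1 averages values 0 and 1,
# query 2 copies value 0
attn = torch.tensor([[0.0, 0.0, 1.0],
                     [0.5, 0.5, 0.0],
                     [1.0, 0.0, 0.0]]).view(seq_len, seq_len, batch_size, heads)

x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
print(x[:, 0, 0])  # rows: [5, 5], [0.5, 0.5], [1, 0]
```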

Save attentions for any other calculations

`self.attn = attn.detach()`

Concatenate multiple heads

`x = x.reshape(seq_len, batch_size, -1)`
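`reshape` merges the `heads` and `d_k` axes, which is the same as concatenating the per-head outputs along the feature axis in head order. Using `reshape` rather than `view` matters here: the output of `einsum` may not be contiguous, and `reshape` copies when needed whereas `view` would raise. A quick check with random data:

```python
import torch

seq_len, batch_size, heads, d_k = 2, 3, 4, 5
x = torch.randn(seq_len, batch_size, heads, d_k)

# reshape merges the last two axes: head h occupies features h*d_k to (h+1)*d_k - 1
merged = x.reshape(seq_len, batch_size, -1)

# Equivalent explicit concatenation of the per-head outputs along the feature axis
concatenated = torch.cat([x[:, :, h, :] for h in range(heads)], dim=-1)

print(merged.shape)                       # torch.Size([2, 3, 20])
print(torch.equal(merged, concatenated))  # True
```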

Output layer

`return self.output(x)`
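Putting the steps together, here is a compact, self-contained sketch of the same computation. It is a hypothetical restatement (`MHASketch`), not the module above: it uses `torch.nn.Module` and plain `nn.Linear` plus `view` instead of labml's `Module` and `PrepareForMultiHeadAttention`, gives every projection a bias, drops `tracker.debug` and the separate `get_scores`/`prepare_mask` methods, and adds a divisibility assertion. The usage checks the output shape and that a causal mask keeps earlier positions independent of later ones.

```python
import math
from typing import Optional

import torch
from torch import nn


class MHASketch(nn.Module):
    """Hypothetical compact restatement of MultiHeadAttention (no labml dependencies)."""

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__()
        assert d_model % heads == 0, "d_model must be divisible by heads"
        self.heads, self.d_k = heads, d_model // heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.output = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_prob)
        self.scale = 1 / math.sqrt(self.d_k)
        self.attn = None

    def split(self, x: torch.Tensor) -> torch.Tensor:
        # [seq_len, batch_size, d_model] -> [seq_len, batch_size, heads, d_k]
        return x.view(*x.shape[:-1], self.heads, self.d_k)

    def forward(self, *, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        seq_len, batch_size, _ = query.shape
        q = self.split(self.query(query))
        k = self.split(self.key(key))
        v = self.split(self.value(value))
        # Scaled scores, shape [seq_len_q, seq_len_k, batch_size, heads]
        scores = torch.einsum('ibhd,jbhd->ijbh', q, k) * self.scale
        if mask is not None:
            # [seq_len_q, seq_len_k, batch_size] -> append a heads axis, then mask
            scores = scores.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))
        # Softmax over the key axis, then dropout on the attention weights
        attn = self.dropout(torch.softmax(scores, dim=1))
        self.attn = attn.detach()
        x = torch.einsum('ijbh,jbhd->ibhd', attn, v)
        # Concatenate heads and apply the output layer
        return self.output(x.reshape(seq_len, batch_size, -1))


torch.manual_seed(0)
seq_len, batch_size, d_model = 6, 2, 16
mha = MHASketch(heads=4, d_model=d_model, dropout_prob=0.0).eval()
x = torch.randn(seq_len, batch_size, d_model)
causal = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)  # [seq_len, seq_len, 1]

out = mha(query=x, key=x, value=x, mask=causal)
print(out.shape)  # torch.Size([6, 2, 16])

# Causality check: changing the last position must not change earlier outputs
x2 = x.clone()
x2[-1] = torch.randn(batch_size, d_model)
out2 = mha(query=x2, key=x2, value=x2, mask=causal)
print(torch.allclose(out[:-1], out2[:-1]))  # True
```

Setting `dropout_prob=0.0` keeps the check deterministic; with dropout active during training, the attention weights would be randomly zeroed and rescaled.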