This is a tutorial/implementation of multi-headed attention (MHA) from the paper Attention Is All You Need in PyTorch. The implementation is inspired by the Annotated Transformer.

Here is the training code that uses a basic transformer with MHA for NLP auto-regression.

Here is an experiment implementation that trains a simple transformer.

```python
import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker
```

This module does a linear transformation and splits the vector into the given number of heads for multi-head attention. It is used to transform the key, query, and value vectors.

```python
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
```

Linear layer for linear transform

```python
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
```

Number of dimensions in vectors in each head

```python
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
```

Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`. We apply the linear transformation to the last dimension and split that into the heads.

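Linear transform

```python
        x = self.linear(x)
```

Split the last dimension, of size `heads * d_k`, into `heads` vectors of `d_k` features. Only the last dimension is reshaped, so both input shapes are handled; `-1` lets PyTorch infer the number of heads.

```python
        x = x.view(*x.shape[:-1], -1, self.d_k)
```

Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`

```python
        return x
```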
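Splitting into heads is only a reshape of the last dimension from `heads * d_k` to `(heads, d_k)`. As a minimal plain-Python sketch (the helper `split_heads` is hypothetical and not part of this implementation), one position's projected vector is cut into `heads` consecutive chunks of `d_k` features, which is the same row-major order a PyTorch reshape of a contiguous tensor uses:

```python
from typing import List


def split_heads(projected: List[float], heads: int, d_k: int) -> List[List[float]]:
    """Hypothetical helper: cut a flat vector of length heads * d_k into
    `heads` consecutive chunks of size d_k (row-major, like a reshape)."""
    if len(projected) != heads * d_k:
        raise ValueError("expected a vector of length heads * d_k")
    return [projected[h * d_k:(h + 1) * d_k] for h in range(heads)]


# One position's output of the linear layer with heads=2, d_k=3
vec = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
print(split_heads(vec, heads=2, d_k=3))  # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
```

Because the chunks are consecutive, head `h` owns features `h * d_k` through `(h + 1) * d_k - 1` of the linear layer's output.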

This computes scaled multi-headed attention for the given `query`, `key` and `value` vectors.

In simple terms, it finds the keys that match the query and gets the values of those keys.

It uses the dot-product of the query and a key as the indicator of how well they match. Before taking the softmax, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing the softmax to give very small gradients when $d_k$ is large.
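The effect of the scaling can be checked with a small plain-Python sketch. For illustration it assumes raw dot products of $\sqrt{d_k}$, $0$ and $-\sqrt{d_k}$ (magnitudes that grow like $\sqrt{d_k}$) and measures the largest diagonal entry $p_i(1 - p_i)$ of the softmax Jacobian with and without the $\frac{1}{\sqrt{d_k}}$ factor. The helper names are hypothetical.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def max_local_gradient(scores: List[float]) -> float:
    """Largest diagonal entry p_i * (1 - p_i) of the softmax Jacobian."""
    return max(p * (1.0 - p) for p in softmax(scores))


d_k = 512
# Assumption: dot products whose magnitude grows roughly like sqrt(d_k)
raw_scores = [math.sqrt(d_k) * z for z in (1.0, 0.0, -1.0)]
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]

print(max_local_gradient(raw_scores))     # vanishingly small: the softmax is saturated
print(max_local_gradient(scaled_scores))  # about 0.22
```

Without scaling, almost all probability lands on one key and the local gradients are on the order of $10^{-10}$; after scaling they stay in a useful range.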

The softmax is calculated along the sequence (or time) axis.
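Putting these pieces together gives the scaled dot-product attention from the paper, with the softmax taken over the sequence axis of the keys:

$$
\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V
$$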

• `d_model` is the number of features in the `query`, `key` and `value` vectors.

```python
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()
```

Number of features per head

```python
        self.d_k = d_model // heads
```

These transform the `query`, `key` and `value` vectors for multi-headed attention.

Softmax for attention along the time dimension of key

```python
        self.softmax = nn.Softmax(dim=1)
```
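With scores laid out as `[seq_len_q, seq_len_k, batch_size, heads]`, `dim=1` is the key axis, so for every query, batch element and head the attention weights over the keys sum to one. A plain-Python sketch of that normalization on nested lists (a stand-in for `nn.Softmax(dim=1)`; the helper name is hypothetical):

```python
import math
from typing import List

Scores = List[List[List[List[float]]]]  # indexed as [i][j][b][h]


def softmax_over_keys(scores: Scores) -> Scores:
    """Hypothetical stand-in for nn.Softmax(dim=1) on a
    [seq_len_q][seq_len_k][batch][heads] nest of floats."""
    seq_q, seq_k = len(scores), len(scores[0])
    batch, heads = len(scores[0][0]), len(scores[0][0][0])
    out = [[[[0.0] * heads for _ in range(batch)] for _ in range(seq_k)]
           for _ in range(seq_q)]
    for i in range(seq_q):
        for b in range(batch):
            for h in range(heads):
                # Normalize across key positions j for this (i, b, h)
                column = [scores[i][j][b][h] for j in range(seq_k)]
                m = max(column)
                exps = [math.exp(x - m) for x in column]
                total = sum(exps)
                for j in range(seq_k):
                    out[i][j][b][h] = exps[j] / total
    return out


# 2 queries, 3 keys, batch of 1, 1 head
scores = [[[[1.0]], [[2.0]], [[3.0]]],
          [[[0.0]], [[0.0]], [[0.0]]]]
probs = softmax_over_keys(scores)
for i in range(2):
    print(sum(probs[i][j][0][0] for j in range(3)))  # 1.0 up to rounding, for every query
```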

Output layer

```python
        self.output = nn.Linear(d_model, d_model)
```

Dropout

```python
        self.dropout = nn.Dropout(dropout_prob)
```

Scaling factor $\frac{1}{\sqrt{d_k}}$ applied before the softmax

```python
        self.scale = 1 / math.sqrt(self.d_k)
```

We store the attentions so that they can be used for logging or other computations if needed.

```python
        self.attn = None
```

### Calculate scores between queries and keys

This method can be overridden for other variations like relative attention.

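```python
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
```

Calculate $Q K^\top$, or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$

```python
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
```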
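The score between query position `i` and key position `j` is the dot product of their `d_k`-dimensional vectors, taken separately for each batch element `b` and head `h`. The loops below spell out that contraction on nested lists; in PyTorch it is usually written as `torch.einsum('ibhd,jbhd->ijbh', query, key)`. This is an illustrative sketch with a hypothetical function name, not the module's own code.

```python
from typing import List

Tensor4 = List[List[List[List[float]]]]


def get_scores_reference(query: Tensor4, key: Tensor4) -> Tensor4:
    """Hypothetical loop version of S[i][j][b][h] = sum_d Q[i][b][h][d] * K[j][b][h][d].

    query: [seq_len_q][batch][heads][d_k], key: [seq_len_k][batch][heads][d_k].
    Returns scores shaped [seq_len_q][seq_len_k][batch][heads].
    """
    seq_q, seq_k = len(query), len(key)
    batch, heads, d_k = len(query[0]), len(query[0][0]), len(query[0][0][0])
    return [[[[sum(query[i][b][h][d] * key[j][b][h][d] for d in range(d_k))
               for h in range(heads)]
              for b in range(batch)]
             for j in range(seq_k)]
            for i in range(seq_q)]


# seq_len=2, batch=1, heads=1, d_k=2
q = [[[[1.0, 0.0]]], [[[0.0, 1.0]]]]
k = [[[[1.0, 2.0]]], [[[3.0, 4.0]]]]
s = get_scores_reference(q, k)
print([[s[i][j][0][0] for j in range(2)] for i in range(2)])  # [[1.0, 3.0], [2.0, 4.0]]
```

The output indices `ijbh` put the query position first and the key position second, which is why the softmax later runs over `dim=1`.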

`mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where the first dimension is the query dimension. If the query dimension is equal to $1$, it will be broadcast.

`query`, `key` and `value` are the tensors that store collections of query, key and value vectors. They have shape `[seq_len, batch_size, d_model]`.

`mask` has shape `[seq_len, seq_len, batch_size]`, and `mask[i, j, b]` indicates whether, for batch `b`, the query at position `i` has access to the key-value at position `j`.
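As an illustration of this layout, a causal (auto-regressive) mask lets query `i` see only keys `j <= i`, while a padding mask can use a query dimension of size 1 that is broadcast over all queries. Both helpers below are hypothetical plain-Python sketches, with `True` meaning the query may attend to that key:

```python
from typing import List

Mask = List[List[List[bool]]]  # indexed as [i][j][b]


def causal_mask(seq_len: int, batch_size: int) -> Mask:
    """Hypothetical helper: mask[i][j][b] is True when query i may see key j (j <= i)."""
    return [[[j <= i for _ in range(batch_size)] for j in range(seq_len)]
            for i in range(seq_len)]


def padding_mask(lengths: List[int], seq_len: int) -> Mask:
    """Hypothetical helper: a [1][seq_len][batch] mask hiding padded keys.

    The size-1 query dimension is meant to be broadcast over all queries.
    """
    return [[[j < n for n in lengths] for j in range(seq_len)]]


m = causal_mask(seq_len=3, batch_size=1)
print([[int(m[i][j][0]) for j in range(3)] for i in range(3)])
# [[1, 0, 0], [1, 1, 0], [1, 1, 1]]

print(padding_mask([2, 3], seq_len=3))
# [[[True, True], [True, True], [False, True]]]
```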

```python
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
```

`query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`

```python
        seq_len, batch_size, _ = query.shape

        if mask is not None:
```

Prepare `query`, `key` and `value` for attention computation. These will then have shape `[seq_len, batch_size, heads, d_k]`.

```python
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
```

Compute attention scores $Q K^\top$. This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.

```python
        scores = self.get_scores(query, key)
```

Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$

```python
        scores *= self.scale
```

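```python
        if mask is not None:
            # Block positions where mask == 0: -inf scores get zero weight after the softmax
            scores = scores.masked_fill(mask == 0, float('-inf'))
```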

$softmax$ attention along the key sequence dimension: $\underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$

```python
        attn = self.softmax(scores)
```
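A common way to apply such a mask, assumed here, is to set blocked scores to $-\infty$ before the softmax, since $e^{-\infty} = 0$ gives those keys exactly zero weight. A plain-Python sketch for a single query (the helper name is hypothetical):

```python
import math
from typing import List


def masked_softmax(scores: List[float], allowed: List[bool]) -> List[float]:
    """Hypothetical helper: softmax over one query's key scores, with
    disallowed keys set to -inf first.

    Assumes at least one key is allowed; if every entry is -inf the result is NaN.
    """
    filled = [s if ok else float("-inf") for s, ok in zip(scores, allowed)]
    m = max(filled)
    exps = [math.exp(x - m) for x in filled]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


# Query at position 1 under a causal mask may only see keys 0 and 1
print(masked_softmax([0.5, 0.5, 9.0], [True, True, False]))  # [0.5, 0.5, 0.0]
```

Even a very large score (here `9.0`) on a blocked key contributes nothing, which is what makes the mask a hard constraint rather than a soft penalty.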

Save attentions if debugging

```python
        tracker.debug('attn', attn)
```

Apply dropout

```python
        attn = self.dropout(attn)
```

Multiply by values: $\underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$

```python
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
```
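The `einsum` above is a weighted sum over key positions `j`, done independently for every query `i`, batch element `b` and head `h`. The loop version below (a plain-Python sketch with a hypothetical name) makes the index bookkeeping explicit:

```python
from typing import List

Tensor4 = List[List[List[List[float]]]]


def weighted_values_reference(attn: Tensor4, value: Tensor4) -> Tensor4:
    """Hypothetical loop version of X[i][b][h][d] = sum_j A[i][j][b][h] * V[j][b][h][d].

    attn: [seq_len_q][seq_len_k][batch][heads], value: [seq_len_k][batch][heads][d_k].
    Returns [seq_len_q][batch][heads][d_k].
    """
    seq_q, seq_k = len(attn), len(value)
    batch, heads, d_k = len(value[0]), len(value[0][0]), len(value[0][0][0])
    return [[[[sum(attn[i][j][b][h] * value[j][b][h][d] for j in range(seq_k))
               for d in range(d_k)]
              for h in range(heads)]
             for b in range(batch)]
            for i in range(seq_q)]


# 2 queries, 2 keys, batch=1, heads=1, d_k=2
attn = [[[[1.0]], [[0.0]]],   # query 0 attends only to key 0
        [[[0.5]], [[0.5]]]]   # query 1 averages keys 0 and 1
value = [[[[2.0, 4.0]]], [[[6.0, 8.0]]]]
out = weighted_values_reference(attn, value)
print([out[i][0][0] for i in range(2)])  # [[2.0, 4.0], [4.0, 6.0]]
```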

Save attentions for any other calculations

```python
        self.attn = attn.detach()
```