Multi-Headed Attention (MHA)

This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need in PyTorch. The implementation is inspired by The Annotated Transformer.

Here is the training code that uses a basic transformer with MHA for NLP auto-regression.

import math
from typing import Optional

import torch
from torch import nn as nn

from labml import tracker
from labml_helpers.module import Module

Prepare for multi-head attention

This module does a linear transformation and splits the vector into a given number of heads for multi-head attention. This is used to transform the key, query, and value vectors.

class PrepareForMultiHeadAttention(Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()

Linear layer for linear transform

        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

Number of heads

        self.heads = heads

Number of dimensions in vectors in each head

        self.d_k = d_k

    def forward(self, x: torch.Tensor):

Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model]. We apply the linear transformation to the last dimension and split that into the heads.

        head_shape = x.shape[:-1]

Linear transform

        x = self.linear(x)

Split last dimension into heads

        x = x.view(*head_shape, self.heads, self.d_k)

Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]

        return x
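
To make the shape handling concrete, here is a minimal standalone sketch (the sizes are chosen only for illustration) that checks the output shape for both a 3D and a 2D input:

import torch

d_model, heads, d_k = 512, 8, 64  # illustrative sizes
prep = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=True)

x = torch.randn(10, 4, d_model)              # [seq_len, batch_size, d_model]
assert prep(x).shape == (10, 4, heads, d_k)  # [seq_len, batch_size, heads, d_k]

x = torch.randn(4, d_model)                  # [batch_size, d_model]
assert prep(x).shape == (4, heads, d_k)      # [batch_size, heads, d_k]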

Multi-Head Attention Module

This computes scaled multi-headed attention for given query, key and value vectors.

In simple terms, it finds keys that match the query and gets the values of those keys.

It uses the dot-product of query and key as an indicator of how well they match. Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.
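
To see the effect of the scaling, here is a small standalone sketch (not part of the implementation): for vectors with independent unit-variance components the dot product has variance about $d_k$, and dividing by $\sqrt{d_k}$ brings it back to unit scale.

import math

import torch

d_k = 512  # illustrative dimensionality
q = torch.randn(10000, d_k)
k = torch.randn(10000, d_k)

scores = (q * k).sum(dim=-1)            # unscaled dot products
print(scores.std())                     # roughly sqrt(d_k), about 22.6
print((scores / math.sqrt(d_k)).std())  # roughly 1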

Softmax is calculated along the axis of the sequence (or time).

class MultiHeadAttention(Module):
  • heads is the number of heads.
  • d_model is the number of features in the query, key and value vectors.
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()

Number of features per head

        self.d_k = d_model // heads

Number of heads

        self.heads = heads

These transform the query, key and value vectors for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

Softmax for attention along the time dimension of key

        self.softmax = nn.Softmax(dim=1)

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)

Scaling factor before the softmax

        self.scale = 1 / math.sqrt(self.d_k)

We store the attentions so that they can be used for logging, or other computations if needed

        self.attn = None

Calculate scores between queries and keys

This method can be overridden for other variations like relative attention.

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$

        return torch.einsum('ibhd,jbhd->ijbh', query, key)
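
This einsum is equivalent to moving the batch and head dimensions to the front and doing a batched matrix multiplication; a standalone sketch with illustrative shapes that checks the equivalence:

import torch

seq_len, batch_size, heads, d_k = 7, 3, 4, 16
query = torch.randn(seq_len, batch_size, heads, d_k)
key = torch.randn(seq_len, batch_size, heads, d_k)

scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

q = query.permute(1, 2, 0, 3)  # [batch_size, heads, seq_len, d_k]
k = key.permute(1, 2, 0, 3)
scores2 = (q @ k.transpose(-2, -1)).permute(2, 3, 0, 1)  # back to [seq_q, seq_k, batch_size, heads]

assert torch.allclose(scores, scores2, atol=1e-5)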

query, key and value are the tensors that store collections of query, key and value vectors. They have shape [seq_len, batch_size, d_model].

mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether for batch b, query at position i has access to key-value at position j.
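
For auto-regression the mask is typically a lower-triangular (subsequent) mask, so that a query at position i can only attend to keys at positions j ≤ i. A minimal sketch of building one in this layout (seq_len is illustrative; the final dimension of size 1 is broadcast across the batch):

import torch

seq_len = 5
# mask[i, j, 0] is 1 when query position i may attend to key position j
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)  # [seq_len, seq_len, 1]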

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):

query, key and value have shape [seq_len, batch_size, d_model]

        seq_len, batch_size, _ = query.shape

        if mask is not None:

mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension is equal to $1$ it will be broadcasted.

            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

Same mask applied to all heads.

            mask = mask.unsqueeze(-1)

Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].

        scores = self.get_scores(query, key)

Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$

        scores *= self.scale

Apply mask

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

$softmax$ attention along the key sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$

        attn = self.softmax(scores)
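
Because the softmax is taken over dim=1 (the key positions), the attention weights for each query position sum to one. A quick standalone check with illustrative shapes:

import torch
from torch import nn

softmax = nn.Softmax(dim=1)
scores = torch.randn(7, 7, 3, 4)  # [seq_q, seq_k, batch_size, heads]
attn = softmax(scores)
assert torch.allclose(attn.sum(dim=1), torch.ones(7, 3, 4), atol=1e-5)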

Save attentions if debugging

        tracker.debug('attn', attn)

Apply dropout

        attn = self.dropout(attn)

Multiply by values

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
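
As with get_scores, this einsum can be read as a batched matrix product, here between the attention weights and the values; a standalone equivalence sketch with illustrative shapes:

import torch

seq_len, batch_size, heads, d_k = 7, 3, 4, 16
attn = torch.softmax(torch.randn(seq_len, seq_len, batch_size, heads), dim=1)
value = torch.randn(seq_len, batch_size, heads, d_k)

x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

a = attn.permute(2, 3, 0, 1)      # [batch_size, heads, seq_q, seq_k]
v = value.permute(1, 2, 0, 3)     # [batch_size, heads, seq_k, d_k]
x2 = (a @ v).permute(2, 0, 1, 3)  # back to [seq_len, batch_size, heads, d_k]

assert torch.allclose(x, x2, atol=1e-5)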

Save attentions for any other calculations

        self.attn = attn.detach()

Concatenate multiple heads

        x = x.reshape(seq_len, batch_size, -1)

Output layer

        return self.output(x)
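
Finally, a minimal usage sketch (sizes are illustrative, and it assumes the tracker.debug call inside forward is safe to run outside a labml experiment): self-attention with a causal mask, checking the output shape and the stored attention.

import torch

seq_len, batch_size, d_model, heads = 10, 4, 512, 8
mha = MultiHeadAttention(heads, d_model)

x = torch.randn(seq_len, batch_size, d_model)
# Causal mask: query at position i attends only to keys at positions j <= i
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)

out = mha(query=x, key=x, value=x, mask=mask)
assert out.shape == (seq_len, batch_size, d_model)
assert mha.attn.shape == (seq_len, seq_len, batch_size, heads)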