This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need in PyTorch. The implementation is inspired by The Annotated Transformer.
Here is the training code that uses a basic transformer with MHA for NLP auto-regression.
Here is an experiment implementation that trains a simple transformer.
import math
from typing import Optional, List

import torch
from torch import nn as nn

from labml import tracker
from labml_helpers.module import Module
This module does a linear transformation and splits the vector into a given number of heads for multi-head attention. It is used to transform the key, query, and value vectors.
class PrepareForMultiHeadAttention(Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
Linear layer for linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
Number of heads
        self.heads = heads
Number of dimensions in vectors in each head
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model]. We apply the linear transformation to the last dimension and split that into the heads.
        head_shape = x.shape[:-1]
Linear transform
        x = self.linear(x)
Split last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)
Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k].
        return x
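For illustration, a minimal usage sketch of this module; the sizes here are arbitrary assumptions, not from the original code:
# Hypothetical usage: project a batch of vectors and split into heads
prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=True)
x = torch.randn(10, 32, 512)          # [seq_len, batch_size, d_model]
out = prep(x)
assert out.shape == (10, 32, 8, 64)   # [seq_len, batch_size, heads, d_k]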
This computes scaled multi-headed attention for given query, key and value vectors.
In simple terms, it finds keys that match the query, and gets the values of those keys.
It uses the dot-product of query and key as the indicator of how well they match. Before taking the softmax, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.
Softmax is calculated along the axis of the sequence (or time).
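Putting it together, each head computes the scaled dot-product attention from the paper (this formula restates the description above):
$$\mathrm{Attention}(Q, K, V) = \underset{seq}{\mathrm{softmax}}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$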
class MultiHeadAttention(Module):
heads is the number of heads.
d_model is the number of features in the query, key and value vectors.
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()
Number of features per head
        self.d_k = d_model // heads
Number of heads
        self.heads = heads
These transform the query, key and value vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
Softmax for attention along the time dimension of key
        self.softmax = nn.Softmax(dim=1)
Output layer
        self.output = nn.Linear(d_model, d_model)
Dropout
        self.dropout = nn.Dropout(dropout_prob)
Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)
We store the attentions so that they can be used for logging, or for other computations if needed.
        self.attn = None
This method can be overridden for other variations like relative attention.
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
Calculate $Q K^\top$, or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
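The einsum above is equivalent to a batched matrix product once the batch and head dimensions are moved to the front; a sketch for illustration only (the intermediate names q and k are not from the original code):
# Equivalent computation without einsum
# query, key: [seq_len, batch_size, heads, d_k]
q = query.permute(1, 2, 0, 3)                  # [batch_size, heads, seq_len_q, d_k]
k = key.permute(1, 2, 0, 3)                    # [batch_size, heads, seq_len_k, d_k]
scores = torch.matmul(q, k.transpose(-2, -1))  # [batch_size, heads, seq_len_q, seq_len_k]
scores = scores.permute(2, 3, 0, 1)            # [seq_len_q, seq_len_k, batch_size, heads]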
mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension is equal to 1 it will be broadcast.
    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
Same mask applied to all heads.
        mask = mask.unsqueeze(-1)
The resulting mask broadcasts across all heads, giving shape [seq_len_q, seq_len_k, batch_size, heads].
        return mask
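As an illustration, a causal (auto-regressive) mask in the expected input shape could be built as follows; this helper is a hypothetical sketch, not part of the original module:
# Hypothetical helper: lower-triangular mask so position i only attends to j <= i
def causal_mask(seq_len: int):
    # [seq_len_q, seq_len_k, batch_size=1]; the last dimension broadcasts over the batch
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)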
query, key and value are the tensors that store collections of query, key and value vectors. They have shape [seq_len, batch_size, d_model].
mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
query, key and value have shape [seq_len, batch_size, d_model]
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)
Prepare query, key and value for the attention computation. These will then have shape [seq_len, batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
        scores = self.get_scores(query, key)
Scale scores
        scores *= self.scale
Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
Compute softmax attention along the key sequence dimension
        attn = self.softmax(scores)
Save attentions if debugging
        tracker.debug('attn', attn)
Apply dropout
        attn = self.dropout(attn)
Multiply by values
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
Save attentions for any other calculations
        self.attn = attn.detach()
Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)
Output layer
        return self.output(x)
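Finally, a small self-attention usage sketch of the module as defined above; the sizes and the causal mask here are illustrative assumptions, not taken from the linked training code:
# Hypothetical end-to-end usage for self-attention
# (tracker.debug() inside forward may expect a labml experiment to be running)
mha = MultiHeadAttention(heads=8, d_model=512)
x = torch.randn(10, 32, 512)                         # [seq_len, batch_size, d_model]
mask = torch.tril(torch.ones(10, 10)).unsqueeze(-1)  # [seq_len_q, seq_len_k, 1], causal
out = mha(query=x, key=x, value=x, mask=mask)
assert out.shape == (10, 32, 512)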