This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need in PyTorch. The implementation is inspired by the Annotated Transformer.
import math
from typing import Optional

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module
This module does a linear transformation and splits the vector into a given number of heads for multi-head attention. This is used to transform the key, query, and value vectors.
class PrepareForMultiHeadAttention(Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
Linear layer for linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
Number of heads
        self.heads = heads
Number of dimensions in vectors in each head
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model]. We apply the linear transformation to the last dimension and split that into the heads.
        head_shape = x.shape[:-1]
Linear transform
        x = self.linear(x)
Split last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)
Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k].
        return x
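As a quick sanity check (an aside, not part of the module), the sketch below runs PrepareForMultiHeadAttention on random inputs and confirms the shape transformation described above; the sizes d_model=512, heads=8, d_k=64 are arbitrary choices, not anything prescribed by the implementation.

import torch

prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=True)
x = torch.randn(10, 32, 512)   # [seq_len, batch_size, d_model]
print(prep(x).shape)           # torch.Size([10, 32, 8, 64])
x = torch.randn(32, 512)       # [batch_size, d_model]
print(prep(x).shape)           # torch.Size([32, 8, 64])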
This computes scaled multi-headed attention for given query, key and value vectors.
In simple terms, it finds keys that match the query, and gets the values of those keys.
It uses the dot-product of query and key as the indicator of how well they match. Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.
Softmax is calculated along the axis of the sequence (or time).
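As an illustrative aside (not part of the implementation), the short check below shows why this scaling helps, assuming query and key components are roughly zero-mean with unit variance: the raw dot-product then has standard deviation about $\sqrt{d_k}$, while the scaled version stays near $1$, keeping the softmax away from the saturated, small-gradient regime.

import torch

for d_k in (16, 64, 256):
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    raw = (q * k).sum(dim=-1)      # dot-products q . k
    scaled = raw / d_k ** 0.5      # scaled as in the paper
    # raw std grows roughly like sqrt(d_k); scaled std stays close to 1
    print(d_k, round(raw.std().item(), 1), round(scaled.std().item(), 2))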
class MultiHeadAttention(Module):
heads is the number of heads.
d_model is the number of features in the query, key and value vectors.
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()
Number of features per head
        self.d_k = d_model // heads
Number of heads
        self.heads = heads
These transform the query, key and value vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
Softmax for attention along the time dimension of key
        self.softmax = nn.Softmax(dim=1)
Output layer
        self.output = nn.Linear(d_model, d_model)
Dropout
        self.dropout = nn.Dropout(dropout_prob)
Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)
We store attentions so that they can be used for logging, or other computations if needed
        self.attn = None
This method can be overridden for other variations like relative attention.
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
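As an aside (not part of the class), this einsum is equivalent to moving the batch and head dimensions to the front and doing a batched matrix multiplication; a small sketch with arbitrary sizes checks the equivalence:

import torch

seq_q, seq_k, batch, heads, d_k = 5, 7, 2, 4, 16
q = torch.randn(seq_q, batch, heads, d_k)
k = torch.randn(seq_k, batch, heads, d_k)

# the einsum used in get_scores
scores = torch.einsum('ibhd,jbhd->ijbh', q, k)

# same result via permute + batched matmul: [batch, heads, seq_q, seq_k]
alt = torch.matmul(q.permute(1, 2, 0, 3), k.permute(1, 2, 3, 0))
assert torch.allclose(scores, alt.permute(2, 3, 0, 1), atol=1e-5)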
query, key and value are the tensors that store collections of query, key and value vectors. They have shape [seq_len, batch_size, d_model].
mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.
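For example (an illustrative aside with an arbitrary sequence length), an autoregressive (causal) mask in this layout can be built with torch.tril; the trailing dimension of size 1 is accepted by the assertions below and broadcasts over the batch:

import torch

seq_len = 6
# [seq_len_q, seq_len_k]: position i may attend to positions j <= i
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# add a batch dimension of size 1: [seq_len_q, seq_len_k, 1]
mask = causal.unsqueeze(-1)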
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
query, key and value have shape [seq_len, batch_size, d_model].
        seq_len, batch_size, _ = query.shape

        if mask is not None:
mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension is equal to $1$ it will be broadcast.
            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]
Same mask applied to all heads.
            mask = mask.unsqueeze(-1)
Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
        scores = self.get_scores(query, key)
Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale
Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
$softmax$ attention along the key sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = self.softmax(scores)
Save attentions if debugging
        tracker.debug('attn', attn)
Apply dropout
        attn = self.dropout(attn)
Multiply by values
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
Save attentions for any other calculations
        self.attn = attn.detach()
Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)
Output layer
        return self.output(x)
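A minimal end-to-end usage sketch (an aside, with assumed sizes d_model=512, heads=8, and the causal mask shown earlier; the tracker.debug call inside forward assumes a labml run context and may otherwise just be a no-op):

import torch

mha = MultiHeadAttention(heads=8, d_model=512)
x = torch.randn(6, 32, 512)                        # [seq_len, batch_size, d_model]
causal = torch.tril(torch.ones(6, 6, dtype=torch.bool)).unsqueeze(-1)
out = mha(query=x, key=x, value=x, mask=causal)    # self-attention
print(out.shape)                                   # torch.Size([6, 32, 512])
print(mha.attn.shape)                              # [seq_len, seq_len, batch_size, heads]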