Relative Multi-Headed Attention

This is a PyTorch implementation of relative multi-headed attention from the paper Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention

This method shifts the $i^{th}$ row of a matrix by $i$ columns to the right.

If the input is [[1, 2, 3], [4, 5, 6], [7, 8, 9]], the shifted result is [[1, 2, 3], [0, 4, 5], [6, 0, 7]]; the lower-triangle entries are leftover junk from the reshape trick. Ideally we should mask out the lower triangle, but it's ok for our purpose.

def shift_right(x: torch.Tensor):

Concatenate a column of zeros

    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)

Reshape and remove excess elements from the end

    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    x = x_padded[:-1].view_as(x)

    return x
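
For intuition, here is a minimal sketch (not part of the library; shift_right_naive is our name) that performs the same shift with explicit indexing, zero-filling the lower triangle instead of leaving junk there.

import torch

def shift_right_naive(x: torch.Tensor) -> torch.Tensor:
    # Shift row i to the right by i columns and fill the vacated
    # lower-triangle entries with zeros (the reshape trick above leaves
    # junk there instead, which is fine because it gets ignored).
    out = torch.zeros_like(x)
    for i in range(x.shape[0]):
        out[i, i:] = x[i, :x.shape[1] - i]
    return out

shift_right_naive(torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
# tensor([[1, 2, 3], [0, 4, 5], [0, 0, 7]])

Both versions agree everywhere except in the lower triangle, which is ignored anyway.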

Relative Multi-Head Attention Module

We override the Multi-Head Attention module, so we only need to write the get_scores method.

class RelativeMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):

The linear transformations do not need a bias since we explicitly include it when calculating the scores. However, having a bias for the value might make sense.

        super().__init__(heads, d_model, dropout_prob, bias=False)

Number of relative positions

        self.P = 2 ** 12

Relative positional embeddings for the key relative to the query. We need $2P$ embeddings because the keys can be before or after the query.

        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)

Relative positional embedding bias for the key relative to the query.

        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)

The positional embedding for the query is independent of the position of the query.

        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
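
To see how these $2P$-row tables are used, here is a small illustrative sketch (the toy value of P, the sizes, and the variable names are ours): in get_scores below, the slice [P - key_len : P + query_len] picks out key_len + query_len consecutive rows, covering keys both before and after the query.

import torch

P = 8                        # toy value; the module uses P = 2 ** 12
heads, d_k = 2, 4
key_pos_embeddings = torch.zeros(2 * P, heads, d_k)

query_len, key_len = 3, 5
key_pos_emb = key_pos_embeddings[P - key_len: P + query_len]
print(key_pos_emb.shape)     # torch.Size([8, 2, 4]), i.e. key_len + query_len rows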

Get relative attention scores

With absolute attention the scores are

$$A^{\mathrm{abs}}_{i,j} = \underset{\color{lightgreen}{A}}{Q_i^\top K_j} + \underset{\color{lightgreen}{B}}{Q_i^\top U^K_j} + \underset{\color{lightgreen}{C}}{{U^Q_i}^\top K_j} + \underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$$

where $Q_i, K_j$ are linear transformations of the original embeddings $X^q_i, X^k_j$, and $U^Q_i, U^K_j$ are linear transformations of the absolute positional encodings $P_i, P_j$.

They reason that the attention to a given key should be the same regardless of the position of the query. Hence, they replace $\underset{\color{lightgreen}{C}}{{U^Q_i}^\top K_j}$ with a constant $\underset{\color{lightgreen}{C}}{\color{orange}{v^\top} K_j}$.

For the second and fourth terms, relative positional encodings are introduced. So $\underset{\color{lightgreen}{B}}{Q_i^\top U^K_j}$ is replaced with $\underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i - j}}}$ and $\underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$ with $\underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$. The relative attention score is then

$$A^{\mathrm{rel}}_{i,j} = \underset{\color{lightgreen}{A}}{Q_i^\top K_j} + \underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i-j}}} + \underset{\color{lightgreen}{C}}{\color{orange}{v^\top} K_j} + \underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$$

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

$\color{orange}{R_k}$

        key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]

$\color{orange}{S_k}$

        key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]

$\color{orange}{v^\top}$

        query_pos_bias = self.query_pos_bias[None, None, :, :]

${(\color{lightgreen}{\mathbf{A + C}})}_{i,j} = Q_i^\top K_j + \color{orange}{v^\top} K_j$

        ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
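
As an illustration of this einsum pattern (the tensors and sizes below are ours, purely for demonstration): query and key are shaped [seq_len, batch_size, heads, d_k], and the result holds one score per (query position, key position, batch, head).

import torch

q = torch.randn(4, 2, 3, 8)      # [query_len, batch_size, heads, d_k]
k = torch.randn(6, 2, 3, 8)      # [key_len, batch_size, heads, d_k]
scores = torch.einsum('ibhd,jbhd->ijbh', q, k)
print(scores.shape)              # torch.Size([4, 6, 2, 3])
# Entry-wise, scores[i, j, b, h] is the dot product of q[i, b, h] and k[j, b, h]
assert torch.allclose(scores[1, 2, 0, 1], q[1, 0, 1] @ k[2, 0, 1])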

$\color{lightgreen}{\mathbf{B'}_{i,k}} = Q_i^\top \color{orange}{R_k}$

        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)

$\color{lightgreen}{\mathbf{D'}_{i,k}} = \color{orange}{S_k}$

        d = key_pos_bias[None, :, None, :]

Shift the rows of $\color{lightgreen}{\mathbf{(B' + D')}_{i,k}}$ to get $\color{lightgreen}{\mathbf{(B + D)}_{i,j}} = \color{lightgreen}{\mathbf{(B' + D')}_{i, i - j}}$.

        bd = shift_right(b + d)

Remove extra positions

        bd = bd[:, -key.shape[0]:]

Return the sum

        return ac + bd
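
A quick usage sketch (assuming the class and imports above are available in the same file; the toy sizes are ours): get_scores expects query and key that have already been split into heads, i.e. shaped [seq_len, batch_size, heads, d_k], and returns scores shaped [query_len, key_len, batch_size, heads].

mha = RelativeMultiHeadAttention(heads=4, d_model=32)
query = torch.randn(7, 2, 4, mha.d_k)    # [query_len, batch_size, heads, d_k]
key = torch.randn(11, 2, 4, mha.d_k)     # [key_len, batch_size, heads, d_k]
scores = mha.get_scores(query, key)
assert scores.shape == (7, 11, 2, 4)     # [query_len, key_len, batch_size, heads]
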
Code to test the shift_right method

def _test_shift_right():
    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inspect(x)
    inspect(shift_right(x))

    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])

    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])


if __name__ == '__main__':
    _test_shift_right()