Relative Multi-Headed Attention

This is an implementation of relative multi-headed attention from the paper Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context, in PyTorch.

The vanilla Transformer has a limited attention span, equal to the length of the sequence trained in parallel. All these positions have a fixed positional encoding. Transformer XL increases this attention span by letting each position pay attention to pre-calculated past embeddings. For instance, if the context length is $l$, it keeps the embeddings of all layers for the previous batch of length $l$ and feeds them to the current step. If we used fixed positional encodings, these pre-calculated embeddings would have the same positions as the current context. Instead, Transformer XL introduces relative positional encodings, where the positional information is injected at the attention calculation.
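As a rough sketch of this idea (the tensor names here are hypothetical; this is not the code in this file): only the current segment provides the queries, while the cached embeddings of the previous segment are concatenated with the current segment to form the keys and values.

```python
import torch

seq_len, batch_size, d_model = 4, 2, 16

# Hypothetical cached embeddings of the previous segment
# (kept fixed, i.e. detached from the computation graph)
mem = torch.randn(seq_len, batch_size, d_model)
# Embeddings of the current segment
current = torch.randn(seq_len, batch_size, d_model)

# Queries come only from the current segment
query_input = current                                # [seq_len, batch_size, d_model]
# Keys and values span the memory plus the current segment,
# so each position can attend further back than seq_len steps
key_value_input = torch.cat([mem, current], dim=0)   # [2 * seq_len, batch_size, d_model]
```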

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention

This method shifts the $i^{th}$ row of a matrix by $i$ columns.

If the input is [[1, 2, 3], [4, 5, 6], [7, 8, 9]], the shifted result is [[1, 2, 3], [0, 4, 5], [6, 0, 7]]. Ideally we should mask out the lower triangle, but it's fine for our purpose.

def shift_right(x: torch.Tensor):

Concatenate a column of zeros

    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)

Reshape and remove the excess elements from the end

    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    x = x_padded[:-1].view_as(x)

    return x
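As a quick check of the example above (a minimal sketch; it assumes shift_right is in scope as defined here):

```python
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
# Row i is shifted right by i columns; the entries below the diagonal
# are spill-over from the reshaping trick and would ideally be masked.
print(shift_right(x))
# tensor([[1, 2, 3],
#         [0, 4, 5],
#         [6, 0, 7]])
```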

Relative Multi-Head Attention Module

We override the Multi-Head Attention module, so we only need to write the get_scores method.

class RelativeMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):

The linear transformations don't need a bias since we take care of it when calculating the scores. However, having a bias for the value projection might make sense.

        super().__init__(heads, d_model, dropout_prob, bias=False)

Number of relative positions

        self.P = 2 ** 12

Relative positional embeddings for key relative to the query. We need $2P$ embeddings because the keys can be before or after the query.

        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)

Relative positional embedding bias for key relative to the query.

        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)

The positional bias for the query is independent of the position of the query.

        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

Get relative attention scores

With absolute positional encodings, the attention score is

$$A^{abs}_{i,j} = \underset{\color{lightgreen}{A}}{Q_i^\top K_j} + \underset{\color{lightgreen}{B}}{Q_i^\top U^K_j} + \underset{\color{lightgreen}{C}}{{U^Q_i}^\top K_j} + \underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$$

where $Q_i$, $K_j$ are linear transformations of the original embeddings $X^q_i$, $X^k_j$, and $U^Q_i$, $U^K_j$ are linear transformations of the absolute positional encodings $P_i$, $P_j$.

They reason that the attention to a given key should be the same regardless of the position of the query. Hence, they replace $\underset{\color{lightgreen}{C}}{{U^Q_i}^\top K_j}$ with a constant $\underset{\color{lightgreen}{C}}{\color{orange}{v^\top} K_j}$.

For the second and fourth terms, relative positional encodings are introduced. So $\underset{\color{lightgreen}{B}}{Q_i^\top U^K_j}$ is replaced with $\underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i - j}}}$, and $\underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$ with $\underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$.
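Putting the three replacements together, the relative attention score assembled by get_scores below is

$$A^{rel}_{i,j} = \underset{\color{lightgreen}{A}}{Q_i^\top K_j} + \underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i-j}}} + \underset{\color{lightgreen}{C}}{\color{orange}{v^\top} K_j} + \underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$$

where the terms $\color{lightgreen}{A}$ and $\color{lightgreen}{C}$ are computed directly, while $\color{lightgreen}{B}$ and $\color{lightgreen}{D}$ are computed for each relative position and then shifted into place.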

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

$\color{orange}{R_k}$

        key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]

$\color{orange}{S_k}$

        key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]

$\color{orange}{v^\top}$

        query_pos_bias = self.query_pos_bias[None, None, :, :]

${(\color{lightgreen}{\mathbf{A + C}})}_{i,j} = Q_i^\top K_j + \color{orange}{v^\top} K_j$

        ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)

$\color{lightgreen}{\mathbf{B'}_{i,k}} = Q_i^\top \color{orange}{R_k}$

        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)

$\color{lightgreen}{\mathbf{D'}_{i,k}} = \color{orange}{S_k}$

        d = key_pos_bias[None, :, None, :]

Shift the rows of $\color{lightgreen}{\mathbf{(B' + D')}_{i,k}}$ to get

$$\color{lightgreen}{\mathbf{(B + D)}_{i,j}} = Q_i^\top \color{orange}{R_{i - j}} + \color{orange}{S_{i - j}}$$

        bd = shift_right(b + d)

Remove extra positions

        bd = bd[:, -key.shape[0]:]

Return the sum $\color{lightgreen}{\mathbf{(A + C)} + \mathbf{(B + D)}}$, which gives the relative attention scores.

        return ac + bd


def _test_shift_right():
    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inspect(x)
    inspect(shift_right(x))

    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])

    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])


if __name__ == '__main__':
    _test_shift_right()
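A minimal usage sketch (hypothetical shapes; it assumes the parent MultiHeadAttention from labml_nn.transformers.mha accepts query, key and value as keyword arguments with shape [seq_len, batch_size, d_model]):

```python
def _test_relative_mha():
    seq_len, batch_size, d_model, heads = 8, 2, 64, 4
    mha = RelativeMultiHeadAttention(heads=heads, d_model=d_model)

    # In Transformer XL the keys and values would also include the cached
    # memories of the previous segment; here we use just the current
    # segment for a simple shape check.
    x = torch.randn(seq_len, batch_size, d_model)
    out = mha(query=x, key=x, value=x)
    assert out.shape == (seq_len, batch_size, d_model)
```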