相对多头注意力

这是 paper Transfor mer-XL:PyTorch 中固定长度上下文之外的细心语言模型中相对多头关注的实现。

16import torch
17from torch import nn
18
19from labml.logger import inspect
20from labml_nn.transformers.mha import MultiHeadAttention

此方法将矩阵的行按列移动。

如果输入为[[1, 2 ,3], [4, 5 ,6], [7, 8, 9]] ,则移位的结果将为[[1, 2 ,3], [0, 4, 5], [9, 0, 7]]理想情况下,我们应该掩盖下三角形,但这对我们的目的来说是可以的

23def shift_right(x: torch.Tensor):

连接一列零

33    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
34    x_padded = torch.cat([x, zero_pad], dim=1)

重塑并从末端移除多余的元素

37    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
38    x = x_padded[:-1].view_as(x)

41    return x

相对多头注意模块

我们重写了多头注意模块,因此我们只需要编写该get_scores 方法即可。

44class RelativeMultiHeadAttention(MultiHeadAttention):
52    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):

线性变换不需要偏差,因为我们在计算分数时会明确包含偏差。但是,有偏见value 可能是有道理的。

56        super().__init__(heads, d_model, dropout_prob, bias=False)

相对位置的数量

59        self.P = 2 ** 12

键相对于查询的相对位置嵌入。我们需要嵌入,因为键可以在查询之前或之后。

63        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)

键相对于查询的相对位置嵌入偏差。

65        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)

查询的位置嵌入与查询的位置无关

67        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

获取相对注意力分数

绝对关注

其中,是原始嵌入的线性变换是绝对位置编码的线性变换

他们认为,无论查询的位置如何,对给定键的关注都应该相同。因此,用常量替换

对于第二项和第三项,引入了相对位置编码。因此,替换

69    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

108        key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]

110        key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]

112        query_pos_bias = self.query_pos_bias[None, None, :, :]

117        ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)

119        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)

121        d = key_pos_bias[None, :, None, :]

移动行以获取

124        bd = shift_right(b + d)

移除多余的头寸

126        bd = bd[:, -key.shape[0]:]

返回总和

134        return ac + bd
137def _test_shift_right():
138    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
139    inspect(x)
140    inspect(shift_right(x))
141
142    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
143    inspect(x[:, :, 0, 0])
144    inspect(shift_right(x)[:, :, 0, 0])
145
146    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
147    inspect(x[:, :, 0, 0])
148    inspect(shift_right(x)[:, :, 0, 0])
149
150
151if __name__ == '__main__':
152    _test_shift_right()