16import torch
17from torch import nn
18
19from labml.logger import inspect
20from labml_nn.transformers.mha import MultiHeadAttention
このメソッドは、行列の行を列ごとにシフトします。
入力がの場合[[1, 2 ,3], [4, 5 ,6], [7, 8, 9]]
、シフトされた結果は次のようになります。[[1, 2 ,3], [0, 4, 5], [9, 0, 7]]
下の三角形をマスクするのが理想的ですが、この目的には問題ありません。
23def shift_right(x: torch.Tensor):
0 の列を連結する
33 zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
34 x_padded = torch.cat([x, zero_pad], dim=1)
形を変えて端から余分な要素を取り除く
37 x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
38 x = x_padded[:-1].view_as(x)
41 return x
44class RelativeMultiHeadAttention(MultiHeadAttention):
52 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
線形変換はスコアの計算時に明示的に含めるので、バイアスは必要ありません。ただし、value
偏見を持つことは理にかなっているかもしれません.
56 super().__init__(heads, d_model, dropout_prob, bias=False)
相対位置の数
59 self.P = 2 ** 12
クエリを基準としたキーの相対位置埋め込み。キーはクエリの前でも後でも構わないので、埋め込みが必要です
。63 self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
クエリに対するキーの相対的な位置埋め込みバイアス。
65 self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
クエリの位置埋め込みはクエリの位置とは無関係です
67 self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
絶対的な注意を払って
ここで、は元の埋め込みの線形変換で、は絶対位置エンコーディングの線形変換です。
彼らは、特定のキーへの注意は、クエリの位置に関係なく同じであるべきだと推論しています。したがって、定数に置き換えてください。
第2用語と第3用語では、相対位置エンコーディングが導入されています。So は、と、に置き換えられます。
69 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
108 key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]
110 key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]
112 query_pos_bias = self.query_pos_bias[None, None, :, :]
117 ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
119 b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
121 d = key_pos_bias[None, :, None, :]
行をシフトすると
124 bd = shift_right(b + d)
余分なポジションを削除
126 bd = bd[:, -key.shape[0]:]
合計を返す
134 return ac + bd
137def _test_shift_right():
138 x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
139 inspect(x)
140 inspect(shift_right(x))
141
142 x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
143 inspect(x[:, :, 0, 0])
144 inspect(shift_right(x)[:, :, 0, 0])
145
146 x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
147 inspect(x[:, :, 0, 0])
148 inspect(shift_right(x)[:, :, 0, 0])
149
150
151if __name__ == '__main__':
152 _test_shift_right()