これは、PyTorchの論文「注意さえあれば十分」の「多面的な注意」のチュートリアル/実装です。実装は注釈付きトランスフォーマーから着想を得ています
。24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import tracker
このモジュールは線形変換を行い、ベクトルを指定された数のヘッドに分割してマルチヘッドアテンションを行います。これは、キー、クエリ、および値のベクトルを変換するために使用されます。
33class PrepareForMultiHeadAttention(nn.Module):
44 def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45 super().__init__()
線形変換用の線形層
47 self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
ヘッド数
49 self.heads = heads
各ヘッドのベクトルの次元数
51 self.d_k = d_k
53 def forward(self, x: torch.Tensor):
[seq_len, batch_size, d_model]
[batch_size, d_model]
入力の形状はまたはです。線形変換を最後の次元に適用し、それを頭に分割します。
57 head_shape = x.shape[:-1]
線形変換
60 x = self.linear(x)
最後のディメンションをヘッドに分割
63 x = x.view(*head_shape, self.heads, self.d_k)
[seq_len, batch_size, heads, d_k]
出力の形状があるか [batch_size, heads, d_model]
66 return x
query
与えられたベクトルやベクトルに対して、スケーリングされたマルチヘッド・アテンションを計算します。key
value
簡単に言うと、クエリに一致するキーを見つけ、それらのキーの値を取得します。
クエリとキーのドット積がどの程度一致しているかを示す指標として使用します。撮影前にドットプロダクトをスケーリングします。これは、ドット積値が大きい場合に softmax のグラデーションが非常に小さくなる原因とならないようにするためです
。Softmax は、シーケンス (または時間) の軸に沿って計算されます。
69class MultiHeadAttention(nn.Module):
heads
は頭の数です。d_model
はquery
、key
value
およびベクトル内の特徴の数です。90 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
96 super().__init__()
ヘッドあたりの機能数
99 self.d_k = d_model // heads
ヘッド数
101 self.heads = heads
これらはquery
、、key
value
のベクトルを変えて、多面的な注意を促します。
104 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
時間軸に沿った注目のソフトマックス key
109 self.softmax = nn.Softmax(dim=1)
出力レイヤー
112 self.output = nn.Linear(d_model, d_model)
ドロップアウト
114 self.dropout = nn.Dropout(dropout_prob)
ソフトマックス前のスケーリングファクター
116 self.scale = 1 / math.sqrt(self.d_k)
必要に応じてロギングやその他の計算に使用できるように、アテンションを保存します
119 self.attn = None
121 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
計算または
129 return torch.einsum('ibhd,jbhd->ijbh', query, key)
mask
には形状があり[seq_len_q, seq_len_k, batch_size]
、最初の次元はクエリ次元です。クエリディメンションがそれと等しい場合はブロードキャストされます
131 def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
137 assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138 assert mask.shape[1] == key_shape[0]
139 assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
すべての頭に同じマスクをかけました。
142 mask = mask.unsqueeze(-1)
生成されるマスクには形状があります [seq_len_q, seq_len_k, batch_size, heads]
145 return mask
query
、key
value
およびは、クエリ、キー、および値のベクトルのコレクションを格納するテンソルです。形があります[seq_len, batch_size, d_model]
。
mask
[seq_len, seq_len, batch_size]
形状があり、バッチの場合b
、mask[i, j, b]
i
その位置のクエリがその位置のキー値にアクセスできるかどうかを示します。j
147 def forward(self, *,
148 query: torch.Tensor,
149 key: torch.Tensor,
150 value: torch.Tensor,
151 mask: Optional[torch.Tensor] = None):
query
、key
value
そして形がある [seq_len, batch_size, d_model]
163 seq_len, batch_size, _ = query.shape
164
165 if mask is not None:
166 mask = self.prepare_mask(mask, query.shape, key.shape)
query
key
value
注意力計算の準備をして[seq_len, batch_size, heads, d_k]
これで形ができあがります。
170 query = self.query(query)
171 key = self.key(key)
172 value = self.value(value)
アテンションスコアを計算します。[seq_len, seq_len, batch_size, heads]
これにより形状のテンソルが得られます
176 scores = self.get_scores(query, key)
スケールスコア
179 scores *= self.scale
マスクを適用
182 if mask is not None:
183 scores = scores.masked_fill(mask == 0, float('-inf'))
キーシーケンス次元に沿って注目
187 attn = self.softmax(scores)
デバッグ時の注意事項を保存
190 tracker.debug('attn', attn)
ドロップアウトを適用
193 attn = self.dropout(attn)
値による乗算
197 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
他の計算に注意を向けておく
200 self.attn = attn.detach()
複数のヘッドを連結
203 x = x.reshape(seq_len, batch_size, -1)
出力レイヤー
206 return self.output(x)