24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import tracker

マルチヘッド・アテンションに備えましょう

このモジュールは線形変換を行い、ベクトルを指定された数のヘッドに分割してマルチヘッドアテンションを行います。これは、キークエリおよび値のベクトルを変換するために使用されます

33class PrepareForMultiHeadAttention(nn.Module):
44    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45        super().__init__()

線形変換用の線形層

47        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

ヘッド数

49        self.heads = heads

各ヘッドのベクトルの次元数

51        self.d_k = d_k
53    def forward(self, x: torch.Tensor):

[seq_len, batch_size, d_model] [batch_size, d_model] 入力の形状はまたはです。線形変換を最後の次元に適用し、それを頭に分割します。

57        head_shape = x.shape[:-1]

線形変換

60        x = self.linear(x)

最後のディメンションをヘッドに分割

63        x = x.view(*head_shape, self.heads, self.d_k)

[seq_len, batch_size, heads, d_k] 出力の形状があるか [batch_size, heads, d_model]

66        return x

マルチヘッドアテンションモジュール

query 与えられたベクトルやベクトルに対して、スケーリングされたマルチヘッド・アテンションを計算します。key value

簡単に言うと、クエリに一致するキーを見つけ、それらのキーの値を取得します。

クエリとキーのドット積がどの程度一致しているかを示す指標として使用します。撮影前にドットプロダクトをスケーリングします。これは、ドット積値が大きい場合に softmax のグラデーションが非常に小さくなる原因とならないようにするためです

Softmax は、シーケンス (または時間) の軸に沿って計算されます。

69class MultiHeadAttention(nn.Module):
  • heads は頭の数です。
  • d_modelquerykey value およびベクトル内の特徴の数です。
90    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
96        super().__init__()

ヘッドあたりの機能数

99        self.d_k = d_model // heads

ヘッド数

101        self.heads = heads

これらはquery 、、key value のベクトルを変えて、多面的な注意を促します。

104        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

時間軸に沿った注目のソフトマックス key

109        self.softmax = nn.Softmax(dim=1)

出力レイヤー

112        self.output = nn.Linear(d_model, d_model)

ドロップアウト

114        self.dropout = nn.Dropout(dropout_prob)

ソフトマックス前のスケーリングファクター

116        self.scale = 1 / math.sqrt(self.d_k)

必要に応じてロギングやその他の計算に使用できるように、アテンションを保存します

119        self.attn = None

クエリとキー間のスコアの計算

この方法は、相対的注意力などの他のバリエーションではオーバーライドできます。

121    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

計算または

129        return torch.einsum('ibhd,jbhd->ijbh', query, key)

mask には形状があり[seq_len_q, seq_len_k, batch_size] 、最初の次元はクエリ次元です。クエリディメンションがそれと等しい場合はブロードキャストされます

131    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
137        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138        assert mask.shape[1] == key_shape[0]
139        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

すべての頭に同じマスクをかけました。

142        mask = mask.unsqueeze(-1)

生成されるマスクには形状があります [seq_len_q, seq_len_k, batch_size, heads]

145        return mask

querykey value およびは、クエリキーおよび値のベクトルのコレクションを格納するテンソルです。形があります[seq_len, batch_size, d_model]

mask [seq_len, seq_len, batch_size] 形状があり、バッチの場合bmask[i, j, b] i その位置のクエリがその位置のキー値にアクセスできるかどうかを示します。j

147    def forward(self, *,
148                query: torch.Tensor,
149                key: torch.Tensor,
150                value: torch.Tensor,
151                mask: Optional[torch.Tensor] = None):

querykey value そして形がある [seq_len, batch_size, d_model]

163        seq_len, batch_size, _ = query.shape
164
165        if mask is not None:
166            mask = self.prepare_mask(mask, query.shape, key.shape)

query key value 注意力計算の準備をして[seq_len, batch_size, heads, d_k] これで形ができあがります。

170        query = self.query(query)
171        key = self.key(key)
172        value = self.value(value)

アテンションスコアを計算します。[seq_len, seq_len, batch_size, heads] これにより形状のテンソルが得られます

176        scores = self.get_scores(query, key)

スケールスコア

179        scores *= self.scale

マスクを適用

182        if mask is not None:
183            scores = scores.masked_fill(mask == 0, float('-inf'))

キーシーケンス次元に沿って注目

187        attn = self.softmax(scores)

デバッグ時の注意事項を保存

190        tracker.debug('attn', attn)

ドロップアウトを適用

193        attn = self.dropout(attn)

値による乗算

197        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

他の計算に注意を向けておく

200        self.attn = attn.detach()

複数のヘッドを連結

203        x = x.reshape(seq_len, batch_size, -1)

出力レイヤー

206        return self.output(x)