これは、「トレインショート、テストロング」という論文の「線形バイアスによる注意(AliBi)」の実装です。線形バイアスによる注意により、入力の長さの推定が可能になります。
これにより、位置エンコーディングがアテンションスコア(ソフトマックスの前のアテンションロジット)にバイアスが加わったものに置き換わります。これは自己回帰タスクでテストされた相対的なスキームで、近くにあるトークンの方がバイアスが大きく、遠いトークンの方がバイアスが低くなります。対数スケールでは(ソフトマックスの前なので)バイアスは直線的に減少し、各ヘッドの傾きは異なります
。-th トークンのアテンションフォーミュラは次のとおりです。
ここで、は -th トークンのクエリ、までのキー、およびヘッドあたりのフィーチャ数です。上記の等式は変換に不変であるため中止されることに注意してください (結果を変更せずにすべての要素に任意の定数を追加できます
)。AliBi モデルのトレーニングコードは次のとおりです。
33import math
34from typing import Optional
35
36import torch
37from torch import nn
38
39from labml.logger import inspect
40from labml_nn.transformers.mha import MultiHeadAttention
n_heads
アテンションレイヤーのヘッド数です 1 番目のヘッドの勾配は
残りのヘッドの勾配は幾何学的に連続しており、その比率は上記と同じです。
たとえば、ヘッドの数がの場合、スロープは
43def get_slopes(n_heads: int):
2 n_heads
の累乗に最も近いものを求めます。が 2 n_heads
の累乗でない場合は、まず 2 に最も近い (小さな) 累乗までの勾配を計算し、次に残りの勾配を加算します
62 n = 2 ** math.floor(math.log2(n_heads))
64 m_0 = 2.0 ** (-8.0 / n)
66 m = torch.pow(m_0, torch.arange(1, 1 + n))
n_heads
が 2 の累乗でない場合は、残りの勾配を加算します。残りの勾配を計算します (以前に追加された勾配は除きます)。そして、n_heads
上の斜面を選んでください
71 if n < n_heads:
73 m_hat_0 = 2.0 ** (-4.0 / n)
なお、以前にスロープが追加されないように対策を講じています。
76 m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))
スロープを残りのスロープと連結します。
78 m = torch.cat([m, m_hat])
79
80 return m
n_heads
アテンションレイヤーのヘッド数ですmask
シェイプの注意マスクです [seq_len_q, seq_len_k]
これにより、AliBi [seq_len_q, seq_len_k, n_heads, ]
の注意バイアスが入った形状のマトリックスが返されます。
83@torch.no_grad()
84def get_alibi_biases(n_heads: int, mask: torch.Tensor):
各ヘッドのスロープを取得
95 m = get_slopes(n_heads).to(mask.device)
距離の計算ここではマスクを使って距離を計算します。
カジュアルマスクなのでそのまま使えます。distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]
102 distance = mask.cumsum(dim=-1)
それらをペアごとに乗算して、AliBi バイアスマトリックスを求めます。
105 return distance[:, :, None] * m[None, None, :]
108class AlibiMultiHeadAttention(MultiHeadAttention):
115 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
116 super().__init__(heads, d_model, dropout_prob)
AliBi にバイアスをキャッシュするには
119 self.alibi_biases = None
query
、key
value
およびは、クエリ、キー、および値のベクトルのコレクションを格納するテンソルです。形があります[seq_len, batch_size, d_model]
。
mask
[seq_len, seq_len, batch_size]
形状があり、バッチの場合b
、mask[i, j, b]
i
その位置のクエリがその位置のキー値にアクセスできるかどうかを示します。j
121 def forward(self, *,
122 query: torch.Tensor,
123 key: torch.Tensor,
124 value: torch.Tensor,
125 mask: Optional[torch.Tensor] = None):
AliBi は因果マスクでのみ機能します。
137 assert mask is not None
138 assert mask.shape[0] == mask.shape[1] and mask.shape[2] == 1
query
、key
value
そして形がある [seq_len, batch_size, d_model]
141 seq_len, batch_size, _ = query.shape
マスクに頭部の寸法を追加し、形状を確認します。
144 mask = self.prepare_mask(mask, query.shape, key.shape)
query
key
value
注意力計算の準備をして[seq_len, batch_size, heads, d_k]
これで形ができあがります。
148 query = self.query(query)
149 key = self.key(key)
150 value = self.value(value)
アテンションスコアを計算します。[seq_len, seq_len, batch_size, heads]
これにより形状のテンソルが得られます
154 scores = self.get_scores(query, key)
スケールスコア
157 scores *= self.scale
キャッシュされていない場合はAliBiバイアスを作成する
160 if self.alibi_biases is None or self.alibi_biases.shape[1] < seq_len:
mask
図形は連番、連番、1、 1です
162 self.alibi_biases = get_alibi_biases(scores.shape[-1], mask[:, :, 0, 0])
AliBi バイアスをアテンションスコアに追加します。AliBi [seq_len, seq_len, n_heads]
scores
バイアスには形と形がある [seq_len, seq_len, batch_size, n_heads]
167 scores += self.alibi_biases[:seq_len, :seq_len, None, :]
マスクを適用
170 scores = scores.masked_fill(mask == 0, float('-inf'))
キーシーケンス次元に沿って注目
174 attn = self.softmax(scores)
ドロップアウトを適用
177 attn = self.dropout(attn)
値による乗算
181 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
複数のヘッドを連結
184 x = x.reshape(seq_len, batch_size, -1)
出力レイヤー
187 return self.output(x)
スロープを確認できる簡単なテスト機能。
190def _test_alibi():
194 inspect(get_slopes(12).tolist(), _n=-1)
195 from labml_nn.transformers.utils import subsequent_mask
196
197 mask = subsequent_mask(8)[:, :, 0]
198 inspect(mask)
199
200 inspect(get_alibi_biases(12, mask)[:, :, 3], _n=-1)
204if __name__ == '__main__':
205 _test_alibi()