線形バイアスによる注意 (AliBi)

これは、「トレインショート、テストロング」という論文の「線形バイアスによる注意(AliBi)」の実装です。線形バイアスによる注意により、入力の長さの推定が可能になります

これにより、位置エンコーディングがアテンションスコア(ソフトマックスの前のアテンションロジット)にバイアスが加わったものに置き換わります。これは自己回帰タスクでテストされた相対的なスキームで、近くにあるトークンの方がバイアスが大きく、遠いトークンの方がバイアスが低くなります。対数スケールでは(ソフトマックスの前なので)バイアスは直線的に減少し、各ヘッドの傾きは異なります

-th トークンのアテンションフォーミュラは次のとおりです。

ここで、-th トークンのクエリ、までのキーおよびヘッドあたりのフィーチャ数です。上記の等式は変換に不変であるため中止されることに注意してください (結果を変更せずにすべての要素に任意の定数を追加できます

)。

AliBi モデルのトレーニングコードは次のとおりです

33import math
34from typing import Optional
35
36import torch
37from torch import nn
38
39from labml.logger import inspect
40from labml_nn.transformers.mha import MultiHeadAttention

各頭部の頭部固有の勾配を取得

  • n_heads アテンションレイヤーのヘッド数です

1 番目のヘッドの勾配は

残りのヘッドの勾配は幾何学的に連続しており、その比率は上記と同じです。

たとえば、ヘッドの数がの場合、スロープは

43def get_slopes(n_heads: int):

2 n_heads の累乗に最も近いものを求めます。が 2 n_heads の累乗でない場合は、まず 2 に最も近い (小さな) 累乗までの勾配を計算し、次に残りの勾配を加算します

62    n = 2 ** math.floor(math.log2(n_heads))

64    m_0 = 2.0 ** (-8.0 / n)

66    m = torch.pow(m_0, torch.arange(1, 1 + n))

n_heads が 2 の累乗でない場合は、残りの勾配を加算します。残りの勾配を計算します (以前に追加された勾配は除きます)。そして、n_heads 上の斜面を選んでください

.
71    if n < n_heads:

73        m_hat_0 = 2.0 ** (-4.0 / n)

なお、以前にスロープが追加されないように対策を講じています。

76        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))

スロープを残りのスロープと連結します。

78        m = torch.cat([m, m_hat])
79
80    return m

注意バイアスマトリックスの計算

  • n_heads アテンションレイヤーのヘッド数です
  • mask シェイプの注意マスクです [seq_len_q, seq_len_k]

これにより、AliBi [seq_len_q, seq_len_k, n_heads, ] の注意バイアスが入った形状のマトリックスが返されます。

83@torch.no_grad()
84def get_alibi_biases(n_heads: int, mask: torch.Tensor):

各ヘッドのスロープを取得

95    m = get_slopes(n_heads).to(mask.device)

距離の計算ここではマスクを使って距離を計算します。

カジュアルマスクなのでそのまま使えます。distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]

102    distance = mask.cumsum(dim=-1)

それらをペアごとに乗算して、AliBi バイアスマトリックスを求めます。

105    return distance[:, :, None] * m[None, None, :]

線形バイアスによる注意 (AliBi)

マルチヘッドアテンションを無効にします

108class AlibiMultiHeadAttention(MultiHeadAttention):
115    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
116        super().__init__(heads, d_model, dropout_prob)

AliBi にバイアスをキャッシュするには

119        self.alibi_biases = None

querykey value およびは、クエリキーおよび値のベクトルのコレクションを格納するテンソルです。形があります[seq_len, batch_size, d_model]

mask [seq_len, seq_len, batch_size] 形状があり、バッチの場合bmask[i, j, b] i その位置のクエリがその位置のキー値にアクセスできるかどうかを示します。j

121    def forward(self, *,
122                query: torch.Tensor,
123                key: torch.Tensor,
124                value: torch.Tensor,
125                mask: Optional[torch.Tensor] = None):

AliBi は因果マスクでのみ機能します。

137        assert mask is not None
138        assert mask.shape[0] == mask.shape[1] and mask.shape[2] == 1

querykey value そして形がある [seq_len, batch_size, d_model]

141        seq_len, batch_size, _ = query.shape

マスクに頭部の寸法を追加し、形状を確認します。

144        mask = self.prepare_mask(mask, query.shape, key.shape)

query key value 注意力計算の準備をして[seq_len, batch_size, heads, d_k] これで形ができあがります。

148        query = self.query(query)
149        key = self.key(key)
150        value = self.value(value)

アテンションスコアを計算します。[seq_len, seq_len, batch_size, heads] これにより形状のテンソルが得られます

154        scores = self.get_scores(query, key)

スケールスコア

157        scores *= self.scale

キャッシュされていない場合はAliBiバイアスを作成する

160        if self.alibi_biases is None or self.alibi_biases.shape[1] < seq_len:
162            self.alibi_biases = get_alibi_biases(scores.shape[-1], mask[:, :, 0, 0])

AliBi バイアスをアテンションスコアに追加します。AliBi [seq_len, seq_len, n_heads] scores バイアスには形と形がある [seq_len, seq_len, batch_size, n_heads]

167        scores += self.alibi_biases[:seq_len, :seq_len, None, :]

マスクを適用

170        scores = scores.masked_fill(mask == 0, float('-inf'))

キーシーケンス次元に沿って注目

174        attn = self.softmax(scores)

ドロップアウトを適用

177        attn = self.dropout(attn)

値による乗算

181        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

複数のヘッドを連結

184        x = x.reshape(seq_len, batch_size, -1)

出力レイヤー

187        return self.output(x)

スロープを確認できる簡単なテスト機能。

190def _test_alibi():
194    inspect(get_slopes(12).tolist(), _n=-1)
195    from labml_nn.transformers.utils import subsequent_mask
196
197    mask = subsequent_mask(8)[:, :, 0]
198    inspect(mask)
199
200    inspect(get_alibi_biases(12, mask)[:, :, 3], _n=-1)

204if __name__ == '__main__':
205    _test_alibi()