これは、論文「アテンション・フリー・トランスフォーマー」をPyTorchで実装したものです。
この論文では、セルフアテンション層を新しい効率的な操作に置き換えます。この操作では、メモリの複雑さが、はシーケンスの長さ、は埋め込みの次元です。
この論文では、AFTとAFTローカルおよびAFT-Convについて紹介しています。ここでは、自己回帰モデルで近傍のトークンに注目するAFT-Localを実装しました
。AFT(MHA と同様)は、まず埋め込みを学習した重み付きのクエリ、キー、値のテンソルに変換します。各ポジションの出力は、次の操作で計算されます。
ここで、は要素ごとの積、は非線形性 (シグモイド) で、ペアごとの位置バイアスの学習行列です。
つまり、値の加重平均値にクエリを掛けます。これにより、MHA が必要とするアテンションマトリックスを計算する必要がなくなるため、必要なメモリ量が少なくなります
。AFT Localは、学習したペアワイズ位置バイアスをローカルにのみ適用します。
、ローカルウィンドウのサイズはどこですか。
ローカルウィンドウの外にありますが、AFT 操作では他の領域のキーと値のペアが引き続き使用されます。これは、ローカルウィンドウの外に埋め込まれたものが完全に見えないローカルトランスフォーマーとは異なります
。59from typing import Optional
60
61import torch
62from torch import nn
63
64from labml_helpers.module import Module67class AFTLocal(Module):d_model
はquery
、key
value
およびベクトル内の特徴の数です。seq_len
は local_window_size
はローカルウィンドウサイズです bias
、の変換にバイアスパラメータを設定するかどうかです。  86    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):94        super().__init__()ローカルウィンドウサイズ
97        self.local_window_size = local_window_sizeこれらはquery
、、key
value
およびベクトルを変換します。
99        self.query = nn.Linear(d_model, d_model, bias=bias)
100        self.key = nn.Linear(d_model, d_model, bias=bias)
101        self.value = nn.Linear(d_model, d_model, bias=bias)ペアワイズの位置バイアス
103        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)用マスク
105        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)アクティベーション
107        self.activation = nn.Sigmoid()出力レイヤー
109        self.output = nn.Linear(d_model, d_model)111    @staticmethod
112    def create_local_mask(seq_len, local_window_size):1 に初期化
128        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)ゼロにする
130        local_mask = torch.tril(local_mask, local_window_size - 1)ゼロにする
132        local_mask = torch.triu(local_mask, -(local_window_size - 1))135        return local_maskquery
、key
value
およびは、クエリ、キー、および値のトークン埋め込みのコレクションを格納するテンソルです。形があります[seq_len, batch_size, d_model]
。
mask
[seq_len, seq_len, batch_size]
形状があり、バッチの場合b
、mask[i, j, b]
i
その位置のクエリがその位置のキー値にアクセスできるかどうかを示します。j
137    def forward(self, *,
138                query: torch.Tensor,
139                key: torch.Tensor,
140                value: torch.Tensor,
141                mask: Optional[torch.Tensor] = None):query
、key
value
そして形がある [seq_len, batch_size, d_model]
153        seq_len, _, _ = query.shape
154
155        if mask is not None:mask
には形状があり[seq_len_q, seq_len_k, batch_size]
、最初の次元はクエリ次元です。クエリディメンションがそれと等しい場合はブロードキャストされます
159            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
160            assert mask.shape[1] == key.shape[0]
161            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]クエリ、キー、値の埋め込みを変換
164        query = self.query(query)
165        key = self.key(key)
166        value = self.value(value)179        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
180        pos_bias = pos_bias.unsqueeze(-1)
181        pos_bias.masked_fill_(~mask, float('-inf'))ソフトマックスの計算を安定させるために、指数を計算する前に減算します。
大きいと巨大になり、の計算が不安定になります。分子と分母から指数を計算する前に定数を引くと相殺され、計算を安定させるのに役立ちます。そこで、減算して計算を安定させます。
203        max_key = key.max(dim=0, keepdims=True)[0]
204        max_pos_bias = pos_bias.max(dim=1,  keepdims=True)[0]207        exp_key = torch.exp(key - max_key)209        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)分子部分
212        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)分母部分
214        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)[出力]
219        y = self.activation(query) * num / den出力レイヤー
222        return self.output(y)ローカルマスクをテスト
225def _test_local_mask():229    from labml.logger import inspect
230    inspect(AFTLocal.create_local_mask(10, 4))234if __name__ == '__main__':
235    _test_local_mask()