これは、論文「アテンション・フリー・トランスフォーマー」をPyTorchで実装したものです。
この論文では、セルフアテンション層を新しい効率的な操作に置き換えます。この操作では、メモリの複雑さが、はシーケンスの長さ、は埋め込みの次元です。
この論文では、AFTとAFTローカルおよびAFT-Convについて紹介しています。ここでは、自己回帰モデルで近傍のトークンに注目するAFT-Localを実装しました
。AFT(MHA と同様)は、まず埋め込みを学習した重み付きのクエリ、キー、値のテンソルに変換します。各ポジションの出力は、次の操作で計算されます。
ここで、は要素ごとの積、は非線形性 (シグモイド) で、ペアごとの位置バイアスの学習行列です。
つまり、値の加重平均値にクエリを掛けます。これにより、MHA が必要とするアテンションマトリックスを計算する必要がなくなるため、必要なメモリ量が少なくなります
。AFT Localは、学習したペアワイズ位置バイアスをローカルにのみ適用します。
、ローカルウィンドウのサイズはどこですか。
ローカルウィンドウの外にありますが、AFT 操作では他の領域のキーと値のペアが引き続き使用されます。これは、ローカルウィンドウの外に埋め込まれたものが完全に見えないローカルトランスフォーマーとは異なります
。59from typing import Optional
60
61import torch
62from torch import nn
63
64from labml_helpers.module import Module
67class AFTLocal(Module):
d_model
はquery
、key
value
およびベクトル内の特徴の数です。seq_len
は local_window_size
はローカルウィンドウサイズです bias
、の変換にバイアスパラメータを設定するかどうかです。 86 def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
94 super().__init__()
ローカルウィンドウサイズ
97 self.local_window_size = local_window_size
これらはquery
、、key
value
およびベクトルを変換します。
99 self.query = nn.Linear(d_model, d_model, bias=bias)
100 self.key = nn.Linear(d_model, d_model, bias=bias)
101 self.value = nn.Linear(d_model, d_model, bias=bias)
ペアワイズの位置バイアス
103 self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)
用マスク
105 self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)
アクティベーション
107 self.activation = nn.Sigmoid()
出力レイヤー
109 self.output = nn.Linear(d_model, d_model)
111 @staticmethod
112 def create_local_mask(seq_len, local_window_size):
1 に初期化
128 local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
ゼロにする
130 local_mask = torch.tril(local_mask, local_window_size - 1)
ゼロにする
132 local_mask = torch.triu(local_mask, -(local_window_size - 1))
135 return local_mask
query
、key
value
およびは、クエリ、キー、および値のトークン埋め込みのコレクションを格納するテンソルです。形があります[seq_len, batch_size, d_model]
。
mask
[seq_len, seq_len, batch_size]
形状があり、バッチの場合b
、mask[i, j, b]
i
その位置のクエリがその位置のキー値にアクセスできるかどうかを示します。j
137 def forward(self, *,
138 query: torch.Tensor,
139 key: torch.Tensor,
140 value: torch.Tensor,
141 mask: Optional[torch.Tensor] = None):
query
、key
value
そして形がある [seq_len, batch_size, d_model]
153 seq_len, _, _ = query.shape
154
155 if mask is not None:
mask
には形状があり[seq_len_q, seq_len_k, batch_size]
、最初の次元はクエリ次元です。クエリディメンションがそれと等しい場合はブロードキャストされます
159 assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
160 assert mask.shape[1] == key.shape[0]
161 assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]
クエリ、キー、値の埋め込みを変換
164 query = self.query(query)
165 key = self.key(key)
166 value = self.value(value)
179 pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
180 pos_bias = pos_bias.unsqueeze(-1)
181 pos_bias.masked_fill_(~mask, float('-inf'))
ソフトマックスの計算を安定させるために、指数を計算する前に減算します。
大きいと巨大になり、の計算が不安定になります。分子と分母から指数を計算する前に定数を引くと相殺され、計算を安定させるのに役立ちます。そこで、減算して計算を安定させます。
203 max_key = key.max(dim=0, keepdims=True)[0]
204 max_pos_bias = pos_bias.max(dim=1, keepdims=True)[0]
207 exp_key = torch.exp(key - max_key)
209 exp_pos_bias = torch.exp(pos_bias - max_pos_bias)
分子部分
212 num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)
分母部分
214 den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)
[出力]
219 y = self.activation(query) * num / den
出力レイヤー
222 return self.output(y)
ローカルマスクをテスト
225def _test_local_mask():
229 from labml.logger import inspect
230 inspect(AFTLocal.create_local_mask(10, 4))
234if __name__ == '__main__':
235 _test_local_mask()