PyTorchの線形トランスフォーマーは密かに高速な重み記憶システムであるという論文では、線形自己アテンションシステムと高速重みシステムの類似点を見出し、それに基づいて自己アテンションの更新ルールを修正しています。また、よりシンプルでありながら効果的なカーネル機能も導入されています
。著者らは、論文で比較した他のバリエーションも含めて、論文の公式な実装を提供しています。
入力のシーケンスまたは長さを考えてみると、各ステップはサイズのベクトルになります。高速ウェイトモデルでは、各ステップでウェイトマトリックスを生成して出力を生成します
。は外積 () で、2 つのベクトルの要素を掛け合わせて行列になります。起動機能です。トレーニング可能なウェイト (パラメータ) です。は各ステップで生成される高速ウェイトです
。オリジナルトランスフォーマーの自己アテンションは、(わかりやすくするために省略)
どこ
セルフアテンションの線形化の背後にある考え方は、セルフアテンション関数の分母をより速く計算できるように、softmax カーネルを別のカーネルに置き換えることです。
これにより
とを使うと、それらを効率的に計算できます。
これはファーストウェイトとよく似ています。
この論文では、新しい線形注意投影関数、正規化の新しい更新規則と変更について紹介しています。
以下は、Tiny Shakespeareデータセットで高速ウェイトトランスフォーマーをトレーニングするためのトレーニングコードとノートブックです。
95import torch
96from torch import nn
97
98from labml_helpers.module import Module
99from labml_nn.transformers.feed_forward import FeedForward
100from labml_nn.transformers.mha import PrepareForMultiHeadAttention
101from labml_nn.utils import clone_module_list
これは論文で紹介した新しい投影機能です。DPFP は次元と次元の関係を投影します。ここで、はハイパーパラメータです
。
ここで、はを連結してサイズのベクトル、、を返します。はベクトルの -番目の要素で、が内の要素数より大きい場合はロールアラウンドされます
。基本的には、 shiftdの要素を乗算して新しいベクトルを作成します。
この結果、投影はまばらで (0 以外の要素はごくわずか)、直交 (ほとんどの場合で例外は除く)、非常によく似た投影が生成されます。
論文では、以下の簡単な正規化について紹介しています。
論文の導出を確認してください。
104class DPFP(Module):
nu
はハイパーパラメータです。eps
正規化時にゼロで除算されないようにするために使用される小さな値です。138 def __init__(self, nu: int = 1, eps: float = 1e-6):
143 super().__init__()
144 self.nu = nu
145 self.relu = nn.ReLU()
146 self.eps = eps
148 def forward(self, k: torch.Tensor):
取得
150 k = self.dpfp(k)
による正規化
152 return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)
154 def dpfp(self, k: torch.Tensor):
159 x = self.relu(torch.cat([k, -k], dim=-1))
シフトしてロールバイすると、
162 x_rolled = [x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)]
連結して取得
165 x_rolled = torch.cat(x_rolled, dim=-1)
のコピーを連結
167 x_repeat = torch.cat([x] * self.nu, dim=-1)
それらを掛け合わせて、
173 return x_repeat * x_rolled
この論文では、計算に関する新しい更新ルールを紹介しています。モデルはまず、キーとペアになっている現在の値を取得します。次に、取得した値と入力の組み合わせを格納します
。where はトレーニング可能なパラメーター、はシグモイド関数です。
は正規化されているので正規化項は必要ないことに注意してください。
176class FastWeightsAttention(Module):
204 def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
205 super().__init__()
ヘッドあたりの機能数
208 self.d_k = d_model // heads
ヘッド数
210 self.heads = heads
これらはquery
、key
value
そして多面的な注意力を変えます。
213 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
214 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
215 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
各ヘッドの補間ウェイト関数
218 self.interpolation_weight = nn.Sequential(
219 PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
220 nn.Sigmoid()
221 )
224 self.phi = phi
出力レイヤー
227 self.output = nn.Linear(d_model, d_model)
ドロップアウト
229 self.dropout = nn.Dropout(dropout_prob)
231 def forward(self, x: torch.Tensor):
ステップ数を取得
233 seq_len = x.shape[0]
すべてのステップとヘッドに対応
235 query = self.phi(self.query(x))
すべてのステップとヘッドに対応
237 key = self.phi(self.key(x))
すべてのステップとヘッドに対応
239 value = self.value(x)
すべてのステップとヘッドに対応
241 beta = self.interpolation_weight(x)
244 weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))
出力を保存するリスト
246 outputs = []
手順を繰り返す
249 for i in range(seq_len):
251 value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])
256 weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])
259 y = torch.einsum('bhvk,bhk->bhv', weights, query[i])
複数のヘッドを結合して追加 outputs
262 outputs.append(y.reshape(y.shape[0], -1))
各ステップの出力を1つのテンソルにスタックします
265 x = torch.stack(outputs)
出力レイヤー
268 return self.output(x)
これは、セルフアテンションとフィードフォワードネットワークを組み合わせた一般的なトランス層です。
271class FastWeightsAttentionTransformerLayer(Module):
275 def __init__(self, *,
276 d_model: int,
277 attn: FastWeightsAttention,
278 feed_forward: FeedForward,
279 dropout_prob: float):
280 super().__init__()
変圧器サイズ
282 self.size = d_model
ファストウェイトアテンションモジュール
284 self.attn = attn
フィードフォワードネットワーク
286 self.feed_forward = feed_forward
ドロップアウトレイヤー
288 self.dropout = nn.Dropout(dropout_prob)
正規化レイヤー
291 self.norm_self_attn = nn.LayerNorm([d_model])
292 self.norm_ff = nn.LayerNorm([d_model])
294 def forward(self, x: torch.Tensor):
セルフアテンションの高速ウェイト計算
296 attn = self.attn(x)
セルフアテンションの結果を追加
298 x = x + self.dropout(attn)
フィードフォワード用に正規化
301 z = self.norm_ff(x)
フィードフォワードネットワークを通過
303 ff = self.feed_forward(z)
フィードフォワードの結果を追加し直す
305 x = x + self.dropout(ff)
308 return x
これは、複数の変圧器層を備えた一般的な変圧器モジュールです。
311class FastWeightsAttentionTransformer(Module):
315 def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
316 super().__init__()
トランスレイヤーのコピーを作成
318 self.layers = clone_module_list(layer, n_layers)
最終正規化レイヤー
320 self.norm = nn.LayerNorm([layer.size])
322 def forward(self, x: torch.Tensor):
323 for i, layer in enumerate(self.layers):
レイヤー出力を取得
325 x = layer(x)
出力を正規化
328 return self.norm(x)