高速ウェイトトランス

PyTorchの線形トランスフォーマーは密かに高速な重み記憶システムであるという論文では、線形自己アテンションシステムと高速重みシステムの類似点を見出し、それに基づいて自己アテンションの更新ルールを修正しています。また、よりシンプルでありながら効果的なカーネル機能も導入されています

著者らは、論文で比較した他のバリエーションも含めて、論文の公式な実装を提供しています

ファーストウェイト

入力のシーケンスまたは長さを考えてみると、各ステップはサイズのベクトルになります。高速ウェイトモデルでは、各ステップでウェイトマトリックスを生成して出力を生成します

は外積 () で、2 つのベクトルの要素を掛け合わせて行列になります。起動機能です。トレーニング可能なウェイト (パラメータ) です。は各ステップで生成される高速ウェイトです

直線的な自己注意

オリジナルトランスフォーマーの自己アテンションは、(わかりやすくするために省略

どこ

セルフアテンションの線形化の背後にある考え方は、セルフアテンション関数の分母をより速く計算できるように、softmax カーネルを別のカーネルに置き換えることです。

これにより

とを使うと、それらを効率的に計算できます。

これはファーストウェイトとよく似ています。

この論文では、新しい線形注意投影関数正規化の新しい更新規則と変更について紹介しています。

以下は、Tiny Shakespeareデータセットで高速ウェイトトランスフォーマーをトレーニングするためのトレーニングコードとノートブックです

Open In Colab

95import torch
96from torch import nn
97
98from labml_helpers.module import Module
99from labml_nn.transformers.feed_forward import FeedForward
100from labml_nn.transformers.mha import PrepareForMultiHeadAttention
101from labml_nn.utils import clone_module_list

決定論的パラメータフリープロジェクト (DPFP)

これは論文で紹介した新しい投影機能です。DPFP は次元と次元の関係を投影します。ここではハイパーパラメータです

ここで、はを連結してサイズのベクトル、を返します。はベクトルの -番目の要素で、が内の要素数より大きい場合はロールアラウンドされます

基本的には、 shiftdの要素を乗算して新しいベクトルを作成します。

この結果、投影はまばらで (0 以外の要素はごくわずか)、直交 (ほとんどの場合で例外は除く)、非常によく似た投影が生成されます。

ノーマライゼーション

論文では、以下の簡単な正規化について紹介しています。

論文の導出を確認してください。

104class DPFP(Module):
  • nu はハイパーパラメータです。
  • eps 正規化時にゼロで除算されないようにするために使用される小さな値です。
138    def __init__(self, nu: int = 1, eps: float = 1e-6):
143        super().__init__()
144        self.nu = nu
145        self.relu = nn.ReLU()
146        self.eps = eps
148    def forward(self, k: torch.Tensor):

取得

150        k = self.dpfp(k)

による正規化

152        return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)

154    def dpfp(self, k: torch.Tensor):

159        x = self.relu(torch.cat([k, -k], dim=-1))

シフトしてロールバイすると

162        x_rolled = [x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)]

連結して取得

165        x_rolled = torch.cat(x_rolled, dim=-1)

のコピーを連結

167        x_repeat = torch.cat([x] * self.nu, dim=-1)

それらを掛け合わせて、

173        return x_repeat * x_rolled

ファストウエイト注意

この論文では、計算に関する新しい更新ルールを紹介しています。モデルはまず、キーとペアになっている現在の値を取得します。次に、取得した値と入力の組み合わせを格納します

where はトレーニング可能なパラメーター、はシグモイド関数です。

は正規化されているので正規化項は必要ないことに注意してください。

176class FastWeightsAttention(Module):
204    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
205        super().__init__()

ヘッドあたりの機能数

208        self.d_k = d_model // heads

ヘッド数

210        self.heads = heads

これらはquerykey value そして多面的な注意力を変えます。

213        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
214        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
215        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

各ヘッドの補間ウェイト関数

218        self.interpolation_weight = nn.Sequential(
219            PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
220            nn.Sigmoid()
221        )

224        self.phi = phi

出力レイヤー

227        self.output = nn.Linear(d_model, d_model)

ドロップアウト

229        self.dropout = nn.Dropout(dropout_prob)
231    def forward(self, x: torch.Tensor):

ステップ数を取得

233        seq_len = x.shape[0]

すべてのステップとヘッドに対応

235        query = self.phi(self.query(x))

すべてのステップとヘッドに対応

237        key = self.phi(self.key(x))

すべてのステップとヘッドに対応

239        value = self.value(x)

すべてのステップとヘッドに対応

241        beta = self.interpolation_weight(x)

244        weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))

出力を保存するリスト

246        outputs = []

手順を繰り返す

249        for i in range(seq_len):

251            value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])

256            weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])

259            y = torch.einsum('bhvk,bhk->bhv', weights, query[i])

複数のヘッドを結合して追加 outputs

262            outputs.append(y.reshape(y.shape[0], -1))

各ステップの出力を1つのテンソルにスタックします

265        x = torch.stack(outputs)

出力レイヤー

268        return self.output(x)

これは、セルフアテンションとフィードフォワードネットワークを組み合わせた一般的なトランス層です。

271class FastWeightsAttentionTransformerLayer(Module):
275    def __init__(self, *,
276                 d_model: int,
277                 attn: FastWeightsAttention,
278                 feed_forward: FeedForward,
279                 dropout_prob: float):
280        super().__init__()

変圧器サイズ

282        self.size = d_model

ファストウェイトアテンションモジュール

284        self.attn = attn

フィードフォワードネットワーク

286        self.feed_forward = feed_forward

ドロップアウトレイヤー

288        self.dropout = nn.Dropout(dropout_prob)

正規化レイヤー

291        self.norm_self_attn = nn.LayerNorm([d_model])
292        self.norm_ff = nn.LayerNorm([d_model])
294    def forward(self, x: torch.Tensor):

セルフアテンションの高速ウェイト計算

296        attn = self.attn(x)

セルフアテンションの結果を追加

298        x = x + self.dropout(attn)

フィードフォワード用に正規化

301        z = self.norm_ff(x)

フィードフォワードネットワークを通過

303        ff = self.feed_forward(z)

フィードフォワードの結果を追加し直す

305        x = x + self.dropout(ff)

308        return x

これは、複数の変圧器層を備えた一般的な変圧器モジュールです。

311class FastWeightsAttentionTransformer(Module):
315    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
316        super().__init__()

トランスレイヤーのコピーを作成

318        self.layers = clone_module_list(layer, n_layers)

最終正規化レイヤー

320        self.norm = nn.LayerNorm([layer.size])
322    def forward(self, x: torch.Tensor):
323        for i, layer in enumerate(self.layers):

レイヤー出力を取得

325            x = layer(x)

出力を正規化

328        return self.norm(x)