アテンションフリーのトランスフォーマー

これは、論文「アテンション・フリー・トランスフォーマー」をPyTorchで実装したものです

この論文では、セルフアテンション層を新しい効率的な操作に置き換えます。この操作では、メモリの複雑さがはシーケンスの長さ、は埋め込みの次元です。

この論文では、AFTとAFTローカルおよびAFT-Convについて紹介しています。ここでは、自己回帰モデルで近傍のトークンに注目するAFT-Localを実装しました

アテンションフリー変圧器

AFT(MHA と同様)は、まず埋め込みを学習した重み付きのクエリ、キー値のテンソルに変換します。各ポジションの出力は、次の操作で計算されます。

ここで、は要素ごとの積、は非線形性 (シグモイド) で、ペアごとの位置バイアスの学習行列です。

つまり、値の加重平均値にクエリを掛けます。これにより、MHA が必要とするアテンションマトリックスを計算する必要がなくなるため、必要なメモリ量が少なくなります

AFT ローカル

AFT Localは、学習したペアワイズ位置バイアスをローカルにのみ適用します。

ローカルウィンドウのサイズはどこですか。

ローカルウィンドウの外にありますが、AFT 操作では他の領域のキーと値のペアが引き続き使用されます。これは、ローカルウィンドウの外に埋め込まれたものが完全に見えないローカルトランスフォーマーとは異なります

AFT Localモデルのトレーニングコードは次のとおりです

59from typing import Optional
60
61import torch
62from torch import nn
63
64from labml_helpers.module import Module

AFT ローカルオペレーション

どこ、

67class AFTLocal(Module):
  • d_modelquerykey value およびベクトル内の特徴の数です。
  • seq_len
  • local_window_size はローカルウィンドウサイズです
  • bias 、の変換にバイアスパラメータを設定するかどうかです。
  • 86    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
    94        super().__init__()

    ローカルウィンドウサイズ

    97        self.local_window_size = local_window_size

    これらはquery 、、key value およびベクトルを変換します。

    99        self.query = nn.Linear(d_model, d_model, bias=bias)
    100        self.key = nn.Linear(d_model, d_model, bias=bias)
    101        self.value = nn.Linear(d_model, d_model, bias=bias)

    ペアワイズの位置バイアス

    103        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)

    用マスク

    105        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)

    アクティベーション

    107        self.activation = nn.Sigmoid()

    出力レイヤー

    109        self.output = nn.Linear(d_model, d_model)

    ローカルマスクの作成

    これにより次のマスクが作成されます

    111    @staticmethod
    112    def create_local_mask(seq_len, local_window_size):

    1 に初期化

    128        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)

    ゼロにする

    130        local_mask = torch.tril(local_mask, local_window_size - 1)

    ゼロにする

    132        local_mask = torch.triu(local_mask, -(local_window_size - 1))

    135        return local_mask

    query key value およびは、クエリ、キー、および値のトークン埋め込みのコレクションを格納するテンソルです。形があります[seq_len, batch_size, d_model]

    mask [seq_len, seq_len, batch_size] 形状があり、バッチの場合bmask[i, j, b] i その位置のクエリがその位置のキー値にアクセスできるかどうかを示します。j

    137    def forward(self, *,
    138                query: torch.Tensor,
    139                key: torch.Tensor,
    140                value: torch.Tensor,
    141                mask: Optional[torch.Tensor] = None):

    querykey value そして形がある [seq_len, batch_size, d_model]

    153        seq_len, _, _ = query.shape
    154
    155        if mask is not None:

    mask には形状があり[seq_len_q, seq_len_k, batch_size] 、最初の次元はクエリ次元です。クエリディメンションがそれと等しい場合はブロードキャストされます

    159            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
    160            assert mask.shape[1] == key.shape[0]
    161            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

    クエリ、キー、値の埋め込みを変換

    164        query = self.query(query)
    165        key = self.key(key)
    166        value = self.value(value)

    取得

    マスクを使う

    179        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
    180        pos_bias = pos_bias.unsqueeze(-1)
    181        pos_bias.masked_fill_(~mask, float('-inf'))

    計算し別々に行列の乗算を行います。わかりやすくするためにアインサムを使用しています。

    ソフトマックスの計算を安定させるために、指数を計算する前に減算します。

    大きいと巨大になり、の計算が不安定になります。分子と分母から指数を計算する前に定数を引くと相殺され、計算を安定させるのに役立ちます。そこで、減算して計算を安定させます。

    203        max_key = key.max(dim=0, keepdims=True)[0]
    204        max_pos_bias = pos_bias.max(dim=1,  keepdims=True)[0]

    207        exp_key = torch.exp(key - max_key)

    209        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)

    分子部分

    212        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)

    分母部分

    214        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)

    [出力]

    219        y = self.activation(query) * num / den

    出力レイヤー

    222        return self.output(y)

    ローカルマスクをテスト

    225def _test_local_mask():
    229    from labml.logger import inspect
    230    inspect(AFTLocal.create_local_mask(10, 4))

    234if __name__ == '__main__':
    235    _test_local_mask()