MLP (GMLP) にご注意ください

これは、論文「MLPに注意してをPyTorchで実装したものです

この論文では、ゲーティングを備えた多層パーセプトロン(MLP)ベースのアーキテクチャ(GMLPと名付けられています)を紹介します。gMLP ブロックのスタックで構成されています

gMLPモデルベースの自己回帰モデルのトレーニングコードは次のとおりです

19from typing import Optional
20
21import torch
22from torch import nn

GmLP ブロック

各ブロックは、入力埋め込みに対して次の変換を行います。ここで、はシーケンスの長さ、は埋め込みの次元です。

学習可能な投影重みの位置と位置は以下に定義するスペーシャル・ゲーティング・ユニットです。の出力次元はの半分になります。はGelUのようなアクティベーション関数です

25class GMLPBlock(nn.Module):
  • d_model はの次元 ()
  • d_ffn の次元です
  • seq_len はトークン・シーケンスの長さ ()
46    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
52        super().__init__()

プレノルムの正規化レイヤー

54        self.norm = nn.LayerNorm([d_model])

アクティベーション機能

56        self.activation = nn.GELU()

の投影レイヤー

58        self.proj1 = nn.Linear(d_model, d_ffn)

空間ゲートユニット

60        self.sgu = SpacialGatingUnit(d_ffn, seq_len)

の投影レイヤー

62        self.proj2 = nn.Linear(d_ffn // 2, d_model)
66        self.size = d_model
  • x 形状の入力埋め込みテンソルです [seq_len, batch_size, d_model]
  • mask は、[seq_len, seq_len, 1] トークン同士の可視性を制御するブーリアンシェイプマスクです。
  • 68    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):

    ショートカット接続用にコピーを保存

    75        shortcut = x

    ノーマライズ

    77        x = self.norm(x)

    プロジェクションとアクティベーション

    79        z = self.activation(self.proj1(x))

    空間ゲートユニット

    81        z = self.sgu(z, mask)

    最終投影

    83        z = self.proj2(z)

    ショートカット接続を追加する

    86        return z + shortcut

    空間ゲートユニット

    ここで、はシーケンス次元に沿った線形変換で、は要素単位の乗算です。チャネル寸法(埋め込み寸法)に沿って同じサイズの2つの部分に分割されます

    89class SpacialGatingUnit(nn.Module):
    • d_z の次元です
    • seq_len はシーケンスの長さです
    99    def __init__(self, d_z: int, seq_len: int):
    104        super().__init__()

    適用前の正規化レイヤー

    106        self.norm = nn.LayerNorm([d_z // 2])

    重量 (イン)

    この論文では、重みを小さい値に初期化し、バイアスをに初期化することが重要であると述べています。そうすれば、最初のトレーニング中に(分割は別として)同一に近いものになります。

    111        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)

    重量 (イン)

    この論文では、バイアスをに初期化することが重要だと指摘しています。

    115        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
    • z 形状の入力です [seq_len, batch_size, d_z]
    • mask is は、[seq_len, seq_len, 1] トークン同士の可視性を制御するブーリアンマスクです。1 サイズの最後のディメンションはバッチです。これは他のトランスフォーマー実装にもありますが、互換性のために残されています
    117    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):

    シーケンスの長さを取得

    126        seq_len = z.shape[0]

    とに分割

    128        z1, z2 = torch.chunk(z, 2, dim=-1)

    チェックマスク

    131        if mask is not None:

    mask 形があります[seq_len_q, seq_len_k, batch_size]1 この実装ではバッチ内のすべてのサンプルに対して同じマスクしかサポートしないため、バッチディメンションはサイズにする必要があります。

    135            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
    136            assert mask.shape[1] == seq_len

    ここでは、すべてのサンプルで同じマスクのみをサポートしています

    138            assert mask.shape[2] == 1

    バッチディメンションを削除する

    140            mask = mask[:, :, 0]

    前にノーマライズ

    143        z2 = self.norm(z2)

    ウェイトマトリックスを取得。これより大きい場合は切り捨てる seq_len

    145        weight = self.weight[:seq_len, :seq_len]

    ウェイトにマスクをかけます。

    もしそうなら、トークンから情報を取得することはない

    150        if mask is not None:
    151            weight = weight * mask

    154        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]

    157        return z1 * z2