これは、論文「MLPに注意して」をPyTorchで実装したものです。
この論文では、ゲーティングを備えた多層パーセプトロン(MLP)ベースのアーキテクチャ(GMLPと名付けられています)を紹介します。gMLP ブロックのスタックで構成されています
。19from typing import Optional
20
21import torch
22from torch import nn各ブロックは、入力埋め込みに対して次の変換を行います。ここで、はシーケンスの長さ、は埋め込みの次元です。
学習可能な投影重みの位置と位置は以下に定義するスペーシャル・ゲーティング・ユニットです。の出力次元はの半分になります。はGelUのようなアクティベーション関数です
。25class GMLPBlock(nn.Module):d_model
はの次元 ()  d_ffn
の次元です seq_len
はトークン・シーケンスの長さ ()46    def __init__(self, d_model: int, d_ffn: int, seq_len: int):52        super().__init__()プレノルムの正規化レイヤー
54        self.norm = nn.LayerNorm([d_model])アクティベーション機能
56        self.activation = nn.GELU()の投影レイヤー
58        self.proj1 = nn.Linear(d_model, d_ffn)空間ゲートユニット
60        self.sgu = SpacialGatingUnit(d_ffn, seq_len)の投影レイヤー
62        self.proj2 = nn.Linear(d_ffn // 2, d_model)66        self.size = d_modelx
形状の入力埋め込みテンソルです [seq_len, batch_size, d_model]
mask
は、[seq_len, seq_len, 1]
トークン同士の可視性を制御するブーリアンシェイプマスクです。68    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):ショートカット接続用にコピーを保存
75        shortcut = xノーマライズ
77        x = self.norm(x)プロジェクションとアクティベーション
79        z = self.activation(self.proj1(x))空間ゲートユニット
81        z = self.sgu(z, mask)最終投影
83        z = self.proj2(z)ショートカット接続を追加する
86        return z + shortcut89class SpacialGatingUnit(nn.Module):d_z
の次元です seq_len
はシーケンスの長さです99    def __init__(self, d_z: int, seq_len: int):104        super().__init__()適用前の正規化レイヤー
106        self.norm = nn.LayerNorm([d_z // 2])111        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)115        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)z
形状の入力です [seq_len, batch_size, d_z]
mask
is は、[seq_len, seq_len, 1]
トークン同士の可視性を制御するブーリアンマスクです。1
サイズの最後のディメンションはバッチです。これは他のトランスフォーマー実装にもありますが、互換性のために残されています117    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):シーケンスの長さを取得
126        seq_len = z.shape[0]とに分割
128        z1, z2 = torch.chunk(z, 2, dim=-1)チェックマスク
131        if mask is not None:mask
形があります[seq_len_q, seq_len_k, batch_size]
。1
この実装ではバッチ内のすべてのサンプルに対して同じマスクしかサポートしないため、バッチディメンションはサイズにする必要があります。
135            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
136            assert mask.shape[1] == seq_lenここでは、すべてのサンプルで同じマスクのみをサポートしています
138            assert mask.shape[2] == 1バッチディメンションを削除する
140            mask = mask[:, :, 0]前にノーマライズ
143        z2 = self.norm(z2)ウェイトマトリックスを取得。これより大きい場合は切り捨てる seq_len
145        weight = self.weight[:seq_len, :seq_len]150        if mask is not None:
151            weight = weight * mask154        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]157        return z1 * z2