これは、論文「MLPに注意して」をPyTorchで実装したものです。
この論文では、ゲーティングを備えた多層パーセプトロン(MLP)ベースのアーキテクチャ(GMLPと名付けられています)を紹介します。gMLP ブロックのスタックで構成されています
。19from typing import Optional
20
21import torch
22from torch import nn
各ブロックは、入力埋め込みに対して次の変換を行います。ここで、はシーケンスの長さ、は埋め込みの次元です。
学習可能な投影重みの位置と位置は以下に定義するスペーシャル・ゲーティング・ユニットです。の出力次元はの半分になります。はGelUのようなアクティベーション関数です
。25class GMLPBlock(nn.Module):
d_model
はの次元 () d_ffn
の次元です seq_len
はトークン・シーケンスの長さ ()46 def __init__(self, d_model: int, d_ffn: int, seq_len: int):
52 super().__init__()
プレノルムの正規化レイヤー
54 self.norm = nn.LayerNorm([d_model])
アクティベーション機能
56 self.activation = nn.GELU()
の投影レイヤー
58 self.proj1 = nn.Linear(d_model, d_ffn)
空間ゲートユニット
60 self.sgu = SpacialGatingUnit(d_ffn, seq_len)
の投影レイヤー
62 self.proj2 = nn.Linear(d_ffn // 2, d_model)
66 self.size = d_model
x
形状の入力埋め込みテンソルです [seq_len, batch_size, d_model]
mask
は、[seq_len, seq_len, 1]
トークン同士の可視性を制御するブーリアンシェイプマスクです。68 def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
ショートカット接続用にコピーを保存
75 shortcut = x
ノーマライズ
77 x = self.norm(x)
プロジェクションとアクティベーション
79 z = self.activation(self.proj1(x))
空間ゲートユニット
81 z = self.sgu(z, mask)
最終投影
83 z = self.proj2(z)
ショートカット接続を追加する
86 return z + shortcut
89class SpacialGatingUnit(nn.Module):
d_z
の次元です seq_len
はシーケンスの長さです99 def __init__(self, d_z: int, seq_len: int):
104 super().__init__()
適用前の正規化レイヤー
106 self.norm = nn.LayerNorm([d_z // 2])
111 self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)
115 self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
z
形状の入力です [seq_len, batch_size, d_z]
mask
is は、[seq_len, seq_len, 1]
トークン同士の可視性を制御するブーリアンマスクです。1
サイズの最後のディメンションはバッチです。これは他のトランスフォーマー実装にもありますが、互換性のために残されています117 def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):
シーケンスの長さを取得
126 seq_len = z.shape[0]
とに分割
128 z1, z2 = torch.chunk(z, 2, dim=-1)
チェックマスク
131 if mask is not None:
mask
形があります[seq_len_q, seq_len_k, batch_size]
。1
この実装ではバッチ内のすべてのサンプルに対して同じマスクしかサポートしないため、バッチディメンションはサイズにする必要があります。
135 assert mask.shape[0] == 1 or mask.shape[0] == seq_len
136 assert mask.shape[1] == seq_len
ここでは、すべてのサンプルで同じマスクのみをサポートしています
138 assert mask.shape[2] == 1
バッチディメンションを削除する
140 mask = mask[:, :, 0]
前にノーマライズ
143 z2 = self.norm(z2)
ウェイトマトリックスを取得。これより大きい場合は切り捨てる seq_len
145 weight = self.weight[:seq_len, :seq_len]
150 if mask is not None:
151 weight = weight * mask
154 z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
157 return z1 * z2