これは、論文「MLPミキサー:ビジョン用のオールMLPアーキテクチャ」をPyTorchで実装したものです。
本稿では、このモデルをビジョンタスクに適用します。このモデルは、アテンションレイヤーがパッチ(NLPタスクの場合はトークン)全体に適用されるMLPに置き換えられるトランスフォーマーに似ています
。MLP Mixerの実装は、トランスフォーマー実装のセルフアテンションレイヤーに代わるものです。つまり、テンソルを転置してシーケンスの次元全体に MLP を適用するだけのコードです
。この論文では視覚タスクにMLP Mixerを適用しましたが、マスクされた言語モデルで試してみました。これが実験コードです。
27from typing import Optional
28
29import torch
30from torch import nnこのモジュールは、セルフアテンションレイヤーに代わるドロップインモジュールです。入力テンソルをMLPにフィードする前に転置して戻すので、MLPはフィーチャディメンションではなくシーケンスディメンション全体(トークンまたはイメージパッチ)に適用されます
。33class MLPMixer(nn.Module):ffn
は MLP モジュールです。43 def __init__(self, mlp: nn.Module):47 super().__init__()
48 self.mlp = mlp通常のアテンションモジュールには、、、、マスクにさまざまなトークンを埋め込むことができます。
同じ関数シグネチャに従うので、直接置換できます。
MLPミキシング用で、マスキングはできません。query
(key
とvalue
) の形はです[seq_len, batch_size, d_model]
。
50 def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):、、すべて同じでなければなりません
62 assert query is key and key is valueMLPミキサーはマスキングをサポートしていません。つまり、すべてのトークンに他のすべてのトークン埋め込みが表示されます。
64 assert mask is Nonex
わかりやすいように割り当てる
67 x = query最後の次元がシーケンス次元になるように転置します。新しい形は [d_model, batch_size, seq_len]
71 x = x.transpose(0, 2)MLP をトークン全体に適用
73 x = self.mlp(x)元の形式に戻す
75 x = x.transpose(0, 2)78 return x