これは、論文「MLPミキサー:ビジョン用のオールMLPアーキテクチャ」をPyTorchで実装したものです。
本稿では、このモデルをビジョンタスクに適用します。このモデルは、アテンションレイヤーがパッチ(NLPタスクの場合はトークン)全体に適用されるMLPに置き換えられるトランスフォーマーに似ています
。MLP Mixerの実装は、トランスフォーマー実装のセルフアテンションレイヤーに代わるものです。つまり、テンソルを転置してシーケンスの次元全体に MLP を適用するだけのコードです
。この論文では視覚タスクにMLP Mixerを適用しましたが、マスクされた言語モデルで試してみました。これが実験コードです。
27from typing import Optional
28
29import torch
30from torch import nn
このモジュールは、セルフアテンションレイヤーに代わるドロップインモジュールです。入力テンソルをMLPにフィードする前に転置して戻すので、MLPはフィーチャディメンションではなくシーケンスディメンション全体(トークンまたはイメージパッチ)に適用されます
。33class MLPMixer(nn.Module):
ffn
は MLP モジュールです。43 def __init__(self, mlp: nn.Module):
47 super().__init__()
48 self.mlp = mlp
通常のアテンションモジュールには、、、、マスクにさまざまなトークンを埋め込むことができます。
同じ関数シグネチャに従うので、直接置換できます。
MLPミキシング用で、マスキングはできません。query
(key
とvalue
) の形はです[seq_len, batch_size, d_model]
。
50 def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
、、すべて同じでなければなりません
62 assert query is key and key is value
MLPミキサーはマスキングをサポートしていません。つまり、すべてのトークンに他のすべてのトークン埋め込みが表示されます。
64 assert mask is None
x
わかりやすいように割り当てる
67 x = query
最後の次元がシーケンス次元になるように転置します。新しい形は [d_model, batch_size, seq_len]
71 x = x.transpose(0, 2)
MLP をトークン全体に適用
73 x = self.mlp(x)
元の形式に戻す
75 x = x.transpose(0, 2)
78 return x