MLPミキサー:ビジョン用のオールMLPアーキテクチャ

これは、論文「MLPミキサー:ビジョン用のオールMLPアーキテクチャ」をPyTorchで実装したものです

本稿では、このモデルをビジョンタスクに適用します。このモデルは、アテンションレイヤーがパッチ(NLPタスクの場合はトークン)全体に適用されるMLPに置き換えられるトランスフォーマーに似ています

MLP Mixerの実装は、トランスフォーマー実装のセルフアテンションレイヤーに代わるものです。つまり、テンソルを転置してシーケンスの次元全体に MLP を適用するだけのコードです

この論文では視覚タスクにMLP Mixerを適用しましたが、マスクされた言語モデルで試してみましたこれが実験コードです

27from typing import Optional
28
29import torch
30from torch import nn

MLP ミキサー

このモジュールは、セルフアテンションレイヤーに代わるドロップインモジュールです。入力テンソルをMLPにフィードする前に転置して戻すので、MLPはフィーチャディメンションではなくシーケンスディメンション全体(トークンまたはイメージパッチ)に適用されます

33class MLPMixer(nn.Module):
  • ffn は MLP モジュールです。
43    def __init__(self, mlp: nn.Module):
47        super().__init__()
48        self.mlp = mlp

通常のアテンションモジュールには、、、マスクにさまざまなトークンを埋め込むことができます。

同じ関数シグネチャに従うので、直接置換できます。

MLPミキシング用で、マスキングはできません。query (keyvalue ) の形はです[seq_len, batch_size, d_model]

50    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):

、、すべて同じでなければなりません

62        assert query is key and key is value

MLPミキサーはマスキングをサポートしていません。つまり、すべてのトークンに他のすべてのトークン埋め込みが表示されます。

64        assert mask is None

x わかりやすいように割り当てる

67        x = query

最後の次元がシーケンス次元になるように転置します。新しい形は [d_model, batch_size, seq_len]

71        x = x.transpose(0, 2)

MLP をトークン全体に適用

73        x = self.mlp(x)

元の形式に戻す

75        x = x.transpose(0, 2)

78        return x