必要なのはパッチだけ?(コンバージョンミキサー

)

これは紙のパッチをPyTorchで実装したものです。必要なのはパッチだけですか

ConvMixer diagram from the paper

ConvMixerはMLPミキサーに似ています。

MLP-Mixerは、空間次元全体にMLPを適用し、次にチャネル次元全体にMLPを適用することで、空間次元とチャネル次元の混合を分離します(空間MLPはvITのアテンションに代わり、チャネルMLPはVITのFFNです)。

ConvMixerは、チャンネルミキシングにはコンボリューションを使用し、空間ミキシングには奥行きコンボリューションを使用します。スペース全体でフルMLPではなく畳み込みなので、VITやMLPミキサーとは対照的に、近くのバッチのみをミキシングします。また、MLPミキサーはミキシングごとに2層のMLPを使用し、ConvMixerはミキシングごとに1層のMLPを使用します

この論文では、チャネルミキシング全体の残留接続を削除し(点単位の畳み込み)、空間ミキシングでは残留接続のみにする(深さ方向の畳み込み)ことを推奨しています。また、

レイヤー正規化の代わりにバッチ正規化を使用します

これは、CIFAR-10 で ConvMixer をトレーニングする実験です

36import torch
37from torch import nn
38
39from labml_helpers.module import Module
40from labml_nn.utils import clone_module_list

ConvMixer レイヤー

これは単一の ConvMixer レイヤーです。モデルにはこれらのシリーズがあります。

43class ConvMixerLayer(Module):
  • d_model はパッチ埋め込みのチャンネル数、
  • kernel_size は空間畳み込みのカーネルの大きさです
52    def __init__(self, d_model: int, kernel_size: int):
57        super().__init__()

深度方向の畳み込みは、チャンネルごとに別々の畳み込みになります。これは、グループの数がチャネル数と等しい畳み込み層で行います。そのため、各チャンネルはそれぞれ独自のグループになります。

61        self.depth_wise_conv = nn.Conv2d(d_model, d_model,
62                                         kernel_size=kernel_size,
63                                         groups=d_model,
64                                         padding=(kernel_size - 1) // 2)

深さ方向の畳み込み後のアクティベーション

66        self.act1 = nn.GELU()

深さ方向の畳み込み後の正規化

68        self.norm1 = nn.BatchNorm2d(d_model)

点単位の畳み込みは畳み込みです。つまり、パッチ埋め込みの線形変換です

72        self.point_wise_conv = nn.Conv2d(d_model, d_model, kernel_size=1)

点単位の畳み込み後のアクティベーション

74        self.act2 = nn.GELU()

点単位の畳み込み後の正規化

76        self.norm2 = nn.BatchNorm2d(d_model)
78    def forward(self, x: torch.Tensor):

深さ方向の畳み込みの周りの残差結合について

80        residual = x

深度方向の畳み込み、活性化、正規化

83        x = self.depth_wise_conv(x)
84        x = self.act1(x)
85        x = self.norm1(x)

残余接続を追加

88        x += residual

点単位の畳み込み、活性化、正規化

91        x = self.point_wise_conv(x)
92        x = self.act2(x)
93        x = self.norm2(x)

96        return x

パッチ埋め込みを入手

これにより、画像が複数のサイズのパッチに分割され、各パッチが埋め込まれます。

99class PatchEmbeddings(Module):
  • d_model パッチ埋め込みのチャンネル数です
  • patch_size はパッチのサイズ、
  • in_channels は入力画像のチャンネル数 (RGB の場合は 3)
108    def __init__(self, d_model: int, patch_size: int, in_channels: int):
114        super().__init__()

カーネルサイズでストライドの長さがパッチサイズと同じコンボリューションレイヤーを作成します。これは、画像をパッチに分割し、各パッチで線形変換を行うのと同じです

119        self.conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)

アクティベーション機能

121        self.act = nn.GELU()

バッチ正規化

123        self.norm = nn.BatchNorm2d(d_model)
  • x 形状の入力画像です [batch_size, channels, height, width]
125    def forward(self, x: torch.Tensor):

畳み込み層を適用

130        x = self.conv(x)

アクティベーションと正規化

132        x = self.act(x)
133        x = self.norm(x)

136        return x

分類責任者

平均プーリング(すべてのパッチ埋め込みの平均を取る)と最終的な線形変換を行って、画像クラスの対数確率を予測します。

139class ClassificationHead(Module):
  • d_model はパッチ埋め込みのチャンネル数、
  • n_classes 分類タスク内のクラス数です
149    def __init__(self, d_model: int, n_classes: int):
154        super().__init__()

アベレージプール

156        self.pool = nn.AdaptiveAvgPool2d((1, 1))

リニアレイヤー

158        self.linear = nn.Linear(d_model, n_classes)
160    def forward(self, x: torch.Tensor):

平均プーリング

162        x = self.pool(x)

x 埋め込みを入れると、形が整います [batch_size, d_model, 1, 1]

164        x = x[:, :, 0, 0]

リニアレイヤー

166        x = self.linear(x)

169        return x

コンバージョンミキサー

これにより、パッチ埋め込みブロック、多数の ConvMixer レイヤー、および分類ヘッドが組み合わされます。

172class ConvMixer(Module):
179    def __init__(self, conv_mixer_layer: ConvMixerLayer, n_layers: int,
180                 patch_emb: PatchEmbeddings,
181                 classification: ClassificationHead):
189        super().__init__()

パッチ埋め込み

191        self.patch_emb = patch_emb

分類ヘッド

193        self.classification = classification

ConvMixer レイヤーのコピーを作成

195        self.conv_mixer_layers = clone_module_list(conv_mixer_layer, n_layers)
  • x 形状の入力画像です [batch_size, channels, height, width]
197    def forward(self, x: torch.Tensor):

パッチの埋め込みを入手してください。[batch_size, d_model, height / patch_size, width / patch_size] これにより形状のテンソルが得られます

202        x = self.patch_emb(x)
205        for layer in self.conv_mixer_layers:
206            x = layer(x)

ロジットを取得するための分類ヘッド

209        x = self.classification(x)

212        return x