你只需要补丁吗?(convMixer)

这是 PyTorch 对论文《补丁就是你所需要的?》的实现

ConvMixer diagram from the paper

convMixer 类似于 MLP 混音器。MLP-Mixer 通过在空间维度上应用 MLP,然后在信道维度上应用 MLP 来分离空间维度和信道维度的混音(空间 MLP 取代 vIT 注意力,信道 MLP 是 ViT 的 FFN)。

ConvMixer 使用卷积进行通道混合,使用深度卷积进行空间混合。由于它是卷积而不是整个空间的完整的 MLP,因此与 vIT 或 MLP-Mixer 相比,它只混合附近的批次。此外,MLP-Mixer 每次混合使用两层 MLP,ConvMixer 每次混合使用单层。

该论文建议删除信道混合(逐点卷积)上的剩余连接,在空间混合(深度卷积)上仅使用残差连接。他们还使用批量标准化而不是图层标准化

这是一项在 CIFAR-10 上训练 ConvMixer 的实验

36import torch
37from torch import nn
38
39from labml_helpers.module import Module
40from labml_nn.utils import clone_module_list

混音器层

这是单个 ConvMixer 层。该模型将有一系列这样的。

43class ConvMixerLayer(Module):
  • d_model 是补丁嵌入中的通道数,
  • kernel_size 是空间卷积内核的大小,
52    def __init__(self, d_model: int, kernel_size: int):
57        super().__init__()

深度卷积是每个通道的单独卷积。我们使用卷积层来完成此操作,该卷积层的组数等于通道数。因此,每个频道都是它自己的组。

61        self.depth_wise_conv = nn.Conv2d(d_model, d_model,
62                                         kernel_size=kernel_size,
63                                         groups=d_model,
64                                         padding=(kernel_size - 1) // 2)

深度卷积后激活

66        self.act1 = nn.GELU()

深度卷积后的归一化

68        self.norm1 = nn.BatchNorm2d(d_model)

逐点卷积是一种卷积。即补丁嵌入的线性变换

72        self.point_wise_conv = nn.Conv2d(d_model, d_model, kernel_size=1)

逐点卷积后激活

74        self.act2 = nn.GELU()

逐点卷积后的归一化

76        self.norm2 = nn.BatchNorm2d(d_model)
78    def forward(self, x: torch.Tensor):

对于围绕深度卷积的剩余连接

80        residual = x

深度卷积、激活和归一化

83        x = self.depth_wise_conv(x)
84        x = self.act1(x)
85        x = self.norm1(x)

添加剩余连接

88        x += residual

逐点卷积、激活和归一化

91        x = self.point_wise_conv(x)
92        x = self.act2(x)
93        x = self.norm2(x)

96        return x

获取补丁嵌入

这会将图像拆分为大小的补丁,并为每个补丁提供嵌入。

99class PatchEmbeddings(Module):
  • d_model 是补丁嵌入中的通道数
  • patch_size 是补丁的大小,
  • in_channels 是输入图像中的通道数(rgb 为 3)
108    def __init__(self, d_model: int, patch_size: int, in_channels: int):
114        super().__init__()

我们创建一个卷积层,其内核大小和步长等于补丁大小。这相当于将图像分割成色块并在每个面片上进行线性变换。

119        self.conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)

激活功能

121        self.act = nn.GELU()

批量标准化

123        self.norm = nn.BatchNorm2d(d_model)
  • x 是形状的输入图像[batch_size, channels, height, width]
125    def forward(self, x: torch.Tensor):

应用卷积层

130        x = self.conv(x)

激活和规范化

132        x = self.act(x)
133        x = self.norm(x)

136        return x

分类主管

它们进行平均池(取所有补丁嵌入的均值)和最终的线性变换来预测影像类的对数概率。

139class ClassificationHead(Module):
  • d_model 是补丁嵌入中的通道数,
  • n_classes 是分类任务中的类数
149    def __init__(self, d_model: int, n_classes: int):
154        super().__init__()

平均池

156        self.pool = nn.AdaptiveAvgPool2d((1, 1))

线性层

158        self.linear = nn.Linear(d_model, n_classes)
160    def forward(self, x: torch.Tensor):

平均汇集

162        x = self.pool(x)

得到嵌入,x 会有形状[batch_size, d_model, 1, 1]

164        x = x[:, :, 0, 0]

线性层

166        x = self.linear(x)

169        return x

混音器

它结合了补丁嵌入块、许多 ConvMixer 层和一个分类头。

172class ConvMixer(Module):
  • conv_mixer_layer 是单个 C onvMixer 层的副本。我们制作它的副本来制作 ConvMixern_layers
  • n_layers 是 ConvMixer 层(或深度)的数量
  • patch_emb补丁嵌入层
  • classification分类头
179    def __init__(self, conv_mixer_layer: ConvMixerLayer, n_layers: int,
180                 patch_emb: PatchEmbeddings,
181                 classification: ClassificationHead):
189        super().__init__()

补丁嵌入

191        self.patch_emb = patch_emb

分类主管

193        self.classification = classification

制作 C onvMixer 图层的副本

195        self.conv_mixer_layers = clone_module_list(conv_mixer_layer, n_layers)
  • x 是形状的输入图像[batch_size, channels, height, width]
197    def forward(self, x: torch.Tensor):

获取补丁嵌入。这给出了形状的张量[batch_size, d_model, height / patch_size, width / patch_size]

202        x = self.patch_emb(x)
205        for layer in self.conv_mixer_layers:
206            x = layer(x)

分类头,获取日志

209        x = self.classification(x)

212        return x