这是 PyTorch 对论文《补丁就是你所需要的?》的实现
。convMixer 类似于 MLP 混音器。MLP-Mixer 通过在空间维度上应用 MLP,然后在信道维度上应用 MLP 来分离空间维度和信道维度的混音(空间 MLP 取代 vIT 注意力,信道 MLP 是 ViT 的 FFN)。
ConvMixer 使用卷积进行通道混合,使用深度卷积进行空间混合。由于它是卷积而不是整个空间的完整的 MLP,因此与 vIT 或 MLP-Mixer 相比,它只混合附近的批次。此外,MLP-Mixer 每次混合使用两层 MLP,ConvMixer 每次混合使用单层。
该论文建议删除信道混合(逐点卷积)上的剩余连接,在空间混合(深度卷积)上仅使用残差连接。他们还使用批量标准化而不是图层标准化。
36import torch
37from torch import nn
38
39from labml_helpers.module import Module
40from labml_nn.utils import clone_module_list
43class ConvMixerLayer(Module):
d_model
是补丁嵌入中的通道数,kernel_size
是空间卷积内核的大小,52 def __init__(self, d_model: int, kernel_size: int):
57 super().__init__()
深度卷积是每个通道的单独卷积。我们使用卷积层来完成此操作,该卷积层的组数等于通道数。因此,每个频道都是它自己的组。
61 self.depth_wise_conv = nn.Conv2d(d_model, d_model,
62 kernel_size=kernel_size,
63 groups=d_model,
64 padding=(kernel_size - 1) // 2)
深度卷积后激活
66 self.act1 = nn.GELU()
深度卷积后的归一化
68 self.norm1 = nn.BatchNorm2d(d_model)
逐点卷积是一种卷积。即补丁嵌入的线性变换
72 self.point_wise_conv = nn.Conv2d(d_model, d_model, kernel_size=1)
逐点卷积后激活
74 self.act2 = nn.GELU()
逐点卷积后的归一化
76 self.norm2 = nn.BatchNorm2d(d_model)
78 def forward(self, x: torch.Tensor):
对于围绕深度卷积的剩余连接
80 residual = x
深度卷积、激活和归一化
83 x = self.depth_wise_conv(x)
84 x = self.act1(x)
85 x = self.norm1(x)
添加剩余连接
88 x += residual
逐点卷积、激活和归一化
91 x = self.point_wise_conv(x)
92 x = self.act2(x)
93 x = self.norm2(x)
96 return x
99class PatchEmbeddings(Module):
d_model
是补丁嵌入中的通道数patch_size
是补丁的大小,in_channels
是输入图像中的通道数(rgb 为 3)108 def __init__(self, d_model: int, patch_size: int, in_channels: int):
114 super().__init__()
我们创建一个卷积层,其内核大小和步长等于补丁大小。这相当于将图像分割成色块并在每个面片上进行线性变换。
119 self.conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
激活功能
121 self.act = nn.GELU()
批量标准化
123 self.norm = nn.BatchNorm2d(d_model)
x
是形状的输入图像[batch_size, channels, height, width]
125 def forward(self, x: torch.Tensor):
应用卷积层
130 x = self.conv(x)
激活和规范化
132 x = self.act(x)
133 x = self.norm(x)
136 return x
139class ClassificationHead(Module):
d_model
是补丁嵌入中的通道数,n_classes
是分类任务中的类数149 def __init__(self, d_model: int, n_classes: int):
154 super().__init__()
平均池
156 self.pool = nn.AdaptiveAvgPool2d((1, 1))
线性层
158 self.linear = nn.Linear(d_model, n_classes)
160 def forward(self, x: torch.Tensor):
平均汇集
162 x = self.pool(x)
得到嵌入,x
会有形状[batch_size, d_model, 1, 1]
164 x = x[:, :, 0, 0]
线性层
166 x = self.linear(x)
169 return x
172class ConvMixer(Module):
conv_mixer_layer
是单个 C onvMixer 层的副本。我们制作它的副本来制作 ConvMixern_layers
。n_layers
是 ConvMixer 层(或深度)的数量。patch_emb
是补丁嵌入层。classification
是分类头。179 def __init__(self, conv_mixer_layer: ConvMixerLayer, n_layers: int,
180 patch_emb: PatchEmbeddings,
181 classification: ClassificationHead):
189 super().__init__()
补丁嵌入
191 self.patch_emb = patch_emb
分类主管
193 self.classification = classification
制作 C onvMixer 图层的副本
195 self.conv_mixer_layers = clone_module_list(conv_mixer_layer, n_layers)
x
是形状的输入图像[batch_size, channels, height, width]
197 def forward(self, x: torch.Tensor):
获取补丁嵌入。这给出了形状的张量[batch_size, d_model, height / patch_size, width / patch_size]
。
202 x = self.patch_emb(x)
穿过 ConvMixer 图层
205 for layer in self.conv_mixer_layers:
206 x = layer(x)
分类头,获取日志
209 x = self.classification(x)
212 return x