视觉变压器 (ViT)

这是 PyTorch 对《An Image Is Worth 16x16 Words:用于大规模图像识别的变形金刚》论文的实现。

视觉转换器将纯粹的变换器应用于没有任何卷积层的图像。他们将图像分割成补丁,并在补丁嵌入上应用转换器。补丁嵌入是通过对补丁的扁平化像素值应用简单的线性变换来生成的。然后,向标准变压器编码器提供补丁嵌入以及分类标记[CLS][CLS] 令牌上的编码用于使用 MLP 对图像进行分类。

当向转换器提供补丁时,学到的位置嵌入会添加到补丁嵌入中,因为补丁嵌入没有任何关于该补丁来自何处的信息。位置嵌入是每个补丁位置的一组向量,这些向量使用梯度下降和其他参数进行训练。

VIT 在大型数据集上进行预训练时表现良好。本文建议使用 MLP 分类头对他们进行预训练,然后在微调时使用单个线性层。该论文在3亿张图像数据集上预先训练了ViT,击败了SOTA。它们还在推理过程中使用更高分辨率的图像,同时保持补丁大小不变。新补丁位置的位置嵌入是通过插值学习位置嵌入来计算的。

这是一项在 CIFAR-10 上训练 ViT 的实验。这效果不太好,因为它是在一个小数据集上训练的。这是一个简单的实验,任何人都可以使用Vits运行和玩游戏。

43import torch
44from torch import nn
45
46from labml_helpers.module import Module
47from labml_nn.transformers import TransformerLayer
48from labml_nn.utils import clone_module_list

获取补丁嵌入

纸张将图像分割成大小相等的斑块,然后对每个补丁的扁平像素进行线性变换。

我们通过卷积层实现同样的东西,因为它更容易实现。

51class PatchEmbeddings(Module):
  • d_model 变压器嵌入的大小是多少
  • patch_size 是补丁的大小
  • in_channels 是输入图像中的通道数(rgb 为 3)
63    def __init__(self, d_model: int, patch_size: int, in_channels: int):
69        super().__init__()

我们创建一个卷积层,其内核大小和步长等于补丁大小。这相当于将图像分割成色块并在每个面片上进行线性变换。

74        self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
  • x 是形状的输入图像[batch_size, channels, height, width]
76    def forward(self, x: torch.Tensor):

应用卷积层

81        x = self.conv(x)

得到形状。

83        bs, c, h, w = x.shape

重新排列成形状[patches, batch_size, d_model]

85        x = x.permute(2, 3, 0, 1)
86        x = x.view(h * w, bs, c)

返回补丁嵌入

89        return x

添加参数化的位置编码

这将学习的位置嵌入添加到补丁嵌入中。

92class LearnedPositionalEmbeddings(Module):
  • d_model 变压器嵌入的大小是多少
  • max_len 是补丁的最大数量
101    def __init__(self, d_model: int, max_len: int = 5_000):
106        super().__init__()

每个位置的位置嵌入

108        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
  • x 是形状的补丁嵌入[patches, batch_size, d_model]
110    def forward(self, x: torch.Tensor):

获取给定补丁的位置嵌入

115        pe = self.positional_encodings[:x.shape[0]]

添加到补丁嵌入并返回

117        return x + pe

MLP 分类主管

这是基于[CLS] 令牌嵌入对图像进行分类的双层 MLP 头。

120class ClassificationHead(Module):
  • d_model 变压器嵌入的大小是多少
  • n_hidden 是隐藏层的大小
  • n_classes 是分类任务中的类数
128    def __init__(self, d_model: int, n_hidden: int, n_classes: int):
134        super().__init__()

第一层

136        self.linear1 = nn.Linear(d_model, n_hidden)

激活

138        self.act = nn.ReLU()

第二层

140        self.linear2 = nn.Linear(n_hidden, n_classes)
  • x[CLS] 令牌的转换器编码
142    def forward(self, x: torch.Tensor):

第一层和激活

147        x = self.act(self.linear1(x))

第二层

149        x = self.linear2(x)

152        return x

视觉变压器

这结合了补丁嵌入位置嵌入、变压器和分类头

155class VisionTransformer(Module):
163    def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
164                 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
165                 classification: ClassificationHead):
174        super().__init__()

补丁嵌入

176        self.patch_emb = patch_emb
177        self.pos_emb = pos_emb

分类主管

179        self.classification = classification

制作变压器层的副本

181        self.transformer_layers = clone_module_list(transformer_layer, n_layers)

[CLS] 令牌嵌入

184        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)

最终归一化层

186        self.ln = nn.LayerNorm([transformer_layer.size])
  • x 是形状的输入图像[batch_size, channels, height, width]
188    def forward(self, x: torch.Tensor):

获取补丁嵌入。这给出了形状的张量[patches, batch_size, d_model]

193        x = self.patch_emb(x)

在给变压器供电之前连接[CLS] 令牌嵌入

195        cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
196        x = torch.cat([cls_token_emb, x])

添加位置嵌入

198        x = self.pos_emb(x)

穿过变压器层,不遮挡注意力

201        for layer in self.transformer_layers:
202            x = layer(x=x, mask=None)

获取[CLS] 令牌的转换器输出(序列中的第一个)。

205        x = x[0]

层规范化

208        x = self.ln(x)

分类头,获取日志

211        x = self.classification(x)

214        return x