这是 PyTorch 对《An Image Is Worth 16x16 Words:用于大规模图像识别的变形金刚》论文的实现。
视觉转换器将纯粹的变换器应用于没有任何卷积层的图像。他们将图像分割成补丁,并在补丁嵌入上应用转换器。补丁嵌入是通过对补丁的扁平化像素值应用简单的线性变换来生成的。然后,向标准变压器编码器提供补丁嵌入以及分类标记[CLS]
。[CLS]
令牌上的编码用于使用 MLP 对图像进行分类。
当向转换器提供补丁时,学到的位置嵌入会添加到补丁嵌入中,因为补丁嵌入没有任何关于该补丁来自何处的信息。位置嵌入是每个补丁位置的一组向量,这些向量使用梯度下降和其他参数进行训练。
VIT 在大型数据集上进行预训练时表现良好。本文建议使用 MLP 分类头对他们进行预训练,然后在微调时使用单个线性层。该论文在3亿张图像数据集上预先训练了ViT,击败了SOTA。它们还在推理过程中使用更高分辨率的图像,同时保持补丁大小不变。新补丁位置的位置嵌入是通过插值学习位置嵌入来计算的。
这是一项在 CIFAR-10 上训练 ViT 的实验。这效果不太好,因为它是在一个小数据集上训练的。这是一个简单的实验,任何人都可以使用Vits运行和玩游戏。
43import torch
44from torch import nn
45
46from labml_helpers.module import Module
47from labml_nn.transformers import TransformerLayer
48from labml_nn.utils import clone_module_list
51class PatchEmbeddings(Module):
d_model
变压器嵌入的大小是多少patch_size
是补丁的大小in_channels
是输入图像中的通道数(rgb 为 3)63 def __init__(self, d_model: int, patch_size: int, in_channels: int):
69 super().__init__()
我们创建一个卷积层,其内核大小和步长等于补丁大小。这相当于将图像分割成色块并在每个面片上进行线性变换。
74 self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
x
是形状的输入图像[batch_size, channels, height, width]
76 def forward(self, x: torch.Tensor):
应用卷积层
81 x = self.conv(x)
得到形状。
83 bs, c, h, w = x.shape
重新排列成形状[patches, batch_size, d_model]
85 x = x.permute(2, 3, 0, 1)
86 x = x.view(h * w, bs, c)
返回补丁嵌入
89 return x
92class LearnedPositionalEmbeddings(Module):
d_model
变压器嵌入的大小是多少max_len
是补丁的最大数量101 def __init__(self, d_model: int, max_len: int = 5_000):
106 super().__init__()
每个位置的位置嵌入
108 self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
x
是形状的补丁嵌入[patches, batch_size, d_model]
110 def forward(self, x: torch.Tensor):
获取给定补丁的位置嵌入
115 pe = self.positional_encodings[:x.shape[0]]
添加到补丁嵌入并返回
117 return x + pe
120class ClassificationHead(Module):
d_model
变压器嵌入的大小是多少n_hidden
是隐藏层的大小n_classes
是分类任务中的类数128 def __init__(self, d_model: int, n_hidden: int, n_classes: int):
134 super().__init__()
第一层
136 self.linear1 = nn.Linear(d_model, n_hidden)
激活
138 self.act = nn.ReLU()
第二层
140 self.linear2 = nn.Linear(n_hidden, n_classes)
x
是[CLS]
令牌的转换器编码142 def forward(self, x: torch.Tensor):
第一层和激活
147 x = self.act(self.linear1(x))
第二层
149 x = self.linear2(x)
152 return x
163 def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
164 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
165 classification: ClassificationHead):
174 super().__init__()
补丁嵌入
176 self.patch_emb = patch_emb
177 self.pos_emb = pos_emb
分类主管
179 self.classification = classification
制作变压器层的副本
181 self.transformer_layers = clone_module_list(transformer_layer, n_layers)
[CLS]
令牌嵌入
184 self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)
最终归一化层
186 self.ln = nn.LayerNorm([transformer_layer.size])
x
是形状的输入图像[batch_size, channels, height, width]
188 def forward(self, x: torch.Tensor):
获取补丁嵌入。这给出了形状的张量[patches, batch_size, d_model]
193 x = self.patch_emb(x)
在给变压器供电之前连接[CLS]
令牌嵌入
195 cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
196 x = torch.cat([cls_token_emb, x])
添加位置嵌入
198 x = self.pos_emb(x)
穿过变压器层,不遮挡注意力
201 for layer in self.transformer_layers:
202 x = layer(x=x, mask=None)
获取[CLS]
令牌的转换器输出(序列中的第一个)。
205 x = x[0]
层规范化
208 x = self.ln(x)
分类头,获取日志
211 x = self.classification(x)
214 return x