これは、「画像は16x16の言葉に値する」という論文「大規模画像認識のためのトランスフォーマー」をPyTorchで実装したものです。
ビジョントランスフォーマーは、畳み込み層のない画像に純粋なトランスフォーマーを適用します。画像をパッチに分割し、パッチの埋め込みにトランスフォーマーを適用します。パッチ埋め込みは、パッチの平坦化されたピクセル値に単純な線形変換を適用することによって生成されます。次に、標準のトランスエンコーダに、分類トークンとともにパッチ埋め込みが供給されます。[CLS]
[CLS]
トークンのエンコーディングは、画像をMLPで分類するために使用されます
トランスにパッチを供給する際、学習した位置埋め込みがパッチ埋め込みに追加されます。これは、パッチ埋め込みにはそのパッチがどこから来たかについての情報がないためです。位置埋め込みは、各パッチ位置のベクトルのセットで、他のパラメーターとともに勾配降下法でトレーニングされます
。VITは、大規模なデータセットで事前にトレーニングしておくとうまく機能します。この論文では、MLP分類ヘッドで事前にトレーニングし、微調整の際には単一の線形層を使用することを提案しています。この論文は、3億の画像データセットで事前にトレーニングされたVITでSOTAを上回っています。また、パッチサイズを同じに保ちながら、推論時には高解像度の画像を使用します。新しいパッチ位置の位置埋め込みは、学習した位置埋め込みを補間することによって計算されます
。これは、CIFAR-10 で VIT をトレーニングする実験です。これは小さなデータセットでトレーニングされているため、あまりうまくいきません。誰でも走ってVITで遊べる簡単な実験です
。43import torch
44from torch import nn
45
46from labml_helpers.module import Module
47from labml_nn.transformers import TransformerLayer
48from labml_nn.utils import clone_module_list
51class PatchEmbeddings(Module):
d_model
変圧器の埋め込みサイズですpatch_size
パッチのサイズin_channels
は入力画像のチャンネル数 (RGB の場合は 3)63 def __init__(self, d_model: int, patch_size: int, in_channels: int):
69 super().__init__()
カーネルサイズでストライドの長さがパッチサイズと同じコンボリューションレイヤーを作成します。これは、画像をパッチに分割し、各パッチで線形変換を行うのと同じです
。74 self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
x
形状の入力画像です [batch_size, channels, height, width]
76 def forward(self, x: torch.Tensor):
畳み込み層を適用
81 x = self.conv(x)
形を手に入れよう。
83 bs, c, h, w = x.shape
図形に再配置 [patches, batch_size, d_model]
85 x = x.permute(2, 3, 0, 1)
86 x = x.view(h * w, bs, c)
パッチの埋め込みを返す
89 return x
92class LearnedPositionalEmbeddings(Module):
d_model
変圧器の埋め込みサイズですmax_len
パッチの最大数です101 def __init__(self, d_model: int, max_len: int = 5_000):
106 super().__init__()
各ロケーションの位置埋め込み
108 self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
x
形状のパッチ埋め込みです [patches, batch_size, d_model]
110 def forward(self, x: torch.Tensor):
与えられたパッチの位置埋め込みを取得
115 pe = self.positional_encodings[:x.shape[0]]
パッチ埋め込みに追加して返す
117 return x + pe
120class ClassificationHead(Module):
d_model
変圧器の埋め込みサイズですn_hidden
隠れレイヤーのサイズn_classes
分類タスク内のクラス数です128 def __init__(self, d_model: int, n_hidden: int, n_classes: int):
134 super().__init__()
第 1 レイヤー
136 self.linear1 = nn.Linear(d_model, n_hidden)
アクティベーション
138 self.act = nn.ReLU()
第 2 レイヤー
140 self.linear2 = nn.Linear(n_hidden, n_classes)
x
トークンのトランスフォーマーエンコーディングです [CLS]
142 def forward(self, x: torch.Tensor):
第1層とアクティベーション
147 x = self.act(self.linear1(x))
第 2 レイヤー
149 x = self.linear2(x)
152 return x
155class VisionTransformer(Module):
transformer_layer
1 つのトランスレイヤーのコピーです。n_layers
それをコピーして変圧器を作りますn_layers
変圧器層の数です。patch_emb
パッチ埋め込みレイヤーです。pos_emb
位置埋め込みレイヤーです。classification
分類責任者です。163 def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
164 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
165 classification: ClassificationHead):
174 super().__init__()
パッチ埋め込み
176 self.patch_emb = patch_emb
177 self.pos_emb = pos_emb
分類ヘッド
179 self.classification = classification
トランスレイヤーのコピーを作成
181 self.transformer_layers = clone_module_list(transformer_layer, n_layers)
[CLS]
トークン埋め込み
184 self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)
最終正規化レイヤー
186 self.ln = nn.LayerNorm([transformer_layer.size])
x
形状の入力画像です [batch_size, channels, height, width]
188 def forward(self, x: torch.Tensor):
パッチの埋め込みを入手してください。これにより形状のテンソルが得られます。[patches, batch_size, d_model]
193 x = self.patch_emb(x)
位置埋め込みを追加
195 x = self.pos_emb(x)
[CLS]
トランスフォーマーに給電する前にトークンの埋め込みを連結してください
197 cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
198 x = torch.cat([cls_token_emb, x])
アテンション・マスクなしで変圧器層を通過
201 for layer in self.transformer_layers:
202 x = layer(x=x, mask=None)
[CLS]
トークン (シーケンスの最初のもの) のトランスフォーマー出力を取得します。
205 x = x[0]
レイヤー正規化
208 x = self.ln(x)
ロジットを取得するための分類ヘッド
211 x = self.classification(x)
214 return x