これは、「画像は16x16の言葉に値する」という論文「大規模画像認識のためのトランスフォーマー」をPyTorchで実装したものです。
ビジョントランスフォーマーは、畳み込み層のない画像に純粋なトランスフォーマーを適用します。画像をパッチに分割し、パッチの埋め込みにトランスフォーマーを適用します。パッチ埋め込みは、パッチの平坦化されたピクセル値に単純な線形変換を適用することによって生成されます。次に、標準のトランスエンコーダに、分類トークンとともにパッチ埋め込みが供給されます。[CLS]
[CLS]
トークンのエンコーディングは、画像をMLPで分類するために使用されます
トランスにパッチを供給する際、学習した位置埋め込みがパッチ埋め込みに追加されます。これは、パッチ埋め込みにはそのパッチがどこから来たかについての情報がないためです。位置埋め込みは、各パッチ位置のベクトルのセットで、他のパラメーターとともに勾配降下法でトレーニングされます
。VITは、大規模なデータセットで事前にトレーニングしておくとうまく機能します。この論文では、MLP分類ヘッドで事前にトレーニングし、微調整の際には単一の線形層を使用することを提案しています。この論文は、3億の画像データセットで事前にトレーニングされたVITでSOTAを上回っています。また、パッチサイズを同じに保ちながら、推論時には高解像度の画像を使用します。新しいパッチ位置の位置埋め込みは、学習した位置埋め込みを補間することによって計算されます
。これは、CIFAR-10 で VIT をトレーニングする実験です。これは小さなデータセットでトレーニングされているため、あまりうまくいきません。誰でも走ってVITで遊べる簡単な実験です
。43import torch
44from torch import nn
45
46from labml_helpers.module import Module
47from labml_nn.transformers import TransformerLayer
48from labml_nn.utils import clone_module_list51class PatchEmbeddings(Module):d_model
変圧器の埋め込みサイズですpatch_size
パッチのサイズin_channels
は入力画像のチャンネル数 (RGB の場合は 3)63    def __init__(self, d_model: int, patch_size: int, in_channels: int):69        super().__init__()カーネルサイズでストライドの長さがパッチサイズと同じコンボリューションレイヤーを作成します。これは、画像をパッチに分割し、各パッチで線形変換を行うのと同じです
。74        self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)x
形状の入力画像です [batch_size, channels, height, width]
76    def forward(self, x: torch.Tensor):畳み込み層を適用
81        x = self.conv(x)形を手に入れよう。
83        bs, c, h, w = x.shape図形に再配置 [patches, batch_size, d_model]
85        x = x.permute(2, 3, 0, 1)
86        x = x.view(h * w, bs, c)パッチの埋め込みを返す
89        return x92class LearnedPositionalEmbeddings(Module):d_model
変圧器の埋め込みサイズですmax_len
パッチの最大数です101    def __init__(self, d_model: int, max_len: int = 5_000):106        super().__init__()各ロケーションの位置埋め込み
108        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)x
形状のパッチ埋め込みです [patches, batch_size, d_model]
110    def forward(self, x: torch.Tensor):与えられたパッチの位置埋め込みを取得
115        pe = self.positional_encodings[:x.shape[0]]パッチ埋め込みに追加して返す
117        return x + pe120class ClassificationHead(Module):d_model
変圧器の埋め込みサイズですn_hidden
隠れレイヤーのサイズn_classes
分類タスク内のクラス数です128    def __init__(self, d_model: int, n_hidden: int, n_classes: int):134        super().__init__()第 1 レイヤー
136        self.linear1 = nn.Linear(d_model, n_hidden)アクティベーション
138        self.act = nn.ReLU()第 2 レイヤー
140        self.linear2 = nn.Linear(n_hidden, n_classes)x
トークンのトランスフォーマーエンコーディングです [CLS]
142    def forward(self, x: torch.Tensor):第1層とアクティベーション
147        x = self.act(self.linear1(x))第 2 レイヤー
149        x = self.linear2(x)152        return x155class VisionTransformer(Module):transformer_layer
1 つのトランスレイヤーのコピーです。n_layers
それをコピーして変圧器を作りますn_layers
変圧器層の数です。patch_emb
パッチ埋め込みレイヤーです。pos_emb
位置埋め込みレイヤーです。classification
分類責任者です。163    def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
164                 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
165                 classification: ClassificationHead):174        super().__init__()パッチ埋め込み
176        self.patch_emb = patch_emb
177        self.pos_emb = pos_emb分類ヘッド
179        self.classification = classificationトランスレイヤーのコピーを作成
181        self.transformer_layers = clone_module_list(transformer_layer, n_layers)[CLS]
トークン埋め込み
184        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)最終正規化レイヤー
186        self.ln = nn.LayerNorm([transformer_layer.size])x
形状の入力画像です [batch_size, channels, height, width]
188    def forward(self, x: torch.Tensor):パッチの埋め込みを入手してください。これにより形状のテンソルが得られます。[patches, batch_size, d_model]
193        x = self.patch_emb(x)位置埋め込みを追加
195        x = self.pos_emb(x)[CLS]
トランスフォーマーに給電する前にトークンの埋め込みを連結してください
197        cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
198        x = torch.cat([cls_token_emb, x])アテンション・マスクなしで変圧器層を通過
201        for layer in self.transformer_layers:
202            x = layer(x=x, mask=None)[CLS]
トークン (シーケンスの最初のもの) のトランスフォーマー出力を取得します。
205        x = x[0]レイヤー正規化
208        x = self.ln(x)ロジットを取得するための分類ヘッド
211        x = self.classification(x)214        return x