ビジョントランスフォーマー (ViT)

これは、「画像は16x16の言葉に値する」という論文「大規模画像認識のためのトランスフォーマー」をPyTorchで実装したものです

ビジョントランスフォーマーは、畳み込み層のない画像に純粋なトランスフォーマーを適用します。画像をパッチに分割し、パッチの埋め込みにトランスフォーマーを適用します。パッチ埋め込みは、パッチの平坦化されたピクセル値に単純な線形変換を適用することによって生成されます。次に、標準のトランスエンコーダに、分類トークンとともにパッチ埋め込みが供給されます。[CLS] [CLS] トークンのエンコーディングは、画像をMLPで分類するために使用されます

トランスにパッチを供給する際、学習した位置埋め込みがパッチ埋め込みに追加されます。これは、パッチ埋め込みにはそのパッチがどこから来たかについての情報がないためです。位置埋め込みは、各パッチ位置のベクトルのセットで、他のパラメーターとともに勾配降下法でトレーニングされます

VITは、大規模なデータセットで事前にトレーニングしておくとうまく機能します。この論文では、MLP分類ヘッドで事前にトレーニングし、微調整の際には単一の線形層を使用することを提案しています。この論文は、3億の画像データセットで事前にトレーニングされたVITでSOTAを上回っています。また、パッチサイズを同じに保ちながら、推論時には高解像度の画像を使用します。新しいパッチ位置の位置埋め込みは、学習した位置埋め込みを補間することによって計算されます

これは、CIFAR-10 で VIT をトレーニングする実験です。これは小さなデータセットでトレーニングされているため、あまりうまくいきません。誰でも走ってVITで遊べる簡単な実験です

43import torch
44from torch import nn
45
46from labml_helpers.module import Module
47from labml_nn.transformers import TransformerLayer
48from labml_nn.utils import clone_module_list

パッチ埋め込みを入手

用紙は画像を同じサイズのパッチに分割し、パッチごとに平坦化されたピクセルを線形変換します。

実装が簡単なため、畳み込み層でも同じことを実装します。

51class PatchEmbeddings(Module):
  • d_model 変圧器の埋め込みサイズです
  • patch_size パッチのサイズ
  • in_channels は入力画像のチャンネル数 (RGB の場合は 3)
63    def __init__(self, d_model: int, patch_size: int, in_channels: int):
69        super().__init__()

カーネルサイズでストライドの長さがパッチサイズと同じコンボリューションレイヤーを作成します。これは、画像をパッチに分割し、各パッチで線形変換を行うのと同じです

74        self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
  • x 形状の入力画像です [batch_size, channels, height, width]
76    def forward(self, x: torch.Tensor):

畳み込み層を適用

81        x = self.conv(x)

形を手に入れよう。

83        bs, c, h, w = x.shape

図形に再配置 [patches, batch_size, d_model]

85        x = x.permute(2, 3, 0, 1)
86        x = x.view(h * w, bs, c)

パッチの埋め込みを返す

89        return x

パラメータ化された位置エンコーディングの追加

これにより、学習した位置埋め込みがパッチ埋め込みに追加されます。

92class LearnedPositionalEmbeddings(Module):
  • d_model 変圧器の埋め込みサイズです
  • max_len パッチの最大数です
101    def __init__(self, d_model: int, max_len: int = 5_000):
106        super().__init__()

各ロケーションの位置埋め込み

108        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
  • x 形状のパッチ埋め込みです [patches, batch_size, d_model]
110    def forward(self, x: torch.Tensor):

与えられたパッチの位置埋め込みを取得

115        pe = self.positional_encodings[:x.shape[0]]

パッチ埋め込みに追加して返す

117        return x + pe

MLP クラス分けヘッド

これは、[CLS] トークンの埋め込みに基づいて画像を分類するための2層のMLPヘッドです。

120class ClassificationHead(Module):
  • d_model 変圧器の埋め込みサイズです
  • n_hidden 隠れレイヤーのサイズ
  • n_classes 分類タスク内のクラス数です
128    def __init__(self, d_model: int, n_hidden: int, n_classes: int):
134        super().__init__()

第 1 レイヤー

136        self.linear1 = nn.Linear(d_model, n_hidden)

アクティベーション

138        self.act = nn.ReLU()

第 2 レイヤー

140        self.linear2 = nn.Linear(n_hidden, n_classes)
  • x トークンのトランスフォーマーエンコーディングです [CLS]
142    def forward(self, x: torch.Tensor):

第1層とアクティベーション

147        x = self.act(self.linear1(x))

第 2 レイヤー

149        x = self.linear2(x)

152        return x
155class VisionTransformer(Module):
163    def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
164                 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
165                 classification: ClassificationHead):
174        super().__init__()

パッチ埋め込み

176        self.patch_emb = patch_emb
177        self.pos_emb = pos_emb

分類ヘッド

179        self.classification = classification

トランスレイヤーのコピーを作成

181        self.transformer_layers = clone_module_list(transformer_layer, n_layers)

[CLS] トークン埋め込み

184        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)

最終正規化レイヤー

186        self.ln = nn.LayerNorm([transformer_layer.size])
  • x 形状の入力画像です [batch_size, channels, height, width]
188    def forward(self, x: torch.Tensor):

パッチの埋め込みを入手してください。これにより形状のテンソルが得られます。[patches, batch_size, d_model]

193        x = self.patch_emb(x)

位置埋め込みを追加

195        x = self.pos_emb(x)

[CLS] トランスフォーマーに給電する前にトークンの埋め込みを連結してください

197        cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
198        x = torch.cat([cls_token_emb, x])

アテンション・マスクなしで変圧器層を通過

201        for layer in self.transformer_layers:
202            x = layer(x=x, mask=None)

[CLS] トークン (シーケンスの最初のもの) のトランスフォーマー出力を取得します。

205        x = x[0]

レイヤー正規化

208        x = self.ln(x)

ロジットを取得するための分類ヘッド

211        x = self.classification(x)

214        return x