Vision Transformer (ViT)

This is a PyTorch implementation of the paper An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale.

The vision transformer applies a pure transformer to images, without any convolution layers. The image is split into patches, and the transformer operates on the patch embeddings. A patch embedding is produced by a simple linear transformation of the flattened pixel values of the patch. A standard transformer encoder is then fed the patch embeddings, along with a classification token [CLS]. The encoding of the [CLS] token is used to classify the image with an MLP.
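
As a concrete example of the sizes involved (using the 224x224 images and 16x16 patches from the paper; the numbers below are just for illustration):

image_size, patch_size, channels = 224, 16, 3

# Number of patches along each side, and in total
patches_per_side = image_size // patch_size                # 14
n_patches = patches_per_side ** 2                          # 196

# Each patch is flattened before the linear projection
flattened_patch_dim = patch_size * patch_size * channels   # 768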

When feeding the patches to the transformer, learned positional embeddings are added to the patch embeddings, because the patch embeddings alone carry no information about where each patch comes from. The positional embeddings are a set of vectors, one for each patch location, trained with gradient descent along with the other parameters.

ViTs perform well when pre-trained on large datasets. The paper suggests pre-training them with an MLP classification head and then using a single linear layer when fine-tuning. The paper beats SOTA with a ViT pre-trained on a 300-million-image dataset. They also fine-tune with higher-resolution images while keeping the patch size the same. The positional embeddings for the new patch locations are calculated by interpolating the learned positional embeddings.
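
That interpolation is not part of this file; here's a minimal sketch of the idea, assuming square images and positional embeddings of shape [patches, 1, d_model] laid out on a 2D grid (the function name is just for illustration):

import torch
import torch.nn.functional as F

def interpolate_pos_embeddings(pos_emb: torch.Tensor, new_grid_size: int):
    # pos_emb has shape [patches, 1, d_model] where patches is a perfect square
    n_patches, _, d_model = pos_emb.shape
    grid_size = int(n_patches ** 0.5)
    # Lay the embeddings out on their 2D grid: [1, d_model, grid, grid]
    grid = pos_emb.view(grid_size, grid_size, d_model).permute(2, 0, 1).unsqueeze(0)
    # Bilinearly interpolate to the new grid of patch locations
    grid = F.interpolate(grid, size=(new_grid_size, new_grid_size),
                         mode='bilinear', align_corners=False)
    # Back to [new_patches, 1, d_model]
    return grid.squeeze(0).permute(1, 2, 0).reshape(new_grid_size * new_grid_size, 1, d_model)

In this implementation the [CLS] token takes position 0, so its embedding would typically be kept as-is and only the patch positions interpolated.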

Here's an experiment that trains a ViT on CIFAR-10. It doesn't do very well because it's trained on a small dataset, but it's a simple experiment that anyone can run to play with ViTs.

import torch
from torch import nn

from labml_nn.transformers import TransformerLayer
from labml_nn.utils import clone_module_list

Get patch embeddings

The paper splits the image into patches of equal size and does a linear transformation on the flattened pixels of each patch.

We do the same thing with a convolution layer, because it's simpler to implement.
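
As a sanity check of that equivalence (not part of the implementation), the following sketch compares the convolution against explicitly unfolding the patches and applying the convolution weights as a linear transformation:

import torch
from torch import nn

d_model, patch_size, in_channels = 32, 4, 3
conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
x = torch.randn(2, in_channels, 8, 8)

# Convolution view: [batch_size, d_model, h_patches, w_patches]
out_conv = conv(x)

# Explicit view: flatten each patch, then apply the same weights as a linear layer
patches = nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size)
out_linear = conv.weight.view(d_model, -1) @ patches + conv.bias.view(1, -1, 1)

assert torch.allclose(out_conv.flatten(2), out_linear, atol=1e-5)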

class PatchEmbeddings(nn.Module):
  • d_model is the transformer embeddings size
  • patch_size is the size of the patch
  • in_channels is the number of channels in the input image (3 for rgb)
    def __init__(self, d_model: int, patch_size: int, in_channels: int):
        super().__init__()

We create a convolution layer with a kernel size and stride length equal to the patch size. This is equivalent to splitting the image into patches and doing a linear transformation on each patch.

        self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
  • x is the input image of shape [batch_size, channels, height, width]
    def forward(self, x: torch.Tensor):

Apply convolution layer

        x = self.conv(x)

Get the shape.

        bs, c, h, w = x.shape

Rearrange to shape [patches, batch_size, d_model]

        x = x.permute(2, 3, 0, 1)
        x = x.view(h * w, bs, c)

Return the patch embeddings

        return x
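
For example, with CIFAR-10-sized images and a hypothetical configuration (continuing the code above):

patch_emb = PatchEmbeddings(d_model=128, patch_size=4, in_channels=3)
images = torch.randn(16, 3, 32, 32)   # [batch_size, channels, height, width]
print(patch_emb(images).shape)        # torch.Size([64, 16, 128]), i.e. [patches, batch_size, d_model]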

Add parameterized positional encodings

This adds learned positional embeddings to patch embeddings.

class LearnedPositionalEmbeddings(nn.Module):
  • d_model is the transformer embeddings size
  • max_len is the maximum number of patches
    def __init__(self, d_model: int, max_len: int = 5_000):
        super().__init__()

Positional embeddings for each location

        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
  • x is the patch embeddings of shape [patches, batch_size, d_model]
    def forward(self, x: torch.Tensor):

Get the positional embeddings for the given patches

        pe = self.positional_encodings[:x.shape[0]]

Add to patch embeddings and return

        return x + pe
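
Continuing the hypothetical shapes from above, the stored embeddings are simply sliced to the sequence length and added:

pos_emb = LearnedPositionalEmbeddings(d_model=128)
x = torch.randn(65, 16, 128)   # [patches + 1 for the [CLS] token, batch_size, d_model]
print(pos_emb(x).shape)        # torch.Size([65, 16, 128])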

MLP Classification Head

This is the two-layer MLP head that classifies the image based on the [CLS] token embedding.

class ClassificationHead(nn.Module):
  • d_model is the transformer embedding size
  • n_hidden is the size of the hidden layer
  • n_classes is the number of classes in the classification task
    def __init__(self, d_model: int, n_hidden: int, n_classes: int):
        super().__init__()

First layer

        self.linear1 = nn.Linear(d_model, n_hidden)

Activation

        self.act = nn.ReLU()

Second layer

        self.linear2 = nn.Linear(n_hidden, n_classes)
  • x is the transformer encoding for [CLS] token
    def forward(self, x: torch.Tensor):

First layer and activation

        x = self.act(self.linear1(x))

Second layer

        x = self.linear2(x)

        return x
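
For example, with the same hypothetical sizes and 10 classes:

head = ClassificationHead(d_model=128, n_hidden=256, n_classes=10)
cls_encoding = torch.randn(16, 128)   # [batch_size, d_model] encoding of the [CLS] token
print(head(cls_encoding).shape)       # torch.Size([16, 10])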

Vision Transformer

This combines the patch embeddings, positional embeddings, transformer and the classification head.

class VisionTransformer(nn.Module):
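  • transformer_layer is a copy of a single transformer layer. It gets cloned to build the transformer with n_layers layers.
  • n_layers is the number of transformer layers
  • patch_emb is the patch embeddings layer
  • pos_emb is the positional embeddings layer
  • classification is the classification head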
    def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
                 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
                 classification: ClassificationHead):
        super().__init__()

Patch embeddings and positional embeddings

        self.patch_emb = patch_emb
        self.pos_emb = pos_emb

Classification head

        self.classification = classification

Make copies of the transformer layer

        self.transformer_layers = clone_module_list(transformer_layer, n_layers)

[CLS] token embedding

        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)

Final normalization layer

        self.ln = nn.LayerNorm([transformer_layer.size])
  • x is the input image of shape [batch_size, channels, height, width]
    def forward(self, x: torch.Tensor):

Get patch embeddings. This gives a tensor of shape [patches, batch_size, d_model]

        x = self.patch_emb(x)

Concatenate the [CLS] token embeddings before feeding the transformer

        cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
        x = torch.cat([cls_token_emb, x])

Add positional embeddings

        x = self.pos_emb(x)

Pass through transformer layers with no attention masking

        for layer in self.transformer_layers:
            x = layer(x=x, mask=None)

Get the transformer output of the [CLS] token (which is the first in the sequence).

        x = x[0]

Layer normalization

        x = self.ln(x)

Classification head, to get logits

        x = self.classification(x)

        return x
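
Finally, a minimal sketch of putting the pieces together and running the model on random CIFAR-10-sized images. The MultiHeadAttention and FeedForward imports and constructor arguments are assumptions based on labml_nn and may differ between versions; the hyper-parameters are arbitrary.

from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward

d_model = 128

# A single transformer layer; VisionTransformer clones it n_layers times
layer = TransformerLayer(d_model=d_model,
                         self_attn=MultiHeadAttention(8, d_model),
                         src_attn=None,
                         feed_forward=FeedForward(d_model, d_model * 4),
                         dropout_prob=0.1)

model = VisionTransformer(layer, n_layers=6,
                          patch_emb=PatchEmbeddings(d_model, patch_size=4, in_channels=3),
                          pos_emb=LearnedPositionalEmbeddings(d_model),
                          classification=ClassificationHead(d_model, n_hidden=256, n_classes=10))

images = torch.randn(16, 3, 32, 32)
print(model(images).shape)   # torch.Size([16, 10]), the class logits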