Patches Are All You Need? (ConvMixer)

This is a PyTorch implementation of the paper Patches Are All You Need?.

ConvMixer diagram from the paper

ConvMixer is similar to MLP-Mixer. MLP-Mixer separates the mixing of the spatial and channel dimensions by applying an MLP across the spatial dimension and then an MLP across the channel dimension (the spatial MLP replaces the ViT attention and the channel MLP is the FFN of ViT).

ConvMixer uses a point-wise (1×1) convolution for channel mixing and a depth-wise convolution for spatial mixing. Since these are convolutions instead of full MLPs across the space, they mix only nearby patches, in contrast to ViT or MLP-Mixer. Also, MLP-Mixer uses two-layer MLPs for each mixing while ConvMixer uses a single layer for each mixing.
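
To make this concrete, here is a minimal sketch (toy dimensions assumed, not taken from the paper's code) of the two mixing operations as PyTorch layers:

import torch
from torch import nn

# A dummy batch of patch embeddings: [batch_size, channels, height, width]
x = torch.randn(2, 64, 8, 8)

# Spatial mixing: a depth-wise convolution (one filter per channel),
# so each output mixes only a local neighborhood of patches.
spatial_mix = nn.Conv2d(64, 64, kernel_size=9, groups=64, padding=4)

# Channel mixing: a point-wise (1x1) convolution; a linear transformation
# applied independently at every spatial position.
channel_mix = nn.Conv2d(64, 64, kernel_size=1)

# Both operations preserve the shape of the patch-embedding tensor.
assert channel_mix(spatial_mix(x)).shape == x.shape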

The paper recommends removing the residual connection across the channel mixing (point-wise convolution) and keeping only the residual connection over the spatial mixing (depth-wise convolution). They also use batch normalization instead of layer normalization.

Here's an experiment that trains ConvMixer on CIFAR-10.


import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.utils import clone_module_list

ConvMixer layer

This is a single ConvMixer layer. The model will have a series of these.

class ConvMixerLayer(Module):
  • d_model is the number of channels in patch embeddings
  • kernel_size is the size of the kernel of the spatial convolution
    def __init__(self, d_model: int, kernel_size: int):
        super().__init__()

Depth-wise convolution is a separate convolution for each channel. We do this with a convolution layer whose number of groups equals the number of channels, so that each channel is its own group (a quick shape check follows this class).

        self.depth_wise_conv = nn.Conv2d(d_model, d_model,
                                         kernel_size=kernel_size,
                                         groups=d_model,
                                         padding=(kernel_size - 1) // 2)

Activation after depth-wise convolution

        self.act1 = nn.GELU()

Normalization after depth-wise convolution

        self.norm1 = nn.BatchNorm2d(d_model)

Point-wise convolution is a 1×1 convolution, i.e. a linear transformation of the patch embeddings.

        self.point_wise_conv = nn.Conv2d(d_model, d_model, kernel_size=1)

Activation after point-wise convolution

        self.act2 = nn.GELU()

Normalization after point-wise convolution

        self.norm2 = nn.BatchNorm2d(d_model)

    def forward(self, x: torch.Tensor):

For the residual connection around the depth-wise convolution

        residual = x

Depth-wise convolution, activation and normalization

        x = self.depth_wise_conv(x)
        x = self.act1(x)
        x = self.norm1(x)

Add residual connection

        x += residual

Point-wise convolution, activation and normalization

        x = self.point_wise_conv(x)
        x = self.act2(x)
        x = self.norm2(x)

        return x
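
Here is a small usage sketch of the layer (illustrative values, not the experiment's configuration). It also verifies the depth-wise weight shape mentioned above:

layer = ConvMixerLayer(d_model=256, kernel_size=9)

# The depth-wise convolution holds one 9x9 filter per channel, so its weight
# has shape [256, 1, 9, 9] instead of a full convolution's [256, 256, 9, 9].
assert layer.depth_wise_conv.weight.shape == (256, 1, 9, 9)

# The layer preserves the shape of the patch-embedding tensor.
x = torch.randn(2, 256, 8, 8)
assert layer(x).shape == x.shape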

Get patch embeddings

This splits the image into patches of size patch_size × patch_size and gives an embedding for each patch.

class PatchEmbeddings(Module):
  • d_model is the number of channels in patch embeddings
  • patch_size is the size of the patch
  • in_channels is the number of channels in the input image (3 for RGB)
    def __init__(self, d_model: int, patch_size: int, in_channels: int):
        super().__init__()

We create a convolution layer with a kernel size and stride equal to the patch size. This is equivalent to splitting the image into patches and doing a linear transformation on each patch.

        self.conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)

Activation function

        self.act = nn.GELU()

Batch normalization

        self.norm = nn.BatchNorm2d(d_model)
  • x is the input image of shape [batch_size, channels, height, width]
    def forward(self, x: torch.Tensor):

Apply convolution layer

        x = self.conv(x)

Activation and normalization

        x = self.act(x)
        x = self.norm(x)

        return x
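
For example (a sketch with assumed CIFAR-10-like sizes), a 32×32 RGB image with a patch size of 4 becomes an 8×8 grid of 256-dimensional patch embeddings:

patch_emb = PatchEmbeddings(d_model=256, patch_size=4, in_channels=3)

x = torch.randn(2, 3, 32, 32)
# Output shape is [batch_size, d_model, height / patch_size, width / patch_size]
assert patch_emb(x).shape == (2, 256, 8, 8)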

Classification Head

They do average pooling (taking the mean of all patch embeddings) and a final linear transformation to get the logits over the image classes.

class ClassificationHead(Module):
  • d_model is the number of channels in patch embeddings
  • n_classes is the number of classes in the classification task
    def __init__(self, d_model: int, n_classes: int):
        super().__init__()

Average Pool

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

Linear layer

        self.linear = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):

Average pooling

        x = self.pool(x)

Get the embedding; x will have shape [batch_size, d_model, 1, 1]

        x = x[:, :, 0, 0]

Linear layer

        x = self.linear(x)

        return x
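
Since the adaptive average pool reduces each channel to a single value, pooling and indexing with [:, :, 0, 0] is equivalent to taking the mean over the spatial dimensions. A quick check (illustrative sizes):

x = torch.randn(2, 256, 8, 8)

# Pooling to 1x1 and indexing gives the per-channel spatial mean.
pool = nn.AdaptiveAvgPool2d((1, 1))
assert torch.allclose(pool(x)[:, :, 0, 0], x.mean(dim=(2, 3)), atol=1e-6)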

ConvMixer

This combines the patch embeddings block, a number of ConvMixer layers and a classification head.

class ConvMixer(Module):
    def __init__(self, conv_mixer_layer: ConvMixerLayer, n_layers: int,
                 patch_emb: PatchEmbeddings,
                 classification: ClassificationHead):
        super().__init__()

Patch embeddings

        self.patch_emb = patch_emb

Classification head

        self.classification = classification

Make n_layers deep copies of the ConvMixer layer

        self.conv_mixer_layers = clone_module_list(conv_mixer_layer, n_layers)
  • x is the input image of shape [batch_size, channels, height, width]
    def forward(self, x: torch.Tensor):

Get patch embeddings. This gives a tensor of shape [batch_size, d_model, height / patch_size, width / patch_size].

        x = self.patch_emb(x)

Pass through ConvMixer layers

        for layer in self.conv_mixer_layers:
            x = layer(x)

Classification head, to get logits

        x = self.classification(x)

        return x
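
Finally, a sketch that assembles the full model and runs a dummy CIFAR-10-shaped batch through it. The hyper-parameters here are illustrative placeholders, not necessarily those of the paper or the CIFAR-10 experiment:

conv_mixer = ConvMixer(
    conv_mixer_layer=ConvMixerLayer(d_model=256, kernel_size=9),
    n_layers=8,
    patch_emb=PatchEmbeddings(d_model=256, patch_size=2, in_channels=3),
    classification=ClassificationHead(d_model=256, n_classes=10),
)

# Two 32x32 RGB images -> class logits for 10 classes.
x = torch.randn(2, 3, 32, 32)
assert conv_mixer(x).shape == (2, 10)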