MLP-Mixer: An all-MLP Architecture for Vision

This is a PyTorch implementation of the paper MLP-Mixer: An all-MLP Architecture for Vision.

The paper applies the model to vision tasks. The model is similar to a transformer, with the attention layer replaced by an MLP that is applied across the patches (or across the tokens, in the case of an NLP task).
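For reference, here is a minimal sketch of one Mixer layer as we read the paper. A token-mixing MLP acts across patches and a channel-mixing MLP acts across features. Each MLP is preceded by LayerNorm and wrapped in a skip connection, and GELU is the nonlinearity. The class name MixerLayer, the hidden sizes, and the [batch_size, n_patches, channels] input layout are illustrative assumptions, not this repository's code.

```python
import torch
from torch import nn


class MixerLayer(nn.Module):
    """Sketch of one Mixer layer: token-mixing MLP, then channel-mixing MLP.

    Assumed input shape: [batch_size, n_patches, channels].
    """

    def __init__(self, n_patches: int, channels: int, token_hidden: int, channel_hidden: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        # Mixes information across patches (operates on the patch dimension)
        self.token_mlp = nn.Sequential(
            nn.Linear(n_patches, token_hidden), nn.GELU(), nn.Linear(token_hidden, n_patches)
        )
        self.norm2 = nn.LayerNorm(channels)
        # Mixes information across channels within each patch
        self.channel_mlp = nn.Sequential(
            nn.Linear(channels, channel_hidden), nn.GELU(), nn.Linear(channel_hidden, channels)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Token mixing: move patches to the last dim, apply the MLP, move them back
        y = self.norm1(x).transpose(1, 2)          # [batch_size, channels, n_patches]
        x = x + self.token_mlp(y).transpose(1, 2)  # skip connection
        # Channel mixing: the MLP acts on the last (channel) dim directly
        x = x + self.channel_mlp(self.norm2(x))    # skip connection
        return x


layer = MixerLayer(n_patches=16, channels=32, token_hidden=64, channel_hidden=128)
out = layer(torch.randn(2, 16, 32))
print(out.shape)  # torch.Size([2, 16, 32])
```

The only difference between the two MLPs is the dimension they act on. The token-mixing step is the part that replaces attention, and it is the part implemented by the module below.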

Our implementation of MLP Mixer is a drop-in replacement for the self-attention layer in our transformer implementation. It is just a couple of lines of code: transpose the tensor so that the MLP is applied across the sequence dimension, then transpose it back.
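A minimal sketch of the transpose trick follows. It assumes the transformer layout [seq_len, batch_size, d_model] and uses a single nn.Linear as a stand-in for the MLP. The name token_mlp is hypothetical. The sketch shows that the output keeps the input shape, and that changing only the last token changes the output at position 0, because the MLP mixes all tokens.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, batch_size, d_model = 8, 2, 4

# Hypothetical token-mixing MLP (a single linear layer here). Its input and
# output size is seq_len, so it mixes tokens once seq_len is the last dim.
token_mlp = nn.Linear(seq_len, seq_len)

x = torch.randn(seq_len, batch_size, d_model)     # [seq_len, batch_size, d_model]
y = token_mlp(x.transpose(0, 2)).transpose(0, 2)  # via [d_model, batch_size, seq_len]
print(y.shape)  # torch.Size([8, 2, 4])

# Perturb only the last token and compare the outputs at position 0.
x2 = x.clone()
x2[-1] += 1.0
y2 = token_mlp(x2.transpose(0, 2)).transpose(0, 2)
print(torch.allclose(y[0], y2[0]))  # position 0 changes: every token sees every other token
```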

Although the paper applies MLP Mixer to vision tasks, we tried it on a masked language model. Here is the experiment code.

from typing import Optional
import torch
from torch import nn

MLP Mixer

This module is a drop-in replacement for the self-attention layer. It transposes the input tensor before feeding it to the MLP and transposes it back afterwards, so that the MLP is applied across the sequence dimension (across tokens or image patches) instead of the feature dimension.

class MLPMixer(nn.Module):
  • mlp is the MLP module.
    def __init__(self, mlp: nn.Module):
        super().__init__()
        self.mlp = mlp

The normal attention module can be fed with different token embeddings for query, key, and value, and a mask.

We follow the same function signature so that we can replace it directly.

For MLP mixing, x = query = key = value, and masking is not possible. The shape of query (and of key and value) is [seq_len, batch_size, d_model].

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):

query, key, and value should all be the same tensor

        assert query is key and key is value

MLP Mixer doesn't support masking, i.e., every token sees the embeddings of all other tokens.

        assert mask is None

Assign to x for clarity

        x = query

Transpose so that the last dimension is the sequence dimension. The new shape is [d_model, batch_size, seq_len].

        x = x.transpose(0, 2)

Apply the MLP across tokens

        x = self.mlp(x)

Transpose back to the original shape [seq_len, batch_size, d_model]

        x = x.transpose(0, 2)

        return x
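The following usage sketch shows the module as a drop-in for attention. The MLPMixer class is reproduced in compact form so the snippet runs on its own. The Linear-GELU-Linear token MLP and the sizes are illustrative assumptions, not the experiment's configuration. Because the MLP's input and output size equals seq_len, the module only accepts sequences of that fixed length, and a different length raises an error.

```python
from typing import Optional

import torch
from torch import nn


# Compact copy of the MLPMixer class above, so this snippet runs on its own.
class MLPMixer(nn.Module):
    def __init__(self, mlp: nn.Module):
        super().__init__()
        self.mlp = mlp

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        assert query is key and key is value
        assert mask is None
        x = query.transpose(0, 2)  # [d_model, batch_size, seq_len]
        x = self.mlp(x)            # mix across tokens
        return x.transpose(0, 2)   # back to [seq_len, batch_size, d_model]


seq_len, batch_size, d_model, hidden = 16, 3, 32, 64

# Hypothetical token-mixing MLP; its input and output size must equal seq_len.
mixer = MLPMixer(nn.Sequential(
    nn.Linear(seq_len, hidden), nn.GELU(), nn.Linear(hidden, seq_len)))

x = torch.randn(seq_len, batch_size, d_model)
out = mixer(x, x, x)  # the same tensor is passed as query, key and value
print(out.shape)  # torch.Size([16, 3, 32])

# A sequence of a different length does not fit the seq_len-sized Linear layers.
rejected = False
try:
    y = torch.randn(seq_len + 1, batch_size, d_model)
    mixer(y, y, y)
except RuntimeError:
    rejected = True
print(rejected)  # True
```

The fixed sequence length is the main practical constraint of this drop-in design compared with self-attention.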