This is a PyTorch implementation of the paper MLP-Mixer: An all-MLP Architecture for Vision.
This paper applies the model to vision tasks. The model is similar to a transformer, with the attention layer replaced by an MLP that is applied across the patches (or across the tokens, in the case of an NLP task).
Our implementation of MLP Mixer is a drop-in replacement for the self-attention layer in our transformer implementation. It takes just a couple of lines of code: the tensor is transposed so that the MLP is applied across the sequence dimension.
Although the paper applied MLP Mixer to vision tasks, we tried it on a masked language model. Here is the experiment code.
from typing import Optional

import torch
from torch import nn
This module is a drop-in replacement for the self-attention layer. It transposes the input tensor before feeding it to the MLP and transposes the result back afterwards, so that the MLP is applied across the sequence dimension (across tokens or image patches) instead of the feature dimension.
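To make the effect of the transpose concrete, here is a small standalone sketch using plain PyTorch. The tensor sizes are illustrative assumptions, and a single `nn.Linear(seq_len, seq_len)` stands in for the MLP. Because `nn.Linear` acts on the last dimension, applying it after `transpose(0, 2)` combines information across positions: changing only token 0 also changes the outputs at the other positions.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the original experiments)
seq_len, batch_size, d_model = 5, 2, 8

# Input laid out as in the transformer: [seq_len, batch_size, d_model]
x = torch.randn(seq_len, batch_size, d_model)

# Swap the first and last dimensions: [d_model, batch_size, seq_len]
x_t = x.transpose(0, 2)
print(x_t.shape)  # torch.Size([8, 2, 5])

# nn.Linear acts on the last dimension, which is now the sequence dimension,
# so each output position is a weighted combination of all input positions
token_linear = nn.Linear(seq_len, seq_len)
y = token_linear(x_t).transpose(0, 2)  # back to [seq_len, batch_size, d_model]

# Perturb only token 0, then compare the outputs at positions 1..seq_len-1
x2 = x.clone()
x2[0] += 1.0
y2 = token_linear(x2.transpose(0, 2)).transpose(0, 2)
print(torch.allclose(y[1:], y2[1:]))  # False: information flowed across tokens

# Transposing the same two dimensions again restores the original layout
print(torch.equal(x_t.transpose(0, 2), x))  # True
```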
class MLPMixer(nn.Module):
`mlp` is the MLP module.
    def __init__(self, mlp: nn.Module):
        super().__init__()
        self.mlp = mlp
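The `mlp` module receives the transposed tensor, so its input and output sizes along the last dimension must both equal the sequence length. Below is a sketch of one way to build such a module. The helper `make_token_mixing_mlp` and the hidden size are hypothetical. The two linear layers with a GELU in between follow the MLP blocks described in the MLP-Mixer paper; they are not necessarily the configuration used in our experiments.

```python
import torch
from torch import nn


def make_token_mixing_mlp(seq_len: int, d_hidden: int) -> nn.Module:
    """Hypothetical helper: a two-layer MLP mapping seq_len -> seq_len.

    It acts on the last dimension of its input, which is the sequence
    dimension once MLPMixer has transposed the tensor.
    """
    return nn.Sequential(
        nn.Linear(seq_len, d_hidden),
        nn.GELU(),
        nn.Linear(d_hidden, seq_len),
    )


seq_len, batch_size, d_model = 16, 4, 32
mlp = make_token_mixing_mlp(seq_len, d_hidden=64)

# After MLPMixer transposes, the MLP sees [d_model, batch_size, seq_len]
x_t = torch.randn(d_model, batch_size, seq_len)
out = mlp(x_t)
print(out.shape)  # torch.Size([32, 4, 16])
```

Since the weights are sized by `seq_len`, the sequence length is fixed when the module is constructed; inputs of a different length would have to be padded or truncated to that length.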
The normal attention module can be fed with different token embeddings for $q$, $k$, and $v$, and a mask.
We follow the same function signature so that we can replace it directly.
For MLP mixing, $x = q = k = v$, and masking is not possible. The shape of `query` (and `key` and `value`) is `[seq_len, batch_size, d_model]`.
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
$q$, $k$, and $v$ must all be the same tensor.
        assert query is key and key is value
MLP Mixer doesn't support masking, i.e., every token sees all other token embeddings.
        assert mask is None
Assign to `x` for clarity.
        x = query
Transpose so that the last dimension is the sequence dimension. The new shape is `[d_model, batch_size, seq_len]`.
        x = x.transpose(0, 2)
Apply the MLP across tokens
        x = self.mlp(x)
Transpose back to the original shape `[seq_len, batch_size, d_model]`.
        x = x.transpose(0, 2)

        return x
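The following is a minimal end-to-end sketch. It assumes a single `nn.Linear(seq_len, seq_len)` as the MLP, which is a simplification of the paper's two-layer token-mixing MLP. It repeats a condensed copy of the module so that it runs on its own, calls it with the attention-style signature, checks the result against an explicit `torch.einsum` formula for mixing across the sequence dimension, and shows the identity check on `query`, `key`, and `value` firing.

```python
from typing import Optional

import torch
from torch import nn


class MLPMixer(nn.Module):
    # Condensed copy of the module above so this example runs standalone
    def __init__(self, mlp: nn.Module):
        super().__init__()
        self.mlp = mlp

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        assert query is key and key is value
        assert mask is None
        x = query.transpose(0, 2)  # [d_model, batch_size, seq_len]
        x = self.mlp(x)            # mix across tokens
        return x.transpose(0, 2)   # [seq_len, batch_size, d_model]


torch.manual_seed(0)
seq_len, batch_size, d_model = 6, 3, 4

# A single linear layer over the sequence dimension (assumed for illustration)
token_linear = nn.Linear(seq_len, seq_len)
mixer = MLPMixer(token_linear)

x = torch.randn(seq_len, batch_size, d_model)
# Called like self-attention, with the same tensor as query, key and value
y = mixer(x, x, x)
print(y.shape)  # torch.Size([6, 3, 4])

# Same result written explicitly: y[s, b, d] = sum_t W[s, t] * x[t, b, d] + bias[s]
W, bias = token_linear.weight, token_linear.bias
y_ref = torch.einsum("st,tbd->sbd", W, x) + bias[:, None, None]
print(torch.allclose(y, y_ref, atol=1e-5))  # True

# Passing distinct tensors (here, a copy of x as key) trips the identity check
try:
    mixer(x, x.clone(), x)
except AssertionError:
    print("query, key and value must be the same tensor")
```

The `is` checks mean that even a tensor with identical values but a different identity is rejected. This matches the assumption that self-attention is always called here with one shared input.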