This paper applies the model to vision tasks. The model is similar to a transformer, with the attention layer replaced by an MLP that is applied across the patches (or across tokens in the case of an NLP task).
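To make that layer structure concrete, here is a minimal sketch of a single Mixer layer as described in the paper. This is not the code from this implementation; the class name and parameters below are made up for illustration. Each layer applies a token-mixing MLP across patches and a channel-mixing MLP across features, each preceded by layer normalization and wrapped in a skip connection.

import torch
from torch import nn


class MixerLayerSketch(nn.Module):
    # Hypothetical sketch of one Mixer layer from the paper, not this implementation
    def __init__(self, n_patches: int, d_model: int, d_token_hidden: int, d_channel_hidden: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Token-mixing MLP: mixes information across patches
        self.token_mlp = nn.Sequential(nn.Linear(n_patches, d_token_hidden), nn.GELU(),
                                       nn.Linear(d_token_hidden, n_patches))
        # Channel-mixing MLP: mixes information across features
        self.channel_mlp = nn.Sequential(nn.Linear(d_model, d_channel_hidden), nn.GELU(),
                                         nn.Linear(d_channel_hidden, d_model))

    def forward(self, x: torch.Tensor):
        # x has shape [batch_size, n_patches, d_model]
        # Token mixing: move the patch dimension to the end, apply the MLP, and move it back
        x = x + self.token_mlp(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        # Channel mixing: the MLP is applied to the feature dimension directly
        x = x + self.channel_mlp(self.norm2(x))
        return x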
Our implementation of MLP Mixer is a drop-in replacement for the self-attention layer in our transformer implementation. So it's just a couple of lines of code, transposing the tensor to apply the MLP across the sequence dimension.
from typing import Optional

import torch
from torch import nn
This module is a drop-in replacement for the self-attention layer. It transposes the input tensor before feeding it to the MLP and transposes it back, so that the MLP is applied across the sequence dimension (across tokens or image patches) instead of the feature dimension.
mlp is the MLP module.
    def __init__(self, mlp: nn.Module):
        super().__init__()
        self.mlp = mlp
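As an example, a token-mixing MLP that could be passed in as mlp might look like the sketch below. This is an assumption for illustration rather than any particular module from this codebase; the key point is that, because the input gets transposed, the MLP's input and output sizes must equal the sequence length (the number of tokens or patches).

from torch import nn


def make_token_mixing_mlp(seq_len: int, d_hidden: int) -> nn.Module:
    # Hypothetical helper: builds an MLP that acts on vectors of size seq_len,
    # since the mixer transposes the input so that the sequence dimension is last
    return nn.Sequential(
        nn.Linear(seq_len, d_hidden),
        nn.GELU(),
        nn.Linear(d_hidden, seq_len),
    )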
The normal attention module can be fed with different token embeddings for query, key, and value, and a mask.
We follow the same function signature so that we can replace it directly.
For MLP mixing, query, key, and value must all be the same, and masking is not possible. The shape of query (and key and value) is [seq_len, batch_size, d_model].
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
query, key, and value should all be the same
        assert query is key and key is value
MLP mixing doesn't support masking; i.e., all tokens will see all other token embeddings.
        assert mask is None
        x = query
Transpose so that the last dimension is the sequence dimension. The new shape is [d_model, batch_size, seq_len].
        x = x.transpose(0, 2)
Apply the MLP across tokens
        x = self.mlp(x)
Transpose back into the original form
        x = x.transpose(0, 2)
        return x
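Finally, a quick usage sketch, assuming the module above is called MLPMixer and using the hypothetical make_token_mixing_mlp helper from the earlier sketch. The mixer is called with the same query, key, value, and mask signature as the attention module, and the output keeps the [seq_len, batch_size, d_model] shape.

seq_len, batch_size, d_model = 16, 4, 64
mixer = MLPMixer(make_token_mixing_mlp(seq_len=seq_len, d_hidden=128))
x = torch.randn(seq_len, batch_size, d_model)
# Same call signature as self-attention: query, key, and value are the same tensor
out = mixer(x, x, x, mask=None)
assert out.shape == (seq_len, batch_size, d_model)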