这是 PyTorch 对论文 MLP-Mixer:适用于视觉的全 MLP 架构的实现。
本文将该模型应用于视觉任务。该模型类似于变压器,注意力层被应用于补丁的 MLP(如果是 NLP 任务,则为代币)。
我们实现的 MLP Mixer 完全取代了变压器实现中的自注意力层。因此,这只是几行代码,对张量进行转置以在序列维度上应用 MLP。
27from typing import Optional
28
29import torch
30from torch import nn
33class MLPMixer(nn.Module):
ffn
是 MLP 模块。43 def __init__(self, mlp: nn.Module):
47 super().__init__()
48 self.mlp = mlp
普通注意力模块可以使用、和的不同令牌嵌入以及掩码进行馈送。
我们遵循相同的函数签名,以便我们可以直接替换它。
对于 MLP 混合,屏蔽是不可能的。query
(和key
和value
)的形状为[seq_len, batch_size, d_model]
。
50 def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
、,并且都应该是一样的
62 assert query is key and key is value
MLP 混音器不支持屏蔽。也就是说,所有令牌都会看到所有其他令牌嵌入。
64 assert mask is None
为了清楚起见,x
请分配给
67 x = query
转置,使最后一个维度成为序列维度。新形状是[d_model, batch_size, seq_len]
71 x = x.transpose(0, 2)
跨代币应用 MLP
73 x = self.mlp(x)
移调回原始形式
75 x = x.transpose(0, 2)
78 return x