MLP-Mixer:适用于视觉的全 MLP 架构

这是 PyTorch 对论文 MLP-Mixer:适用于视觉的全 MLP 架构的实现。

本文将该模型应用于视觉任务。该模型类似于变压器,注意力层被应用于补丁的 MLP(如果是 NLP 任务,则为代币)。

我们实现的 MLP Mixer 完全取代了变压器实现中的自注意力层。因此,这只是几行代码,对张量进行转置以在序列维度上应用 MLP。

尽管该论文将 MLP Mixer 应用于视觉任务,但我们在掩码语言模型上进行了尝试。这是实验代码

27from typing import Optional
28
29import torch
30from torch import nn

MLP 混音器

该模块是自我注意力层的直接替代品。它在将输入张量馈送到 MLP 之前对其进行转置并向后移调,这样 MLP 将应用于整个序列维度(跨标记或图像补丁),而不是特征维度。

33class MLPMixer(nn.Module):
  • ffn 是 MLP 模块。
43    def __init__(self, mlp: nn.Module):
47        super().__init__()
48        self.mlp = mlp

普通注意力模块可以使用和的不同令牌嵌入以及掩码进行馈送。

我们遵循相同的函数签名,以便我们可以直接替换它。

对于 MLP 混合,屏蔽是不可能的。query (和keyvalue )的形状为[seq_len, batch_size, d_model]

50    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):

,并且都应该是一样的

62        assert query is key and key is value

MLP 混音器不支持屏蔽。也就是说,所有令牌都会看到所有其他令牌嵌入。

64        assert mask is None

为了清楚起见,x 请分配给

67        x = query
进行@@

转置,使最后一个维度成为序列维度。新形状是[d_model, batch_size, seq_len]

71        x = x.transpose(0, 2)

跨代币应用 MLP

73        x = self.mlp(x)

移调回原始形式

75        x = x.transpose(0, 2)

78        return x