FNet:将令牌与傅里叶变换混合

这是论文《FNet:将代币与傅里叶变换混合》的 PyTor ch 实现。

本文用两个傅里叶变换取代了自我注意力层,以混合令牌。这比自我注意力更有效。在 GLUE 基准测试中,BERT 使用它而不是自我注意力的准确性损失约为92%。

将令牌与两个傅里叶变换混合

我们沿隐藏维度(嵌入维度)应用傅里叶变换,然后沿序列维度应用傅里叶变换。

其中是嵌入输入,代表傅里叶变换,代表复数中的实分量。

这在 PyTorch 上实现起来非常简单-只需一行代码。本文建议使用预先计算的DFT矩阵并进行矩阵乘法来获得傅里叶变换。

以下是使用基于 FNet 的模型对 AG News 进行分类的训练代码。

41from typing import Optional
42
43import torch
44from torch import nn

FNet-混合代币

这个模块简单地实现了

该模块的结构类似于标准的注意力模块,因此我们可以简单地替换它。

47class FNetMix(nn.Module):

普通注意力模块可以使用和的不同令牌嵌入以及掩码进行馈送。

我们遵循相同的函数签名,以便我们可以直接替换它。

对于 FNet 混合,屏蔽是不可能的。query (和keyvalue )的形状为[seq_len, batch_size, d_model]

60    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):

、,对于令牌混合,al l 应该等于

72        assert query is key and key is value

令牌混合不支持掩码。即所有令牌都将看到所有其他令牌嵌入。

74        assert mask is None

为了清楚起见,x 请分配给

77        x = query

沿隐藏(嵌入)维度应用傅里叶变换

傅里叶变换的输出是复数张量。

84        fft_hidden = torch.fft.fft(x, dim=2)

沿序列维度应用傅里叶变换

87        fft_seq = torch.fft.fft(fft_hidden, dim=0)

获取真正的组件

91        return torch.real(fft_seq)