这是论文《FNet:将代币与傅里叶变换混合》的 PyTor ch 实现。
本文用两个傅里叶变换取代了自我注意力层,以混合令牌。这比自我注意力更有效。在 GLUE 基准测试中,BERT 使用它而不是自我注意力的准确性损失约为92%。
我们沿隐藏维度(嵌入维度)应用傅里叶变换,然后沿序列维度应用傅里叶变换。
其中是嵌入输入,代表傅里叶变换,代表复数中的实分量。
这在 PyTorch 上实现起来非常简单-只需一行代码。本文建议使用预先计算的DFT矩阵并进行矩阵乘法来获得傅里叶变换。
以下是使用基于 FNet 的模型对 AG News 进行分类的训练代码。
41from typing import Optional
42
43import torch
44from torch import nn
普通注意力模块可以使用、和的不同令牌嵌入以及掩码进行馈送。
我们遵循相同的函数签名,以便我们可以直接替换它。
对于 FNet 混合,屏蔽是不可能的。query
(和key
和value
)的形状为[seq_len, batch_size, d_model]
。
60 def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
、,对于令牌混合,al l 应该等于
72 assert query is key and key is value
令牌混合不支持掩码。即所有令牌都将看到所有其他令牌嵌入。
74 assert mask is None
为了清楚起见,x
请分配给
77 x = query
沿序列维度应用傅里叶变换
87 fft_seq = torch.fft.fft(fft_hidden, dim=0)
获取真正的组件
91 return torch.real(fft_seq)