これは、論文「FNet: トークンをフーリエ変換と混合する」をPyTorchで実装したものです。
この論文では、セルフアテンション層を2つのフーリエ変換に置き換えてトークンを混合します。これは自己注意よりも効率的です。BERT on GLUE ベンチマークでは、自己注意よりもこれを使用した場合の精度の低下は約 92%
です。フーリエ変換を非表示次元 (埋め込み次元) に沿って適用し、次にシーケンス次元に沿って適用します。
ここで、は埋め込み入力で、フーリエ変換を表し、複素数の実数成分を表します。
これをPyTorchに実装するのはとても簡単です。たった1行のコードです。この論文では、事前に計算されたDFT行列を使用し、行列の乗算を行ってフーリエ変換を行うことを提案しています
。以下は、FNetベースのモデルを使用してAG Newsを分類するためのトレーニングコードです。
41from typing import Optional
42
43import torch
44from torch import nn47class FNetMix(nn.Module):通常のアテンションモジュールには、、、、マスクにさまざまなトークンを埋め込むことができます。
同じ関数シグネチャに従うので、直接置換できます。
FNetミキシングの場合、マスキングはできません。query
(key
とvalue
) の形はです[seq_len, batch_size, d_model]
。
60    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):、、そしてトークンのミキシングではすべてが等しくなければなりません
72        assert query is key and key is valueトークンのミキシングはマスキングをサポートしていません。つまり、すべてのトークンに他のすべてのトークンの埋め込みが表示されます。
74        assert mask is Nonex
わかりやすいように割り当てる
77        x = query84        fft_hidden = torch.fft.fft(x, dim=2)シーケンスの次元に沿ってフーリエ変換を適用します
87        fft_seq = torch.fft.fft(fft_hidden, dim=0)本物のコンポーネントを手に入れよう
91        return torch.real(fft_seq)