フィードバック変圧器

これは、フィードバックメモリを用いたシーケンシャル・トランスフォーマーの高位表現へのアクセスに関する論文PyTorch」のPyTorch実装です

通常のトランスフォーマーはトークンを並行して処理します。各トランス層は、前の層の出力に注目します。フィードバックトランスは、前のステップのすべてのレイヤーの出力に注目します。そのため、これによって繰り返しが発生し、トークンごとに処理する必要があります。これにより、トレーニングが大幅に遅くなります(シーケンスの長さにもよりますが、約5倍から10倍です)。ただし、Feedback Transformerを予測する場合、メモリベクトルをキャッシュすれば次のトークンを予測できるため、より高速です

トレーニングをスピードアップするために、この論文では短いシーケンス長から始めて、徐々に長くする方法について説明します。また、事前学習済みの並列変圧器を出発点として使用する方法についても説明します

オリジナルのフィードバックトランスは、すべてのレイヤーの出力を保持するわけではありません。代わりに、すべてのレイヤーの出力の加重合計が保持されます。これにより、予測中のキャッシュに使用されるメモリが減ります。このファイルの前半はこれを実装しています。

更新されたフィードバックトランスフォーマーは重みを共有し、レイヤー間のキーと値の計算に使用されます。その後、各ステップのキーと値を一度だけ計算し、キャッシュに保存します。このファイルの後半はこれを実装しています。パフォーマンスを向上させるために、カスタム PyTorch 関数を実装しました

Tiny Shakespeareデータセットのフィードバックトランスフォーマーをトレーニングするためのトレーニングコードとノートブックは次のとおりです

Open In Colab

42import math
43from typing import Optional
44
45import torch
46from torch import nn
47
48from labml_helpers.module import Module
49from labml_nn.transformers.feed_forward import FeedForward
50from labml_nn.transformers.mha import PrepareForMultiHeadAttention
51from labml_nn.utils import clone_module_list

フィードバック注意

このモジュールは、オリジナルのトランスフォーマーの論文で出てきたアテンションと同様の反復アテンションを計算します。

54class FeedbackAttention(Module):
  • 「ヘッド」はアテンションヘッドの数です
  • d_model はトランスフォーマー内のフィーチャーの数です
  • dropout_prob は注意力低下の確率です
  • is_kv_precomputed キー、値のテンソルが既に計算されているかどうかです
65    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
66                 is_kv_precomputed: bool = False):
74        super().__init__()

ヘッドあたりの機能数

77        self.d_k = d_model // heads

79        self.heads = heads

query これらは多面的な注意力を変えます。

82        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

これらは頭の中を一変させkeyvalue 多面的な注目を集めます。

84        if not is_kv_precomputed:
85            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
86            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

キーと値はすでに計算されています

88        else:
89            self.key = None
90            self.value = None

出力レイヤー

93        self.output = nn.Linear(d_model, d_model)

ドロップアウト

95        self.dropout = nn.Dropout(dropout_prob)

ソフトマックス前のスケーリングファクター

97        self.scale = 1 / math.sqrt(self.d_k)

時間軸に沿った注目のソフトマックス key

100        self.softmax = nn.Softmax(dim=0)

相対位置の数

103        self.P = 2 ** 12

クエリを基準としたキーの相対位置埋め込み。

106        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)

クエリに対するキーの相対的な位置埋め込みバイアス。

108        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)

クエリの位置埋め込みはクエリの位置とは無関係です

110        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

必要に応じてロギングやその他の計算に使用できるように、アテンションを保存します

113        self.attn = None

アテンションスコアを取得

Transformer-XLの論文で紹介されている相対マルチヘッドアテンションと同様に、アテンションには相対位置エンコーディングを使用しています

現在のステップのクエリからキーインステップ(現在のステップを基準とする)への注意は、

ここでは元の埋め込みの線形変換で、は位置エンコーディングの線形変換です。

用語をに置き換えます

115    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

143        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]

145        query_pos_bias = self.query_pos_bias[None, :, :]

147        key_pos_bias = self.key_pos_bias[-key.shape[0]:]

150        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)

152        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]

155        return ac + bd
  • query 形がある [batch_size, d_model]
  • key value そして形をしている [seq_len, batch_size, d_model]
157    def forward(self, *,
158                query: torch.Tensor,
159                key: torch.Tensor,
160                value: torch.Tensor):

key value key 注意力計算の準備をしてqueryvalue そうすれば形ができて、[seq_len, batch_size, heads, d_k] 形ができていく query [batch_size, heads, d_k]

169        query = self.query(query)
170        if self.key:
171            key = self.key(key)
172        if self.value:
173            value = self.value(value)

アテンションスコアを計算します。結果として形状のテンソルが生成されます [seq_len, batch_size, heads]

177        scores = self.get_scores(query, key)

スケールスコア

180        scores *= self.scale

ソフトマックス

183        attn = self.softmax(scores)

ドロップアウトを適用

186        attn = self.dropout(attn)

値を掛ける

189        x = torch.einsum("jbh,jbhd->bhd", attn, value)

複数のヘッドを連結

192        x = x.reshape(x.shape[0], -1)

出力レイヤー

195        return self.output(x)

フィードバックトランス層

これにより、フィードバックトランスに単一のトランス層が実装されます。

198class FeedbackTransformerLayer(Module):
  • d_model はトランスフォーマー内のフィーチャーの数です
  • attn フィードバック・アテンション・モジュールです
  • feed_forward 位置単位のフィードフォワード層です
  • dropout_prob アテンションとフィードフォワード後のドロップアウト層のドロップアウト確率です
205    def __init__(self, *,
206                 d_model: int,
207                 attn: FeedbackAttention,
208                 feed_forward: FeedForward,
209                 dropout_prob: float):
216        super().__init__()

変圧器サイズ

218        self.size = d_model

220        self.attn = attn
221        self.feed_forward = feed_forward
222        self.dropout = nn.Dropout(dropout_prob)

正規化レイヤー

225        self.norm_self_attn = nn.LayerNorm([d_model])
226        self.norm_ff = nn.LayerNorm([d_model])
228    def forward(self, *,
229                x: torch.Tensor,
230                key: Optional[torch.Tensor],
231                value: Optional[torch.Tensor]):

メモリがあれば

233        if key is not None:

セルフアテンションを行う前にベクトルを正規化してください

235            z = self.norm_self_attn(x)

自己注意を向ける。つまり、キーと値は自己からのものだ

237            self_attn = self.attn(query=z, key=key, value=value)

セルフアテンションの結果を追加

239            x = x + self.dropout(self_attn)

フィードフォワード用に正規化

242        z = self.norm_ff(x)

フィードフォワードネットワークを通過

244        ff = self.feed_forward(z)

フィードフォワードの結果を追加し直す

246        x = x + self.dropout(ff)

249        return x

フィードバックトランスモジュール

252class FeedbackTransformer(Module):
  • layer はフィードバックトランス層で、各層ごとにクローンを作成します
  • n_layers 変圧器の層数です
257    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
263        super().__init__()

トランスレイヤーのコピーを作成

265        self.layers = clone_module_list(layer, n_layers)

最終正規化レイヤー

267        self.norm = nn.LayerNorm([layer.size])

メモリベクトルは、各レイヤーの表現の加重和として計算されます。これがそのためのウェイトパラメータです

270        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)

加重和を取る前の重みのソフトマックス

272        self.softmax = nn.Softmax(0)
  • x_seq は形状付きの入力です [seq_len, batch_size, d_model]
274    def forward(self, x_seq: torch.Tensor):

入力をシーケンス軸に沿ってリストに分割します

280        x_seq = torch.unbind(x_seq, dim=0)

出力を保存するリスト

282        res = []

メモリベクトルを格納するリスト

284        mem = []

各インプットステップについて

286        for x in x_seq:

レイヤー出力を保存するリスト

288            layer_outputs = [x]

メモリがあれば、それらをベクトルに積み重ねる

291            mem_tensor = torch.stack(mem) if mem else None

各レイヤーを貫通する

294            for layer in self.layers:

レイヤー出力を取得

296                x = layer(x=x, key=mem_tensor, value=mem_tensor)

それらをレイヤー出力のリストに追加します

298                layer_outputs.append(x)

レイヤーの出力をテンソルに積み重ねる

301            layer_outputs = torch.stack(layer_outputs)

メモリベクトルを層出力の加重和として計算します

303            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))

出力を結果に追加

305            res.append(x)

出力テンソルを積み重ねる

308        res = torch.stack(res)

出力を正規化

310        return self.norm(res)

レイヤー間での共有キーと値

スタック機能の実装

torch.stack Pythonリストに追加してから実行する代わりに、カスタム関数を実装します。これにより、torch.stack シーケンスの各ステップで呼び出すよりもパフォーマンスが大幅に向上します。torch.stack 呼び出されるたびに新しいテンソルが作成され、Stack このメソッドとそれに付随するクラスは各ステップのメモリを共有します

317class StackFunction(torch.autograd.Function):
  • ctx 関数のコンテキストです (キャッシュできます)
  • memory は、各ステップの値(キーと値)を積み重ねて保存する共有メモリテンソルです
  • memory_grad 各ステップの勾配を保存して蓄積する共有メモリテンソルです
  • last 最後に積み重ねられた値です
  • n はステップ数 (つまりスタックのサイズ)

これにより、ステップアップのスタックテンソルが返されます。n

329    @staticmethod
330    def forward(ctx, memory, memory_grad, last, n):

累積グラデーションをキャッシュ

342        ctx._mem_grad = memory_grad

スタックのサイズをキャッシュする

344        ctx._n = n

スタックを返す

346        return memory[:n + 1]
  • grad_output about 関数の出力に対する勾配です forward

これにより、共有メモリテンソルに勾配が蓄積され、スタック内の結果に対する勾配が返されます。last

348    @staticmethod
349    def backward(ctx, grad_output):

スタックの現在のサイズを取得

357        n = ctx._n

累積したグラデーションを取得

359        memory_grad = ctx._mem_grad

グラデーションを追加

361        memory_grad[:n + 1] += grad_output

スタックの最後の値までのグラデーションを返します

363        return None, None, memory_grad[n], None

スタックモジュール

これは上記で定義したスタック関数を使用し、必要な初期化を行います。

366class Stack:
  • max_len スタックの最大サイズです
373    def __init__(self, max_len: int):
377        self.max_len = max_len
378        self.memory = None
379        self.memory_grad = None
380        self.last = None
381        self.n = -1
382        self.last_get_n = -1
  • n はスタックのサイズ
  • value スタックに追加する必要があるテンソルです
384    def append(self, n: int, value: torch.Tensor):

値を追加した後にスタックを取得(使用)する必要があります。そうしないと、この実装は失敗します

392        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"

これをグラデーションなしで行う

395        with torch.no_grad():

共有メモリテンソルを初期化してスタックを維持する

397            if self.memory is None or self.memory.shape[1:] != value.shape:

これはスタックが空の場合にのみ発生するはずです。

399                assert n == 0

スタック用のテンソルの作成

401                self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)

テンソルを作成して勾配を累積する

403                self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)

メモリはすでに初期化されていますが、スタックをリセットしています。

これは別の機能でもかまいませんがreset 、こちらの方が使いやすいことがわかりました。

408            elif n == 0:

累積したグラデーションをリセット

410                self.memory_grad.fill_(0.)

値をスタックの正しい位置に設定します

413            self.memory.data[n] = value.detach()

スタックを追跡する (デバッグ用)

415            self.n = n

スタックに最後に追加された値を記録します。グラデーションを逆方向に伝播させるには、これを渡す必要があります

StackFunction
420        self.last = value

スタックを返します

422    def get(self):

使用時のスタックのサイズを記録しておきます。append これはサニティチェックインに使用されます

429        self.last_get_n = self.n

それをすべて実行して、バックプロパゲーション中に PyTorch StackFunction StackFunction.backwards から呼び出されるようにします。

432        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)

メモリを解放するには

434    def free(self):
439        self.memory = None
440        self.memory_grad = None
441        self.last = None

更新されたフィードバックトランスモジュール

これは、キーと値をキャッシュする更新されたフィードバックトランスモジュールです。

444class FeedbackTransformerKV(Module):
  • layer はフィードバックトランス層で、各層ごとにクローンを作成します
  • n_layers 変圧器の層数です
  • d_model はトランスフォーマー内のフィーチャーの数です
  • 「ヘッド」はアテンションヘッドの数です
451    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
459        super().__init__()

トランスレイヤーのコピーを作成

461        self.layers = clone_module_list(layer, n_layers)

最終正規化レイヤー

463        self.norm = nn.LayerNorm([layer.size])

メモリベクトルは、各レイヤーの表現の加重和として計算されます。これがそのためのウェイトパラメータです

466        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)

加重和を取る前の重みのソフトマックス

468        self.softmax = nn.Softmax(0)

頭の中の機能の数

471        d_k = d_model // heads

埋め込み (メモリ) を変換してキーを取得するモジュール

473        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)

埋め込み (メモリ) を変換してキーを取得するモジュール

475        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)

スタックキー用メモリ

478        self.mem_key = Stack(512)

積み重ねた値用のメモリ

480        self.mem_value = Stack(512)
  • x_seq は形状付きの入力です [seq_len, batch_size, d_model]
482    def forward(self, x_seq: torch.Tensor):

入力をシーケンス軸に沿ってリストに分割します

488        x_seq = torch.unbind(x_seq, dim=0)

出力を保存するリスト

490        res = []

各インプットステップについて

492        for step, x in enumerate(x_seq):

レイヤー出力を保存するリスト

494            layer_outputs = [x]

キーと値のスタック

497            key_tensor = None
498            value_tensor = None

最初のステップを超えたら、キーと値のテンソルを取得してください

500            if step > 0:
501                key_tensor = self.mem_key.get()
502                value_tensor = self.mem_value.get()

各レイヤーを貫通する

505            for layer in self.layers:

レイヤー出力を取得

507                x = layer(x=x, key=key_tensor, value=value_tensor)

それらをレイヤー出力のリストに追加します

509                layer_outputs.append(x)

レイヤーの出力をテンソルに積み重ねる

512            layer_outputs = torch.stack(layer_outputs)

メモリベクトルを層出力の加重和として計算します

514            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))

メモリからキーを計算してスタックに追加する

516            self.mem_key.append(step, self.key(mem))

メモリから値を計算してスタックに追加する

518            self.mem_value.append(step, self.value(mem))

出力を結果に追加

520            res.append(x)

出力テンソルを積み重ねる

523        res = torch.stack(res)

出力を正規化

525        return self.norm(res)
527    def free(self):
528        self.mem_key.free()
529        self.mem_value.free()