これは、フィードバックメモリを用いたシーケンシャル・トランスフォーマーの高位表現へのアクセスに関する論文「PyTorch」のPyTorch実装です。
通常のトランスフォーマーはトークンを並行して処理します。各トランス層は、前の層の出力に注目します。フィードバックトランスは、前のステップのすべてのレイヤーの出力に注目します。そのため、これによって繰り返しが発生し、トークンごとに処理する必要があります。これにより、トレーニングが大幅に遅くなります(シーケンスの長さにもよりますが、約5倍から10倍です)。ただし、Feedback Transformerを予測する場合、メモリベクトルをキャッシュすれば次のトークンを予測できるため、より高速です
。トレーニングをスピードアップするために、この論文では短いシーケンス長から始めて、徐々に長くする方法について説明します。また、事前学習済みの並列変圧器を出発点として使用する方法についても説明します
。オリジナルのフィードバックトランスは、すべてのレイヤーの出力を保持するわけではありません。代わりに、すべてのレイヤーの出力の加重合計が保持されます。これにより、予測中のキャッシュに使用されるメモリが減ります。このファイルの前半はこれを実装しています。
更新されたフィードバックトランスフォーマーは重みを共有し、レイヤー間のキーと値の計算に使用されます。その後、各ステップのキーと値を一度だけ計算し、キャッシュに保存します。このファイルの後半はこれを実装しています。パフォーマンスを向上させるために、カスタム PyTorch 関数を実装しました
。Tiny Shakespeareデータセットのフィードバックトランスフォーマーをトレーニングするためのトレーニングコードとノートブックは次のとおりです。
42import math
43from typing import Optional
44
45import torch
46from torch import nn
47
48from labml_helpers.module import Module
49from labml_nn.transformers.feed_forward import FeedForward
50from labml_nn.transformers.mha import PrepareForMultiHeadAttention
51from labml_nn.utils import clone_module_list
54class FeedbackAttention(Module):
d_model
はトランスフォーマー内のフィーチャーの数ですdropout_prob
は注意力低下の確率ですis_kv_precomputed
キー、値のテンソルが既に計算されているかどうかです65 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
66 is_kv_precomputed: bool = False):
74 super().__init__()
ヘッドあたりの機能数
77 self.d_k = d_model // heads
79 self.heads = heads
query
これらは多面的な注意力を変えます。
82 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
これらは頭の中を一変させkey
、value
多面的な注目を集めます。
84 if not is_kv_precomputed:
85 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
86 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
キーと値はすでに計算されています
88 else:
89 self.key = None
90 self.value = None
出力レイヤー
93 self.output = nn.Linear(d_model, d_model)
ドロップアウト
95 self.dropout = nn.Dropout(dropout_prob)
ソフトマックス前のスケーリングファクター
97 self.scale = 1 / math.sqrt(self.d_k)
時間軸に沿った注目のソフトマックス key
100 self.softmax = nn.Softmax(dim=0)
相対位置の数
103 self.P = 2 ** 12
クエリを基準としたキーの相対位置埋め込み。
106 self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
クエリに対するキーの相対的な位置埋め込みバイアス。
108 self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
クエリの位置埋め込みはクエリの位置とは無関係です
110 self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
必要に応じてロギングやその他の計算に使用できるように、アテンションを保存します
113 self.attn = None
Transformer-XLの論文で紹介されている相対マルチヘッドアテンションと同様に、アテンションには相対位置エンコーディングを使用しています。
現在のステップのクエリからキーインステップ(現在のステップを基準とする)への注意は、
ここで、は元の埋め込みの線形変換で、は位置エンコーディングの線形変換です。
用語をに置き換えます。
115 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
143 key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
145 query_pos_bias = self.query_pos_bias[None, :, :]
147 key_pos_bias = self.key_pos_bias[-key.shape[0]:]
150 ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
152 bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
155 return ac + bd
query
形がある [batch_size, d_model]
key
value
そして形をしている [seq_len, batch_size, d_model]
157 def forward(self, *,
158 query: torch.Tensor,
159 key: torch.Tensor,
160 value: torch.Tensor):
key
value
key
注意力計算の準備をしてquery
、value
そうすれば形ができて、[seq_len, batch_size, heads, d_k]
形ができていく query
[batch_size, heads, d_k]
169 query = self.query(query)
170 if self.key:
171 key = self.key(key)
172 if self.value:
173 value = self.value(value)
アテンションスコアを計算します。結果として形状のテンソルが生成されます [seq_len, batch_size, heads]
177 scores = self.get_scores(query, key)
スケールスコア
180 scores *= self.scale
ソフトマックス
183 attn = self.softmax(scores)
ドロップアウトを適用
186 attn = self.dropout(attn)
値を掛ける
189 x = torch.einsum("jbh,jbhd->bhd", attn, value)
複数のヘッドを連結
192 x = x.reshape(x.shape[0], -1)
出力レイヤー
195 return self.output(x)
198class FeedbackTransformerLayer(Module):
d_model
はトランスフォーマー内のフィーチャーの数ですattn
フィードバック・アテンション・モジュールですfeed_forward
位置単位のフィードフォワード層ですdropout_prob
アテンションとフィードフォワード後のドロップアウト層のドロップアウト確率です205 def __init__(self, *,
206 d_model: int,
207 attn: FeedbackAttention,
208 feed_forward: FeedForward,
209 dropout_prob: float):
216 super().__init__()
変圧器サイズ
218 self.size = d_model
220 self.attn = attn
221 self.feed_forward = feed_forward
222 self.dropout = nn.Dropout(dropout_prob)
正規化レイヤー
225 self.norm_self_attn = nn.LayerNorm([d_model])
226 self.norm_ff = nn.LayerNorm([d_model])
228 def forward(self, *,
229 x: torch.Tensor,
230 key: Optional[torch.Tensor],
231 value: Optional[torch.Tensor]):
メモリがあれば
233 if key is not None:
セルフアテンションを行う前にベクトルを正規化してください
235 z = self.norm_self_attn(x)
自己注意を向ける。つまり、キーと値は自己からのものだ
237 self_attn = self.attn(query=z, key=key, value=value)
セルフアテンションの結果を追加
239 x = x + self.dropout(self_attn)
フィードフォワード用に正規化
242 z = self.norm_ff(x)
フィードフォワードネットワークを通過
244 ff = self.feed_forward(z)
フィードフォワードの結果を追加し直す
246 x = x + self.dropout(ff)
249 return x
252class FeedbackTransformer(Module):
layer
はフィードバックトランス層で、各層ごとにクローンを作成しますn_layers
変圧器の層数です257 def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
263 super().__init__()
トランスレイヤーのコピーを作成
265 self.layers = clone_module_list(layer, n_layers)
最終正規化レイヤー
267 self.norm = nn.LayerNorm([layer.size])
メモリベクトルは、各レイヤーの表現の加重和として計算されます。これがそのためのウェイトパラメータです
。270 self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
加重和を取る前の重みのソフトマックス
272 self.softmax = nn.Softmax(0)
x_seq
は形状付きの入力です [seq_len, batch_size, d_model]
274 def forward(self, x_seq: torch.Tensor):
入力をシーケンス軸に沿ってリストに分割します
280 x_seq = torch.unbind(x_seq, dim=0)
出力を保存するリスト
282 res = []
メモリベクトルを格納するリスト
284 mem = []
各インプットステップについて
286 for x in x_seq:
レイヤー出力を保存するリスト
288 layer_outputs = [x]
メモリがあれば、それらをベクトルに積み重ねる
291 mem_tensor = torch.stack(mem) if mem else None
各レイヤーを貫通する
294 for layer in self.layers:
レイヤー出力を取得
296 x = layer(x=x, key=mem_tensor, value=mem_tensor)
それらをレイヤー出力のリストに追加します
298 layer_outputs.append(x)
レイヤーの出力をテンソルに積み重ねる
301 layer_outputs = torch.stack(layer_outputs)
メモリベクトルを層出力の加重和として計算します
303 mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
出力を結果に追加
305 res.append(x)
出力テンソルを積み重ねる
308 res = torch.stack(res)
出力を正規化
310 return self.norm(res)
torch.stack
Pythonリストに追加してから実行する代わりに、カスタム関数を実装します。これにより、torch.stack
シーケンスの各ステップで呼び出すよりもパフォーマンスが大幅に向上します。torch.stack
呼び出されるたびに新しいテンソルが作成され、Stack
このメソッドとそれに付随するクラスは各ステップのメモリを共有します
317class StackFunction(torch.autograd.Function):
ctx
関数のコンテキストです (キャッシュできます)memory
は、各ステップの値(キーと値)を積み重ねて保存する共有メモリテンソルですmemory_grad
各ステップの勾配を保存して蓄積する共有メモリテンソルですlast
最後に積み重ねられた値ですn
はステップ数 (つまりスタックのサイズ)これにより、ステップアップのスタックテンソルが返されます。n
329 @staticmethod
330 def forward(ctx, memory, memory_grad, last, n):
累積グラデーションをキャッシュ
342 ctx._mem_grad = memory_grad
スタックのサイズをキャッシュする
344 ctx._n = n
スタックを返す
346 return memory[:n + 1]
348 @staticmethod
349 def backward(ctx, grad_output):
スタックの現在のサイズを取得
357 n = ctx._n
累積したグラデーションを取得
359 memory_grad = ctx._mem_grad
グラデーションを追加
361 memory_grad[:n + 1] += grad_output
スタックの最後の値までのグラデーションを返します
363 return None, None, memory_grad[n], None
366class Stack:
max_len
スタックの最大サイズです373 def __init__(self, max_len: int):
377 self.max_len = max_len
378 self.memory = None
379 self.memory_grad = None
380 self.last = None
381 self.n = -1
382 self.last_get_n = -1
n
はスタックのサイズvalue
スタックに追加する必要があるテンソルです384 def append(self, n: int, value: torch.Tensor):
値を追加した後にスタックを取得(使用)する必要があります。そうしないと、この実装は失敗します
392 assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"
これをグラデーションなしで行う
395 with torch.no_grad():
共有メモリテンソルを初期化してスタックを維持する
397 if self.memory is None or self.memory.shape[1:] != value.shape:
これはスタックが空の場合にのみ発生するはずです。
399 assert n == 0
スタック用のテンソルの作成
401 self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)
テンソルを作成して勾配を累積する
403 self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)
408 elif n == 0:
累積したグラデーションをリセット
410 self.memory_grad.fill_(0.)
値をスタックの正しい位置に設定します
413 self.memory.data[n] = value.detach()
スタックを追跡する (デバッグ用)
415 self.n = n
420 self.last = value
スタックを返します
422 def get(self):
使用時のスタックのサイズを記録しておきます。append
これはサニティチェックインに使用されます
429 self.last_get_n = self.n
それをすべて実行して、バックプロパゲーション中に PyTorch StackFunction
StackFunction.backwards
から呼び出されるようにします。
432 return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)
メモリを解放するには
434 def free(self):
439 self.memory = None
440 self.memory_grad = None
441 self.last = None
444class FeedbackTransformerKV(Module):
layer
はフィードバックトランス層で、各層ごとにクローンを作成しますn_layers
変圧器の層数ですd_model
はトランスフォーマー内のフィーチャーの数です451 def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
459 super().__init__()
トランスレイヤーのコピーを作成
461 self.layers = clone_module_list(layer, n_layers)
最終正規化レイヤー
463 self.norm = nn.LayerNorm([layer.size])
メモリベクトルは、各レイヤーの表現の加重和として計算されます。これがそのためのウェイトパラメータです
。466 self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
加重和を取る前の重みのソフトマックス
468 self.softmax = nn.Softmax(0)
頭の中の機能の数
471 d_k = d_model // heads
埋め込み (メモリ) を変換してキーを取得するモジュール
473 self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
埋め込み (メモリ) を変換してキーを取得するモジュール
475 self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
スタックキー用メモリ
478 self.mem_key = Stack(512)
積み重ねた値用のメモリ
480 self.mem_value = Stack(512)
x_seq
は形状付きの入力です [seq_len, batch_size, d_model]
482 def forward(self, x_seq: torch.Tensor):
入力をシーケンス軸に沿ってリストに分割します
488 x_seq = torch.unbind(x_seq, dim=0)
出力を保存するリスト
490 res = []
各インプットステップについて
492 for step, x in enumerate(x_seq):
レイヤー出力を保存するリスト
494 layer_outputs = [x]
キーと値のスタック
497 key_tensor = None
498 value_tensor = None
最初のステップを超えたら、キーと値のテンソルを取得してください
500 if step > 0:
501 key_tensor = self.mem_key.get()
502 value_tensor = self.mem_value.get()
各レイヤーを貫通する
505 for layer in self.layers:
レイヤー出力を取得
507 x = layer(x=x, key=key_tensor, value=value_tensor)
それらをレイヤー出力のリストに追加します
509 layer_outputs.append(x)
レイヤーの出力をテンソルに積み重ねる
512 layer_outputs = torch.stack(layer_outputs)
メモリベクトルを層出力の加重和として計算します
514 mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
メモリからキーを計算してスタックに追加する
516 self.mem_key.append(step, self.key(mem))
メモリから値を計算してスタックに追加する
518 self.mem_value.append(step, self.value(mem))
出力を結果に追加
520 res.append(x)
出力テンソルを積み重ねる
523 res = torch.stack(res)
出力を正規化
525 return self.norm(res)
527 def free(self):
528 self.mem_key.free()
529 self.mem_value.free()