これは PyTorch の長距離シーケンスモデリング用の圧縮トランスフォーマーの実装です。
これはTransformer XLの拡張版で、過去の記憶を圧縮して注意範囲を広げています。つまり、最も遠いメモリがメモリに圧縮されます。ここで、は圧縮率です
。圧縮操作は次のように定義されます。この論文では複数の選択肢を紹介していますが、最良の結果が得られると思われる1次元の畳み込みのみを実装しています。各レイヤーには個別の圧縮操作があります。ここで、はレイヤー番号です。
BPTTによるトレーニング圧縮では、非常に大きな計算グラフ(多くのタイムステップ)を維持する必要があるため、この論文では自動エンコーディング損失と注意再構成損失を提案しています。自動エンコーディング損失は、圧縮されたメモリから元のメモリをデコードし、損失を計算します。アテンション再構成損失では、圧縮メモリと非圧縮メモリでマルチヘッドアテンションの結果を計算し、それらの間の平均二乗誤差を求めます。後者の方が良い結果が得られるため、ここでは後者を実装しました。
この実装ではレイヤー前の正規化を使用しますが、ペーパーではレイヤー後の正規化を使用します。前層ノルムはFFNやセルフアテンション前の層ノルムを行い、残差接続でのパススルーは正規化されません。これは標準的な変圧器の設定ではより安定しているはずです
。Tiny Shakespeareデータセットで圧縮トランスフォーマーモデルをトレーニングするためのトレーニングコードとノートブックは次のとおりです。
53from typing import Optional, List
54
55import torch
56import torch.nn.functional as F
57from torch import nn
58
59from labml_helpers.module import Module, TypedModuleList
60from labml_nn.transformers.feed_forward import FeedForward
61from labml_nn.transformers.mha import PrepareForMultiHeadAttention
62from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
63from labml_nn.utils import clone_module_list
compression_rate
d_model
は埋め込みサイズ74 def __init__(self, compression_rate: int, d_model: int):
79 super().__init__()
80 self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)
mem
形がある [seq_len, batch, d_model]
82 def forward(self, mem: torch.Tensor):
の次元を並べ替えて、mem
畳み込み層に通せるようにします。畳み込み層は次の形式を受け入れます [batch, features, sequence]
89 mem = mem.permute(1, 2, 0)
畳み込み層に通して圧縮メモリを取得
91 c_mem = self.conv(mem)
フォームに戻す [seq_len, batch, d_model]
93 return c_mem.permute(2, 0, 1)
96class CompressiveTransformerLayer(Module):
d_model
トークンの埋め込みサイズですself_attn
セルフアテンションモジュールですfeed_forward
フィードフォワードモジュールですdropout_prob
セルフアテンションとFFNの後に脱落する確率ですcompress
は圧縮関数です 102 def __init__(self, *,
103 d_model: int,
104 self_attn: RelativeMultiHeadAttention,
105 feed_forward: FeedForward,
106 dropout_prob: float,
107 compress: Conv1dCompression):
115 super().__init__()
116 self.compress = compress
117 self.size = d_model
118 self.self_attn = self_attn
119 self.feed_forward = feed_forward
120 self.dropout = nn.Dropout(dropout_prob)
121 self.norm_self_attn = nn.LayerNorm([d_model])
122 self.norm_ff = nn.LayerNorm([d_model])
124 def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
メモリがない場合は、トークンの埋め込みを返してください
133 if mem is None:
134 return z
圧縮メモリがある場合は、それをメモリと連結します。
137 if c_mem is not None:
138 mem = torch.cat((c_mem, mem), dim=0)
メモリを正規化層に通す
141 mem = self.norm_self_attn(mem)
正規化されたメモリと正規化されたトークンの埋め込みを連結する
143 return torch.cat((mem, z), dim=0)
x
形状のトークンレベルの特徴ベクトルのテンソルです [seq_len, batch_size, d_model]
mem
過去のトークンレベルの形状の特徴ベクトル (メモリ) のテンソルです [mem_len, batch_size, d_model]
c_mem
圧縮メモリのテンソルです [c_mem_len, batch_size, d_model]
mask
[seq_len, c_mem_len + mem_len + seq_len, batch_size]
[seq_len, c_mem_len + mem_len + seq_len, 1]
は形状のマトリックスかmask[i, j]
トークン at が at i
のトークンを参照できる場合は true j
になります。145 def forward(self, *,
146 x: torch.Tensor,
147 mem: Optional[torch.Tensor],
148 c_mem: Optional[torch.Tensor],
149 mask: torch.Tensor):
セルフアテンションを行う前にベクトルを正規化してください
159 z = self.norm_self_attn(x)
メモリと圧縮メモリの正規化と連結
161 m_z = self.concat_memory(z, mem, c_mem)
注意
163 self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
アテンション結果を追加
165 x = x + self.dropout(self_attn)
フィードフォワード用に正規化
168 z = self.norm_ff(x)
フィードフォワードネットワークを通過
170 ff = self.feed_forward(z)
フィードフォワードの結果を追加し直す
172 x = x + self.dropout(ff)
175 return x
178class CompressiveTransformer(Module):
185 def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
186 super().__init__()
トランスレイヤーのコピーを作成
188 self.layers = clone_module_list(layer, n_layers)
最終正規化レイヤー
190 self.norm = nn.LayerNorm([layer.size])
x
トークン埋め込みの形状ベクトルのテンソルです [seq_len, batch_size, d_model]
mem
過去のトークンレベルのテンソル、[mem_len, batch_size, d_model]
各レイヤーの形状ベクトルのリストですc_mem
[c_mem_len, batch_size, d_model]
各レイヤーの圧縮メモリのテンソルのリストですmask
はマスキングマトリックスです192 def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):
次のシーケンシャルバッチのメモリとなるトークンレベルの特徴ベクトルを格納するリスト。
203 new_mem = []
各変圧器層に通す
205 for i, layer in enumerate(self.layers):
特徴ベクトルのリストに追加
207 new_mem.append(x.detach())
メモリー
209 m = mem[i] if mem else None
圧縮メモリ
211 cm = c_mem[i] if c_mem else None
トランスフォーマーXLレイヤーを通す
213 x = layer(x=x, mem=m, c_mem=cm, mask=mask)
最後に、ベクトルを正規化します。
215 return self.norm(x), new_mem
注意再構成損失は、非圧縮メモリと圧縮メモリで自己注意出力を再現し、両者の平均二乗誤差を計算します。これは位置エンコーディングなしで行います
。注意再構成損失を伴う圧縮関数の計算とトレーニングを行うと、すべてのパラメーターがフリーズします。これには、正規化後のキー/値の予測とバイアス/スケーリングが含まれます
。この損失はモデルのクロスエントロピー損失とは独立して計算できるため、更新のみを行う別のオプティマイザーを使用できます。ただし、更新には同じオプティマイザーを使用するため、注意再構成損失を計算するときは、勾配計算を除く他のすべてのパラメーターを切り離します
。218class AttentionReconstructionLoss:
layers
圧縮トランスレイヤーのリストです
236 def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
240 self.layers = layers
241 self.loss_func = nn.MSELoss()
これは 「PrepareForMultiHeadAttention」を再実装したもので、勾配計算から切り離されたパラメーターを使用して投影が行われます。
pmha
は 「マルチヘッドアテンション対策」モジュールですx
トークンが埋め込まれたテンソルです243 def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):
埋め込み寸法以外の入力の形状;[seq_len, batch_size]
.
253 head_shape = x.shape[:-1]
プロジェクションウェイトとバイアスをデタッチ
256 weight = pmha.linear.weight.detach()
257 bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None
線形変換
259 x = F.linear(x, weight, bias)
最後のディメンションをヘッドに分割
262 x = x.view(*head_shape, pmha.heads, pmha.d_k)
[seq_len, batch_size, heads, d_k]
出力の形状があるか [batch_size, d_model]
265 return x
これは 「マルチヘッドアテンション」を再実装したもので、「PrepareForMultiHead Attention」prepare_for_attn
の代わりに呼び出してプロジェクションパラメータをデタッチします。
267 def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
クエリ、キー、値の予測を計算
274 query = self.prepare_for_attn(layer.query, query)
275 key = self.prepare_for_attn(layer.key, key)
276 value = self.prepare_for_attn(layer.value, value)
アテンションスコアを計算します。[seq_len, seq_len, batch_size, heads]
これにより形状のテンソルが得られます
280 scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
スケールスコア
283 scores *= layer.scale
キーシーケンス次元に沿って注目
287 attn = layer.softmax(scores)
値による乗算
291 return torch.einsum("ijbh,jbhd->ibhd", attn, value)
シフトとスケールのパラメーターをデタッチしてレイヤーの正規化を実行します。
293 def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
shift (bias
) とスケーリング (weight
) パラメーターのデタッチ
299 weight = ln.weight.detach() if ln.weight is not None else None
300 bias = ln.bias.detach() if ln.bias is not None else None
レイヤー正規化
303 return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)
これにより、レイヤーの損失が計算されます
305 def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
トークンの埋め込みとメモリを切り離します。
311 h = h.detach()
312 mem = mem.detach()
でメモリを圧縮します。のパラメータは、勾配計算から切り離されない唯一のパラメータです
。316 c_mem = layer.compress(mem)
埋め込みとメモリを正規化
319 h = self.norm(layer.norm_self_attn, h)
320 mem = self.norm(layer.norm_self_attn, mem)
321 c_mem = self.norm(layer.norm_self_attn, c_mem)
非圧縮メモリで注意度を計算
324 attn_mem = self.attn(layer.self_attn, h, mem, mem)
圧縮メモリで注意度を計算
326 attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)
平均二乗誤差の計算
329 return self.loss_func(attn_cmem, attn_mem)
331 def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):
各層の損失の計算
333 losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]
損失の合計
335 return sum(losses)