これは PyTorch の長距離シーケンスモデリング用の圧縮トランスフォーマーの実装です。
これはTransformer XLの拡張版で、過去の記憶を圧縮して注意範囲を広げています。つまり、最も遠いメモリがメモリに圧縮されます。ここで、は圧縮率です
。圧縮操作は次のように定義されます。この論文では複数の選択肢を紹介していますが、最良の結果が得られると思われる1次元の畳み込みのみを実装しています。各レイヤーには個別の圧縮操作があります。ここで、はレイヤー番号です。
BPTTによるトレーニング圧縮では、非常に大きな計算グラフ(多くのタイムステップ)を維持する必要があるため、この論文では自動エンコーディング損失と注意再構成損失を提案しています。自動エンコーディング損失は、圧縮されたメモリから元のメモリをデコードし、損失を計算します。アテンション再構成損失では、圧縮メモリと非圧縮メモリでマルチヘッドアテンションの結果を計算し、それらの間の平均二乗誤差を求めます。後者の方が良い結果が得られるため、ここでは後者を実装しました。
この実装ではレイヤー前の正規化を使用しますが、ペーパーではレイヤー後の正規化を使用します。前層ノルムはFFNやセルフアテンション前の層ノルムを行い、残差接続でのパススルーは正規化されません。これは標準的な変圧器の設定ではより安定しているはずです
。Tiny Shakespeareデータセットで圧縮トランスフォーマーモデルをトレーニングするためのトレーニングコードとノートブックは次のとおりです。
53from typing import Optional, List
54
55import torch
56import torch.nn.functional as F
57from torch import nn
58
59from labml_helpers.module import Module, TypedModuleList
60from labml_nn.transformers.feed_forward import FeedForward
61from labml_nn.transformers.mha import PrepareForMultiHeadAttention
62from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
63from labml_nn.utils import clone_module_listcompression_rate
d_model
は埋め込みサイズ74    def __init__(self, compression_rate: int, d_model: int):79        super().__init__()
80        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)mem
形がある [seq_len, batch, d_model]
82    def forward(self, mem: torch.Tensor):の次元を並べ替えて、mem
畳み込み層に通せるようにします。畳み込み層は次の形式を受け入れます [batch, features, sequence]
89        mem = mem.permute(1, 2, 0)畳み込み層に通して圧縮メモリを取得
91        c_mem = self.conv(mem)フォームに戻す [seq_len, batch, d_model]
93        return c_mem.permute(2, 0, 1)96class CompressiveTransformerLayer(Module):d_model
トークンの埋め込みサイズですself_attn
セルフアテンションモジュールですfeed_forward
フィードフォワードモジュールですdropout_prob
セルフアテンションとFFNの後に脱落する確率ですcompress
は圧縮関数です 102    def __init__(self, *,
103                 d_model: int,
104                 self_attn: RelativeMultiHeadAttention,
105                 feed_forward: FeedForward,
106                 dropout_prob: float,
107                 compress: Conv1dCompression):115        super().__init__()
116        self.compress = compress
117        self.size = d_model
118        self.self_attn = self_attn
119        self.feed_forward = feed_forward
120        self.dropout = nn.Dropout(dropout_prob)
121        self.norm_self_attn = nn.LayerNorm([d_model])
122        self.norm_ff = nn.LayerNorm([d_model])124    def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):メモリがない場合は、トークンの埋め込みを返してください
133        if mem is None:
134            return z圧縮メモリがある場合は、それをメモリと連結します。
137        if c_mem is not None:
138            mem = torch.cat((c_mem, mem), dim=0)メモリを正規化層に通す
141        mem = self.norm_self_attn(mem)正規化されたメモリと正規化されたトークンの埋め込みを連結する
143        return torch.cat((mem, z), dim=0)x
形状のトークンレベルの特徴ベクトルのテンソルです [seq_len, batch_size, d_model]
mem
過去のトークンレベルの形状の特徴ベクトル (メモリ) のテンソルです [mem_len, batch_size, d_model]
c_mem
圧縮メモリのテンソルです [c_mem_len, batch_size, d_model]
mask
[seq_len, c_mem_len + mem_len + seq_len, batch_size]
[seq_len, c_mem_len + mem_len + seq_len, 1]
は形状のマトリックスかmask[i, j]
トークン at が at i
 のトークンを参照できる場合は true j
 になります。145    def forward(self, *,
146                x: torch.Tensor,
147                mem: Optional[torch.Tensor],
148                c_mem: Optional[torch.Tensor],
149                mask: torch.Tensor):セルフアテンションを行う前にベクトルを正規化してください
159        z = self.norm_self_attn(x)メモリと圧縮メモリの正規化と連結
161        m_z = self.concat_memory(z, mem, c_mem)注意
163        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)アテンション結果を追加
165        x = x + self.dropout(self_attn)フィードフォワード用に正規化
168        z = self.norm_ff(x)フィードフォワードネットワークを通過
170        ff = self.feed_forward(z)フィードフォワードの結果を追加し直す
172        x = x + self.dropout(ff)175        return x178class CompressiveTransformer(Module):185    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
186        super().__init__()トランスレイヤーのコピーを作成
188        self.layers = clone_module_list(layer, n_layers)最終正規化レイヤー
190        self.norm = nn.LayerNorm([layer.size])x
トークン埋め込みの形状ベクトルのテンソルです [seq_len, batch_size, d_model]
mem
過去のトークンレベルのテンソル、[mem_len, batch_size, d_model]
各レイヤーの形状ベクトルのリストですc_mem
[c_mem_len, batch_size, d_model]
各レイヤーの圧縮メモリのテンソルのリストですmask
はマスキングマトリックスです192    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):次のシーケンシャルバッチのメモリとなるトークンレベルの特徴ベクトルを格納するリスト。
203        new_mem = []各変圧器層に通す
205        for i, layer in enumerate(self.layers):特徴ベクトルのリストに追加
207            new_mem.append(x.detach())メモリー
209            m = mem[i] if mem else None圧縮メモリ
211            cm = c_mem[i] if c_mem else NoneトランスフォーマーXLレイヤーを通す
213            x = layer(x=x, mem=m, c_mem=cm, mask=mask)最後に、ベクトルを正規化します。
215        return self.norm(x), new_mem注意再構成損失は、非圧縮メモリと圧縮メモリで自己注意出力を再現し、両者の平均二乗誤差を計算します。これは位置エンコーディングなしで行います
。注意再構成損失を伴う圧縮関数の計算とトレーニングを行うと、すべてのパラメーターがフリーズします。これには、正規化後のキー/値の予測とバイアス/スケーリングが含まれます
。この損失はモデルのクロスエントロピー損失とは独立して計算できるため、更新のみを行う別のオプティマイザーを使用できます。ただし、更新には同じオプティマイザーを使用するため、注意再構成損失を計算するときは、勾配計算を除く他のすべてのパラメーターを切り離します
。218class AttentionReconstructionLoss:layers
圧縮トランスレイヤーのリストです
236    def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):240        self.layers = layers
241        self.loss_func = nn.MSELoss()これは 「PrepareForMultiHeadAttention」を再実装したもので、勾配計算から切り離されたパラメーターを使用して投影が行われます。
pmha
は 「マルチヘッドアテンション対策」モジュールですx
トークンが埋め込まれたテンソルです243    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):埋め込み寸法以外の入力の形状;[seq_len, batch_size]
.
253        head_shape = x.shape[:-1]プロジェクションウェイトとバイアスをデタッチ
256        weight = pmha.linear.weight.detach()
257        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None線形変換
259        x = F.linear(x, weight, bias)最後のディメンションをヘッドに分割
262        x = x.view(*head_shape, pmha.heads, pmha.d_k)[seq_len, batch_size, heads, d_k]
出力の形状があるか [batch_size, d_model]
265        return xこれは 「マルチヘッドアテンション」を再実装したもので、「PrepareForMultiHead Attention」prepare_for_attn
 の代わりに呼び出してプロジェクションパラメータをデタッチします。
267    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):クエリ、キー、値の予測を計算
274        query = self.prepare_for_attn(layer.query, query)
275        key = self.prepare_for_attn(layer.key, key)
276        value = self.prepare_for_attn(layer.value, value)アテンションスコアを計算します。[seq_len, seq_len, batch_size, heads]
これにより形状のテンソルが得られます
280        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)スケールスコア
283        scores *= layer.scaleキーシーケンス次元に沿って注目
287        attn = layer.softmax(scores)値による乗算
291        return torch.einsum("ijbh,jbhd->ibhd", attn, value)シフトとスケールのパラメーターをデタッチしてレイヤーの正規化を実行します。
293    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):shift (bias
) とスケーリング (weight
) パラメーターのデタッチ
299        weight = ln.weight.detach() if ln.weight is not None else None
300        bias = ln.bias.detach() if ln.bias is not None else Noneレイヤー正規化
303        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)これにより、レイヤーの損失が計算されます
305    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):トークンの埋め込みとメモリを切り離します。
311        h = h.detach()
312        mem = mem.detach()でメモリを圧縮します。のパラメータは、勾配計算から切り離されない唯一のパラメータです
。316        c_mem = layer.compress(mem)埋め込みとメモリを正規化
319        h = self.norm(layer.norm_self_attn, h)
320        mem = self.norm(layer.norm_self_attn, mem)
321        c_mem = self.norm(layer.norm_self_attn, c_mem)非圧縮メモリで注意度を計算
324        attn_mem = self.attn(layer.self_attn, h, mem, mem)圧縮メモリで注意度を計算
326        attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)平均二乗誤差の計算
329        return self.loss_func(attn_cmem, attn_mem)331    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):各層の損失の計算
333        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]損失の合計
335        return sum(losses)