圧縮変圧器

これは PyTorch の長距離シーケンスモデリング用の圧縮トランスフォーマーの実装です

これはTransformer XLの拡張版で、過去の記憶を圧縮して注意範囲を広げています。つまり、最も遠いメモリがメモリに圧縮されます。ここで、は圧縮率です

圧縮操作

圧縮操作は次のように定義されます。この論文では複数の選択肢を紹介していますが、最良の結果が得られると思われる1次元の畳み込みのみを実装しています。各レイヤーには個別の圧縮操作があります。ここで、はレイヤー番号です。

トレーニング用圧縮操作

BPTTによるトレーニング圧縮では、非常に大きな計算グラフ(多くのタイムステップ)を維持する必要があるため、この論文では自動エンコーディング損失と注意再構成損失を提案しています自動エンコーディング損失は、圧縮されたメモリから元のメモリをデコードし、損失を計算します。アテンション再構成損失では、圧縮メモリと非圧縮メモリでマルチヘッドアテンションの結果を計算し、それらの間の平均二乗誤差を求めます。後者の方が良い結果が得られるため、ここでは後者を実装しました。

この実装ではレイヤー前の正規化を使用しますが、ペーパーではレイヤー後の正規化を使用します。前層ノルムはFFNやセルフアテンション前の層ノルムを行い、残差接続でのパススルーは正規化されません。これは標準的な変圧器の設定ではより安定しているはずです

Tiny Shakespeareデータセットで圧縮トランスフォーマーモデルをトレーニングするためのトレーニングコードとノートブックは次のとおりです

Open In Colab

53from typing import Optional, List
54
55import torch
56import torch.nn.functional as F
57from torch import nn
58
59from labml_helpers.module import Module, TypedModuleList
60from labml_nn.transformers.feed_forward import FeedForward
61from labml_nn.transformers.mha import PrepareForMultiHeadAttention
62from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
63from labml_nn.utils import clone_module_list

1D コンボリューション圧縮

これは、nn.Conv1d テンソル次元の順列を組み合わせた単純なラッパーです。

66class Conv1dCompression(Module):
  • compression_rate
  • d_model は埋め込みサイズ
74    def __init__(self, compression_rate: int, d_model: int):
79        super().__init__()
80        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

mem 形がある [seq_len, batch, d_model]

82    def forward(self, mem: torch.Tensor):

の次元を並べ替えて、mem 畳み込み層に通せるようにします。畳み込み層は次の形式を受け入れます [batch, features, sequence]

89        mem = mem.permute(1, 2, 0)

畳み込み層に通して圧縮メモリを取得

91        c_mem = self.conv(mem)

フォームに戻す [seq_len, batch, d_model]

93        return c_mem.permute(2, 0, 1)

圧縮変圧器層

これは単一の圧縮変圧器層の実装です

96class CompressiveTransformerLayer(Module):
102    def __init__(self, *,
103                 d_model: int,
104                 self_attn: RelativeMultiHeadAttention,
105                 feed_forward: FeedForward,
106                 dropout_prob: float,
107                 compress: Conv1dCompression):
115        super().__init__()
116        self.compress = compress
117        self.size = d_model
118        self.self_attn = self_attn
119        self.feed_forward = feed_forward
120        self.dropout = nn.Dropout(dropout_prob)
121        self.norm_self_attn = nn.LayerNorm([d_model])
122        self.norm_ff = nn.LayerNorm([d_model])

正規化されたトークンの埋め込みをメモリと圧縮メモリと連結します。

  • z レイヤー正規化トークンの埋め込みです。
  • mem c_mem メモリと圧縮メモリ (正規化されていない) です。
  • 124    def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):

    メモリがない場合は、トークンの埋め込みを返してください

    133        if mem is None:
    134            return z

    圧縮メモリがある場合は、それをメモリと連結します。

    137        if c_mem is not None:
    138            mem = torch.cat((c_mem, mem), dim=0)

    メモリを正規化層に通す

    141        mem = self.norm_self_attn(mem)

    正規化されたメモリと正規化されたトークンの埋め込みを連結する

    143        return torch.cat((mem, z), dim=0)
    • x 形状のトークンレベルの特徴ベクトルのテンソルです [seq_len, batch_size, d_model]
    • mem 過去のトークンレベルの形状の特徴ベクトル (メモリ) のテンソルです [mem_len, batch_size, d_model]
    • c_mem 圧縮メモリのテンソルです [c_mem_len, batch_size, d_model]
    • mask [seq_len, c_mem_len + mem_len + seq_len, batch_size] [seq_len, c_mem_len + mem_len + seq_len, 1] は形状のマトリックスかmask[i, j] トークン at が at i のトークンを参照できる場合は true j になります。
    145    def forward(self, *,
    146                x: torch.Tensor,
    147                mem: Optional[torch.Tensor],
    148                c_mem: Optional[torch.Tensor],
    149                mask: torch.Tensor):

    セルフアテンションを行う前にベクトルを正規化してください

    159        z = self.norm_self_attn(x)

    メモリと圧縮メモリの正規化と連結

    161        m_z = self.concat_memory(z, mem, c_mem)

    注意

    163        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)

    アテンション結果を追加

    165        x = x + self.dropout(self_attn)

    フィードフォワード用に正規化

    168        z = self.norm_ff(x)

    フィードフォワードネットワークを通過

    170        ff = self.feed_forward(z)

    フィードフォワードの結果を追加し直す

    172        x = x + self.dropout(ff)

    175        return x

    圧縮変圧器モデル

    これは複数の圧縮変圧器層で構成されています

    178class CompressiveTransformer(Module):
    185    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
    186        super().__init__()

    トランスレイヤーのコピーを作成

    188        self.layers = clone_module_list(layer, n_layers)

    最終正規化レイヤー

    190        self.norm = nn.LayerNorm([layer.size])
    • x トークン埋め込みの形状ベクトルのテンソルです [seq_len, batch_size, d_model]
    • mem 過去のトークンレベルのテンソル、[mem_len, batch_size, d_model] 各レイヤーの形状ベクトルのリストです
    • c_mem [c_mem_len, batch_size, d_model] 各レイヤーの圧縮メモリのテンソルのリストです
    • mask はマスキングマトリックスです
    192    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):

    次のシーケンシャルバッチのメモリとなるトークンレベルの特徴ベクトルを格納するリスト。

    203        new_mem = []

    各変圧器層に通す

    205        for i, layer in enumerate(self.layers):

    特徴ベクトルのリストに追加

    207            new_mem.append(x.detach())

    メモリー

    209            m = mem[i] if mem else None

    圧縮メモリ

    211            cm = c_mem[i] if c_mem else None

    トランスフォーマーXLレイヤーを通す

    213            x = layer(x=x, mem=m, c_mem=cm, mask=mask)

    最後に、ベクトルを正規化します。

    215        return self.norm(x), new_mem

    注意力再建ロス

    注意再構成損失は、非圧縮メモリと圧縮メモリで自己注意出力を再現し、両者の平均二乗誤差を計算します。これは位置エンコーディングなしで行います

    注意再構成損失を伴う圧縮関数の計算とトレーニングを行うと、すべてのパラメーターがフリーズします。これには、正規化後のキー/値の予測とバイアス/スケーリングが含まれます

    この損失はモデルのクロスエントロピー損失とは独立して計算できるため、更新のみを行う別のオプティマイザーを使用できます。ただし、更新には同じオプティマイザーを使用するため、注意再構成損失を計算するときは、勾配計算を除く他のすべてのパラメーターを切り離します

    218class AttentionReconstructionLoss:

    layers 圧縮トランスレイヤーのリストです

    236    def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
    240        self.layers = layers
    241        self.loss_func = nn.MSELoss()
    243    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):

    埋め込み寸法以外の入力の形状;[seq_len, batch_size] .

    253        head_shape = x.shape[:-1]

    プロジェクションウェイトとバイアスをデタッチ

    256        weight = pmha.linear.weight.detach()
    257        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None

    線形変換

    259        x = F.linear(x, weight, bias)

    最後のディメンションをヘッドに分割

    262        x = x.view(*head_shape, pmha.heads, pmha.d_k)

    [seq_len, batch_size, heads, d_k] 出力の形状があるか [batch_size, d_model]

    265        return x

    これは 「マルチヘッドアテンション」を再実装したもので、「PrepareForMultiHead Attention」prepare_for_attn の代わりに呼び出してプロジェクションパラメータをデタッチします。

    267    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):

    クエリ、キー、値の予測を計算

    274        query = self.prepare_for_attn(layer.query, query)
    275        key = self.prepare_for_attn(layer.key, key)
    276        value = self.prepare_for_attn(layer.value, value)

    アテンションスコアを計算します。[seq_len, seq_len, batch_size, heads] これにより形状のテンソルが得られます

    280        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

    スケールスコア

    283        scores *= layer.scale

    キーシーケンス次元に沿って注目

    287        attn = layer.softmax(scores)

    値による乗算

    291        return torch.einsum("ijbh,jbhd->ibhd", attn, value)

    シフトとスケールのパラメーターをデタッチしてレイヤーの正規化を実行します。

    293    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):

    shift (bias ) とスケーリング (weight ) パラメーターのデタッチ

    299        weight = ln.weight.detach() if ln.weight is not None else None
    300        bias = ln.bias.detach() if ln.bias is not None else None

    レイヤー正規化

    303        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)

    これにより、レイヤーの損失が計算されます

    305    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):

    トークンの埋め込みとメモリを切り離します。

    311        h = h.detach()
    312        mem = mem.detach()

    でメモリを圧縮します。のパラメータは、勾配計算から切り離されない唯一のパラメータです

    316        c_mem = layer.compress(mem)

    埋め込みとメモリを正規化

    319        h = self.norm(layer.norm_self_attn, h)
    320        mem = self.norm(layer.norm_self_attn, mem)
    321        c_mem = self.norm(layer.norm_self_attn, c_mem)

    非圧縮メモリで注意度を計算

    324        attn_mem = self.attn(layer.self_attn, h, mem, mem)

    圧縮メモリで注意度を計算

    326        attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)

    平均二乗誤差の計算

    329        return self.loss_func(attn_cmem, attn_mem)
    331    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):

    各層の損失の計算

    333        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]

    損失の合計

    335        return sum(losses)