トランスフォーマー XL

これは PyTorch の Transformer-XL: 固定長のコンテキストを超えた注意深い言語モデルの実装です

Transformer のアテンションスパンは、並行してトレーニングされたシーケンスの長さと同じくらいの制限があります。これらの位置はすべて固定された位置エンコーディングになっています。Transformer XLは、事前に計算された過去の埋め込みに各ポジションに注目させることで、このアテンションスパンを増やします。たとえば、コンテキストの長さがの場合前のバッチの長さのすべてのレイヤーの埋め込みを保持し、それらを現在のステップに送ります。固定位置エンコーディングを使用すると、これらの事前に計算された埋め込みは現在のコンテキストと同じ位置になります。相対位置エンコーディングが導入され、アテンション計算時に位置エンコーディングが導入されます

相対的多面的注意の注釈付き実装が導入されました。relative_mha.py

Tiny ShakespeareデータセットでトランスフォーマーXLモデルをトレーニングするためのトレーニングコードとノートブックです

Open In Colab

35from typing import List, Optional
36
37import torch
38import torch.nn as nn
39
40from labml_helpers.module import Module
41from labml_nn.utils import clone_module_list
42from .relative_mha import RelativeMultiHeadAttention
43from ..feed_forward import FeedForward

トランスフォーマー XL レイヤー

トランスフォーマーXLモデルは、これらのレイヤーを多数備えています。

46class TransformerXLLayer(Module):
  • d_model トークンの埋め込みサイズです
  • self_attn セルフアテンションモジュールです
  • feed_forward フィードフォワードモジュールです
  • dropout_prob セルフアテンションとFFNの後に脱落する確率です
52    def __init__(self, *,
53                 d_model: int,
54                 self_attn: RelativeMultiHeadAttention,
55                 feed_forward: FeedForward,
56                 dropout_prob: float):
63        super().__init__()
64        self.size = d_model
65        self.self_attn = self_attn
66        self.feed_forward = feed_forward
67        self.dropout = nn.Dropout(dropout_prob)
68        self.norm_self_attn = nn.LayerNorm([d_model])
69        self.norm_ff = nn.LayerNorm([d_model])
  • x トークンレベルの形状ベクトルのテンソルです [seq_len, batch_size, d_model]
  • mem 過去のトークンレベルの形状ベクトルのテンソルです [mem_len, batch_size, d_model]
  • mask [seq_len, mem_len + seq_len, batch_size] [seq_len, mem_len + seq_len, 1] は形状のマトリックスかmask[i, j] トークン at が at i のトークンを参照できる場合は true j になります。
71    def forward(self, *,
72                x: torch.Tensor,
73                mem: Optional[torch.Tensor],
74                mask: torch.Tensor):

セルフアテンションを行う前にベクトルを正規化してください

82        z = self.norm_self_attn(x)

メモリがあれば

84        if mem is not None:

正規化してください

86            mem = self.norm_self_attn(mem)

と連結 z

88            m_z = torch.cat((mem, z), dim=0)

メモリがない場合は無視してください

90        else:
91            m_z = z

注意

93        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)

アテンション結果を追加

95        x = x + self.dropout(self_attn)

フィードフォワード用に正規化

98        z = self.norm_ff(x)

フィードフォワードネットワークを通過

100        ff = self.feed_forward(z)

フィードフォワードの結果を追加し直す

102        x = x + self.dropout(ff)

105        return x

トランスフォーマー XL モデル

これは複数のトランスXL層で構成されています

108class TransformerXL(Module):
115    def __init__(self, layer: TransformerXLLayer, n_layers: int):
116        super().__init__()

トランスレイヤーのコピーを作成

118        self.layers = clone_module_list(layer, n_layers)

最終正規化レイヤー

120        self.norm = nn.LayerNorm([layer.size])
  • x トークン埋め込みの形状ベクトルのテンソルです [seq_len, batch_size, d_model]
  • mem 過去のトークンレベルのテンソル、[mem_len, batch_size, d_model] 各レイヤーの形状ベクトルのリストです
  • mask はマスキングマトリックスです
122    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):

次のシーケンシャルバッチのメモリとなるトークンレベルの特徴ベクトルを格納するリスト。

131        new_mem = []

各変圧器層に通す

133        for i, layer in enumerate(self.layers):

特徴ベクトルのリストに追加

135            new_mem.append(x.detach())

メモリー

137            m = mem[i] if mem else None

トランスフォーマーXLレイヤーを通す

139            x = layer(x=x, mem=m, mask=mask)

最後に、ベクトルを正規化します。

141        return self.norm(x), new_mem