これは PyTorch の Transformer-XL: 固定長のコンテキストを超えた注意深い言語モデルの実装です。
Transformer のアテンションスパンは、並行してトレーニングされたシーケンスの長さと同じくらいの制限があります。これらの位置はすべて固定された位置エンコーディングになっています。Transformer XLは、事前に計算された過去の埋め込みに各ポジションに注目させることで、このアテンションスパンを増やします。たとえば、コンテキストの長さがの場合、前のバッチの長さのすべてのレイヤーの埋め込みを保持し、それらを現在のステップに送ります。固定位置エンコーディングを使用すると、これらの事前に計算された埋め込みは現在のコンテキストと同じ位置になります。相対位置エンコーディングが導入され、アテンション計算時に位置エンコーディングが導入されます
。相対的多面的注意の注釈付き実装が導入されました。relative_mha.py
Tiny ShakespeareデータセットでトランスフォーマーXLモデルをトレーニングするためのトレーニングコードとノートブックです。
35from typing import List, Optional
36
37import torch
38import torch.nn as nn
39
40from labml_helpers.module import Module
41from labml_nn.utils import clone_module_list
42from .relative_mha import RelativeMultiHeadAttention
43from ..feed_forward import FeedForward
46class TransformerXLLayer(Module):
d_model
トークンの埋め込みサイズですself_attn
セルフアテンションモジュールですfeed_forward
フィードフォワードモジュールですdropout_prob
セルフアテンションとFFNの後に脱落する確率です52 def __init__(self, *,
53 d_model: int,
54 self_attn: RelativeMultiHeadAttention,
55 feed_forward: FeedForward,
56 dropout_prob: float):
63 super().__init__()
64 self.size = d_model
65 self.self_attn = self_attn
66 self.feed_forward = feed_forward
67 self.dropout = nn.Dropout(dropout_prob)
68 self.norm_self_attn = nn.LayerNorm([d_model])
69 self.norm_ff = nn.LayerNorm([d_model])
x
トークンレベルの形状ベクトルのテンソルです [seq_len, batch_size, d_model]
mem
過去のトークンレベルの形状ベクトルのテンソルです [mem_len, batch_size, d_model]
mask
[seq_len, mem_len + seq_len, batch_size]
[seq_len, mem_len + seq_len, 1]
は形状のマトリックスかmask[i, j]
トークン at が at i
のトークンを参照できる場合は true j
になります。71 def forward(self, *,
72 x: torch.Tensor,
73 mem: Optional[torch.Tensor],
74 mask: torch.Tensor):
セルフアテンションを行う前にベクトルを正規化してください
82 z = self.norm_self_attn(x)
メモリがあれば
84 if mem is not None:
正規化してください
86 mem = self.norm_self_attn(mem)
と連結 z
88 m_z = torch.cat((mem, z), dim=0)
メモリがない場合は無視してください
90 else:
91 m_z = z
注意
93 self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
アテンション結果を追加
95 x = x + self.dropout(self_attn)
フィードフォワード用に正規化
98 z = self.norm_ff(x)
フィードフォワードネットワークを通過
100 ff = self.feed_forward(z)
フィードフォワードの結果を追加し直す
102 x = x + self.dropout(ff)
105 return x
108class TransformerXL(Module):
115 def __init__(self, layer: TransformerXLLayer, n_layers: int):
116 super().__init__()
トランスレイヤーのコピーを作成
118 self.layers = clone_module_list(layer, n_layers)
最終正規化レイヤー
120 self.norm = nn.LayerNorm([layer.size])
x
トークン埋め込みの形状ベクトルのテンソルです [seq_len, batch_size, d_model]
mem
過去のトークンレベルのテンソル、[mem_len, batch_size, d_model]
各レイヤーの形状ベクトルのリストですmask
はマスキングマトリックスです122 def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
次のシーケンシャルバッチのメモリとなるトークンレベルの特徴ベクトルを格納するリスト。
131 new_mem = []
各変圧器層に通す
133 for i, layer in enumerate(self.layers):
特徴ベクトルのリストに追加
135 new_mem.append(x.detach())
メモリー
137 m = mem[i] if mem else None
トランスフォーマーXLレイヤーを通す
139 x = layer(x=x, mem=m, mask=mask)
最後に、ベクトルを正規化します。
141 return self.norm(x), new_mem