これは、「階層型トランスフォーマーはより効率的な言語モデル」という論文をPyTorchで実装したものです。
本稿では、長いシーケンスを効率的に処理するための階層型トランスフォーマーアーキテクチャを紹介します。トランスフォーマーレイヤーの前半はトークンをダウンサンプリングし、後半は同じ解像度のレイヤー間を直接スキップ接続してアップサンプリングします。これはビジョンタスク用のU-Netに少し似ています
。彼らはさまざまなアップサンプリングとダウンサンプリングの手法を試し、砂時計モデルと呼ばれる最もパフォーマンスの高いアップサンプリング手法とダウンサンプリング手法を使用してモデルを構築します。
ここでは、わかりやすくするために、最も単純なアップサンプリングとダウンサンプリングの手法を実装しました。後ほど、より複雑な (そしてより高性能な) 実装を追加することを検討します
。28from typing import List
29
30import torch
31from torch import nn
32
33from labml_helpers.module import Module
34from labml_nn.transformers import MultiHeadAttention, TransformerLayer
35from labml_nn.transformers.feed_forward import FeedForward
36from labml_nn.transformers.utils import subsequent_mask
このモデルは、ダウンサンプリングによってシーケンスを短縮しながら、中央にレイヤーを再帰的に追加します。別の砂時計モデルで処理された短縮シーケンスは、2つの通常のトランスレイヤーの間に挟まれます。(トランス層にはセルフアテンション層と位置ごとのフィードフォワード層があります)
。39class HourGlass(Module):
n_heads
マルチヘッド・アテンション・レイヤー内のヘッド数ですd_model
トークンの埋め込みのサイズですdropout
は脱落確率ですd_ff
位置ごとのフィードフォワード層における隠れ層の次元ですshortening_factors
短縮係数のリストです49 def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
57 super().__init__()
ダウンサンプリング前のトランスレイヤー
60 self.pre = TransformerLayer(d_model=d_model,
62 self_attn=MultiHeadAttention(n_heads, d_model, dropout),
64 feed_forward=FeedForward(d_model, d_ff, dropout),
66 dropout_prob=dropout)
自動回帰マスク
68 self.mask = AutoregressiveMask()
短縮係数 (またはダウンサンプリングレート)
71 k = shortening_factors[0]
ダウンサンプリングとアップサンプリングの結果、将来のトークンから過去のトークンに情報が漏れないように、トークンを段階的に右にシフトします。
76 self.shift_right = ShiftRight(k - 1)
ショートニングまたはダウンサンプリングレイヤー。最も単純な形式、つまり平均プーリングを使用します。この論文では、注意に基づくダウンサンプリングが最も効果的であることが示されていますが、まだ実装していません。
79 self.shortening = AvgPoolShortening(k)
ショートニングがなくなったら (砂時計の真ん中)
82 if len(shortening_factors) == 1:
中央の層は別の変圧器層です
84 self.shortened = TransformerLayer(d_model=d_model,
85 self_attn=MultiHeadAttention(n_heads, d_model, dropout),
86 feed_forward=FeedForward(d_model, d_ff, dropout),
87 dropout_prob=dropout)
自己回帰マスク
89 self.mask_short = AutoregressiveMask()
90 self.hour_glass = None
91 else:
別の砂時計モデルを再帰的に挿入
93 self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])
アップサンプリングレイヤー。簡略化のためにナイーブなアップサンプリングを使用しており、論文では注意に基づくサンプリングの方が効果的であることが示されています
。97 self.up_sampling = NaiveUpSampling(k)
アップサンプリング後の最後のトランス層
100 self.post = TransformerLayer(d_model=d_model,
101 self_attn=MultiHeadAttention(n_heads, d_model, dropout),
102 feed_forward=FeedForward(d_model, d_ff, dropout),
103 dropout_prob=dropout)
105 def forward(self, x: torch.Tensor):
初期変圧器層
108 x = self.pre(x=x, mask=self.mask(x))
シフトとショートニング
111 x_short = self.shortening(self.shift_right(x))
砂時計の中心にいると
115 if self.hour_glass is None:
センタートランス層
118 x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))
120 else:
122 x_short = self.hour_glass(x_short)
短縮されたシーケンスをアップサンプリングし、スキップ接続を追加します
126 x = x + self.up_sampling(x, x_short)
最終変圧器層
129 x = self.post(x=x, mask=self.mask(x))
132 return x
135class ShiftRight(Module):
shift
はシフトするステップ数です142 def __init__(self, shift: int):
146 super().__init__()
負の値にすることはできません
148 assert shift >= 0
150 self.shift = shift
x
形状のテンソルです [seq_len, ...]
152 def forward(self, x: torch.Tensor):
シフトがの場合、元の状態に戻す
157 if self.shift == 0:
158 return x
左にゼロを追加
160 prefix = x.new_zeros([self.shift, *x.shape[1:]])
0 を連結して右を切り捨てる
162 return torch.cat([prefix, x[:-self.shift]])
165class AvgPoolShortening(Module):
k
は短縮係数172 def __init__(self, k: int):
176 super().__init__()
平均プーリング層
178 self.pool = nn.AvgPool1d(k, ceil_mode=True)
x
形が合っている [seq_len, batch_size, d_model]
180 def forward(self, x: torch.Tensor):
[batch_size, d_model, seq_len]
プーリング層は形状を受け入れるので、軸を並べ替えます。
186 return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
189class NaiveUpSampling(Module):
k
は短縮係数196 def __init__(self, k: int):
200 super().__init__()
201 self.k = k
x
ダウンサンプリング前の埋め込みを含むテンソルですx_short
より高い密度の (アップサンプリング対象の) 表現のテンソルです203 def forward(self, x: torch.Tensor, x_short: torch.Tensor):
シーケンスディメンション全体で繰り返します
209 expanded = torch.repeat_interleave(x_short, self.k, dim=0)
最後の余分な埋め込みは切り捨ててください
211 expanded = expanded[:x.shape[0]]
214 return expanded
217class AutoregressiveMask(Module):
222 def __init__(self):
223 super().__init__()
224 self.mask = None
226 def forward(self, x: torch.Tensor):
まだ作成していない場合やサイズが変更された場合はマスクを作成してください
228 if self.mask is None or self.mask.size(0) != len(x):
233 return self.mask
236class LinearPoolingShortening(Module):
244 def __init__(self):
245 super().__init__()
246 raise NotImplementedError
249class AttentionBasedShortening(Module):
261 def __init__(self):
262 super().__init__()
263 raise NotImplementedError
266class LinearUpSampling(Module):
273 def __init__(self):
274 super().__init__()
275 raise NotImplementedError
278class AttentionBasedUpSampling(Module):
290 def __init__(self):
291 super().__init__()
292 raise NotImplementedError