階層型トランスフォーマーはより効率的な言語モデル

これは、「階層型トランスフォーマーはより効率的な言語モデルという論文をPyTorchで実装したものです

本稿では、長いシーケンスを効率的に処理するための階層型トランスフォーマーアーキテクチャを紹介します。トランスフォーマーレイヤーの前半はトークンをダウンサンプリングし、後半は同じ解像度のレイヤー間を直接スキップ接続してアップサンプリングします。これはビジョンタスク用のU-Netに少し似ています

彼らはさまざまなアップサンプリングとダウンサンプリングの手法を試し、砂時計モデルと呼ばれる最もパフォーマンスの高いアップサンプリング手法とダウンサンプリング手法を使用してモデルを構築します。

ここでは、わかりやすくするために、最も単純なアップサンプリングとダウンサンプリングの手法を実装しました。後ほど、より複雑な (そしてより高性能な) 実装を追加することを検討します

砂時計モデルのトレーニングコードは次のとおりです

28from typing import List
29
30import torch
31from torch import nn
32
33from labml_helpers.module import Module
34from labml_nn.transformers import MultiHeadAttention, TransformerLayer
35from labml_nn.transformers.feed_forward import FeedForward
36from labml_nn.transformers.utils import subsequent_mask

砂時計モデル

このモデルは、ダウンサンプリングによってシーケンスを短縮しながら、中央にレイヤーを再帰的に追加します。別の砂時計モデルで処理された短縮シーケンスは、2つの通常のトランスレイヤーの間に挟まれます。(トランス層にはセルフアテンション層と位置ごとのフィードフォワード層があります

39class HourGlass(Module):
49    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
57        super().__init__()

ダウンサンプリング前のトランスレイヤー

60        self.pre = TransformerLayer(d_model=d_model,
62                                    self_attn=MultiHeadAttention(n_heads, d_model, dropout),
64                                    feed_forward=FeedForward(d_model, d_ff, dropout),

66                                    dropout_prob=dropout)

自動回帰マスク

68        self.mask = AutoregressiveMask()

短縮係数 (またはダウンサンプリングレート)

71        k = shortening_factors[0]

ダウンサンプリングとアップサンプリングの結果、将来のトークンから過去のトークンに情報が漏れないように、トークンを段階的に右にシフトします。

76        self.shift_right = ShiftRight(k - 1)

ショートニングまたはダウンサンプリングレイヤー。最も単純な形式、つまり平均プーリングを使用します。この論文では、注意に基づくダウンサンプリングが最も効果的であることが示されていますが、まだ実装していません。

79        self.shortening = AvgPoolShortening(k)

ショートニングがなくなったら (砂時計の真ん中)

82        if len(shortening_factors) == 1:

中央の層は別の変圧器層です

84            self.shortened = TransformerLayer(d_model=d_model,
85                                              self_attn=MultiHeadAttention(n_heads, d_model, dropout),
86                                              feed_forward=FeedForward(d_model, d_ff, dropout),
87                                              dropout_prob=dropout)

自己回帰マスク

89            self.mask_short = AutoregressiveMask()
90            self.hour_glass = None
91        else:

別の砂時計モデルを再帰的に挿入

93            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])

アップサンプリングレイヤー。簡略化のためにナイーブなアップサンプリングを使用しており、論文では注意に基づくサンプリングの方が効果的であることが示されています

97        self.up_sampling = NaiveUpSampling(k)

アップサンプリング後の最後のトランス層

100        self.post = TransformerLayer(d_model=d_model,
101                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
102                                     feed_forward=FeedForward(d_model, d_ff, dropout),
103                                     dropout_prob=dropout)
105    def forward(self, x: torch.Tensor):

初期変圧器層

108        x = self.pre(x=x, mask=self.mask(x))

シフトとショートニング

111        x_short = self.shortening(self.shift_right(x))

砂時計の中心にいると

115        if self.hour_glass is None:

センタートランス層

118            x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))

120        else:

122            x_short = self.hour_glass(x_short)

短縮されたシーケンスをアップサンプリングし、スキップ接続を追加します

126        x = x + self.up_sampling(x, x_short)

最終変圧器層

129        x = self.post(x=x, mask=self.mask(x))

132        return x

右シフト操作

これにより、指定したステップ数だけシーケンスが右にシフトします。

135class ShiftRight(Module):
  • shift はシフトするステップ数です
142    def __init__(self, shift: int):
146        super().__init__()

負の値にすることはできません

148        assert shift >= 0

150        self.shift = shift
  • x 形状のテンソルです [seq_len, ...]
152    def forward(self, x: torch.Tensor):

シフトがの場合、元の状態に戻す

157        if self.shift == 0:
158            return x

左にゼロを追加

160        prefix = x.new_zeros([self.shift, *x.shape[1:]])

0 を連結して右を切り捨てる

162        return torch.cat([prefix, x[:-self.shift]])

平均的なプール短縮

これは、平均プーリングを使用して特定の係数でダウンサンプリングします。

165class AvgPoolShortening(Module):
  • k は短縮係数
172    def __init__(self, k: int):
176        super().__init__()

平均プーリング層

178        self.pool = nn.AvgPool1d(k, ceil_mode=True)
  • x 形が合っている [seq_len, batch_size, d_model]
180    def forward(self, x: torch.Tensor):

[batch_size, d_model, seq_len] プーリング層は形状を受け入れるので、軸を並べ替えます。

186        return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)

ナイーブアップサンプリング

これを繰り返してアップサンプリングします

189class NaiveUpSampling(Module):
  • k は短縮係数
196    def __init__(self, k: int):
200        super().__init__()
201        self.k = k
  • x ダウンサンプリング前の埋め込みを含むテンソルです
  • x_short より高い密度の (アップサンプリング対象の) 表現のテンソルです
203    def forward(self, x: torch.Tensor, x_short: torch.Tensor):

シーケンスディメンション全体で繰り返します

209        expanded = torch.repeat_interleave(x_short, self.k, dim=0)

最後の余分な埋め込みは切り捨ててください

211        expanded = expanded[:x.shape[0]]

214        return expanded

自動回帰マスクを生成

217class AutoregressiveMask(Module):
222    def __init__(self):
223        super().__init__()
224        self.mask = None
226    def forward(self, x: torch.Tensor):

まだ作成していない場合やサイズが変更された場合はマスクを作成してください

228        if self.mask is None or self.mask.size(0) != len(x):

次にマスクすると、トークンがマスクされ、将来のトークンが見えなくなります

230            self.mask = subsequent_mask(len(x)).to(x.device)

233        return self.mask

🚧 ダウンサンプリング用のリニアプーリング

これにより、マージが必要な連続したトークンの埋め込みが連結され、1 つのトークン埋め込みのサイズに合わせて線形変換が行われます。

236class LinearPoolingShortening(Module):
244    def __init__(self):
245        super().__init__()
246        raise NotImplementedError

🚧 注意してダウンサンプリング

ここで、は平均プーリングか線形プーリングかです。

249class AttentionBasedShortening(Module):
261    def __init__(self):
262        super().__init__()
263        raise NotImplementedError

🚧 アップサンプリング用のリニアプロジェクション

密度の高いトークンの埋め込みを、のサイズに合わせて線形投影します。

266class LinearUpSampling(Module):
273    def __init__(self):
274        super().__init__()
275        raise NotImplementedError

🚧 アテンションベースのアップサンプリング

どこ

278class AttentionBasedUpSampling(Module):
290    def __init__(self):
291        super().__init__()
292        raise NotImplementedError