14import math
15from typing import Set
16
17import torch
18from torch import nn
19
20from labml.logger import inspect

ロープ埋め込み

セルフアテンションレイヤーには回転位置埋め込みを使用しています。位置情報は埋め込みに埋め込まれているため、因果関係には使用しないと想定しています。因果関係のない自己注意には、推測できないため、明確な位置情報が必要です

23class RotaryPositionalEmbeddings(nn.Module):
  • d は機能の数
  • base は計算に使用される定数です
34    def __init__(self, d: int, base: int = 10_000):
39        super().__init__()

41        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
  • x キーまたは形状のあるクエリの先頭にあるテンソルです [ batch_size, seq_len, n_heads, d]
43    def forward(self, x: torch.Tensor):

形状を抽出

48        batch_size, seq_len, n_heads, d = x.shape

51        d_2 = d // 2

位置インデックスの作成 [0, 1, ..., seq_len - 1]

54        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

位置指数の積を計算し、

57        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)

行が次のようになるように連結します

61        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

計算

65        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

計算

にとって

77        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])

80        return rx
83class SelfAttention(nn.Module):
  • d_model トランスフォーマー埋め込みのフィーチャー数です
  • n_heads アテンション・ヘッドの数です
  • d_k はヘッドあたりのフィーチャ数です
  • is_causal これが因果関係であるかどうかを示します (マスクされています)
90    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
97        super().__init__()
98
99        self.is_causal = is_causal
100        self.n_heads = n_heads
101        self.d_k = d_k

ソフトマックスの前にアテンションをスケーリングするには

104        self.scale = 1 / math.sqrt(self.d_k)

クエリ、キー、バリューヘッド用のリニアレイヤー。

107        self.query = nn.Linear(d_model, n_heads * d_k)
108        self.key = nn.Linear(d_model, n_heads * d_k)
109        self.value = nn.Linear(d_model, n_heads * d_k)

プレ・ノルム・レイヤーこの論文では代わりにRMSnormを使用しています

112        self.norm = nn.LayerNorm(d_model)

注意確率のソフトマックス

115        self.softmax = nn.Softmax(dim=-1)

ロータリーポジションエンベディング

118        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)

最終線形レイヤー

121        self.output = nn.Linear(n_heads * d_k, d_model)

アテンションレイヤーをマスクして原因となる注意を促す

  • attn シェイプ・アテンション・マトリックスです [batch_size, n_heads, seq_len, seq_len]
123    def mask_attention(self, attn: torch.Tensor):

因果関係のない注意のためのマスキングなし

131        if not self.is_causal:
132            return attn

三角マスクの作成

135        mask = torch.tril(attn.new_ones(attn.shape[-2:]))

マスクで絞り込む

137        return attn.masked_fill(mask == 0, float('-inf'))
  • h 形をした変圧器の埋め込みです [batch_size, seq_len, d_model]
139    def forward(self, h: torch.Tensor):

残留接続

145        h_res = h

事前正規化

148        h = self.norm(h)

クエリ、キー、値を取得し、それらをヘッドに分割します。これらには形があります [batch_size, seq_len, n_heads, d_k]

152        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
153        q = self.query(h).view(mh_shape)
154        k = self.key(h).view(mh_shape)
155        v = self.value(h).view(mh_shape)

回転式位置埋め込みを適用

158        q = self.rotary_pe(q)
159        k = self.rotary_pe(k)

注意事項の計算

162        attn = torch.einsum('bihd,bjhd->bhij', q, k)

スケーリング・バイ・スケーリング

164        attn = attn * self.scale

原因となる注意が必要な場合はマスクを着用してください

167        attn = self.mask_attention(attn)

注意確率の計算

170        attn = self.softmax(attn)

値を取得

173        h = torch.einsum("bhij,bjhd->bihd", attn, v)

[batch_size, seq_len, n_heads, d_k] 形状を次のように変更 [batch_size, seq_len, n_heads * d_k]

177        h = h.reshape(*h.shape[:-2], -1)

最後の線形レイヤーを適用します。結果は形になります [batch_size, seq_len, d_model]

181        h = self.output(h)

残余接続を追加

184        return h + h_res

クロスアテンションレイヤー

これは上で定義したセルフアテンションレイヤーと似ていますが、クエリとは異なる埋め込みセットからキーと値を取得する点が異なります。

これをエンコーダーで使用して、取得したチャンクを入力チャンクに基づいてエンコードします。

ここでは明示的な位置埋め込みは一切使用していません。モデルは埋め込み内の位置情報を暗黙的に表現できると仮定します

187class CrossAttention(nn.Module):
  • d_model トランスフォーマー埋め込みのフィーチャー数です
  • n_heads アテンション・ヘッドの数です
  • d_k はヘッドあたりのフィーチャ数です
201    def __init__(self, d_model: int, n_heads: int, d_k: int):
207        super().__init__()
208
209        self.n_heads = n_heads
210        self.d_k = d_k

ソフトマックスの前にアテンションをスケーリングするには

213        self.scale = 1 / math.sqrt(self.d_k)

クエリ、キー、バリューヘッド用のリニアレイヤー。

216        self.query = nn.Linear(d_model, n_heads * d_k)
217        self.key = nn.Linear(d_model, n_heads * d_k)
218        self.value = nn.Linear(d_model, n_heads * d_k)

クエリ埋め込み用のプレノルムレイヤー。この論文では代わりにRMSnormを使用しています

221        self.norm = nn.LayerNorm(d_model)

注意確率のソフトマックス

224        self.softmax = nn.Softmax(dim=-1)

最終線形レイヤー

227        self.output = nn.Linear(n_heads * d_k, d_model)
  • e 検索された最も近い近傍のチャンク埋め込みの形状の埋め込みです [batch_size, chunks, neighbors, neighbor_len, d_model]
  • h は入力チャンクで、そこから最も近い近傍データが取得されました。[batch_size, chunks, chunk_len, d_model] これはすでに標準化されています。
229    def forward(self, e: torch.Tensor, h: torch.Tensor):

残留接続

238        e_res = e

取得したチャンクを正規化

241        e = self.norm(e)

取得したチャンクからクエリを取得

244        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)

入力チャンクからキーと値を取得

246        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
247        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)

すべてのチャンクのアテンションスコアを計算します。取得された各ネイバーは、それを取得した元のチャンクに注目します。これは形になります [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]

252        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)

スケールアテンションスコア

254        attn = attn * self.scale

最後の次元のソフトマックスを計算

257        attn = self.softmax(attn)

価値を集める

260        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)

[batch_size, chunks, neighbors, neighbor_len, n_heads, d_k] 形状を次のように変更 [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]

264        e = e.reshape(*e.shape[:-2], -1)

最後の線形レイヤーを適用します。結果は形になります [batch_size, chunks, neighbors, neighbor_len, d_model]

268        e = self.output(e)

残余接続を追加

271        return e + e_res

チャンク・クロスアテンション・レイヤー

これは上で定義したクロス・アテンション・レイヤーに似ています。

これをデコーダで使用して、取得した隣接チャンクに注目します。

ここでは明示的な位置埋め込みは一切使用していません。モデルは埋め込み内の位置情報を暗黙的に表現できると仮定します

274class ChunkedCrossAttention(nn.Module):
  • d_model トランスフォーマー埋め込みのフィーチャー数です
  • n_heads アテンション・ヘッドの数です
  • d_k はヘッドあたりのフィーチャ数です
  • chunk_len チャンクの長さです
286    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
294        super().__init__()
295
296        self.chunk_len = chunk_len
297        self.n_heads = n_heads
298        self.d_k = d_k

ソフトマックスの前にアテンションをスケーリングするには

301        self.scale = 1 / math.sqrt(self.d_k)

クエリ、キー、バリューヘッド用のリニアレイヤー。

304        self.query = nn.Linear(d_model, n_heads * d_k)
305        self.key = nn.Linear(d_model, n_heads * d_k)
306        self.value = nn.Linear(d_model, n_heads * d_k)

クエリ埋め込み用のプレノルムレイヤー。この論文では代わりにRMSnormを使用しています

309        self.norm = nn.LayerNorm(d_model)

注意確率のソフトマックス

312        self.softmax = nn.Softmax(dim=-1)

最終線形レイヤー

315        self.output = nn.Linear(n_heads * d_k, d_model)

h 図形の入力埋め込みは、[batch_size, seq_len, d_model] e 検索された図形の最も近い近傍データです [batch_size, chunks, neighbors, neighbor_len, d_model]

317    def forward(self, h: torch.Tensor, e: torch.Tensor):

シェイプを取得

324        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

チャンクがない場合は不要 (サンプリング時の入力が短い場合)

327        if chunks == 0:
328            return h

残留接続

331        h_res = h

chunk_len - 1 最初の埋め込みを削除します。入力は、過去のトークンのみを使用して取得およびエンコードされたネイバーに注目し、情報漏えいが発生しないようにします。つまり、最初のチャンクから取得したネイバーには、最初のチャンクからの情報が含まれます。そのため、シーケンスを左にシフトすることでchunk_len - 1 、情報が右にのみ流れるようにします

339        h = h[:, self.chunk_len - 1:]

プレノルム

341        h = self.norm(h)

入力をチャンクに分割できるように、最後に空の埋め込みを追加します

343        if h.shape[1] < chunks * self.chunk_len:
344            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)

入力をチャンクにリシェイプします。

346        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)

入力からクエリを取得

349        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)

取得したネイバーからキーと値を取得

351        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
352        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)

入力チャンクのアテンションスコアを計算します。各チャンクは、前のチャンクで取得した隣接チャンクに注目します。これは形になります [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]

357        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)

スケールアテンションスコア

359        attn = attn * self.scale

最後の 2 次元にソフトマックスを適用 neighbors, neighbor_len

362        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)

価値を集める

365        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)

[batch_size, chunks, chunk_len, n_heads, d_k] 形状を次のように変更 [batch_size, chunks * chunk_len, n_heads * d_k]

369        h = h.reshape(batch_size, chunks * self.chunk_len, -1)

最後の線形レイヤーを適用します。結果は形になります [batch_size, chunks * chunk_len, d_model]

373        h = self.output(h)

chunk_len - 1 左にゼロの埋め込みを追加、つまり右にシフトして戻す

376        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)

残りの接続を切り捨てて追加

379        return h[:, :h_res.shape[1]] + h_res

位置別フィードフォワード層

これは2つの線形レイヤーと中央のアクティベーションで構成されています。

382class FeedForward(nn.Module):
  • d_model トランスフォーマー埋め込みのフィーチャー数です
  • d_ff は隠れレイヤーのナンバーフィーチャです
389    def __init__(self, d_model: int, d_ff: int):
395        super().__init__()

2 つの線形レイヤー

398        self.lin1 = nn.Linear(d_model, d_ff)
399        self.lin2 = nn.Linear(d_ff, d_model)

ReLU アクティベーション

402        self.act = nn.ReLU()

プレ・ノルム・レイヤー

405        self.norm = nn.LayerNorm(d_model)

h 形が埋め込まれているものです [batch_size, seq_len, d_model]

407    def forward(self, h: torch.Tensor):

残余

413        h_res = h

プレノルム

415        h = self.norm(h)

第 1 線形レイヤー

417        h = self.lin1(h)

アクティベーション

419        h = self.act(h)

2 番目の線形レイヤー

421        h = self.lin2(h)

残余接続を追加

424        return h + h_res

近傍エンコーダ

このモジュールは、取得した最近傍をエンコードします

427class NearestNeighborEncoder(nn.Module):
  • chunk_len チャンクの長さです
  • n_layer はエンコーダーのレイヤー数です
  • ca_layers クロスアテンションの対象となるレイヤーは
  • d_model は埋め込みに含まれるフィーチャの数です
  • n_heads はアテンションレイヤーのヘッド数です
  • d_k アテンションヘッドのサイズです
  • d_ff フィードフォワードネットワークの隠れ層のサイズ
434    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
435                 d_model: int, n_heads: int, d_k: int, d_ff: int):
446        super().__init__()
447        self.ca_layers = ca_layers
448        self.chunk_len = chunk_len

クロスアテンションレイヤー

450        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])

双方向のセルフアテンションレイヤー

452        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])

フィードフォワードレイヤー

454        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

の事前正規化レイヤー

457        self.norm_h = nn.LayerNorm(d_model)
  • e 検索された最も近い近傍のトークンの埋め込みで、形も整っています [batch_size, chunks, neighbors, neighbor_len, d_model]
  • h は入力トークンの埋め込みで、形状はさまざまです [batch_size, seq_len, d_model]

チャンクとネイバーは並行して処理されます。

459    def forward(self, e: torch.Tensor, h: torch.Tensor):

シェイプを取得

472        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

475        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)

プレノルム

478        h_split = self.norm_h(h_split)

クロス・アテンション・レイヤーのインデックスを維持

481        p_ca = 0

すべてのレイヤー用

483        for p in range(len(self.attn)):

双方向のセルフアテンション

486            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)

次の場合はクロスアテンション

489            if p in self.ca_layers:

491                e = self.ca[p_ca](e, h_split)

クロス・アテンション・インデックスを増やす

493                p_ca += 1

フィードフォワードレイヤー

496            e = self.ffw[p](e)

戻る

499        return e

レトロモデル

これはレトロデコーダーです

502class RetroModel(nn.Module):
  • v_vocab ボキャブラリ内のトークンの数です
  • d_model は埋め込みに含まれるフィーチャの数です
  • n_layers はデコーダーの層数です
  • ca_layers クロスアテンションの対象となるレイヤーは
  • chunk_len チャンクの長さです
  • n_heads はアテンションレイヤーのヘッド数です
  • d_k アテンションヘッドのサイズです
  • d_ff フィードフォワードネットワークの隠れ層のサイズ
  • encoder 最も近い近傍エンコーダです
509    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
510                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
522        super().__init__()
523
524        self.ca_layers = ca_layers
525        self.encoder = encoder

トークン埋め込みレイヤー

528        self.emb = nn.Embedding(n_vocab, d_model)

チャンク型クロス・アテンション・レイヤー

530        self.cca = nn.ModuleList(
531            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])

アテンションレイヤー

533        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])

フィードフォワードレイヤー

535        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

読み出し層

537        self.read = nn.Linear(d_model, n_vocab)

からの最近傍埋め込み用の事前正規化レイヤー

541        self.norm_e = nn.LayerNorm(d_model)
  • x は形状の入力シーケンス [batch_size, seq_len]
  • ret 検索されたシェイプの近傍です [batch_size, chunks, neighbors, neighbor_len]
543    def forward(self, x: torch.Tensor, ret: torch.Tensor):

入力埋め込みを取得

552        h = self.emb(x)

検索したネイバーの埋め込み。

入力と近傍の両方に同じ埋め込みを使用します

558        ret_emb = self.emb(ret)

チャンク化されたクロス・アテンション・レイヤーのインデックスを保持

561        p_ca = 0

すべてのレイヤー用

563        for p in range(len(self.attn)):

因果的自己注意

565            h = self.attn[p](h)

次の場合に、最初のレイヤーの前にエンコーダーの埋め込みを取得する

569            if self.ca_layers and p == min(self.ca_layers):

の埋め込みをエンコーダに渡しました。

573                e = self.encoder(ret_emb, h)

エンコーダ埋め込みを正規化

575                e = self.norm_e(e)

チャンク・クロス・アテンション if

578            if p in self.ca_layers:

580                h = self.cca[p_ca](h, e)

チャンク・クロス・アテンション・インデックスを増やす

582                p_ca += 1

585            h = self.ffw[p](h)

588        return self.read(h)

フェイクデータでモデルをテスト

591def _test():
595    chunk_len = 4
596    d_model = 8
597    d_ff = 32
598    n_heads = 2
599    d_k = 4
600
601    device = torch.device('cuda:0')
602
603    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
604                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
605
606    m.to(device)
607    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
608    ret = [
609        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
610        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
611    ]
612    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
613
614    inspect(res)

618if __name__ == '__main__':
619    _test()