14import math
15from typing import Set
16
17import torch
18from torch import nn
19
20from labml.logger import inspectセルフアテンションレイヤーには回転位置埋め込みを使用しています。位置情報は埋め込みに埋め込まれているため、因果関係には使用しないと想定しています。因果関係のない自己注意には、推測できないため、明確な位置情報が必要です
。23class RotaryPositionalEmbeddings(nn.Module):d
は機能の数 base
は計算に使用される定数です 34    def __init__(self, d: int, base: int = 10_000):39        super().__init__()41        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)x
キーまたは形状のあるクエリの先頭にあるテンソルです [ batch_size, seq_len, n_heads, d]
43    def forward(self, x: torch.Tensor):形状を抽出
48        batch_size, seq_len, n_heads, d = x.shape51        d_2 = d // 2位置インデックスの作成 [0, 1, ..., seq_len - 1]
54        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)位置指数の積を計算し、
57        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)行が次のようになるように連結します
61        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)計算
65        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)77        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])80        return rx83class SelfAttention(nn.Module):d_model
トランスフォーマー埋め込みのフィーチャー数ですn_heads
アテンション・ヘッドの数ですd_k
はヘッドあたりのフィーチャ数ですis_causal
これが因果関係であるかどうかを示します (マスクされています)90    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):97        super().__init__()
98
99        self.is_causal = is_causal
100        self.n_heads = n_heads
101        self.d_k = d_kソフトマックスの前にアテンションをスケーリングするには
104        self.scale = 1 / math.sqrt(self.d_k)クエリ、キー、バリューヘッド用のリニアレイヤー。
107        self.query = nn.Linear(d_model, n_heads * d_k)
108        self.key = nn.Linear(d_model, n_heads * d_k)
109        self.value = nn.Linear(d_model, n_heads * d_k)プレ・ノルム・レイヤーこの論文では代わりにRMSnormを使用しています
。112        self.norm = nn.LayerNorm(d_model)注意確率のソフトマックス
115        self.softmax = nn.Softmax(dim=-1)ロータリーポジションエンベディング
118        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)最終線形レイヤー
121        self.output = nn.Linear(n_heads * d_k, d_model)123    def mask_attention(self, attn: torch.Tensor):因果関係のない注意のためのマスキングなし
131        if not self.is_causal:
132            return attn三角マスクの作成
135        mask = torch.tril(attn.new_ones(attn.shape[-2:]))マスクで絞り込む
137        return attn.masked_fill(mask == 0, float('-inf'))h
形をした変圧器の埋め込みです [batch_size, seq_len, d_model]
139    def forward(self, h: torch.Tensor):残留接続
145        h_res = h事前正規化
148        h = self.norm(h)クエリ、キー、値を取得し、それらをヘッドに分割します。これらには形があります [batch_size, seq_len, n_heads, d_k]
152        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
153        q = self.query(h).view(mh_shape)
154        k = self.key(h).view(mh_shape)
155        v = self.value(h).view(mh_shape)回転式位置埋め込みを適用
158        q = self.rotary_pe(q)
159        k = self.rotary_pe(k)注意事項の計算
162        attn = torch.einsum('bihd,bjhd->bhij', q, k)スケーリング・バイ・スケーリング
164        attn = attn * self.scale原因となる注意が必要な場合はマスクを着用してください
167        attn = self.mask_attention(attn)注意確率の計算
170        attn = self.softmax(attn)値を取得
173        h = torch.einsum("bhij,bjhd->bihd", attn, v)[batch_size, seq_len, n_heads, d_k]
形状を次のように変更 [batch_size, seq_len, n_heads * d_k]
177        h = h.reshape(*h.shape[:-2], -1)最後の線形レイヤーを適用します。結果は形になります [batch_size, seq_len, d_model]
181        h = self.output(h)残余接続を追加
184        return h + h_resこれは上で定義したセルフアテンションレイヤーと似ていますが、クエリとは異なる埋め込みセットからキーと値を取得する点が異なります。
これをエンコーダーで使用して、取得したチャンクを入力チャンクに基づいてエンコードします。
ここでは明示的な位置埋め込みは一切使用していません。モデルは埋め込み内の位置情報を暗黙的に表現できると仮定します
。187class CrossAttention(nn.Module):d_model
トランスフォーマー埋め込みのフィーチャー数ですn_heads
アテンション・ヘッドの数ですd_k
はヘッドあたりのフィーチャ数です201    def __init__(self, d_model: int, n_heads: int, d_k: int):207        super().__init__()
208
209        self.n_heads = n_heads
210        self.d_k = d_kソフトマックスの前にアテンションをスケーリングするには
213        self.scale = 1 / math.sqrt(self.d_k)クエリ、キー、バリューヘッド用のリニアレイヤー。
216        self.query = nn.Linear(d_model, n_heads * d_k)
217        self.key = nn.Linear(d_model, n_heads * d_k)
218        self.value = nn.Linear(d_model, n_heads * d_k)クエリ埋め込み用のプレノルムレイヤー。この論文では代わりにRMSnormを使用しています
。221        self.norm = nn.LayerNorm(d_model)注意確率のソフトマックス
224        self.softmax = nn.Softmax(dim=-1)最終線形レイヤー
227        self.output = nn.Linear(n_heads * d_k, d_model)e
検索された最も近い近傍のチャンク埋め込みの形状の埋め込みです [batch_size, chunks, neighbors, neighbor_len, d_model]
h
は入力チャンクで、そこから最も近い近傍データが取得されました。[batch_size, chunks, chunk_len, d_model]
これはすでに標準化されています。229    def forward(self, e: torch.Tensor, h: torch.Tensor):残留接続
238        e_res = e取得したチャンクを正規化
241        e = self.norm(e)取得したチャンクからクエリを取得
244        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)入力チャンクからキーと値を取得
246        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
247        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)すべてのチャンクのアテンションスコアを計算します。取得された各ネイバーは、それを取得した元のチャンクに注目します。これは形になります [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]
252        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)スケールアテンションスコア
254        attn = attn * self.scale最後の次元のソフトマックスを計算
257        attn = self.softmax(attn)価値を集める
260        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)[batch_size, chunks, neighbors, neighbor_len, n_heads, d_k]
形状を次のように変更 [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]
264        e = e.reshape(*e.shape[:-2], -1)最後の線形レイヤーを適用します。結果は形になります [batch_size, chunks, neighbors, neighbor_len, d_model]
268        e = self.output(e)残余接続を追加
271        return e + e_resこれは上で定義したクロス・アテンション・レイヤーに似ています。
これをデコーダで使用して、取得した隣接チャンクに注目します。
ここでは明示的な位置埋め込みは一切使用していません。モデルは埋め込み内の位置情報を暗黙的に表現できると仮定します
。274class ChunkedCrossAttention(nn.Module):d_model
トランスフォーマー埋め込みのフィーチャー数ですn_heads
アテンション・ヘッドの数ですd_k
はヘッドあたりのフィーチャ数ですchunk_len
チャンクの長さです286    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):294        super().__init__()
295
296        self.chunk_len = chunk_len
297        self.n_heads = n_heads
298        self.d_k = d_kソフトマックスの前にアテンションをスケーリングするには
301        self.scale = 1 / math.sqrt(self.d_k)クエリ、キー、バリューヘッド用のリニアレイヤー。
304        self.query = nn.Linear(d_model, n_heads * d_k)
305        self.key = nn.Linear(d_model, n_heads * d_k)
306        self.value = nn.Linear(d_model, n_heads * d_k)クエリ埋め込み用のプレノルムレイヤー。この論文では代わりにRMSnormを使用しています
。309        self.norm = nn.LayerNorm(d_model)注意確率のソフトマックス
312        self.softmax = nn.Softmax(dim=-1)最終線形レイヤー
315        self.output = nn.Linear(n_heads * d_k, d_model)h
図形の入力埋め込みは、[batch_size, seq_len, d_model]
e
検索された図形の最も近い近傍データです [batch_size, chunks, neighbors, neighbor_len, d_model]
317    def forward(self, h: torch.Tensor, e: torch.Tensor):シェイプを取得
324        batch_size, chunks, neighbors, neighbor_len, d_model = e.shapeチャンクがない場合は不要 (サンプリング時の入力が短い場合)
327        if chunks == 0:
328            return h残留接続
331        h_res = hchunk_len - 1
最初の埋め込みを削除します。入力は、過去のトークンのみを使用して取得およびエンコードされたネイバーに注目し、情報漏えいが発生しないようにします。つまり、最初のチャンクから取得したネイバーには、最初のチャンクからの情報が含まれます。そのため、シーケンスを左にシフトすることでchunk_len - 1
、情報が右にのみ流れるようにします
339        h = h[:, self.chunk_len - 1:]プレノルム
341        h = self.norm(h)入力をチャンクに分割できるように、最後に空の埋め込みを追加します
343        if h.shape[1] < chunks * self.chunk_len:
344            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)入力をチャンクにリシェイプします。
346        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)入力からクエリを取得
349        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)取得したネイバーからキーと値を取得
351        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
352        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)入力チャンクのアテンションスコアを計算します。各チャンクは、前のチャンクで取得した隣接チャンクに注目します。これは形になります [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]
357        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)スケールアテンションスコア
359        attn = attn * self.scale最後の 2 次元にソフトマックスを適用 neighbors, neighbor_len
362        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)価値を集める
365        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)[batch_size, chunks, chunk_len, n_heads, d_k]
形状を次のように変更 [batch_size, chunks * chunk_len, n_heads * d_k]
369        h = h.reshape(batch_size, chunks * self.chunk_len, -1)最後の線形レイヤーを適用します。結果は形になります [batch_size, chunks * chunk_len, d_model]
373        h = self.output(h)chunk_len - 1
左にゼロの埋め込みを追加、つまり右にシフトして戻す
376        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)残りの接続を切り捨てて追加
379        return h[:, :h_res.shape[1]] + h_res382class FeedForward(nn.Module):d_model
トランスフォーマー埋め込みのフィーチャー数ですd_ff
は隠れレイヤーのナンバーフィーチャです389    def __init__(self, d_model: int, d_ff: int):395        super().__init__()2 つの線形レイヤー
398        self.lin1 = nn.Linear(d_model, d_ff)
399        self.lin2 = nn.Linear(d_ff, d_model)ReLU アクティベーション
402        self.act = nn.ReLU()プレ・ノルム・レイヤー
405        self.norm = nn.LayerNorm(d_model)h
形が埋め込まれているものです [batch_size, seq_len, d_model]
407    def forward(self, h: torch.Tensor):残余
413        h_res = hプレノルム
415        h = self.norm(h)第 1 線形レイヤー
417        h = self.lin1(h)アクティベーション
419        h = self.act(h)2 番目の線形レイヤー
421        h = self.lin2(h)残余接続を追加
424        return h + h_res427class NearestNeighborEncoder(nn.Module):chunk_len
チャンクの長さですn_layer
はエンコーダーのレイヤー数です ca_layers
クロスアテンションの対象となるレイヤーは d_model
は埋め込みに含まれるフィーチャの数ですn_heads
はアテンションレイヤーのヘッド数ですd_k
アテンションヘッドのサイズですd_ff
フィードフォワードネットワークの隠れ層のサイズ434    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
435                 d_model: int, n_heads: int, d_k: int, d_ff: int):446        super().__init__()
447        self.ca_layers = ca_layers
448        self.chunk_len = chunk_lenクロスアテンションレイヤー
450        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])双方向のセルフアテンションレイヤー
452        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])フィードフォワードレイヤー
454        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])の事前正規化レイヤー
457        self.norm_h = nn.LayerNorm(d_model)e
検索された最も近い近傍のトークンの埋め込みで、形も整っています  [batch_size, chunks, neighbors, neighbor_len, d_model]
h
は入力トークンの埋め込みで、形状はさまざまです  [batch_size, seq_len, d_model]
チャンクとネイバーは並行して処理されます。
459    def forward(self, e: torch.Tensor, h: torch.Tensor):シェイプを取得
472        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape475        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)プレノルム
478        h_split = self.norm_h(h_split)クロス・アテンション・レイヤーのインデックスを維持
481        p_ca = 0すべてのレイヤー用
483        for p in range(len(self.attn)):双方向のセルフアテンション
486            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)次の場合はクロスアテンション
489            if p in self.ca_layers:491                e = self.ca[p_ca](e, h_split)クロス・アテンション・インデックスを増やす
493                p_ca += 1フィードフォワードレイヤー
496            e = self.ffw[p](e)戻る
499        return e502class RetroModel(nn.Module):v_vocab
ボキャブラリ内のトークンの数ですd_model
は埋め込みに含まれるフィーチャの数ですn_layers
はデコーダーの層数です ca_layers
クロスアテンションの対象となるレイヤーは chunk_len
チャンクの長さですn_heads
はアテンションレイヤーのヘッド数ですd_k
アテンションヘッドのサイズですd_ff
フィードフォワードネットワークの隠れ層のサイズencoder
最も近い近傍エンコーダです509    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
510                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):522        super().__init__()
523
524        self.ca_layers = ca_layers
525        self.encoder = encoderトークン埋め込みレイヤー
528        self.emb = nn.Embedding(n_vocab, d_model)チャンク型クロス・アテンション・レイヤー
530        self.cca = nn.ModuleList(
531            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])アテンションレイヤー
533        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])フィードフォワードレイヤー
535        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])読み出し層
537        self.read = nn.Linear(d_model, n_vocab)からの最近傍埋め込み用の事前正規化レイヤー
541        self.norm_e = nn.LayerNorm(d_model)x
は形状の入力シーケンス [batch_size, seq_len]
ret
検索されたシェイプの近傍です  [batch_size, chunks, neighbors, neighbor_len]
543    def forward(self, x: torch.Tensor, ret: torch.Tensor):入力埋め込みを取得
552        h = self.emb(x)558        ret_emb = self.emb(ret)チャンク化されたクロス・アテンション・レイヤーのインデックスを保持
561        p_ca = 0すべてのレイヤー用
563        for p in range(len(self.attn)):因果的自己注意
565            h = self.attn[p](h)次の場合に、最初のレイヤーの前にエンコーダーの埋め込みを取得する
569            if self.ca_layers and p == min(self.ca_layers):573                e = self.encoder(ret_emb, h)エンコーダ埋め込みを正規化
575                e = self.norm_e(e)チャンク・クロス・アテンション if
578            if p in self.ca_layers:580                h = self.cca[p_ca](h, e)チャンク・クロス・アテンション・インデックスを増やす
582                p_ca += 1585            h = self.ffw[p](h)588        return self.read(h)591def _test():595    chunk_len = 4
596    d_model = 8
597    d_ff = 32
598    n_heads = 2
599    d_k = 4
600
601    device = torch.device('cuda:0')
602
603    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
604                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
605
606    m.to(device)
607    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
608    ret = [
609        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
610        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
611    ]
612    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
613
614    inspect(res)618if __name__ == '__main__':
619    _test()