14import math
15from typing import Set
16
17import torch
18from torch import nn
19
20from labml.logger import inspect
セルフアテンションレイヤーには回転位置埋め込みを使用しています。位置情報は埋め込みに埋め込まれているため、因果関係には使用しないと想定しています。因果関係のない自己注意には、推測できないため、明確な位置情報が必要です
。23class RotaryPositionalEmbeddings(nn.Module):
d
は機能の数 base
は計算に使用される定数です 34 def __init__(self, d: int, base: int = 10_000):
39 super().__init__()
41 self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
x
キーまたは形状のあるクエリの先頭にあるテンソルです [ batch_size, seq_len, n_heads, d]
43 def forward(self, x: torch.Tensor):
形状を抽出
48 batch_size, seq_len, n_heads, d = x.shape
51 d_2 = d // 2
位置インデックスの作成 [0, 1, ..., seq_len - 1]
54 seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
位置指数の積を計算し、
57 idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
行が次のようになるように連結します
61 idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
計算
65 neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
77 rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])
80 return rx
83class SelfAttention(nn.Module):
d_model
トランスフォーマー埋め込みのフィーチャー数ですn_heads
アテンション・ヘッドの数ですd_k
はヘッドあたりのフィーチャ数ですis_causal
これが因果関係であるかどうかを示します (マスクされています)90 def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
97 super().__init__()
98
99 self.is_causal = is_causal
100 self.n_heads = n_heads
101 self.d_k = d_k
ソフトマックスの前にアテンションをスケーリングするには
104 self.scale = 1 / math.sqrt(self.d_k)
クエリ、キー、バリューヘッド用のリニアレイヤー。
107 self.query = nn.Linear(d_model, n_heads * d_k)
108 self.key = nn.Linear(d_model, n_heads * d_k)
109 self.value = nn.Linear(d_model, n_heads * d_k)
プレ・ノルム・レイヤーこの論文では代わりにRMSnormを使用しています
。112 self.norm = nn.LayerNorm(d_model)
注意確率のソフトマックス
115 self.softmax = nn.Softmax(dim=-1)
ロータリーポジションエンベディング
118 self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)
最終線形レイヤー
121 self.output = nn.Linear(n_heads * d_k, d_model)
123 def mask_attention(self, attn: torch.Tensor):
因果関係のない注意のためのマスキングなし
131 if not self.is_causal:
132 return attn
三角マスクの作成
135 mask = torch.tril(attn.new_ones(attn.shape[-2:]))
マスクで絞り込む
137 return attn.masked_fill(mask == 0, float('-inf'))
h
形をした変圧器の埋め込みです [batch_size, seq_len, d_model]
139 def forward(self, h: torch.Tensor):
残留接続
145 h_res = h
事前正規化
148 h = self.norm(h)
クエリ、キー、値を取得し、それらをヘッドに分割します。これらには形があります [batch_size, seq_len, n_heads, d_k]
152 mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
153 q = self.query(h).view(mh_shape)
154 k = self.key(h).view(mh_shape)
155 v = self.value(h).view(mh_shape)
回転式位置埋め込みを適用
158 q = self.rotary_pe(q)
159 k = self.rotary_pe(k)
注意事項の計算
162 attn = torch.einsum('bihd,bjhd->bhij', q, k)
スケーリング・バイ・スケーリング
164 attn = attn * self.scale
原因となる注意が必要な場合はマスクを着用してください
167 attn = self.mask_attention(attn)
注意確率の計算
170 attn = self.softmax(attn)
値を取得
173 h = torch.einsum("bhij,bjhd->bihd", attn, v)
[batch_size, seq_len, n_heads, d_k]
形状を次のように変更 [batch_size, seq_len, n_heads * d_k]
177 h = h.reshape(*h.shape[:-2], -1)
最後の線形レイヤーを適用します。結果は形になります [batch_size, seq_len, d_model]
181 h = self.output(h)
残余接続を追加
184 return h + h_res
これは上で定義したセルフアテンションレイヤーと似ていますが、クエリとは異なる埋め込みセットからキーと値を取得する点が異なります。
これをエンコーダーで使用して、取得したチャンクを入力チャンクに基づいてエンコードします。
ここでは明示的な位置埋め込みは一切使用していません。モデルは埋め込み内の位置情報を暗黙的に表現できると仮定します
。187class CrossAttention(nn.Module):
d_model
トランスフォーマー埋め込みのフィーチャー数ですn_heads
アテンション・ヘッドの数ですd_k
はヘッドあたりのフィーチャ数です201 def __init__(self, d_model: int, n_heads: int, d_k: int):
207 super().__init__()
208
209 self.n_heads = n_heads
210 self.d_k = d_k
ソフトマックスの前にアテンションをスケーリングするには
213 self.scale = 1 / math.sqrt(self.d_k)
クエリ、キー、バリューヘッド用のリニアレイヤー。
216 self.query = nn.Linear(d_model, n_heads * d_k)
217 self.key = nn.Linear(d_model, n_heads * d_k)
218 self.value = nn.Linear(d_model, n_heads * d_k)
クエリ埋め込み用のプレノルムレイヤー。この論文では代わりにRMSnormを使用しています
。221 self.norm = nn.LayerNorm(d_model)
注意確率のソフトマックス
224 self.softmax = nn.Softmax(dim=-1)
最終線形レイヤー
227 self.output = nn.Linear(n_heads * d_k, d_model)
e
検索された最も近い近傍のチャンク埋め込みの形状の埋め込みです [batch_size, chunks, neighbors, neighbor_len, d_model]
h
は入力チャンクで、そこから最も近い近傍データが取得されました。[batch_size, chunks, chunk_len, d_model]
これはすでに標準化されています。229 def forward(self, e: torch.Tensor, h: torch.Tensor):
残留接続
238 e_res = e
取得したチャンクを正規化
241 e = self.norm(e)
取得したチャンクからクエリを取得
244 q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)
入力チャンクからキーと値を取得
246 k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
247 v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)
すべてのチャンクのアテンションスコアを計算します。取得された各ネイバーは、それを取得した元のチャンクに注目します。これは形になります [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]
252 attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)
スケールアテンションスコア
254 attn = attn * self.scale
最後の次元のソフトマックスを計算
257 attn = self.softmax(attn)
価値を集める
260 e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)
[batch_size, chunks, neighbors, neighbor_len, n_heads, d_k]
形状を次のように変更 [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]
264 e = e.reshape(*e.shape[:-2], -1)
最後の線形レイヤーを適用します。結果は形になります [batch_size, chunks, neighbors, neighbor_len, d_model]
268 e = self.output(e)
残余接続を追加
271 return e + e_res
これは上で定義したクロス・アテンション・レイヤーに似ています。
これをデコーダで使用して、取得した隣接チャンクに注目します。
ここでは明示的な位置埋め込みは一切使用していません。モデルは埋め込み内の位置情報を暗黙的に表現できると仮定します
。274class ChunkedCrossAttention(nn.Module):
d_model
トランスフォーマー埋め込みのフィーチャー数ですn_heads
アテンション・ヘッドの数ですd_k
はヘッドあたりのフィーチャ数ですchunk_len
チャンクの長さです286 def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
294 super().__init__()
295
296 self.chunk_len = chunk_len
297 self.n_heads = n_heads
298 self.d_k = d_k
ソフトマックスの前にアテンションをスケーリングするには
301 self.scale = 1 / math.sqrt(self.d_k)
クエリ、キー、バリューヘッド用のリニアレイヤー。
304 self.query = nn.Linear(d_model, n_heads * d_k)
305 self.key = nn.Linear(d_model, n_heads * d_k)
306 self.value = nn.Linear(d_model, n_heads * d_k)
クエリ埋め込み用のプレノルムレイヤー。この論文では代わりにRMSnormを使用しています
。309 self.norm = nn.LayerNorm(d_model)
注意確率のソフトマックス
312 self.softmax = nn.Softmax(dim=-1)
最終線形レイヤー
315 self.output = nn.Linear(n_heads * d_k, d_model)
h
図形の入力埋め込みは、[batch_size, seq_len, d_model]
e
検索された図形の最も近い近傍データです [batch_size, chunks, neighbors, neighbor_len, d_model]
317 def forward(self, h: torch.Tensor, e: torch.Tensor):
シェイプを取得
324 batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
チャンクがない場合は不要 (サンプリング時の入力が短い場合)
327 if chunks == 0:
328 return h
残留接続
331 h_res = h
chunk_len - 1
最初の埋め込みを削除します。入力は、過去のトークンのみを使用して取得およびエンコードされたネイバーに注目し、情報漏えいが発生しないようにします。つまり、最初のチャンクから取得したネイバーには、最初のチャンクからの情報が含まれます。そのため、シーケンスを左にシフトすることでchunk_len - 1
、情報が右にのみ流れるようにします
339 h = h[:, self.chunk_len - 1:]
プレノルム
341 h = self.norm(h)
入力をチャンクに分割できるように、最後に空の埋め込みを追加します
343 if h.shape[1] < chunks * self.chunk_len:
344 h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)
入力をチャンクにリシェイプします。
346 h = h.reshape(batch_size, chunks, self.chunk_len, d_model)
入力からクエリを取得
349 q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)
取得したネイバーからキーと値を取得
351 k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
352 v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)
入力チャンクのアテンションスコアを計算します。各チャンクは、前のチャンクで取得した隣接チャンクに注目します。これは形になります [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]
357 attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)
スケールアテンションスコア
359 attn = attn * self.scale
最後の 2 次元にソフトマックスを適用 neighbors, neighbor_len
362 attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)
価値を集める
365 h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)
[batch_size, chunks, chunk_len, n_heads, d_k]
形状を次のように変更 [batch_size, chunks * chunk_len, n_heads * d_k]
369 h = h.reshape(batch_size, chunks * self.chunk_len, -1)
最後の線形レイヤーを適用します。結果は形になります [batch_size, chunks * chunk_len, d_model]
373 h = self.output(h)
chunk_len - 1
左にゼロの埋め込みを追加、つまり右にシフトして戻す
376 h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)
残りの接続を切り捨てて追加
379 return h[:, :h_res.shape[1]] + h_res
382class FeedForward(nn.Module):
d_model
トランスフォーマー埋め込みのフィーチャー数ですd_ff
は隠れレイヤーのナンバーフィーチャです389 def __init__(self, d_model: int, d_ff: int):
395 super().__init__()
2 つの線形レイヤー
398 self.lin1 = nn.Linear(d_model, d_ff)
399 self.lin2 = nn.Linear(d_ff, d_model)
ReLU アクティベーション
402 self.act = nn.ReLU()
プレ・ノルム・レイヤー
405 self.norm = nn.LayerNorm(d_model)
h
形が埋め込まれているものです [batch_size, seq_len, d_model]
407 def forward(self, h: torch.Tensor):
残余
413 h_res = h
プレノルム
415 h = self.norm(h)
第 1 線形レイヤー
417 h = self.lin1(h)
アクティベーション
419 h = self.act(h)
2 番目の線形レイヤー
421 h = self.lin2(h)
残余接続を追加
424 return h + h_res
427class NearestNeighborEncoder(nn.Module):
chunk_len
チャンクの長さですn_layer
はエンコーダーのレイヤー数です ca_layers
クロスアテンションの対象となるレイヤーは d_model
は埋め込みに含まれるフィーチャの数ですn_heads
はアテンションレイヤーのヘッド数ですd_k
アテンションヘッドのサイズですd_ff
フィードフォワードネットワークの隠れ層のサイズ434 def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
435 d_model: int, n_heads: int, d_k: int, d_ff: int):
446 super().__init__()
447 self.ca_layers = ca_layers
448 self.chunk_len = chunk_len
クロスアテンションレイヤー
450 self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
双方向のセルフアテンションレイヤー
452 self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
フィードフォワードレイヤー
454 self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
の事前正規化レイヤー
457 self.norm_h = nn.LayerNorm(d_model)
e
検索された最も近い近傍のトークンの埋め込みで、形も整っています [batch_size, chunks, neighbors, neighbor_len, d_model]
h
は入力トークンの埋め込みで、形状はさまざまです [batch_size, seq_len, d_model]
チャンクとネイバーは並行して処理されます。
459 def forward(self, e: torch.Tensor, h: torch.Tensor):
シェイプを取得
472 batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
475 h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)
プレノルム
478 h_split = self.norm_h(h_split)
クロス・アテンション・レイヤーのインデックスを維持
481 p_ca = 0
すべてのレイヤー用
483 for p in range(len(self.attn)):
双方向のセルフアテンション
486 e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)
次の場合はクロスアテンション
489 if p in self.ca_layers:
491 e = self.ca[p_ca](e, h_split)
クロス・アテンション・インデックスを増やす
493 p_ca += 1
フィードフォワードレイヤー
496 e = self.ffw[p](e)
戻る
499 return e
502class RetroModel(nn.Module):
v_vocab
ボキャブラリ内のトークンの数ですd_model
は埋め込みに含まれるフィーチャの数ですn_layers
はデコーダーの層数です ca_layers
クロスアテンションの対象となるレイヤーは chunk_len
チャンクの長さですn_heads
はアテンションレイヤーのヘッド数ですd_k
アテンションヘッドのサイズですd_ff
フィードフォワードネットワークの隠れ層のサイズencoder
最も近い近傍エンコーダです509 def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
510 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
522 super().__init__()
523
524 self.ca_layers = ca_layers
525 self.encoder = encoder
トークン埋め込みレイヤー
528 self.emb = nn.Embedding(n_vocab, d_model)
チャンク型クロス・アテンション・レイヤー
530 self.cca = nn.ModuleList(
531 [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
アテンションレイヤー
533 self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
フィードフォワードレイヤー
535 self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
読み出し層
537 self.read = nn.Linear(d_model, n_vocab)
からの最近傍埋め込み用の事前正規化レイヤー
541 self.norm_e = nn.LayerNorm(d_model)
x
は形状の入力シーケンス [batch_size, seq_len]
ret
検索されたシェイプの近傍です [batch_size, chunks, neighbors, neighbor_len]
543 def forward(self, x: torch.Tensor, ret: torch.Tensor):
入力埋め込みを取得
552 h = self.emb(x)
558 ret_emb = self.emb(ret)
チャンク化されたクロス・アテンション・レイヤーのインデックスを保持
561 p_ca = 0
すべてのレイヤー用
563 for p in range(len(self.attn)):
因果的自己注意
565 h = self.attn[p](h)
次の場合に、最初のレイヤーの前にエンコーダーの埋め込みを取得する
569 if self.ca_layers and p == min(self.ca_layers):
573 e = self.encoder(ret_emb, h)
エンコーダ埋め込みを正規化
575 e = self.norm_e(e)
チャンク・クロス・アテンション if
578 if p in self.ca_layers:
580 h = self.cca[p_ca](h, e)
チャンク・クロス・アテンション・インデックスを増やす
582 p_ca += 1
585 h = self.ffw[p](h)
588 return self.read(h)
591def _test():
595 chunk_len = 4
596 d_model = 8
597 d_ff = 32
598 n_heads = 2
599 d_k = 4
600
601 device = torch.device('cuda:0')
602
603 m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
604 encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
605
606 m.to(device)
607 x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
608 ret = [
609 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
610 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
611 ]
612 res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
613
614 inspect(res)
618if __name__ == '__main__':
619 _test()