これは PyTorch のロータリー位置埋め込み (RoPE) の実装です。
Rotary Positional Embeddings(RoPE)は、トークンの位置情報を回転マトリックスでエンコードします。回転マトリックスには、明示的な相対位置依存性が自然に組み込まれています。
Tiny ShakespeareデータセットでRoPEを使用してトランスフォーマーモデルをトレーニングするためのトレーニングコードは次のとおりです。
23import torch
24from torch import nn
25
26from labml.logger import inspect
27from labml_nn.transformers.mha import MultiHeadAttention
ロータリーエンコーディングでは、2 つのフィーチャを 2D 平面上で回転させて変換します。つまり、フィーチャをペアとして整理します。各ペアは2D平面内の座標と見なすことができ、エンコーディングではトークンの位置に応じて角度だけ回転します。
任意の位置で任意のヘッドのキーまたはクエリの2つの特徴としましょう。または、簡単にするために、2 つの機能しか持っていないと仮定します。そうすると、変換は、
ここで、は一定の角度です。他のフィーチャペアも同様に変換されます
。2 つの特徴について、2 つの位置間のアテンションスコアを点積すると、
このことから、ドットプロダクションで注目される場合は、ロータリーエンコーディングが相対的に注目されることがわかります。
フィーチャはペアにグループ化され、上記のように処理されます。彼らはペアごとに違うものを使います。
この論文では、これらの機能を組み合わせて使用することを提案しています。
機能と機能を組み合わせます。だから位置は変身する
に
30class RotaryPositionalEmbeddings(nn.Module):
d
は機能の数 base
は計算に使用される定数です 117 def __init__(self, d: int, base: int = 10_000):
122 super().__init__()
123
124 self.base = base
125 self.d = d
126 self.cos_cached = None
127 self.sin_cached = None
キャッシュと値
129 def _build_cache(self, x: torch.Tensor):
キャッシュが既に構築されている場合は返す
134 if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
135 return
シーケンスの長さを取得
138 seq_len = x.shape[0]
141 theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
位置インデックスの作成 [0, 1, ..., seq_len - 1]
144 seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
位置指数の積を計算し、
147 idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
行が次のようになるように連結します
151 idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
それらをキャッシュする
154 self.cos_cached = idx_theta2.cos()[:, None, None, :]
155 self.sin_cached = idx_theta2.sin()[:, None, None, :]
157 def _neg_half(self, x: torch.Tensor):
159 d_2 = self.d // 2
計算
162 return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
x
キーまたは形状のあるクエリの先頭にあるテンソルです [seq_len, batch_size, n_heads, d]
164 def forward(self, x: torch.Tensor):
キャッシュと値
169 self._build_cache(x)
機能を分割して、一部の機能セットにのみロータリー埋め込みを適用することもできます。
172 x_rope, x_pass = x[..., :self.d], x[..., self.d:]
計算
176 neg_half_x = self._neg_half(x_rope)
188 x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
191 return torch.cat((x_rope, x_pass), dim=-1)
194class RotaryPEMultiHeadAttention(MultiHeadAttention):
201 def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
202 super().__init__(heads, d_model, dropout_prob)
ロータリーポジショナル埋め込みレイヤー
205 d_rope = int(self.d_k * rope_percentage)
206 self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
207 self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
209 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
RoPE によるドット積の計算
215 return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
簡単な例での RoPE のテスト
218def _test_rotary():
222 x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
223 x = x[:, None, None, :]
224 inspect(x)
225
226 rotary_pe = RotaryPositionalEmbeddings(3)
227 inspect(rotary_pe(x))
228
229
230if __name__ == '__main__':
231 _test_rotary()