这是 PyT orch 中旋转位置嵌入 (RoP E) 的实现。
Rotary Positional Embeddings (RoPE) 使用自然包含明确的相对位置依赖关系的旋转矩阵对代币的位置信息进行编码。
23import torch
24from torch import nn
25
26from labml.logger import inspect
27from labml_nn.transformers.mha import MultiHeadAttention
旋转编码通过在 2D 平面中旋转来转换成对的要素。也就是说,它将要素组织成对。每对都可以被视为二维平面中的一个坐标,编码将根据令牌的位置将其旋转一个角度。
让和成为任何头部位置的键或查询的两个特征。或者为了简单起见,假设只有两个功能。那么转变就是,
其中是恒定角度。其他要素对的变换方式类似。
对于一对功能,点产品注意力分数介于两个位置之间,将为
这表明,对于点生产的关注,旋转编码给予了相对的关注。
这些要素分组成对,并按上述方式处理。他们对每对使用不同的。
本文建议使用成对的特征。
我们将功能与功能配对。因此,对于位置我们进行转换
至
30class RotaryPositionalEmbeddings(nn.Module):
117 def __init__(self, d: int, base: int = 10_000):
122 super().__init__()
123
124 self.base = base
125 self.d = d
126 self.cos_cached = None
127 self.sin_cached = None
缓存和值
129 def _build_cache(self, x: torch.Tensor):
如果缓存已经构建,则返回
134 if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
135 return
获取序列长度
138 seq_len = x.shape[0]
141 theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
创建头寸指数[0, 1, ..., seq_len - 1]
144 seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
计算持仓指数的乘积和
147 idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
连接这样我们就有 row
151 idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
缓存它们
154 self.cos_cached = idx_theta2.cos()[:, None, None, :]
155 self.sin_cached = idx_theta2.sin()[:, None, None, :]
157 def _neg_half(self, x: torch.Tensor):
159 d_2 = self.d // 2
计算
162 return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
x
是位于键或带有形状的查询开头的 Tensor[seq_len, batch_size, n_heads, d]
164 def forward(self, x: torch.Tensor):
缓存和值
169 self._build_cache(x)
拆分特征,我们可以选择仅将旋转嵌入应用于部分特征集。
172 x_rope, x_pass = x[..., :self.d], x[..., self.d:]
计算
176 neg_half_x = self._neg_half(x_rope)
188 x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
191 return torch.cat((x_rope, x_pass), dim=-1)
194class RotaryPEMultiHeadAttention(MultiHeadAttention):
201 def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
202 super().__init__(heads, d_model, dropout_prob)
旋转位置嵌入层
205 d_rope = int(self.d_k * rope_percentage)
206 self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
207 self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
209 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
使用 ROPE 计算点积
215 return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
用一个简单的例子测试 RoPe
218def _test_rotary():
222 x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
223 x = x[:, None, None, :]
224 inspect(x)
225
226 rotary_pe = RotaryPositionalEmbeddings(4)
227 inspect(rotary_pe(x))
228
229
230if __name__ == '__main__':
231 _test_rotary()