旋转位置嵌入 (RoPE)

这是 PyT orch 中旋转位置嵌入 (RoP E) 的实现。

Rotary Positional Embeddings (RoPE) 使用自然包含明确的相对位置依赖关系的旋转矩阵对代币的位置信息进行编码。

以下是在 Tiny Shakespeare 数据集上使用 RoPE 训练变压器模型的训练代码

23import torch
24from torch import nn
25
26from labml.logger import inspect
27from labml_nn.transformers.mha import MultiHeadAttention

绳索模块

旋转编码通过在 2D 平面中旋转来转换成对的要素。也就是说,它将要素组织成对。每对都可以被视为二维平面中的一个坐标,编码将根据令牌的位置将其旋转一个角度。

对于一对功能

成为任何头部位置的键或查询的两个特征。或者为了简单起见,假设只有两个功能。那么转变就是,

其中是恒定角度。其他要素对的变换方式类似。

注意力是相对的

对于一对功能,点产品注意力分数介于两个位置之间,将为

这表明,对于点生产的关注,旋转编码给予了相对的关注。

对于所有功能

这些要素分组成对,并按上述方式处理。他们对每对使用不同的。

本文建议使用对的特征。

我们将功能与功能配对。因此,对于位置我们进行转换

30class RotaryPositionalEmbeddings(nn.Module):
  • d 是要素的数量
  • base 是用于计算的常数
  • 117    def __init__(self, d: int, base: int = 10_000):
    122        super().__init__()
    123
    124        self.base = base
    125        self.d = d
    126        self.cos_cached = None
    127        self.sin_cached = None

    缓存

    129    def _build_cache(self, x: torch.Tensor):

    如果缓存已经构建,则返回

    134        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
    135            return

    获取序列长度

    138        seq_len = x.shape[0]

    141        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

    创建头寸指数[0, 1, ..., seq_len - 1]

    144        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)

    计算持仓指数的乘积和

    147        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)

    连接这样我们就有 row

    151        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

    缓存它们

    154        self.cos_cached = idx_theta2.cos()[:, None, None, :]
    155        self.sin_cached = idx_theta2.sin()[:, None, None, :]
    157    def _neg_half(self, x: torch.Tensor):

    159        d_2 = self.d // 2

    计算

    162        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
    • x 是位于键或带有形状的查询开头的 Tensor[seq_len, batch_size, n_heads, d]
    164    def forward(self, x: torch.Tensor):

    缓存

    169        self._build_cache(x)

    拆分特征,我们可以选择仅将旋转嵌入应用于部分特征集。

    172        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

    计算

    176        neg_half_x = self._neg_half(x_rope)

    计算

    对于

    188        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

    191        return torch.cat((x_rope, x_pass), dim=-1)

    通过旋转定位嵌入实现多头关注

    我们超越了原装变压器的多头注意力

    194class RotaryPEMultiHeadAttention(MultiHeadAttention):
    201    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
    202        super().__init__(heads, d_model, dropout_prob)

    旋转位置嵌入层

    205        d_rope = int(self.d_k * rope_percentage)
    206        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
    207        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

    计算查询和键之间的分数

    209    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

    使用 ROPE 计算点积

    215        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))

    用一个简单的例子测试 RoPe

    218def _test_rotary():
    222    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
    223    x = x[:, None, None, :]
    224    inspect(x)
    225
    226    rotary_pe = RotaryPositionalEmbeddings(3)
    227    inspect(rotary_pe(x))
    228
    229
    230if __name__ == '__main__':
    231    _test_rotary()