Rotary Positional Embeddings (RoPE)

This is an implementation of Rotary Positional Embeddings (RoPE) in PyTorch.

Rotary Positional Embeddings (RoPE) encode position information of tokens with a rotation matrix that naturally incorporates explicit relative position dependency.

Here is the training code for a transformer model with RoPE on the Tiny Shakespeare dataset.


import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention

RoPE module

Rotary encoding transforms pairs of features by rotating in the 2D plane. That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it by an angle depending on the position of the token.

For a pair of features

Let $x^{(1)}_m$ and $x^{(2)}_m$ be two features of the key or query of any head at position $m$. Or for simplicity assume $x$ has only two features. Then the transformation is,

$$\begin{pmatrix} x^{(1)}_m \\ x^{(2)}_m \end{pmatrix} \mapsto \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} x^{(1)}_m \\ x^{(2)}_m \end{pmatrix} = \begin{pmatrix} x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m\theta \\ x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m\theta \end{pmatrix}$$

where $\theta$ is a constant angle. The other pairs of features are transformed similarly.
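
As a concrete illustration (a minimal sketch, not part of the implementation below; the pair values, position and angle are made up for the example), the transformation is just a 2D rotation:

import math
import torch

# A feature pair and a token position (example values)
x = torch.tensor([1., 2.])
m, theta = 3, 0.5

# 2D rotation matrix for the angle m * theta
c, s = math.cos(m * theta), math.sin(m * theta)
rotation = torch.tensor([[c, -s],
                         [s, c]])

# The RoPE-transformed pair at position m
print(rotation @ x)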

Attention is relative

For a pair of features, the dot-product attention score between two positions $m$ and $n$ would be

$$\Big\langle \mathrm{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big), \mathrm{RoPE}\big(x^{(1)}_n, x^{(2)}_n, n\big) \Big\rangle = \big(x^{(1)}_m x^{(1)}_n + x^{(2)}_m x^{(2)}_n\big) \cos (m - n)\theta + \big(x^{(1)}_m x^{(2)}_n - x^{(2)}_m x^{(1)}_n\big) \sin (m - n)\theta$$

which depends on the positions only through their difference $m - n$. This shows that for dot-product attention the rotary encodings give relative attention.
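
Here is a quick numerical check of this property (a sketch with made-up values; rope_pair is a helper written only for this example):

import math
import torch

def rope_pair(x: torch.Tensor, m: int, theta: float = 0.5) -> torch.Tensor:
    # Rotate the feature pair (x[0], x[1]) by the angle m * theta
    c, s = math.cos(m * theta), math.sin(m * theta)
    return torch.stack([x[0] * c - x[1] * s, x[1] * c + x[0] * s])

q, k = torch.tensor([1., 2.]), torch.tensor([3., -1.])

# Two pairs of positions with the same relative offset m - n = 2
score_a = torch.dot(rope_pair(q, 5), rope_pair(k, 3))
score_b = torch.dot(rope_pair(q, 12), rope_pair(k, 10))

# The scores match (up to floating point error) because they depend only on m - n
assert torch.allclose(score_a, score_b)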

For all features

The features are grouped into pairs and handled as above. They use a different $\theta$ for each pair.

The paper suggests using $\Theta = \big\{\theta_i = 10000^{-\frac{2(i-1)}{d}},\ i \in [1, 2, ..., \tfrac{d}{2}]\big\}$ for the $\tfrac{d}{2}$ pairs of features.

We pair feature $i$ with feature $i + \frac{d}{2}$. So for position $m$ we transform

$$\begin{pmatrix} x^{(i)}_m \\ x^{(i + \frac{d}{2})}_m \end{pmatrix}$$

to

$$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$$
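
For example, here is a minimal sketch of this pairing for a single position (the values $d = 8$, base $10000$ and $m = 5$ are chosen just for illustration); the module below does the same thing for all positions at once:

import torch

d, base, m = 8, 10_000, 5
x = torch.arange(1., d + 1.)  # one d-dimensional feature vector at position m

# theta_i = base^(-2(i-1)/d) for i = 1, ..., d/2
theta = 1. / base ** (torch.arange(0, d, 2).float() / d)

# Pair feature i with feature i + d/2 and rotate each pair by m * theta_i
x1, x2 = x[:d // 2], x[d // 2:]
cos, sin = (m * theta).cos(), (m * theta).sin()
rotated = torch.cat([x1 * cos - x2 * sin,
                     x2 * cos + x1 * sin])
print(rotated)  # shape [d]: the rotary-encoded features at position m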

class RotaryPositionalEmbeddings(nn.Module):
  • d is the number of features
  • base is the constant used for calculating $\theta_i$
    def __init__(self, d: int, base: int = 10_000):
        super().__init__()

        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None

Cache $\cos$ and $\sin$ values

    def _build_cache(self, x: torch.Tensor):

Return if cache is already built

        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

Get sequence length

        seq_len = x.shape[0]

Calculate $\Theta = \big\{\theta_i = 10000^{-\frac{2(i-1)}{d}},\ i \in [1, 2, ..., \tfrac{d}{2}]\big\}$

        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

Create position indexes [0, 1, ..., seq_len - 1]

        seq_idx = torch.arange(seq_len, device=x.device).float()

Calculate the product of position index and $\theta_i$

        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)

Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}-1}, m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}-1}]$

        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

Cache them

        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):

        d_2 = self.d // 2

Calculate $[-x^{(\frac{d}{2})}, -x^{(\frac{d}{2}+1)}, ..., -x^{(d-1)}, x^{(0)}, x^{(1)}, ..., x^{(\frac{d}{2}-1)}]$

        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
  • x is the Tensor at the head of a key or a query with shape [seq_len, batch_size, n_heads, d]
    def forward(self, x: torch.Tensor):

Cache $\cos$ and $\sin$ values

        self._build_cache(x)

Split the features; we can choose to apply rotary embeddings only to a partial set of features.

        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

Calculate $[-x^{(\frac{d}{2})}, -x^{(\frac{d}{2}+1)}, ..., -x^{(d-1)}, x^{(0)}, x^{(1)}, ..., x^{(\frac{d}{2}-1)}]$

        neg_half_x = self._neg_half(x_rope)

Calculate

$$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$$

for $i \in \{1, 2, ..., \frac{d}{2}\}$

        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

Concatenate with the features that did not get rotary embeddings

        return torch.cat((x_rope, x_pass), dim=-1)
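
For instance, a usage sketch of the class above (the shapes and values are made up): apply rotary embeddings to the first 16 of 32 features of a random key/query-shaped tensor.

seq_len, batch_size, n_heads, d_k = 10, 2, 4, 32
x = torch.randn(seq_len, batch_size, n_heads, d_k)

rotary_pe = RotaryPositionalEmbeddings(16)  # number of rotary features; must be even
y = rotary_pe(x)

assert y.shape == x.shape
# Features beyond the first 16 pass through unchanged
assert torch.equal(y[..., 16:], x[..., 16:])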

Multi-head attention with rotary positional embeddings

We override the multi-head attention module from the original transformer implementation.

class RotaryPEMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
        super().__init__(heads, d_model, dropout_prob)

Rotary positional embedding layers

        d_rope = int(self.d_k * rope_percentage)
        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

Calculate scores between queries and keys

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

Calculate dot-product with RoPE

        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
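
As a usage sketch (the shapes are made up, and it assumes, as the constructor above relies on, that the MultiHeadAttention base class sets self.d_k = d_model // heads):

seq_len, batch_size, heads, d_model = 10, 2, 4, 64
mha = RotaryPEMultiHeadAttention(heads, d_model, rope_percentage=0.5)

# get_scores takes per-head tensors of shape [seq_len, batch_size, heads, d_k]
query = torch.randn(seq_len, batch_size, heads, mha.d_k)
key = torch.randn(seq_len, batch_size, heads, mha.d_k)

scores = mha.get_scores(query, key)
assert scores.shape == (seq_len, seq_len, batch_size, heads)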

Testing RoPE with a simple example

def _test_rotary():
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
    x = x[:, None, None, :]
    inspect(x)

    # The number of rotary features must be even for the pairing, so use all 4 features
    rotary_pe = RotaryPositionalEmbeddings(4)
    inspect(rotary_pe(x))


if __name__ == '__main__':
    _test_rotary()