RoPER is work by Georges Harik (@gharik), and this implementation is based on his original code.
Rotary Positional Embeddings (RoPE) include relative positions in the attention score calculation. However, the embeddings themselves do not get any positional information, except what they can get implicitly from causal attention.
RoPER adds relative positional information explicitly to value embeddings. Specifically, it adds the relative positions of the tokens it paid attention to. We use the same rotary positional embeddings to rotate the values in attention. Then, after taking the weighted sum, we rotate the final result in the opposite direction, which is equivalent to rotating each of the values (before attention) relative to the current position.
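The sketch below illustrates this three-step value path in plain PyTorch for a single attention head. It is only an illustration of the idea, not the implementation further down: rotate_pairs and roper_value_path are made-up helper names, adjacent features are paired for simplicity, and a single rotation frequency theta is used instead of RoPE's per-pair frequency spectrum.

import torch


def rotate_pairs(x: torch.Tensor, positions: torch.Tensor, theta: float = 0.01) -> torch.Tensor:
    """Rotate consecutive feature pairs of x by positions * theta (RoPE-style)."""
    angles = positions[:, None] * theta              # [seq_len, 1]
    x1, x2 = x[..., 0::2], x[..., 1::2]              # split into feature pairs
    cos, sin = angles.cos(), angles.sin()
    return torch.stack((x1 * cos - x2 * sin,
                        x1 * sin + x2 * cos), dim=-1).flatten(-2)


def roper_value_path(attn: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """attn: [n, i] attention weights; value: [i, d] value embeddings."""
    pos = torch.arange(value.shape[0], dtype=value.dtype)
    value = rotate_pairs(value, pos)                 # rotate each value by its own position i
    out = attn @ value                               # usual attention-weighted sum
    return rotate_pairs(out, -pos)                   # rotate back by the query position n


attn = torch.softmax(torch.randn(5, 5), dim=-1)
out = roper_value_path(attn, torch.randn(5, 8))      # net rotation of (i - n) on every summed value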
Here's the training code for training a transformer model with RoPER on an arithmetic addition task, where we can see a significant improvement over RoPE.
For any head, let $a_{n,i}$ be the attention from position $n$ to position $i$, and $v_i$ be the value embeddings at position $i$. Let's denote the individual features as $v^{(1)}_i, v^{(2)}_i, \dots$.
Normally, we would take the weighted sum of value embeddings,

$$o^{(j)}_n = \sum_i a_{n,i} v^{(j)}_i$$

This doesn't explicitly add any distance information about the positions $i$ to the final result $o^{(j)}_n$.
RoPER pairs features like RoPE and transforms them. For a pair $v^{(1)}_i$ and $v^{(2)}_i$, it transforms them by rotating by $i\theta$. Let us denote the transformed features with $\hat{v}^{(1)}_i, \hat{v}^{(2)}_i$,

$$\begin{pmatrix} \hat{v}^{(1)}_i \\ \hat{v}^{(2)}_i \end{pmatrix} =
\begin{pmatrix} \cos i\theta & -\sin i\theta \\ \sin i\theta & \cos i\theta \end{pmatrix}
\begin{pmatrix} v^{(1)}_i \\ v^{(2)}_i \end{pmatrix}$$

Then it rotates the weighted sum $\hat{o}^{(j)}_n$ in the reverse direction with $-n\theta$. Note the $-n\theta$.
Note that,

$$\begin{pmatrix} \cos n\theta & \sin n\theta \\ -\sin n\theta & \cos n\theta \end{pmatrix} =
\begin{pmatrix} \cos (-n\theta) & -\sin (-n\theta) \\ \sin (-n\theta) & \cos (-n\theta) \end{pmatrix}$$
The final output after the transformations is,

$$\begin{pmatrix} o^{(1)}_n \\ o^{(2)}_n \end{pmatrix} =
\begin{pmatrix} \cos n\theta & \sin n\theta \\ -\sin n\theta & \cos n\theta \end{pmatrix}
\begin{pmatrix} \hat{o}^{(1)}_n \\ \hat{o}^{(2)}_n \end{pmatrix}$$
Note that $\hat{o}^{(j)}_n = \sum_i a_{n,i} \hat{v}^{(j)}_i$.
Let's expand the first term $o^{(1)}_n$,

$$\begin{aligned}
o^{(1)}_n &= \cos n\theta \sum_i a_{n,i} \hat{v}^{(1)}_i + \sin n\theta \sum_i a_{n,i} \hat{v}^{(2)}_i \\
&= \cos n\theta \sum_i a_{n,i} \big( v^{(1)}_i \cos i\theta - v^{(2)}_i \sin i\theta \big)
 + \sin n\theta \sum_i a_{n,i} \big( v^{(1)}_i \sin i\theta + v^{(2)}_i \cos i\theta \big) \\
&= \sum_i a_{n,i} \Big( v^{(1)}_i \big( \cos n\theta \cos i\theta + \sin n\theta \sin i\theta \big)
 - v^{(2)}_i \big( \cos n\theta \sin i\theta - \sin n\theta \cos i\theta \big) \Big) \\
&= \sum_i a_{n,i} \Big( v^{(1)}_i \cos (i - n)\theta - v^{(2)}_i \sin (i - n)\theta \Big)
\end{aligned}$$
Similarly, we can show that the second term is equal to,

$$o^{(2)}_n = \sum_i a_{n,i} \Big( v^{(1)}_i \sin (i - n)\theta + v^{(2)}_i \cos (i - n)\theta \Big)$$
Which gives,

$$\begin{pmatrix} o^{(1)}_n \\ o^{(2)}_n \end{pmatrix} =
\sum_i a_{n,i}
\begin{pmatrix} \cos (i - n)\theta & -\sin (i - n)\theta \\ \sin (i - n)\theta & \cos (i - n)\theta \end{pmatrix}
\begin{pmatrix} v^{(1)}_i \\ v^{(2)}_i \end{pmatrix}$$
That is, the weighted average of values rotated relative to the current position.
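As a quick sanity check of this result, here is a small self-contained PyTorch snippet (not part of the original code) that verifies the equivalence numerically for a single feature pair and a single shared $\theta$: rotating the values, taking the weighted sum, and reverse-rotating gives the same output as summing values rotated by their relative positions.

import torch


def rot(x: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
    """Rotate the 2-D feature pair in the last dimension by angle (one angle per row)."""
    c, s = angle.cos()[:, None], angle.sin()[:, None]
    return torch.cat((x[..., :1] * c - x[..., 1:] * s,
                      x[..., :1] * s + x[..., 1:] * c), dim=-1)


seq_len, theta = 6, 0.3
v = torch.randn(seq_len, 2)                                # one feature pair per position i
a = torch.softmax(torch.randn(seq_len, seq_len), dim=-1)   # attention weights a[n, i]
pos = torch.arange(seq_len, dtype=torch.float32)

# RoPER: rotate values by i * theta, take the weighted sum, rotate back by -n * theta
o = rot(a @ rot(v, pos * theta), -pos * theta)

# Equivalent form: weighted sum of values rotated by the relative position (i - n) * theta
o_rel = torch.stack([a[n] @ rot(v, (pos - n) * theta) for n in range(seq_len)])

assert torch.allclose(o, o_rel, atol=1e-6)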
Here's an experiment that uses RoPER on an arithmetic addition task.
from typing import Optional

import torch

from labml_nn.transformers.rope import RotaryPositionalEmbeddings, RotaryPEMultiHeadAttention
This inherits from the RoPE rotation implementation and reverses the direction of rotation.
class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
x is the tensor at the head of a key or a query, with shape [seq_len, batch_size, n_heads, d].
    def forward(self, x: torch.Tensor):
Cache $\cos$ and $\sin$ values
        self._build_cache(x)
Split the features. We can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
Calculate $\big[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, \dots, -x^{(d)}, x^{(1)}, x^{(2)}, \dots, x^{(\frac{d}{2})}\big]$
        neg_half_x = self._neg_half(x_rope)
Unlike the forward rotation, we subtract the $\sin$ term. Since $\cos(-m\theta) = \cos m\theta$ and $\sin(-m\theta) = -\sin m\theta$, this rotates the features at position $m$ by $-m\theta$, i.e. in the opposite direction.

        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
Concatenate with the features that didn't get rotated.

        return torch.cat((x_rope, x_pass), dim=-1)
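As a quick check of the class above (a hedged sketch, not from the original code): applying the forward RoPE rotation and then this reverse rotation at the same positions should give back the original tensor, since the $+m\theta$ and $-m\theta$ rotations cancel. This assumes RotaryPositionalEmbeddings accepts tensors of shape [seq_len, batch_size, n_heads, d], as stated above.

import torch
from labml_nn.transformers.rope import RotaryPositionalEmbeddings

seq_len, batch_size, n_heads, d = 8, 2, 4, 16

forward_rope = RotaryPositionalEmbeddings(d)
reverse_rope = ReverseRotaryPositionalEmbeddings(d)

x = torch.randn(seq_len, batch_size, n_heads, d)

# The +m*theta and -m*theta rotations cancel out position by position
assert torch.allclose(reverse_rope(forward_rope(x)), x, atol=1e-5)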
We override the multi-head attention from the original transformer.
class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
    def __init__(self, heads: int, d_model: int,
                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
                 dropout_prob: float = 0.0):
        super().__init__(heads, d_model, rope_percentage, dropout_prob)
Rotary positional embedding layers
        d_rope_value = int(self.d_k * rope_value_percentage)

        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
query, key and value are the tensors that store the collection of query, key and value vectors. They have shape [seq_len, batch_size, d_model].

mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
query, key and value have shape [seq_len, batch_size, d_model]
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)
Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
        scores = self.get_scores(query, key)
Scale scores by $\frac{1}{\sqrt{d_k}}$
        scores *= self.scale
Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
Softmax attention along the key sequence dimension
        attn = self.softmax(scores)
Apply dropout
        attn = self.dropout(attn)
Rotate value embeddings before taking the weighted sum so that they contain positional information
        value = self.value_rotary_pe(value)
Multiply by values
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
Rotate in the opposite direction so that each embedding holds the relative position information
        x = self.value_reverse_rotary_pe(x)
Save attentions for any other calculations
        self.attn = attn.detach()
Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)
Output layer
        return self.output(x)
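Below is a minimal usage sketch (an illustration, not part of the original experiment) that runs a forward pass with random inputs and a causal mask built to the [seq_len, seq_len, batch_size] shape described above, just to check that the shapes work out.

import torch

seq_len, batch_size, d_model, heads = 10, 2, 64, 4

mha = RotaryValuePEMultiHeadAttention(heads, d_model,
                                      rope_percentage=0.5,
                                      rope_value_percentage=0.5)

x = torch.randn(seq_len, batch_size, d_model)

# Causal mask: query at position i can only attend to key-value positions j <= i
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1).repeat(1, 1, batch_size)

out = mha(query=x, key=x, value=x, mask=mask)
assert out.shape == (seq_len, batch_size, d_model)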