import math
from typing import Set

import torch
from torch import nn

from labml.logger import inspect
We use rotary position embeddings in the self-attention layers. We assume the positional information gets embedded into the embeddings by those layers, and therefore do not use any explicit positional embeddings in the cross-attention layers.
class RotaryPositionalEmbeddings(nn.Module):
d is the number of features $d$
base is the constant used for calculating $\theta_i$
    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
$\theta_i = \frac{1}{\text{base}^{2i/d}}$ for $i \in \{0, 1, \dots, \frac{d}{2} - 1\}$
        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
x is the tensor at the head of a key or a query, with shape [batch_size, seq_len, n_heads, d]
    def forward(self, x: torch.Tensor):
Extract the shape
        batch_size, seq_len, n_heads, d = x.shape
Half the number of features, $\frac{d}{2}$
        d_2 = d // 2
Create position indexes [0, 1, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, \dots, m\theta_{\frac{d}{2}-1}, m\theta_0, m\theta_1, \dots, m\theta_{\frac{d}{2}-1}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
Calculate $[-x^{(\frac{d}{2})}, -x^{(\frac{d}{2}+1)}, \dots, -x^{(d-1)}, x^{(0)}, x^{(1)}, \dots, x^{(\frac{d}{2}-1)}]$
        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
Calculate the rotated features $rx = x \odot \cos(m\theta) + \hat{x} \odot \sin(m\theta)$, where $\hat{x}$ is neg_half_x and $m\theta$ is idx_theta2
        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])
        return rx
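For illustration only (not part of the model file), here is a minimal sketch of applying the module above to a randomly generated query tensor; the dimensions are arbitrary assumptions for the example.

import torch

rope = RotaryPositionalEmbeddings(d=16)
q = torch.randn(2, 10, 4, 16)  # [batch_size, seq_len, n_heads, d]
rq = rope(q)
assert rq.shape == q.shape  # the rotation changes values but keeps the shape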
class SelfAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
is_causal indicates whether this is causal attention (masked)
    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
        super().__init__()

        self.is_causal = is_causal
        self.n_heads = n_heads
        self.d_k = d_k
To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$
        self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer. The paper uses RMSNorm instead.
        self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
        self.softmax = nn.Softmax(dim=-1)
Rotary positional embeddings
        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)
Final linear layer
        self.output = nn.Linear(n_heads * d_k, d_model)
attn is the attention matrix of shape [batch_size, n_heads, seq_len, seq_len]
    def mask_attention(self, attn: torch.Tensor):
No masking for non-causal attention
        if not self.is_causal:
            return attn
Create a lower-triangular mask
        mask = torch.tril(attn.new_ones(attn.shape[-2:]))
Set masked positions to $-\infty$ so that they get zero weight after the softmax
        return attn.masked_fill(mask == 0, float('-inf'))
h is the transformer embeddings of shape [batch_size, seq_len, d_model]
    def forward(self, h: torch.Tensor):
Residual connection
        h_res = h
Pre-normalization
        h = self.norm(h)
Get query, key, and values and split them into heads. These will have shapes [batch_size, seq_len, n_heads, d_k]
        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
        q = self.query(h).view(mh_shape)
        k = self.key(h).view(mh_shape)
        v = self.value(h).view(mh_shape)
Apply rotary positional embeddings
        q = self.rotary_pe(q)
        k = self.rotary_pe(k)
Calculate attention scores
        attn = torch.einsum('bihd,bjhd->bhij', q, k)
Scale them by $\frac{1}{\sqrt{d_k}}$
        attn = attn * self.scale
Apply the mask if it's causal attention
        attn = self.mask_attention(attn)
Calculate attention probabilities
        attn = self.softmax(attn)
Get values
        h = torch.einsum("bhij,bjhd->bihd", attn, v)
Change from shape [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_heads * d_k]
        h = h.reshape(*h.shape[:-2], -1)
Apply the final linear layer. The result will have shape [batch_size, seq_len, d_model]
        h = self.output(h)
Add the residual connection
        return h + h_res
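A minimal usage sketch of the layer above (illustrative only; the dimensions are arbitrary assumptions):

import torch

attn = SelfAttention(d_model=64, n_heads=4, d_k=16, is_causal=True)
h = torch.randn(2, 32, 64)  # [batch_size, seq_len, d_model]
out = attn(h)
assert out.shape == h.shape  # the pre-norm residual block keeps the embedding shape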
This is similar to the self-attention layer defined above, except that it gets keys and values from a different set of embeddings than the queries.
This is used in the encoder to encode the retrieved chunks based on the input chunks.
We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.
class CrossAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
    def __init__(self, d_model: int, n_heads: int, d_k: int):
        super().__init__()

        self.n_heads = n_heads
        self.d_k = d_k
To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$
        self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.
        self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
        self.softmax = nn.Softmax(dim=-1)
Final linear layer
        self.output = nn.Linear(n_heads * d_k, d_model)
e are the retrieved nearest neighbor chunk embeddings, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
h are the input chunks from which the nearest neighbors were retrieved, of shape [batch_size, chunks, chunk_len, d_model]. This is already normalized.
    def forward(self, e: torch.Tensor, h: torch.Tensor):
Residual connection
        e_res = e
Normalize retrieved chunks
        e = self.norm(e)
Get query from the retrieved chunks
        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)
Get keys and values from the input chunks
        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)
Calculate attention scores for all chunks. Each retrieved neighbor will pay attention to the original chunk that retrieved it. This will have shape [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]
        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)
Scale attention scores
        attn = attn * self.scale
Calculate softmax across the last dimension
        attn = self.softmax(attn)
Gather values
        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)
Change from shape [batch_size, chunks, neighbors, neighbor_len, n_heads, d_k] to [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]
        e = e.reshape(*e.shape[:-2], -1)
Apply the final linear layer. The result will have shape [batch_size, chunks, neighbors, neighbor_len, d_model]
        e = self.output(e)
Add the residual connection
        return e + e_res
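For illustration, a shape-only sketch of the cross-attention layer (all dimensions here are arbitrary assumptions; in the encoder, h is the already-normalized chunked input):

import torch

ca = CrossAttention(d_model=64, n_heads=4, d_k=16)
e = torch.randn(2, 3, 2, 8, 64)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 3, 4, 64)     # [batch_size, chunks, chunk_len, d_model]
out = ca(e, h)
assert out.shape == e.shape  # each neighbor is re-encoded; the shape is preserved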
This is similar to the cross-attention layer defined above.
This is used in the decoder to pay attention to the retrieved neighbor chunks.
We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.
class ChunkedCrossAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
chunk_len is the length of a chunk
    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
        super().__init__()

        self.chunk_len = chunk_len
        self.n_heads = n_heads
        self.d_k = d_k
To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$
        self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.
        self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
        self.softmax = nn.Softmax(dim=-1)
Final linear layer
        self.output = nn.Linear(n_heads * d_k, d_model)
h are the input embeddings, of shape [batch_size, seq_len, d_model]
e are the retrieved nearest neighbors, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
    def forward(self, h: torch.Tensor, e: torch.Tensor):
Get shape
        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
No attention if there are no chunks (for short inputs when sampling)
        if chunks == 0:
            return h
Residual connection
        h_res = h
Remove the first chunk_len - 1 embeddings. The input should pay attention only to neighbors that were retrieved and encoded using past tokens, so that there is no information leakage: the neighbors retrieved for the first chunk contain information from the first chunk itself. By shifting the sequence to the left by chunk_len - 1 we make sure that information only flows to the right.
        h = h[:, self.chunk_len - 1:]
Pre-norm
        h = self.norm(h)
Append empty embeddings at the end so that the input can be split into whole chunks
        if h.shape[1] < chunks * self.chunk_len:
            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)
Reshape the input into chunks.
        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)
Get query from the input
        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)
Get keys and values from the retrieved neighbors
        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)
Calculate attention scores for input chunks. Each chunk will pay attention to neighbors retrieved by the previous chunk. This will have shape [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]
        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)
Scale attention scores
        attn = attn * self.scale
Apply softmax over the last two dimensions neighbors, neighbor_len
        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)
Gather values
        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)
Change from shape [batch_size, chunks, chunk_len, n_heads, d_k] to [batch_size, chunks * chunk_len, n_heads * d_k]
        h = h.reshape(batch_size, chunks * self.chunk_len, -1)
Apply the final linear layer. The result will have shape [batch_size, chunks * chunk_len, d_model]
        h = self.output(h)
Prepend chunk_len - 1 zero embeddings; i.e. shift the sequence back to the right
        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)
Truncate and add the residual connection
        return h[:, :h_res.shape[1]] + h_res
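An illustrative sketch of the shift-and-chunk logic above; the sizes are assumptions chosen only for this example.

import torch

chunk_len = 4
cca = ChunkedCrossAttention(d_model=64, n_heads=4, d_k=16, chunk_len=chunk_len)
h = torch.randn(2, 10, 64)       # [batch_size, seq_len, d_model]; seq_len = 10 gives 2 full chunks
e = torch.randn(2, 2, 3, 6, 64)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
out = cca(h, e)
# Internally h is shifted left by chunk_len - 1 = 3, padded to chunks * chunk_len = 8,
# attended to the neighbors, then shifted back and truncated to the original length.
assert out.shape == h.shape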
This consists of two linear layers and an activation in the middle.
class FeedForward(nn.Module):
d_model is the number of features in transformer embeddings
d_ff is the number of features in the hidden layer
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
The two linear layers
        self.lin1 = nn.Linear(d_model, d_ff)
        self.lin2 = nn.Linear(d_ff, d_model)
ReLU activation
        self.act = nn.ReLU()
Pre-norm layer
        self.norm = nn.LayerNorm(d_model)
h are the embeddings of shape [batch_size, seq_len, d_model]
    def forward(self, h: torch.Tensor):
Residual
        h_res = h
Pre-norm
        h = self.norm(h)
First linear layer
        h = self.lin1(h)
Activation
        h = self.act(h)
Second linear layer
        h = self.lin2(h)
Add the residual connection
        return h + h_res
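A one-line sanity check of the feed-forward block (illustrative; the sizes are assumptions):

import torch

ffw = FeedForward(d_model=64, d_ff=256)
h = torch.randn(2, 10, 64)  # [batch_size, seq_len, d_model]
assert ffw(h).shape == h.shape  # the pre-norm residual MLP preserves the embedding shape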
class NearestNeighborEncoder(nn.Module):
chunk_len is the length of a chunk
n_layers is the number of layers in the encoder
ca_layers are the layers with cross attention
d_model is the number of features in embeddings
n_heads is the number of heads in attention layers
d_k is the size of attention heads
d_ff is the size of the hidden layer in the feed-forward networks
    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
                 d_model: int, n_heads: int, d_k: int, d_ff: int):
        super().__init__()
        self.ca_layers = ca_layers
        self.chunk_len = chunk_len
Cross-attention layers
        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
Bi-directional self-attention layers
        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
Feed-forward layers
        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
Pre-normalization layer for the input embeddings h
        self.norm_h = nn.LayerNorm(d_model)
e are the token embeddings of the retrieved nearest neighbors, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
h are the input token embeddings, of shape [batch_size, seq_len, d_model]
The chunks and neighbors are processed in parallel.
    def forward(self, e: torch.Tensor, h: torch.Tensor):
Get shape
        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
Split the input into chunks, giving shape [batch_size, chunks, chunk_len, d_model]
        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)
Pre-norm
        h_split = self.norm_h(h_split)
Keep the index of the cross attention layer
        p_ca = 0
For all layers
        for p in range(len(self.attn)):
Bi-directional self attention
            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)
Cross attention if p is in ca_layers
            if p in self.ca_layers:
                e = self.ca[p_ca](e, h_split)
Increment the cross attention index
                p_ca += 1
Feed forward layer
            e = self.ffw[p](e)
Return the encoded neighbors
        return e
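A standalone shape sketch of the encoder (the dimensions are arbitrary assumptions for illustration):

import torch

chunk_len = 4
encoder = NearestNeighborEncoder(chunk_len=chunk_len, n_layers=2, ca_layers={1},
                                 d_model=64, n_heads=4, d_k=16, d_ff=256)
e = torch.randn(2, 2, 3, 6, 64)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 10, 64)       # [batch_size, seq_len, d_model]
out = encoder(e, h)
assert out.shape == e.shape  # neighbors are re-encoded conditioned on the input chunks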
class RetroModel(nn.Module):
n_vocab is the number of tokens in the vocabulary
d_model is the number of features in embeddings
n_layers is the number of layers in the decoder
ca_layers are the layers with cross attention
chunk_len is the length of a chunk
n_heads is the number of heads in attention layers
d_k is the size of attention heads
d_ff is the size of the hidden layer in the feed-forward networks
encoder is the nearest neighbor encoder
    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
        super().__init__()

        self.ca_layers = ca_layers
        self.encoder = encoder
Token embedding layer
        self.emb = nn.Embedding(n_vocab, d_model)
Chunked cross attention layers
        self.cca = nn.ModuleList(
            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
Causal self-attention layers
        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
Feed forward layers
        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
Readout layer
        self.read = nn.Linear(d_model, n_vocab)
Pre-normalization layer for the nearest neighbor embeddings from the encoder
        self.norm_e = nn.LayerNorm(d_model)
x is the input sequence, of shape [batch_size, seq_len]
ret are the retrieved neighbors, of shape [batch_size, chunks, neighbors, neighbor_len]
    def forward(self, x: torch.Tensor, ret: torch.Tensor):
Get input embeddings
        h = self.emb(x)
Embeddings of the retrieved neighbors
        ret_emb = self.emb(ret)
Keep index of the chunked cross attention layer
        p_ca = 0
For all layers
        for p in range(len(self.attn)):
Causal self attention
            h = self.attn[p](h)
Get the encoder embeddings just before the first cross-attention layer, i.e. when p == min(ca_layers)
            if self.ca_layers and p == min(self.ca_layers):
                e = self.encoder(ret_emb, h)
Normalize encoder embeddings
                e = self.norm_e(e)
Chunked cross attention if p is in ca_layers
            if p in self.ca_layers:
                h = self.cca[p_ca](h, e)
Increment the chunked cross-attention index
                p_ca += 1
Feed forward layer
            h = self.ffw[p](h)
Get the logits from the readout layer
        return self.read(h)
Test the model with fake data
def _test():
    chunk_len = 4
    d_model = 8
    d_ff = 32
    n_heads = 2
    d_k = 4

    device = torch.device('cuda:0')

    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))

    m.to(device)
    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    ret = [
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    ]
    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))

    inspect(res)


if __name__ == '__main__':
    _test()
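Beyond the shape test above, here is a hedged sketch of how the readout logits could be used for next-token prediction training. Everything in it (the batch size, the random data, and the shifted cross-entropy target) is an assumption made for illustration and is not part of the original file.

import torch
import torch.nn.functional as F


def _train_step_sketch():
    chunk_len, d_model, d_ff, n_heads, d_k = 4, 8, 32, 2, 4
    model = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
                       encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
    x = torch.randint(0, 5, (3, 10))         # [batch_size, seq_len] token ids
    ret = torch.randint(0, 5, (3, 2, 2, 6))  # [batch_size, chunks, neighbors, neighbor_len]
    logits = model(x, ret)                   # [batch_size, seq_len, n_vocab]
    # Predict token t+1 from position t: compare logits at t with the target shifted left by one.
    loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.shape[-1]), x[:, 1:].reshape(-1))
    loss.backward()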