RETRO model

This is the model definition for RETRO.

14import math
15from typing import Set
16
17import torch
18from torch import nn
19
20from labml.logger import inspect

RoPE embeddings

We use rotary position embeddings in the self-attention layers. Non-causal self-attention needs explicit positional information because it cannot infer it. Causal attention is often assumed to be able to infer positional information from the embeddings on its own, but the `SelfAttention` layer below applies rotary embeddings in both the causal and the non-causal case.

class RotaryPositionalEmbeddings(nn.Module):
    """
    ## RoPE embeddings

    Rotates each feature pair $(x_i, x_{i + d/2})$ at position $m$ by the angle $m \\theta_i$,
    where $\\theta_i = base^{-2i/d}$ for $i \\in [0, d/2)$.
    """

    def __init__(self, d: int, base: int = 10_000):
        """
        * `d` is the number of features. It must be even and equal to the last dimension of the
          tensors passed to `forward`, because `forward` uses `2 * len(theta)` angles per position.
        * `base` is the constant used for calculating $\\theta_i = base^{-2i/d}$
        """
        super().__init__()

        # $\theta_i = \frac{1}{base^{2i/d}}$ for $i \in [0, d/2)$.
        # This is a fixed, non-trainable constant. It is registered as a persistent buffer rather
        # than a frozen `nn.Parameter`: it still moves with `.to()`/`.half()` and is saved in the
        # `state_dict` under the same `theta` key, but it is not returned by `parameters()`.
        # So it is not handed to optimizers or counted as a model weight.
        self.register_buffer('theta', 1. / (base ** (torch.arange(0, d, 2).float() / d)))

    def forward(self, x: torch.Tensor):
        """
        * `x` is the tensor of keys or queries for the heads, with shape `[batch_size, seq_len, n_heads, d]`

        Returns a tensor of the same shape with the rotation applied.
        """
        # Extract the shape (the unpacking also enforces a rank-4 input)
        _, seq_len, _, d = x.shape

        # $\frac{d}{2}$
        d_2 = d // 2

        # Position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

        # Product of position index and $\theta_i$: $m \theta_i$, shape `[seq_len, d/2]`
        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{d/2-1}, m \theta_0, m \theta_1, ..., m \theta_{d/2-1}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # $[-x^{(d/2)}, -x^{(d/2 + 1)}, ..., -x^{(d-1)}, x^{(0)}, x^{(1)}, ..., x^{(d/2-1)}]$
        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

        # For $i \in [0, d/2)$:
        #   $x_m^{(i)}       = x^{(i)} \cos m\theta_i - x^{(i + d/2)} \sin m\theta_i$
        #   $x_m^{(i + d/2)} = x^{(i + d/2)} \cos m\theta_i + x^{(i)} \sin m\theta_i$
        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])

        return rx

Self-Attention Layer

This applies causal and non-causal multi-headed self-attention.

class SelfAttention(nn.Module):
    """
    ## Self-Attention Layer

    Multi-headed self-attention, causal or non-causal, with a pre-norm and a residual connection.
    Rotary positional embeddings are applied to the queries and keys.
    """

    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
        """
        * `d_model` is the number of features in transformer embeddings
        * `n_heads` is the number of attention heads
        * `d_k` is the number of features per head
        * `is_causal` indicates whether this is causal attention (masked)
        """
        super().__init__()

        self.is_causal = is_causal
        self.n_heads = n_heads
        self.d_k = d_k

        # Attention logits are multiplied by $\frac{1}{\sqrt{d_k}}$ before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # Projections for the query, key and value heads (all heads in one matrix)
        heads_dim = n_heads * d_k
        self.query = nn.Linear(d_model, heads_dim)
        self.key = nn.Linear(d_model, heads_dim)
        self.value = nn.Linear(d_model, heads_dim)

        # Pre-norm layer. The paper uses RMSNorm instead.
        self.norm = nn.LayerNorm(d_model)

        # Softmax over the key dimension
        self.softmax = nn.Softmax(dim=-1)

        # Rotary positional embeddings, applied to each head
        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)

        # Projects the concatenated heads back to `d_model`
        self.output = nn.Linear(heads_dim, d_model)

    def mask_attention(self, attn: torch.Tensor):
        """
        Mask the attention logits for causal attention.

        * `attn` is the attention matrix of shape `[batch_size, n_heads, seq_len, seq_len]`

        For causal attention, the logit of query `i` attending to key `j > i` is set to `-inf`.
        For non-causal attention the same tensor is returned unchanged.
        """
        if self.is_causal:
            # Lower-triangular matrix of ones: 1 where key index <= query index
            allowed = torch.tril(attn.new_ones(attn.shape[-2:]))
            attn = attn.masked_fill(allowed == 0, float('-inf'))
        return attn

    def forward(self, h: torch.Tensor):
        """
        * `h` is the transformer embeddings of shape `[batch_size, seq_len, d_model]`

        Returns a tensor of shape `[batch_size, seq_len, d_model]`: the attention output
        added to the input (residual connection).
        """
        residual = h
        normed = self.norm(h)

        # Project and split into heads: `[batch_size, seq_len, n_heads, d_k]`.
        # Rotary embeddings are applied to queries and keys only.
        head_shape = (*normed.shape[:-1], self.n_heads, self.d_k)
        q = self.rotary_pe(self.query(normed).view(head_shape))
        k = self.rotary_pe(self.key(normed).view(head_shape))
        v = self.value(normed).view(head_shape)

        # Scaled (and, if causal, masked) logits: `[batch_size, n_heads, seq_len, seq_len]`
        scores = self.mask_attention(torch.einsum('bihd,bjhd->bhij', q, k) * self.scale)
        probs = self.softmax(scores)

        # Weighted sum of values, back in `[batch_size, seq_len, n_heads, d_k]`
        out = torch.einsum('bhij,bjhd->bihd', probs, v)
        # Merge the heads: `[batch_size, seq_len, n_heads * d_k]`
        out = out.reshape(*out.shape[:-2], -1)

        return self.output(out) + residual

Cross-Attention Layer

This is similar to the self-attention layer defined above, except that it gets keys and values from a different set of embeddings than the queries.

This is used in the encoder to encode the retrieved chunks based on the input chunks.

We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.

class CrossAttention(nn.Module):
    """
    ## Cross-Attention Layer

    Used in the encoder. The retrieved neighbor chunks provide the queries, and the input chunk
    that retrieved them provides the keys and values. A pre-norm is applied to the queries'
    source embeddings, and a residual connection is added.

    No explicit positional embeddings are used; we assume the model can represent positional
    information in the embeddings implicitly.
    """

    def __init__(self, d_model: int, n_heads: int, d_k: int):
        """
        * `d_model` is the number of features in transformer embeddings
        * `n_heads` is the number of attention heads
        * `d_k` is the number of features per head
        """
        super().__init__()

        self.n_heads = n_heads
        self.d_k = d_k

        # Attention logits are multiplied by $\frac{1}{\sqrt{d_k}}$ before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # Projections for the query, key and value heads
        heads_dim = n_heads * d_k
        self.query = nn.Linear(d_model, heads_dim)
        self.key = nn.Linear(d_model, heads_dim)
        self.value = nn.Linear(d_model, heads_dim)

        # Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.
        self.norm = nn.LayerNorm(d_model)

        # Softmax over the input-chunk positions
        self.softmax = nn.Softmax(dim=-1)

        # Projects the concatenated heads back to `d_model`
        self.output = nn.Linear(heads_dim, d_model)

    def forward(self, e: torch.Tensor, h: torch.Tensor):
        """
        * `e` are the retrieved nearest neighbor chunk embeddings with shape
          `[batch_size, chunks, neighbors, neighbor_len, d_model]`
        * `h` are the input chunks from which the nearest neighbors were retrieved, with shape
          `[batch_size, chunks, chunk_len, d_model]`. No normalization is applied to `h` here.

        Returns a tensor with the same shape as `e`.
        """
        residual = e
        normed_e = self.norm(e)

        # Queries from the neighbors: `[batch_size, chunks, neighbors, neighbor_len, n_heads, d_k]`
        q = self.query(normed_e).view(*normed_e.shape[:-1], self.n_heads, self.d_k)
        # Keys and values from the input chunks: `[batch_size, chunks, chunk_len, n_heads, d_k]`
        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)

        # Each neighbor attends to the chunk that retrieved it:
        # `[batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]`
        scores = torch.einsum('bcnihd,bcjhd->bcnhij', q, k) * self.scale
        probs = self.softmax(scores)

        # Gather values: `[batch_size, chunks, neighbors, neighbor_len, n_heads, d_k]`
        out = torch.einsum('bcnhij,bcjhd->bcnihd', probs, v)
        # Merge the heads: `[..., neighbor_len, n_heads * d_k]`
        out = out.reshape(*out.shape[:-2], -1)

        return self.output(out) + residual

Chunked Cross-Attention Layer

This is similar to the cross-attention layer defined above.

This is used in the decoder to pay attention to the retrieved neighbor chunks.

We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.

class ChunkedCrossAttention(nn.Module):
    """
    ## Chunked Cross-Attention Layer

    Used in the decoder to pay attention to the retrieved neighbor chunks. It is similar to
    `CrossAttention`, but here the input chunks provide the queries and the retrieved
    neighbors provide the keys and values.

    No explicit positional embeddings are used; we assume the model can represent positional
    information in the embeddings implicitly.
    """

    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
        """
        * `d_model` is the number of features in transformer embeddings
        * `n_heads` is the number of attention heads
        * `d_k` is the number of features per head
        * `chunk_len` is the length of a chunk
        """
        super().__init__()

        self.chunk_len = chunk_len
        self.n_heads = n_heads
        self.d_k = d_k

        # Attention logits are multiplied by $\frac{1}{\sqrt{d_k}}$ before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # Linear layers for query, key and value heads.
        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

        # Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.
        self.norm = nn.LayerNorm(d_model)

        # Softmax for attention probabilities
        self.softmax = nn.Softmax(dim=-1)

        # Final linear layer
        self.output = nn.Linear(n_heads * d_k, d_model)

    def forward(self, h: torch.Tensor, e: torch.Tensor):
        """
        * `h` are the input embeddings of shape `[batch_size, seq_len, d_model]`
        * `e` are the retrieved nearest neighbors of shape `[batch_size, chunks, neighbors, neighbor_len, d_model]`

        Returns a tensor with the same shape as `h`. The first `chunk_len - 1` positions, and any
        position past the last retrieved chunk, receive no cross-attention contribution.
        """
        # Get shape
        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

        # No attention if there are no chunks (for short inputs when sampling)
        if chunks == 0:
            return h

        # Residual connection
        h_res = h
        seq_len = h_res.shape[1]

        # Number of (shifted) input positions covered by the retrieved chunks
        n_tokens = chunks * self.chunk_len

        # Remove the first `chunk_len - 1` embeddings. The input pays attention to neighbors
        # retrieved and encoded using the past tokens only, so that there is no information
        # leakage: the neighbors retrieved for the first chunk carry information from the whole
        # first chunk. Shifting the sequence to the left by `chunk_len - 1` makes sure that
        # information only flows to the right.
        #
        # Positions past the last retrieved chunk are dropped here, since they have no neighbors
        # to attend to. Without this, the `reshape` into chunks below fails whenever the input is
        # longer than the retrieved chunks cover.
        h = h[:, self.chunk_len - 1:self.chunk_len - 1 + n_tokens]

        # Pre-norm
        h = self.norm(h)

        # Append empty embeddings to the end to be able to split the input into chunks
        if h.shape[1] < n_tokens:
            h = torch.cat((h, h.new_zeros(batch_size, n_tokens - h.shape[1], d_model)), dim=1)

        # Reshape the input into chunks: `[batch_size, chunks, chunk_len, d_model]`
        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)

        # Get query from the input
        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)

        # Get keys and values from the retrieved neighbors
        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)

        # Attention scores for the input chunks. Shifted input chunk `c` attends to the neighbors
        # retrieved for chunk `c`. Shape: `[batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]`
        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)

        # Scale attention scores
        attn = attn * self.scale

        # Softmax jointly over the last two dimensions `neighbors, neighbor_len`.
        # `reshape` (not `view`) because the einsum output may not be contiguous.
        attn = self.softmax(attn.reshape(*attn.shape[:-2], -1)).view(attn.shape)

        # Gather values: `[batch_size, chunks, chunk_len, n_heads, d_k]`
        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)

        # Merge chunks and heads: `[batch_size, chunks * chunk_len, n_heads * d_k]`
        h = h.reshape(batch_size, n_tokens, self.n_heads * self.d_k)

        # Final linear layer: `[batch_size, chunks * chunk_len, d_model]`
        h = self.output(h)

        # Prepend `chunk_len - 1` zero embeddings, i.e. shift it back to the right
        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)

        # Zero contribution for positions past the last retrieved chunk
        if h.shape[1] < seq_len:
            h = torch.cat((h, h.new_zeros(batch_size, seq_len - h.shape[1], d_model)), dim=1)

        # Truncate and add the residual connection
        return h[:, :seq_len] + h_res

Position-wise Feed Forward Layer

This consists of two linear layers and an activation in the middle.

class FeedForward(nn.Module):
    """
    ## Position-wise Feed Forward Layer

    A pre-norm block made of two linear layers with a ReLU activation between them,
    followed by a residual connection.
    """

    def __init__(self, d_model: int, d_ff: int):
        """
        * `d_model` is the number of features in transformer embeddings
        * `d_ff` is the number features in the hidden layer
        """
        super().__init__()

        # Expansion to the hidden size and projection back to `d_model`
        self.lin1 = nn.Linear(d_model, d_ff)
        self.lin2 = nn.Linear(d_ff, d_model)

        # ReLU activation
        self.act = nn.ReLU()

        # Pre-norm layer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, h: torch.Tensor):
        """
        * `h` are the embeddings of shape `[batch_size, seq_len, d_model]`

        Returns `h + lin2(relu(lin1(norm(h))))`, with the same shape as `h`.
        """
        hidden = self.act(self.lin1(self.norm(h)))
        return self.lin2(hidden) + h

Nearest Neighbor Encoder

This module encodes the retrieved nearest neighbors

class NearestNeighborEncoder(nn.Module):
    """
    ## Nearest Neighbor Encoder

    Encodes the retrieved nearest neighbors. Every layer applies bi-directional self-attention
    within each neighbor. On the layers listed in `ca_layers` it then applies cross-attention
    from each neighbor to the input chunk that retrieved it. Every layer ends with a
    feed-forward layer. The chunks and neighbors are processed in parallel.
    """

    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
                 d_model: int, n_heads: int, d_k: int, d_ff: int):
        """
        * `chunk_len` is the length of a chunk
        * `n_layers` is the number of layers in the encoder
        * `ca_layers` are the indexes of the layers with cross attention
        * `d_model` is the number of features in embeddings
        * `n_heads` is the number of heads in attention layers
        * `d_k` is the size of attention heads
        * `d_ff` is the size of the feed-forward networks hidden layers
        """
        super().__init__()
        self.ca_layers = ca_layers
        self.chunk_len = chunk_len

        # Cross-attention layers, one per entry of `ca_layers`, consumed in increasing layer order
        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])

        # Bi-directional self attention layers
        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])

        # Feed forward layers
        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

        # Pre-normalization layer for the input chunks `h`
        self.norm_h = nn.LayerNorm(d_model)

    def forward(self, e: torch.Tensor, h: torch.Tensor):
        """
        * `e` are token embeddings of the retrieved nearest neighbors, of shape
          `[batch_size, chunks, neighbors, neighbor_len, d_model]`
        * `h` are the input token embeddings, of shape `[batch_size, seq_len, d_model]`;
          only the first `chunks * chunk_len` positions are used

        Returns the encoded neighbors, with the same shape as `e`.
        """
        # Get shape
        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

        # Nothing to encode, e.g. when no chunks were retrieved yet for a short input while sampling.
        # This matches the `chunks == 0` short-cut in the decoder's chunked cross-attention.
        # Without it, empty tensors would be pushed through the layers below, and PyTorch's
        # `reshape(..., -1)` raises on zero-size tensors.
        if e.numel() == 0:
            return e

        # Split the input into chunks: `[batch_size, chunks, chunk_len, d_model]`
        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)

        # Pre-norm
        h_split = self.norm_h(h_split)

        # Index of the next cross attention layer
        p_ca = 0

        # For all layers
        for p in range(len(self.attn)):
            # Bi-directional self attention within each neighbor
            # (chunks and neighbors are flattened into the batch dimension)
            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)

            # Cross attention if `p` is a cross-attention layer
            if p in self.ca_layers:
                # Cross attention between the neighbors and the input chunk that retrieved them
                e = self.ca[p_ca](e, h_split)
                # Increment the cross attention index
                p_ca += 1

            # Feed forward layer
            e = self.ffw[p](e)

        return e

Retro Model

This is the Retro decoder

502class RetroModel(nn.Module):
  • v_vocab is the number of tokens in the vocabulary
  • d_model is the number of features in embeddings
  • n_layers is the number of layers in the decoder
  • ca_layers are the layers with cross attention
  • chunk_len is the length of a chunk
  • n_heads is the number of heads in attention layers
  • d_k is the size of attention heads
  • d_ff is the size of the feed-forward networks hidden layers
  • encoder is the nearest neighbor encoder
509    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
510                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
522        super().__init__()
523
524        self.ca_layers = ca_layers
525        self.encoder = encoder

Token embedding layer

528        self.emb = nn.Embedding(n_vocab, d_model)

Chunked cross attention layers

530        self.cca = nn.ModuleList(
531            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])

Attention layers

533        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])

Feed forward layers

535        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

Readout layer

537        self.read = nn.Linear(d_model, n_vocab)

Pre-normalization layer for nearest neighbor embeddings from

541        self.norm_e = nn.LayerNorm(d_model)
  • x is the input sequence, of shape [batch_size, seq_len]
  • ret are the retrieved neighbors of shape [batch_size, chunks, neighbors, neighbor_len]
543    def forward(self, x: torch.Tensor, ret: torch.Tensor):

Get input embeddings

552        h = self.emb(x)

Embeddings of the retrieved neighbors .

We use same embeddings for both input and neighbors

558        ret_emb = self.emb(ret)

Keep index of the chunked cross attention layer

561        p_ca = 0

For all layers

563        for p in range(len(self.attn)):

Causal self attention

565            h = self.attn[p](h)

Get encoder embeddings before the first layer, when

569            if self.ca_layers and p == min(self.ca_layers):

We passed the embeddings of to encoder.

573                e = self.encoder(ret_emb, h)

Normalize encoder embeddings

575                e = self.norm_e(e)

Chunked-cross attention if

578            if p in self.ca_layers:

580                h = self.cca[p_ca](h, e)

Increment chunked cross-attention index

582                p_ca += 1

585            h = self.ffw[p](h)

588        return self.read(h)

Test the model with fake data

def _test():
    """
    Test the model with fake data.

    Runs on the first CUDA device when one is available and falls back to the CPU otherwise.
    """
    chunk_len = 4
    d_model = 8
    d_ff = 32
    n_heads = 2
    d_k = 4

    # Fall back to the CPU so that the smoke test also runs on machines without a GPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Vocabulary of 5 tokens, 6 decoder layers with chunked cross-attention on layers 2 and 5,
    # and a 2-layer encoder with cross-attention on layer 1
    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))

    m.to(device)
    # One input sequence of 10 tokens
    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    # 2 chunks, each with 2 retrieved neighbors of 6 tokens
    ret = [
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    ]
    # A batch of 10 copies of the same sample
    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))

    inspect(res)


if __name__ == '__main__':
    _test()