RETRO training dataset

We pre-retrieve nearest neighbors from the key-value database and create the dataset to train the RETRO model.

import json
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset as PyTorchDataset

from labml import lab, monit
from labml_helpers.datasets.text import TextFileDataset, TextDataset
from labml_nn.transformers.retro.database import RetroIndex

Build the dataset

  • chunk_len is the chunk length
  • chunks_per_sample is the number of chunks per training sample
  • skip_range is the exclusive upper bound on the number of characters skipped before each sample (a value in [0, skip_range) is drawn at random). We skip a few characters between samples so that the samples aren't aligned perfectly with the chunks in the database.
def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):
    """
    Build the RETRO training dataset and save it as JSON.

    * `chunk_len` is the chunk length
    * `chunks_per_sample` is the number of chunks per training sample
    * `skip_range` is the exclusive upper bound on the number of characters skipped
      before each sample (a value in `[0, skip_range)` is drawn with `np.random.randint`).
      Skipping keeps the samples from lining up exactly with the chunks in the database.
      A value of `0` or less disables skipping.

    Nearest neighbors for every chunk of every sample are retrieved with `RetroIndex`,
    and the list of `(src, tgt, neighbors)` samples is written to
    `retro_train_dataset.json` under `lab.get_data_path()`.
    """

    # Load the text file (downloaded from `url` if it is not present locally)
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

    # Training portion of it
    text = dataset.train

    # Load the index for retrieving neighbors
    index = RetroIndex()

    # Number of input characters in one sample
    sample_len = chunks_per_sample * chunk_len

    # The input sample offsets
    sample_offsets = []
    # Cursor for the text
    i = 0
    while i < len(text):
        # Skip a few characters to make sure it's not aligned with the neighbors.
        # `np.random.randint` needs a positive upper bound, so a non-positive
        # `skip_range` simply means "don't skip".
        if skip_range > 0:
            i += np.random.randint(skip_range)

        # Stop if there isn't room for a full sample *plus* the extra character used
        # as the last prediction target (`sample_len + 1` characters in total).
        if i + sample_len >= len(text):
            break

        # Collect the offset
        sample_offsets.append(i)

        # Increment the cursor
        i += sample_len

    # For samples
    samples = []
    # Iterate through sample offsets
    for i in monit.iterate('Gather Neighbors', sample_offsets):
        # Get the sample including an extra character (for prediction)
        sample = text[i: i + sample_len + 1]
        # The input
        src = sample[:-1]
        # The target is the input shifted by one character
        tgt = sample[1:]

        # Break the input into chunks, and record each chunk's offset within `text`
        chunk_starts = range(0, len(src), chunk_len)
        chunks = [src[j:j + chunk_len] for j in chunk_starts]
        chunk_offsets = [i + j for j in chunk_starts]

        # Retrieve nearest neighbors
        neighbor_offsets = index(chunks, chunk_offsets)

        # Get neighbor texts. The neighbor length is twice the `chunk_len`:
        # the `chunk_len` characters at each neighbor offset plus the following `chunk_len`.
        # NOTE(review): a neighbor offset near the end of `text` would give a shorter
        # string, which `Dataset.__getitem__` cannot `torch.stack`; confirm `RetroIndex`
        # never returns such offsets.
        neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]

        # Add to list of samples
        samples.append((src, tgt, neighbors))

    # Save the samples in JSON. We don't need to use complex dataset storage mechanisms
    # or pre-tokenize since our dataset is small.
    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w', encoding='utf-8') as f:
        json.dump(samples, f)


This is the PyTorch dataset that loads the dataset created by `build_dataset`.

class Dataset(PyTorchDataset):
    """
    PyTorch dataset that loads the JSON file created by `build_dataset`.

    Each item is a `(src, tgt, neighbors)` triple of encodings produced by
    `tds.text_to_i`.
    """

    def __init__(self, file_path: Path, tds: 'TextDataset'):
        """
        * `file_path` is the path of the saved JSON file
        * `tds` is the `TextDataset`; only its `text_to_i` method is used here
        """
        super().__init__()
        self.tds = tds

        # Load the samples
        with open(str(file_path), 'r', encoding='utf-8') as f:
            self.samples = json.load(f)

    def __len__(self):
        """Number of samples"""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Get a sample

        Returns `(src, tgt, neighbors)`. `src` and `tgt` are the encoded input and
        target text. `neighbors` is built by `torch.stack`-ing, for every chunk, the
        encodings of that chunk's neighbors, then stacking over chunks.
        """
        # Get the sample: input text, target text, and per-chunk lists of neighbor texts
        src_text, tgt_text, neighbor_texts = self.samples[idx]

        # Encode the input and target
        src = self.tds.text_to_i(src_text)
        tgt = self.tds.text_to_i(tgt_text)

        # Encode and stack the neighbors; `torch.stack` requires every neighbor
        # encoding to have the same shape.
        neighbors = torch.stack([
            torch.stack([self.tds.text_to_i(n) for n in chunk_neighbors])
            for chunk_neighbors in neighbor_texts
        ])

        return src, tgt, neighbors

if __name__ == '__main__':
    # Running this module as a script builds the dataset with the default settings
    build_dataset()