We pre-retrieve nearest neighbors from the key-value database and create the dataset to train the RETRO model.
15import json
16from pathlib import Path
17
18import numpy as np
19import torch
20from torch.utils.data import Dataset as PyTorchDataset
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset, TextDataset
24from labml_nn.transformers.retro.database import RetroIndex
def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):
    """
    ## Build the dataset

    Pre-retrieves nearest neighbors for every chunk of every training sample
    and saves the samples as JSON to `retro_train_dataset.json` in the data path.

    * `chunk_len` is the chunk length
    * `chunks_per_sample` is the number of chunks per training sample
    * `skip_range` is the (exclusive) upper bound on the number of characters to skip
      between two samples; each skip is drawn uniformly from `[0, skip_range)`.
      We skip a few characters between samples to make sure the samples aren't
      aligned perfectly with the chunks in the database.
      Values `<= 0` disable skipping (previously they made `np.random.randint` raise).

    Each saved sample is `[input_text, target_text, neighbors]`, where the target is
    the input shifted by one character and `neighbors` holds one list of neighbor
    texts per entry returned by the index (presumably one entry per chunk).
    """
    # Load the text file
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

    # Training portion of it
    text = dataset.train

    # Load the index for retrieving neighbors
    index = RetroIndex()

    # Number of input characters per sample
    sample_len = chunks_per_sample * chunk_len

    # The input sample offsets
    sample_offsets = []
    # Cursor for the text
    i = 0
    while i < len(text):
        # Skip a few characters to make sure it's not aligned with the neighbors.
        # `np.random.randint(high)` raises for `high <= 0`, so only draw when positive.
        if skip_range > 0:
            i += np.random.randint(skip_range)

        # Stop if there's no room for a full sample plus the one extra character
        # used to build the target (the slice below takes `sample_len + 1` characters).
        # Checking only `i + sample_len` would admit a truncated final sample.
        if i + sample_len + 1 > len(text):
            break

        # Collect the offset
        sample_offsets.append(i)

        # Increment the cursor
        i += sample_len

    # For samples
    samples = []
    # Iterate through sample offsets
    for i in monit.iterate('Gather Neighbors', sample_offsets):
        # Get the sample including an extra character (for prediction)
        sample = text[i: i + sample_len + 1]
        # The input
        src = sample[:-1]
        # Break it into chunks
        chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]
        # The chunk offsets (absolute positions in `text`)
        chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]

        # Retrieve nearest neighbors
        neighbor_offsets = index(chunks, chunk_offsets)

        # Get neighbor texts. The neighbor length is twice the `chunk_len`.
        # NOTE(review): an offset within `2 * chunk_len` of the end of `text` would yield a
        # shorter neighbor and break `torch.stack` in `Dataset.__getitem__`; this presumably
        # relies on `RetroIndex` never returning such offsets — confirm.
        neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]

        # Add to list of samples: (input, target shifted by one, neighbors)
        samples.append((src, sample[1:], neighbors))

    # Save the samples in JSON.
    # We don't need to use complex dataset storage mechanisms or pre-tokenize since our dataset is small.
    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
        json.dump(samples, f)
class Dataset(PyTorchDataset):
    """
    ## Dataset

    Serves samples from a JSON file of pre-retrieved samples (such as the one written
    by `build_dataset`), converting their text to tensors with the given `TextDataset`.
    """

    def __init__(self, file_path: Path, tds: TextDataset):
        """
        * `file_path` is the path of the saved JSON file
        * `tds` is the `TextDataset` whose `text_to_i` is used to encode text
        """
        self.tds = tds
        # Read all samples into memory up front
        with open(str(file_path), 'r') as handle:
            self.samples = json.load(handle)

    def __len__(self):
        """Number of samples"""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Get a sample.

        Returns `(src, tgt, neighbors)`. `src` and `tgt` are the encoded input and
        target texts. `neighbors` stacks, for each entry of the sample's neighbor list,
        the stacked encodings of its neighbor texts.
        """
        record = self.samples[idx]
        encode = self.tds.text_to_i

        # Encode input and target text
        src = encode(record[0])
        tgt = encode(record[1])

        # Encode neighbors group by group, then stack the groups
        per_chunk = [torch.stack([encode(neighbor) for neighbor in group]) for group in record[2]]
        neighbors = torch.stack(per_chunk)

        return src, tgt, neighbors
if __name__ == '__main__':
    # When run as a script, build and save the RETRO training dataset
    build_dataset()