This builds the database and retrieves the nearest neighbors for the RETRO model.
We use the FAISS library for the database, whilst the paper used the SCaNN library.
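If FAISS is not already available in your environment, it can typically be installed with pip install faiss-cpu (or faiss-gpu for CUDA builds); this setup note is an assumption about the environment, not part of the original code.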
from typing import List, Optional

import faiss
import numpy as np
import torch

from labml import lab, monit
from labml_helpers.datasets.text import TextFileDataset
from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings
chunk_len is the length of a chunk (number of characters)
batch_size is the batch size to use when calculating BERT embeddings
d_emb is the number of features in the BERT embeddings
n_centeroids is the number of lists in the FAISS index
code_size is the encoded vector size in the index
n_probe is the number of lists to probe
n_train is the number of keys to train the index on

def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):
Load the dataset text file
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
Get training data (a string)
    text = dataset.train
Split the text into chunks of chunk_len characters

    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]
Get the offsets of each of the chunks
    chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])
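To make the chunking concrete, here is a small standalone sketch (illustrative only; the toy text and names are not part of this module). The i + chunk_len * 2 < len(text) condition drops trailing positions so that every stored chunk has a full chunk of text after it.

# Illustrative only: chunking a toy string with a chunk length of 4
toy_text = "To be, or not to be, that"
toy_len = 4
toy_chunks = [toy_text[i:i + toy_len]
              for i in range(0, len(toy_text), toy_len)
              if i + toy_len * 2 < len(toy_text)]
toy_offsets = [i for i in range(0, len(toy_text), toy_len)
               if i + toy_len * 2 < len(toy_text)]
# toy_chunks  == ['To b', 'e, o', 'r no', 't to', ' be,']
# toy_offsets == [0, 4, 8, 12, 16]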
Number of chunks
    n_chunks = len(chunks)
Initialize BERT to get chunk embeddings

    bert = BERTChunkEmbeddings(torch.device('cuda:0'))
Get chunk embeddings by processing batch_size chunks on each iteration

    chunk_emb = []
    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())
Merge them into a single tensor
    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()
Create the FAISS index
    quantizer = faiss.IndexFlatL2(d_emb)
    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
    index.nprobe = n_probe
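The last argument (8) is the number of bits used for each sub-quantizer code. As a minimal, self-contained sketch of the same IVFPQ workflow (illustrative only: random float32 vectors and small sizes stand in for the BERT embeddings and the defaults above):

import faiss
import numpy as np

d, n = 64, 10_000                                   # vector size and number of database vectors
xb = np.random.random((n, d)).astype('float32')     # FAISS expects float32

quantizer = faiss.IndexFlatL2(d)                    # coarse quantizer that assigns vectors to lists
index = faiss.IndexIVFPQ(quantizer, d, 128, 8, 8)   # 128 lists, 8 sub-quantizers, 8 bits each
index.train(xb)                                      # IVFPQ must be trained before adding vectors
index.add(xb)

index.nprobe = 8                                     # how many inverted lists to scan per query
distances, ids = index.search(xb[:5], 4)             # 4 nearest neighbors of the first 5 vectors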
Get a random sample of the chunk indexes

    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)
Train the index to store the keys
    with monit.section('Train index'):
        index.train(chunk_emb[random_sample])
Add the chunks to the index in batches of size 1024
    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
        e = min(s + 1024, n_chunks)
Add to index
        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])
Save the index
    with monit.section('Save'):
        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))
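Because the chunk offsets were stored as ids, a search on this index returns positions in the text directly. A hedged sketch of reading the saved index back and querying it (illustrative only; it assumes build_database has already run and a GPU is available, as above):

# Illustrative only: read the saved index back and query it
index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
index.nprobe = 8
print(index.ntotal)  # number of chunks stored in the index

bert = BERTChunkEmbeddings(torch.device('cuda:0'))
emb = bert(['First Citizen: B']).cpu().numpy()  # one 16-character query chunk
distances, offsets = index.search(emb, 2)       # `offsets` are chunk positions in the text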
Index for retrieving nearest neighbors.

class RetroIndex:
chunk_len is the chunk length
n_probe is the number of lists to probe
n_neighbors is the number of neighbors to retrieve
n_extra is the number of extra neighbors to retrieve, since we will be removing neighbors that overlap with the query chunk
exclude_neighbor_span is the extra text length to avoid when checking for overlaps

    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
                 n_neighbors: int = 2, n_extra: int = 2,
                 exclude_neighbor_span: int = 8):
        self.n_neighbors = n_neighbors
        self.chunk_len = chunk_len
        self.exclude_neighbor_span = exclude_neighbor_span
        self.n_extra = n_extra
Initialize BERT to get chunk embeddings

        self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))
Load the database
        with monit.section('Load index'):
            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
            self.index.nprobe = n_probe
Filter out neighbors that overlap with the query chunk. The positions of the neighbors are given by neighbor_offsets and the position of the query chunk is offset.

    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):
        return [n for n in neighbor_offsets
                if n < offset - (self.chunk_len + self.exclude_neighbor_span)
                or n > offset + (self.chunk_len + self.exclude_neighbor_span)]
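For instance, with the defaults chunk_len = 16 and exclude_neighbor_span = 8, a query chunk at offset 64 keeps only neighbors outside the window 64 ± 24. A tiny standalone sketch of the same comprehension (illustrative numbers, not part of this module):

# Illustrative only: mirror of the filtering comprehension above
offset, span = 64, 16 + 8                     # span = chunk_len + exclude_neighbor_span
neighbor_offsets = [0, 48, 50, 64, 80, 100]
kept = [n for n in neighbor_offsets if n < offset - span or n > offset + span]
# kept == [0, 100]; 48, 50, 64 and 80 overlap the query chunk's neighborhood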
Retrieve the nearest neighbors for a list of query chunks.

    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):
Get BERT embeddings of the query chunks

        emb = self.bert(query_chunks).cpu()
Get n_neighbors + n_extra nearest neighbors from the database

        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)
If the query chunk offsets are given, filter out overlapping chunks

        if offsets is not None:
            neighbor_offsets = [self.filter_neighbors(off, n_off)
                                for off, n_off in zip(offsets, neighbor_offsets)]
Get the closest n_neighbors after filtering

        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]

        return neighbor_offsets
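A hedged usage sketch (illustrative only; it assumes build_database has already been run so that retro.index exists, and the query strings and offsets are made up):

retro_index = RetroIndex()

# Two 16-character query chunks and their assumed positions in the source text
query_chunks = ['First Citizen: B', 'efore we proceed']
offsets = [0, 16]

neighbor_offsets = retro_index(query_chunks, offsets)
# neighbor_offsets[i] holds up to n_neighbors chunk offsets for query_chunks[i];
# passing None for offsets skips the overlap filtering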
if __name__ == '__main__':
    build_database()