Database for nearest neighbor retrieval

This builds the database and retrieves the nearest neighbors for the RETRO model.

We use the FAISS library for the database, whereas the paper used the ScaNN library.
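
For readers new to FAISS, the basic pattern is: create an index, add vectors, and search; index types that partition or compress the vectors additionally need a training step. A minimal sketch with made-up data (not part of the original code):

import faiss
import numpy as np

d = 64                                                # embedding dimension
xb = np.random.random((1_000, d)).astype('float32')   # vectors to store
xq = np.random.random((4, d)).astype('float32')       # query vectors

index = faiss.IndexFlatL2(d)          # exact L2 search; flat indexes need no training
index.add(xb)                         # store the vectors
distances, ids = index.search(xq, 3)  # 3 nearest neighbors per query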

from typing import List, Optional

import faiss
import numpy as np
import torch

from labml import lab, monit
from labml_helpers.datasets.text import TextFileDataset
from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings

Build Database

  • chunk_len is the length of a chunk (number of characters)
  • batch_size is the batch size to use when calculating the BERT embeddings
  • d_emb is the number of features in the BERT embeddings
  • n_centeroids is the number of lists in the index
  • code_size is the size of the encoded vectors in the index
  • n_probe is the number of lists to probe
  • n_train is the number of keys to train the index on
def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):

Load the dataset text file

    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

Get training data (a string)

    text = dataset.train

Split the text into chunks of chunk_len characters

    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]

Get the offsets of each of the chunks

    chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])
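
The condition i + chunk_len * 2 < len(text) drops the final chunks, so that every stored chunk is followed by a full chunk_len characters of continuation (RETRO retrieves the neighbor chunk together with its continuation). A toy illustration, assuming chunk_len = 4:

text = 'abcdefghijklmnop'  # 16 characters
chunk_len = 4
chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len)
          if i + chunk_len * 2 < len(text)]
# chunks == ['abcd', 'efgh']; the chunk at offset 12 has no full continuation,
# and the one at offset 8 is dropped too because the check is strict
# (i + chunk_len * 2 must be < len(text))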

Number of chunks

    n_chunks = len(chunks)

Initialize BERT to get the chunk embeddings

    bert = BERTChunkEmbeddings(torch.device('cuda:0'))

Get chunk embeddings by processing batch_size number of chunks on each iteration

    chunk_emb = []
    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())

Merge them into a single tensor

    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()
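
At this point chunk_emb should be an (n_chunks, d_emb) float32 array, which is the layout FAISS expects. A quick sanity check one could add here (hypothetical, not in the original):

    assert chunk_emb.shape == (n_chunks, d_emb)
    assert chunk_emb.dtype == np.float32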

Create the FAISS index

    quantizer = faiss.IndexFlatL2(d_emb)
    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
    index.nprobe = n_probe
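
IndexIVFPQ partitions the vectors into n_centeroids lists via the quantizer, and compresses each vector with product quantization into code_size sub-quantizer codes of 8 bits each (the final argument), so each stored vector takes roughly code_size bytes instead of d_emb * 4 bytes of raw float32; at search time only n_probe of the lists are scanned. A rough back-of-the-envelope estimate (a sketch, not in the original):

    raw_bytes = n_chunks * d_emb * 4  # uncompressed float32 storage
    pq_bytes = n_chunks * code_size   # one byte per 8-bit sub-quantizer code
    # with the defaults this is a ~48x reduction (768 * 4 / 64)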

Get a random sample of the chunk indexes

    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)

Train the index to store the keys

    with monit.section('Train index'):
        index.train(chunk_emb[random_sample])

Add the chunks to the index in batches of size 1024

    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
        e = min(s + 1024, n_chunks)

Add to index

        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])

Save the index

    with monit.section('Save'):
        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))
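
Because the chunk offsets were stored as the FAISS ids, a search on this index returns character offsets into the original text directly, so neighbor text can be sliced out without a separate lookup table. A quick check one could run after building (a sketch; bert, text and chunk_len as above, and a query string taken from the dataset):

index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
query_emb = bert(['First Citizen:']).cpu().numpy()  # embed a query chunk
_, offsets_found = index.search(query_emb, 2)       # returned ids are text offsets
neighbors = [text[int(n): int(n) + chunk_len * 2] for n in offsets_found[0]]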

Index for retrieving nearest neighbors

class RetroIndex:
  • chunk_len is the chunk length
  • n_probe is the number of lists to probe
  • n_neighbors is the number of neighbors to retrieve
  • n_extra is the number of extra neighbors to retrieve since we will be removing neighbors overlapping with the query chunk
  • exclude_neighbor_span is the extra text length to avoid when checking for overlaps
    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
                 n_neighbors: int = 2, n_extra: int = 2,
                 exclude_neighbor_span: int = 8):
        self.n_neighbors = n_neighbors
        self.chunk_len = chunk_len
        self.exclude_neighbor_span = exclude_neighbor_span
        self.n_extra = n_extra

Initialize BERT to get the chunk embeddings

        self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))

Load the database

        with monit.section('Load index'):
            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
            self.index.nprobe = n_probe

Filter neighbors that overlap with the query

The positions of the neighbors are given by neighbor_offsets and the position of the query chunk is offset.

    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):
        return [n for n in neighbor_offsets
                if n < offset - (self.chunk_len + self.exclude_neighbor_span)
                or n > offset + (self.chunk_len + self.exclude_neighbor_span)]
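
For example, with the defaults chunk_len = 16 and exclude_neighbor_span = 8, a query chunk at offset = 100 rejects any neighbor whose offset lies within chunk_len + exclude_neighbor_span = 24 characters on either side, i.e. in [76, 124]:

chunk_len, exclude_neighbor_span = 16, 8
offset, neighbor_offsets = 100, [50, 80, 100, 120, 130]
span = chunk_len + exclude_neighbor_span  # 24
kept = [n for n in neighbor_offsets if n < offset - span or n > offset + span]
# kept == [50, 130]; 80, 100 and 120 fall inside the exclusion window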

Retrieve nearest neighbors

    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):

Get the BERT embeddings of the query chunks

        emb = self.bert(query_chunks).cpu()

Get n_neighbors + n_extra nearest neighbors from the database

        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)

If the query chunk offsets are given, filter out overlapping chunks

        if offsets is not None:
            neighbor_offsets = [self.filter_neighbors(off, n_off)
                                for off, n_off in zip(offsets, neighbor_offsets)]

Get the closest n_neighbors after filtering

        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]

        return neighbor_offsets
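
A hedged end-to-end sketch of how this module might be used once the database has been built (e.g. by running this file's __main__ below; the query strings are illustrative, and building runs BERT over the whole dataset so it takes a while on the first run):

retro_index = RetroIndex()             # loads BERT and the saved FAISS index
queries = ['First Citizen: B', 'efore we proceed']
neighbor_offsets = retro_index(queries, offsets=None)  # no overlap filtering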

if __name__ == '__main__':
    build_database()