Build FAISS index for k-NN search

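We want to build an index of $\big(f(c_i), w_i\big)$, where $f(c_i)$ is the input to the feed-forward layer for context $c_i$ and $w_i$ is the token that follows it. We store $f(c_i)$ and $w_i$ in memory-mapped numpy arrays. We find the $f(c_i)$ nearest to $f(c_t)$ using FAISS: FAISS indexes $\big(f(c_i), i\big)$, and we query it with $f(c_t)$.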

15from typing import Optional
17import faiss
18import numpy as np
19import torch
21from labml import experiment, monit, lab
22from labml.utils.pytorch import get_modules
23from labml_nn.transformers.knn.train_model import Configs

Load a saved experiment produced by the model-training script (`train_model`).

def load_experiment(run_uuid: str, checkpoint: Optional[int] = None):
    """
    Load a saved experiment produced by the training script.

    The order of the ``experiment`` calls below matters: the run is put in
    evaluation mode, configured, has its models registered, is pointed at the
    saved run, and only then started (which is when the weights are loaded).

    :param run_uuid: UUID of the training run to load.
    :param checkpoint: checkpoint to load; passed through unchanged to
        ``experiment.load`` (``None`` leaves the choice to labml).
    :return: the ``Configs`` object holding the loaded model.
    """
    configs = Configs()

    # Configurations that were recorded for the training run
    saved = experiment.load_configs(run_uuid)
    # The model must keep the inputs to its feed-forward layer so that they
    # can later be read back as the k-NN keys
    saved['is_save_ff_input'] = True

    # Evaluation only: nothing is tracked or saved
    experiment.evaluate()
    experiment.configs(configs, saved)
    # Register the models so their weights are restored from the run
    experiment.add_pytorch_models(get_modules(configs))
    experiment.load(run_uuid, checkpoint)
    # The models are actually loaded when the experiment starts
    experiment.start()

    return configs

Gather $\big(f(c_i), w_i\big)$ and save them in numpy arrays

Note that these numpy arrays can take up a lot of space (even a few hundred gigabytes), depending on the size of your dataset.

def gather_keys(conf: Configs):
    r"""
    Gather $\big(f(c_i), w_i\big)$ and save them in memory-mapped numpy arrays.

    $f(c_i)$ is the input to the model's feed-forward layer
    (``conf.model.ff_input``) and $w_i$ is the corresponding target token.
    Keys go to ``keys.npy`` (``float32``, shape ``(n_keys, d_model)``) and
    values to ``vals.npy`` (``int64``, shape ``(n_keys, 1)``) in the labml
    data directory.

    *These arrays can take up a lot of space (even a few hundred gigabytes)
    depending on the size of your dataset.*
    """
    # Dimensions of $f(c_i)$
    d_model = conf.transformer.d_model
    # Training data loader
    data_loader = conf.trainer.data_loader
    # Number of contexts, i.e. number of tokens in the training data minus one:
    # pairs $\big(f(c_i), w_i\big)$ for $i \in [2, T]$.
    # NOTE(review): assumes ``data_loader.data`` is a 2-D tensor of token ids — confirm.
    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

    data_path = lab.get_data_path()
    # Numpy array for $f(c_i)$
    keys_store = np.memmap(str(data_path / 'keys.npy'), dtype=np.float32, mode='w+', shape=(n_keys, d_model))
    # Numpy array for $w_i$. The dtype must be given explicitly: np.memmap
    # defaults to uint8, which would silently wrap token ids >= 256.
    vals_store = np.memmap(str(data_path / 'vals.npy'), dtype=np.int64, mode='w+', shape=(n_keys, 1))

    # Number of keys $f(c_i)$ collected so far
    added = 0
    with torch.no_grad():
        for _, batch in monit.enum("Collect data", data_loader, is_children_silent=True):
            # $w_i$, the target labels
            vals = batch[1].view(-1, 1)
            # Input data moved to the device of the model
            data = batch[0].to(conf.device)
            # Run the model; only its side effect of recording ``ff_input`` is needed
            conf.model(data)
            # $f(c_i)$
            keys = conf.model.ff_input.view(-1, d_model)
            n_batch = keys.shape[0]
            # Save keys $f(c_i)$ and values $w_i$ in the memory-mapped arrays
            keys_store[added:added + n_batch] = keys.cpu().numpy()
            vals_store[added:added + n_batch] = vals.cpu().numpy()
            added += n_batch

    # Make sure everything is written to disk before the index is built
    keys_store.flush()
    vals_store.flush()

Build FAISS index

The *Getting started*, *Faster search*, and *Lower memory footprint* tutorials of FAISS will help you learn more about its usage.

def build_index(conf: Configs, n_centeroids: int = 2048, code_size: int = 64, n_probe: int = 8, n_train: int = 200_000,
                *, add_batch_size: int = 1024):
    r"""
    Build a FAISS index over the keys $f(c_i)$ stored in ``keys.npy``.

    The index uses Voronoi cells (inverted file, IVF) for faster search and
    product quantization (PQ) for compression, so full vectors are not stored.
    Each key is added with its row number $i$ as id, so a search result maps
    back to row $i$ of the keys/values arrays. The index is written to
    ``faiss.index`` in the labml data directory.

    :param conf: configurations of the loaded experiment
    :param n_centeroids: number of Voronoi cells of the IVF index
    :param code_size: number of PQ sub-quantizers (8 bits each); FAISS
        requires ``d_model`` to be divisible by it
    :param n_probe: number of cells visited per query
    :param n_train: maximum number of keys sampled to train the index
    :param add_batch_size: number of keys added to the index per call
    """
    # Dimensions of $f(c_i)$
    d_model = conf.transformer.d_model
    # Training data loader
    data_loader = conf.trainer.data_loader
    # Number of contexts, i.e. number of tokens in the training data minus one.
    # Must match the value used when the keys were gathered.
    # NOTE(review): assumes ``data_loader.data`` is a 2-D tensor of token ids — confirm.
    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

    # IVF index with PQ compression (8 bits per sub-quantizer)
    quantizer = faiss.IndexFlatL2(d_model)
    index = faiss.IndexIVFPQ(quantizer, d_model, n_centeroids, code_size, 8)
    index.nprobe = n_probe

    # Load the memory-mapped numpy array of keys
    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))

    # Pick a random sample of keys to train the index with. Passing the int
    # directly draws the same sample as ``np.arange(n_keys)`` without
    # materialising that array. Sorting makes the memmap read sequential.
    # FAISS k-means needs at least ``n_centeroids`` training points.
    random_sample = np.sort(np.random.choice(n_keys, size=min(n_train, n_keys), replace=False))
    with monit.section('Train index'):
        index.train(np.ascontiguousarray(keys_store[random_sample]))

    # Add keys to the index with their row number $i$ as id
    for start in monit.iterate('Index', range(0, n_keys, add_batch_size)):
        end = min(start + add_batch_size, n_keys)
        keys = keys_store[start:end]
        # FAISS ids are 64-bit; numpy's default int width is platform dependent
        ids = np.arange(start, end, dtype=np.int64)
        index.add_with_ids(keys, ids)

    with monit.section('Save'):
        faiss.write_index(index, str(lab.get_data_path() / 'faiss.index'))
def main():
    """
    Load a trained model, gather its keys and values, and build the FAISS index.
    """
    # Replace this run UUID with the one from your own training run
    configs = load_experiment('4984b85c20bf11eb877a69c1a03717cd')
    # Put the model in evaluation mode
    configs.model.eval()

    # Collect the keys $f(c_i)$ and values $w_i$
    gather_keys(configs)
    # Add them to the index for fast search
    build_index(configs)


if __name__ == '__main__':
    main()