为 k-nn 搜索建立 FAISS 索引

我们要建立的索引。我们在内存中存储映射的 numpy 数组。我们发现最接近使用 FAISS。FAISS 索引,我们使用进行查询

15from typing import Optional
16
17import faiss
18import numpy as np
19import torch
20
21from labml import experiment, monit, lab
22from labml.utils.pytorch import get_modules
23from labml_nn.transformers.knn.train_model import Configs

训练模型加载已保存的实验。

26def load_experiment(run_uuid: str, checkpoint: Optional[int] = None):

创建配置对象

32    conf = Configs()

加载实验中使用的自定义配置

34    conf_dict = experiment.load_configs(run_uuid)

我们需要获取前馈层的输入,

36    conf_dict['is_save_ff_input'] = True

这个实验只是一个评估;也就是说,没有追踪或保存任何东西

39    experiment.evaluate()

初始化配置

41    experiment.configs(conf, conf_dict)

设置用于保存/加载的模型

43    experiment.add_pytorch_models(get_modules(conf))

指定要从中加载的实验

45    experiment.load(run_uuid, checkpoint)

开始实验;这是它实际加载模型的时候

48    experiment.start()
49
50    return conf

将它们收集并保存在numpy数组中

请注意,这些 numpy 数组将占用大量空间(甚至几百千兆字节),具体取决于数据集的大小

53def gather_keys(conf: Configs):

的尺寸

62    d_model = conf.transformer.d_model

训练数据加载器

64    data_loader = conf.trainer.data_loader

上下文的数量;即训练数据中的令牌数减一。对于

67    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

Numpy 数组用于

69    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='w+', shape=(n_keys, d_model))

Numpy 数组用于

71    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int, mode='w+', shape=(n_keys, 1))

收集的钥匙数量

74    added = 0
75    with torch.no_grad():

循环浏览数据

77        for i, batch in monit.enum("Collect data", data_loader, is_children_silent=True):

目标标签

79            vals = batch[1].view(-1, 1)

输入数据已移至模型的设备

81            data = batch[0].to(conf.device)

运行模型

83            _ = conf.model(data)

得到

85            keys = conf.model.ff_input.view(-1, d_model)

在内存映射的 numpy 数组中保存键

87            keys_store[added: added + keys.shape[0]] = keys.cpu()

在内存映射的 numpy 数组中保存值

89            vals_store[added: added + keys.shape[0]] = vals

增加收集的密钥数量

91            added += keys.shape[0]

建立 FAISS 指数

入门更快的搜索和更低的内存占用教程FAISS 将帮助您进一步了解 FAISS 的使用情况。

94def build_index(conf: Configs, n_centeroids: int = 2048, code_size: int = 64, n_probe: int = 8, n_train: int = 200_000):

的尺寸

104    d_model = conf.transformer.d_model

训练数据加载器

106    data_loader = conf.trainer.data_loader

上下文的数量;即训练数据中的令牌数减一。对于

109    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

使用基于Verenoi单元格的快速搜索构建索引,压缩不会存储完整向量。

113    quantizer = faiss.IndexFlatL2(d_model)
114    index = faiss.IndexIVFPQ(quantizer, d_model, n_centeroids, code_size, 8)
115    index.nprobe = n_probe

加载内存映射的 numpy 键数组

118    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))

随机选择一个键样本来训练索引

121    random_sample = np.random.choice(np.arange(n_keys), size=[min(n_train, n_keys)], replace=False)
122
123    with monit.section('Train index'):

训练索引以存储密钥

125        index.train(keys_store[random_sample])

在索引中添加密钥;

128    for s in monit.iterate('Index', range(0, n_keys, 1024)):
129        e = min(s + 1024, n_keys)

131        keys = keys_store[s:e]

133        idx = np.arange(s, e)

添加到索引

135        index.add_with_ids(keys, idx)
136
137    with monit.section('Save'):

保存索引

139        faiss.write_index(index, str(lab.get_data_path() / 'faiss.index'))
142def main():

加载实验。将 run uuid 替换为你在训练模型时运行的 uuid。

145    conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')

将模型设置为评估模式

147    conf.model.eval()

收集

150    gather_keys(conf)

将它们添加到索引中以便快速搜索

152    build_index(conf)
153
154
155if __name__ == '__main__':
156    main()