15from typing import Optional
16
17import faiss
18import numpy as np
19import torch
20
21from labml import experiment, monit, lab
22from labml.utils.pytorch import get_modules
23from labml_nn.transformers.knn.train_model import Configs
创建配置对象
32 conf = Configs()
加载实验中使用的自定义配置
34 conf_dict = experiment.load_configs(run_uuid)
我们需要获取前馈层的输入,
36 conf_dict['is_save_ff_input'] = True
这个实验只是一个评估;也就是说,没有追踪或保存任何东西
39 experiment.evaluate()
初始化配置
41 experiment.configs(conf, conf_dict)
设置用于保存/加载的模型
43 experiment.add_pytorch_models(get_modules(conf))
指定要从中加载的实验
45 experiment.load(run_uuid, checkpoint)
开始实验;这是它实际加载模型的时候
48 experiment.start()
49
50 return conf
53def gather_keys(conf: Configs):
的尺寸
62 d_model = conf.transformer.d_model
训练数据加载器
64 data_loader = conf.trainer.data_loader
上下文的数量;即训练数据中的令牌数减一。对于
67 n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1
Numpy 数组用于
69 keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='w+', shape=(n_keys, d_model))
Numpy 数组用于
71 vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int, mode='w+', shape=(n_keys, 1))
收集的钥匙数量
74 added = 0
75 with torch.no_grad():
循环浏览数据
77 for i, batch in monit.enum("Collect data", data_loader, is_children_silent=True):
目标标签
79 vals = batch[1].view(-1, 1)
输入数据已移至模型的设备
81 data = batch[0].to(conf.device)
运行模型
83 _ = conf.model(data)
得到
85 keys = conf.model.ff_input.view(-1, d_model)
在内存映射的 numpy 数组中保存键
87 keys_store[added: added + keys.shape[0]] = keys.cpu()
在内存映射的 numpy 数组中保存值
89 vals_store[added: added + keys.shape[0]] = vals
增加收集的密钥数量
91 added += keys.shape[0]
94def build_index(conf: Configs, n_centeroids: int = 2048, code_size: int = 64, n_probe: int = 8, n_train: int = 200_000):
的尺寸
104 d_model = conf.transformer.d_model
训练数据加载器
106 data_loader = conf.trainer.data_loader
上下文的数量;即训练数据中的令牌数减一。对于
109 n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1
使用基于Verenoi单元格的快速搜索构建索引,压缩不会存储完整向量。
113 quantizer = faiss.IndexFlatL2(d_model)
114 index = faiss.IndexIVFPQ(quantizer, d_model, n_centeroids, code_size, 8)
115 index.nprobe = n_probe
加载内存映射的 numpy 键数组
118 keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))
随机选择一个键样本来训练索引
121 random_sample = np.random.choice(np.arange(n_keys), size=[min(n_train, n_keys)], replace=False)
122
123 with monit.section('Train index'):
训练索引以存储密钥
125 index.train(keys_store[random_sample])
在索引中添加密钥;
128 for s in monit.iterate('Index', range(0, n_keys, 1024)):
129 e = min(s + 1024, n_keys)
131 keys = keys_store[s:e]
133 idx = np.arange(s, e)
添加到索引
135 index.add_with_ids(keys, idx)
136
137 with monit.section('Save'):
保存索引
139 faiss.write_index(index, str(lab.get_data_path() / 'faiss.index'))
142def main():
145 conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')
将模型设置为评估模式
147 conf.model.eval()
收集
150 gather_keys(conf)
将它们添加到索引中以便快速搜索
152 build_index(conf)
153
154
155if __name__ == '__main__':
156 main()