最近邻检索的数据库

这是构建数据库并检索 RETRO 模型的最近邻域

我们使用 FAISS 库作为数据库,而论文使用了 ScanN 库。

16from typing import List, Optional
17
18import faiss
19import numpy as np
20import torch
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset
24from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings

建立数据库

  • chunk_len 是块的长度(字符数)
  • batch_size 是计算时要使用的批次大小
  • d_emb 是要在 FAISS 索引中选择的嵌入列表中要素的数量
  • n_centeroids 是索引中的列表数
  • code_size 索引中的编码向量大小
  • n_probe 是要探测的列表的数量
  • `n_train'是用于训练索引的键的数量
27def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
28                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):

加载数据集文本文件

43    dataset = TextFileDataset(
44        lab.get_data_path() / 'tiny_shakespeare.txt',
45        list,
46        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

获取训练数据(字符串)

49    text = dataset.train

将文本分成大块chunk_length

52    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]

获取每个区块的偏移量

54    chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])

区块数

56    n_chunks = len(chunks)

初始化 BERT 以获取

59    bert = BERTChunkEmbeddings(torch.device('cuda:0'))

通过处理每次迭代的块batch_size 数来获取区块嵌入

62    chunk_emb = []
63    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
64        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())

将它们合并成单个张量

66    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()

创建 FAISS 指数

69    quantizer = faiss.IndexFlatL2(d_emb)
70    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
71    index.nprobe = n_probe

获取区块索引的随机样本

74    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)

训练索引以存储密钥

77    with monit.section('Train index'):
78        index.train(chunk_emb[random_sample])

将块按大小批次添加到索引中1024

81    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
82        e = min(s + 1024, n_chunks)

添加到索引

84        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])

保存索引

87    with monit.section('Save'):
88        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))

检索最近邻居的索引

91class RetroIndex:
  • chunk_len 是区块长度
  • n_probe 是要探测的列表的数量
  • n_neighbors 是要检索的邻居数量
  • n_extra 是要检索的额外邻居的数量,因为我们将移除与查询块重叠的邻居
  • exclude_neighbor_span 是检查重叠时要避免的额外文本长度
96    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
97                 n_neighbors: int = 2, n_extra: int = 2,
98                 exclude_neighbor_span: int = 8):
108        self.n_neighbors = n_neighbors
109        self.chunk_len = chunk_len
110        self.exclude_neighbor_span = exclude_neighbor_span
111        self.n_extra = n_extra

初始化 BERT 以获取

114        self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))

装载数据库

116        with monit.section('Load index'):
117            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
118            self.index.nprobe = n_probe

筛选与查询重叠的邻域

邻居的位置由给出neighbor_offsets ,查询块的位置为offset

120    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):
127        return [n for n in neighbor_offsets
128                if n < offset - (self.chunk_len + self.exclude_neighbor_span)
129                or n > offset + (self.chunk_len + self.exclude_neighbor_span)]

检索最近的邻居

131    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):

获取查询区块

137        emb = self.bert(query_chunks).cpu()

从数据库中获取n_neighbors + n_extra 最近的邻居

140        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)

如果给出了查询区块偏移量,请过滤掉重叠的块

143        if offsets is not None:
144            neighbor_offsets = [self.filter_neighbors(off, n_off)
145                                for off, n_off in zip(offsets, neighbor_offsets)]

筛选n_neighbors 后获取最接近的值

148        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]

151        return neighbor_offsets

155if __name__ == '__main__':
156    build_database()