from typing import List, Optional

import faiss
import numpy as np
import torch

from labml import lab, monit
from labml_helpers.datasets.text import TextFileDataset
from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings
chunk_len is the length of a chunk (number of characters)
batch_size is the batch size to use when calculating BERT embeddings
d_emb is the number of features in the BERT embeddings, i.e. the size of the vectors stored in the FAISS index
n_centeroids is the number of lists (clusters) in the index
code_size is the encoded vector size in the index
n_probe is the number of lists to probe
n_train is the number of chunks to sample when training the index

def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):
Load the text file dataset
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
Get the training data (a string)
    text = dataset.train
Split the text into chunks of chunk_len characters
    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]
Get the offsets of each of the chunks
    chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])
Number of chunks
    n_chunks = len(chunks)
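
To make this concrete, here is a minimal sketch on a toy string (hypothetical data, not part of the original). The i + chunk_len * 2 < len(text) condition drops the final chunks so that every stored chunk has a full chunk of continuation text after it:

    # Toy illustration of the chunking above (hypothetical data)
    chunk_len = 4
    text = 'abcdefghijklmnopqr'  # 18 characters
    # Keep a chunk only if a full continuation chunk follows it
    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]
    chunk_offsets = [i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]
    print(chunks)         # ['abcd', 'efgh', 'ijkl']
    print(chunk_offsets)  # [0, 4, 8]
    # Offsets 12 and 16 are dropped: 12 + 8 and 16 + 8 are not < 18,
    # so those chunks would lack a full continuation chunk
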
Initialize BERT to get the chunk embeddings
    bert = BERTChunkEmbeddings(torch.device('cuda:0'))
Get the chunk embeddings by processing batch_size chunks on each iteration
    chunk_emb = []
    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())
Merge them into a single tensor (and convert to a NumPy array for FAISS)
    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()
Create the FAISS index: an IVFPQ index with a flat L2 quantizer
    quantizer = faiss.IndexFlatL2(d_emb)
    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
    index.nprobe = n_probe
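
As rough intuition for these parameters: IndexIVFPQ clusters the stored vectors into n_centeroids inverted lists using the flat L2 quantizer, compresses each vector with product quantization into code_size sub-vectors of 8 bits each, and at query time scans only nprobe of the lists. A self-contained sketch on random vectors (toy sizes, not from the original):

    import faiss
    import numpy as np

    d = 64          # vector dimensionality (toy value; must be divisible by code_size)
    nlist = 16      # number of inverted lists (n_centeroids above)
    code_size = 8   # number of product-quantization sub-vectors per vector

    quantizer = faiss.IndexFlatL2(d)  # coarse quantizer that assigns vectors to lists
    index = faiss.IndexIVFPQ(quantizer, d, nlist, code_size, 8)  # 8 bits per sub-vector

    xb = np.random.random((10_000, d)).astype('float32')
    index.train(xb)  # learn the coarse centroids and the PQ codebooks
    index.add(xb)

    index.nprobe = 4  # scan 4 of the 16 lists per query (recall/speed trade-off)
    distances, ids = index.search(xb[:3], 5)
    print(ids.shape)  # (3, 5): the 5 nearest ids for each of the 3 queries
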
Get a random sample of chunk indexes to train the index on
    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)
Train the index to store the keys
    with monit.section('Train index'):
        index.train(chunk_emb[random_sample])
Add the chunks to the index in batches of size 1024
    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
        e = min(s + 1024, n_chunks)
Add to the index, using the chunk offsets as the ids
        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])
Save the index
    with monit.section('Save'):
        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))
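
A quick way to sanity-check the result is to load the index back and query it with an embedding that was just added; with high probability the nearest returned id is that chunk's own offset. A hedged sketch, assuming build_database has already run and chunk_emb is still in scope:

    # Hypothetical sanity check (not in the original source)
    index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
    index.nprobe = 8
    # Query with the embedding of the first chunk; its own offset (0)
    # should come back as the nearest id despite PQ compression
    distance, neighbor_offsets = index.search(chunk_emb[:1], 5)
    print(neighbor_offsets[0])
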
Index for retrieving nearest neighbor chunks

class RetroIndex:
chunk_len is the chunk length
n_probe is the number of lists to probe
n_neighbors is the number of neighbors to retrieve
n_extra is the number of extra neighbors to retrieve, since we will be removing neighbors that overlap with the query chunk
exclude_neighbor_span is the extra text length to avoid when checking for overlaps

    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
                 n_neighbors: int = 2, n_extra: int = 2,
                 exclude_neighbor_span: int = 8):
        self.n_neighbors = n_neighbors
        self.chunk_len = chunk_len
        self.exclude_neighbor_span = exclude_neighbor_span
        self.n_extra = n_extra
Initialize BERT to get the chunk embeddings
        self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))
Load the FAISS index
        with monit.section('Load index'):
            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
            self.index.nprobe = n_probe
Filter out neighbors that overlap with the query chunk; the positions of the neighbors are given by neighbor_offsets and the position of the query chunk is offset
    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):
        return [n for n in neighbor_offsets
                if n < offset - (self.chunk_len + self.exclude_neighbor_span)
                or n > offset + (self.chunk_len + self.exclude_neighbor_span)]
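
For example, with the defaults chunk_len = 16 and exclude_neighbor_span = 8, a query at offset 100 excludes any neighbor whose offset falls within [100 - 24, 100 + 24]. A small sketch replicating the filter's arithmetic on hypothetical offsets:

    # Hypothetical offsets, replicating the filter logic above
    chunk_len, exclude_neighbor_span = 16, 8
    offset, span = 100, chunk_len + exclude_neighbor_span  # span = 24
    neighbor_offsets = [75, 76, 100, 124, 125]
    kept = [n for n in neighbor_offsets
            if n < offset - span or n > offset + span]
    print(kept)  # [75, 125]: 76, 100 and 124 overlap the query's vicinity
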
Retrieve the nearest neighbors of query_chunks; offsets are the optional positions of the query chunks within the text
    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):
Get the BERT embeddings of the query chunks
        emb = self.bert(query_chunks).cpu()
Get the n_neighbors + n_extra nearest neighbors from the database
        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)
If the query chunk offsets are given, filter out the overlapping chunks
        if offsets is not None:
            neighbor_offsets = [self.filter_neighbors(off, n_off)
                                for off, n_off in zip(offsets, neighbor_offsets)]
Get the closest n_neighbors after filtering
        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]

        return neighbor_offsets
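
Putting it together, querying the index might look like this hedged sketch (assuming build_database has already run, so the text file and retro.index exist):

    # Hypothetical usage example (not in the original source)
    retro_index = RetroIndex()
    dataset = TextFileDataset(lab.get_data_path() / 'tiny_shakespeare.txt', list)
    text = dataset.train

    # Query with a chunk from the text; passing its offset filters out
    # neighbors that overlap with the query chunk itself
    offset = 64
    query = text[offset:offset + retro_index.chunk_len]
    neighbor_offsets = retro_index([query], [offset])

    # Print the retrieved neighbor chunks
    for n_off in neighbor_offsets[0]:
        print(text[n_off:n_off + retro_index.chunk_len])
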
if __name__ == '__main__':
    build_database()