16from typing import List, Optional
17
18import faiss
19import numpy as np
20import torch
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset
24from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddingschunk_len
是块的长度(字符数)batch_size
是计算时要使用的批次大小d_emb
是要在 FAISS 索引中选择的嵌入列表中要素的数量n_centeroids
是索引中的列表数code_size
索引中的编码向量大小n_probe
是要探测的列表的数量27def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
28                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):加载数据集文本文件
43    dataset = TextFileDataset(
44        lab.get_data_path() / 'tiny_shakespeare.txt',
45        list,
46        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')获取训练数据(字符串)
49    text = dataset.train将文本分成大块chunk_length
52    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]获取每个区块的偏移量
54    chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])区块数
56    n_chunks = len(chunks)初始化 BERT 以获取
59    bert = BERTChunkEmbeddings(torch.device('cuda:0'))通过处理每次迭代的块batch_size
数来获取区块嵌入
62    chunk_emb = []
63    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
64        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())将它们合并成单个张量
66    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()69    quantizer = faiss.IndexFlatL2(d_emb)
70    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
71    index.nprobe = n_probe获取区块索引的随机样本
74    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)训练索引以存储密钥
77    with monit.section('Train index'):
78        index.train(chunk_emb[random_sample])将块按大小批次添加到索引中1024
81    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
82        e = min(s + 1024, n_chunks)添加到索引
84        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])保存索引
87    with monit.section('Save'):
88        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))91class RetroIndex:chunk_len
是区块长度n_probe
是要探测的列表的数量n_neighbors
是要检索的邻居数量n_extra
是要检索的额外邻居的数量,因为我们将移除与查询块重叠的邻居exclude_neighbor_span
是检查重叠时要避免的额外文本长度96    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
97                 n_neighbors: int = 2, n_extra: int = 2,
98                 exclude_neighbor_span: int = 8):108        self.n_neighbors = n_neighbors
109        self.chunk_len = chunk_len
110        self.exclude_neighbor_span = exclude_neighbor_span
111        self.n_extra = n_extra初始化 BERT 以获取
114        self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))装载数据库
116        with monit.section('Load index'):
117            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
118            self.index.nprobe = n_probe120    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):127        return [n for n in neighbor_offsets
128                if n < offset - (self.chunk_len + self.exclude_neighbor_span)
129                or n > offset + (self.chunk_len + self.exclude_neighbor_span)]131    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):获取查询区块
137        emb = self.bert(query_chunks).cpu()从数据库中获取n_neighbors + n_extra
最近的邻居
140        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)如果给出了查询区块偏移量,请过滤掉重叠的块
143        if offsets is not None:
144            neighbor_offsets = [self.filter_neighbors(off, n_off)
145                                for off, n_off in zip(offsets, neighbor_offsets)]筛选n_neighbors
后获取最接近的值
148        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]151        return neighbor_offsets155if __name__ == '__main__':
156    build_database()