これは、RETROモデルのデータベースを構築し、最も近い近傍を検索するものです。
論文ではScanNライブラリを使用していましたが、データベースにはFAISSライブラリを使用しています。
16from typing import List, Optional
17
18import faiss
19import numpy as np
20import torch
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset
24from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddingschunk_len
チャンクの長さ (文字数)batch_size
計算時に使用するバッチサイズです d_emb
FAISS  インデックスで選択する埋め込みリスト内のフィーチャの数ですn_centeroids
インデックス内のリストの数ですcode_size
インデックス内のエンコードされたベクトルサイズn_probe
は調べるリストの数です27def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
28                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):データセットのテキストファイルを読み込む
43    dataset = TextFileDataset(
44        lab.get_data_path() / 'tiny_shakespeare.txt',
45        list,
46        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')トレーニングデータ (文字列) を取得
49    text = dataset.trainテキストを次のチャンクに分割します chunk_length
52    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]各チャンクのオフセットを取得
54    chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])チャンクの数
56    n_chunks = len(chunks)BERT を初期化して取得
59    bert = BERTChunkEmbeddings(torch.device('cuda:0'))batch_size
各反復処理でチャンクの数を処理してチャンクの埋め込みを行う
62    chunk_emb = []
63    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
64        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())それらを単一のテンソルにマージします
66    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()69    quantizer = faiss.IndexFlatL2(d_emb)
70    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
71    index.nprobe = n_probeチャンクインデックスのランダムサンプルを取得
74    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)キーを保存するようにインデックスをトレーニングする
77    with monit.section('Train index'):
78        index.train(chunk_emb[random_sample])チャンクをサイズごとにインデックスに追加します。1024
81    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
82        e = min(s + 1024, n_chunks)索引に追加
84        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])インデックスを保存する
87    with monit.section('Save'):
88        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))91class RetroIndex:chunk_len
はチャンクの長さですn_probe
は調べるリストの数ですn_neighbors
取得する近傍の数ですn_extra
クエリチャンクと重複している近傍を削除することになるため、取得する余分な近傍の数ですexclude_neighbor_span
オーバーラップをチェックするときに避けるべき余分なテキスト長です96    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
97                 n_neighbors: int = 2, n_extra: int = 2,
98                 exclude_neighbor_span: int = 8):108        self.n_neighbors = n_neighbors
109        self.chunk_len = chunk_len
110        self.exclude_neighbor_span = exclude_neighbor_span
111        self.n_extra = n_extraBERT を初期化して取得
114        self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))データベースを読み込む
116        with monit.section('Load index'):
117            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
118            self.index.nprobe = n_probe120    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):127        return [n for n in neighbor_offsets
128                if n < offset - (self.chunk_len + self.exclude_neighbor_span)
129                or n > offset + (self.chunk_len + self.exclude_neighbor_span)]131    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):クエリチャンクの取得
137        emb = self.bert(query_chunks).cpu()n_neighbors + n_extra
データベースから最も近い近傍を取得
140        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)クエリのチャンクオフセットが指定されている場合は、重複するチャンクを除外します。
143        if offsets is not None:
144            neighbor_offsets = [self.filter_neighbors(off, n_off)
145                                for off, n_off in zip(offsets, neighbor_offsets)]n_neighbors
フィルタリング後に最も近いものを取得
148        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]151        return neighbor_offsets155if __name__ == '__main__':
156    build_database()