最近傍検索用データベース

これは、RETROモデルのデータベースを構築し、最も近い近傍を検索するものです。

論文ではScanNライブラリを使用していましたが、データベースにはFAISSライブラリを使用しています。

16from typing import List, Optional
17
18import faiss
19import numpy as np
20import torch
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset
24from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings

データベースを構築

  • chunk_len チャンクの長さ (文字数)
  • batch_size 計算時に使用するバッチサイズです
  • d_emb FAISS インデックスで選択する埋め込みリスト内のフィーチャの数です
  • n_centeroids インデックス内のリストの数です
  • code_size インデックス内のエンコードされたベクトルサイズ
  • n_probe は調べるリストの数です
  • `n_train' はインデックスをトレーニングするキーの数です
27def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
28                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):

データセットのテキストファイルを読み込む

43    dataset = TextFileDataset(
44        lab.get_data_path() / 'tiny_shakespeare.txt',
45        list,
46        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

トレーニングデータ (文字列) を取得

49    text = dataset.train

テキストを次のチャンクに分割します chunk_length

52    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]

各チャンクのオフセットを取得

54    chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])

チャンクの数

56    n_chunks = len(chunks)

BERT を初期化して取得

59    bert = BERTChunkEmbeddings(torch.device('cuda:0'))

batch_size 各反復処理でチャンクの数を処理してチャンクの埋め込みを行う

62    chunk_emb = []
63    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
64        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())

それらを単一のテンソルにマージします

66    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()

FAISS インデックスを作成する

69    quantizer = faiss.IndexFlatL2(d_emb)
70    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
71    index.nprobe = n_probe

チャンクインデックスのランダムサンプルを取得

74    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)

キーを保存するようにインデックスをトレーニングする

77    with monit.section('Train index'):
78        index.train(chunk_emb[random_sample])

チャンクをサイズごとにインデックスに追加します。1024

81    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
82        e = min(s + 1024, n_chunks)

索引に追加

84        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])

インデックスを保存する

87    with monit.section('Save'):
88        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))

最近傍を検索するためのインデックス

91class RetroIndex:
  • chunk_len はチャンクの長さです
  • n_probe は調べるリストの数です
  • n_neighbors 取得する近傍の数です
  • n_extra クエリチャンクと重複している近傍を削除することになるため、取得する余分な近傍の数です
  • exclude_neighbor_span オーバーラップをチェックするときに避けるべき余分なテキスト長です
96    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
97                 n_neighbors: int = 2, n_extra: int = 2,
98                 exclude_neighbor_span: int = 8):
108        self.n_neighbors = n_neighbors
109        self.chunk_len = chunk_len
110        self.exclude_neighbor_span = exclude_neighbor_span
111        self.n_extra = n_extra

BERT を初期化して取得

114        self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))

データベースを読み込む

116        with monit.section('Load index'):
117            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
118            self.index.nprobe = n_probe

クエリと重複する近傍をフィルタリングする

neighbor_offsets 近傍の位置はで指定され、クエリチャンクの位置はです。offset

120    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):
127        return [n for n in neighbor_offsets
128                if n < offset - (self.chunk_len + self.exclude_neighbor_span)
129                or n > offset + (self.chunk_len + self.exclude_neighbor_span)]

最も近い近傍を検索する

131    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):

クエリチャンクの取得

137        emb = self.bert(query_chunks).cpu()

n_neighbors + n_extra データベースから最も近い近傍を取得

140        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)

クエリのチャンクオフセットが指定されている場合は、重複するチャンクを除外します。

143        if offsets is not None:
144            neighbor_offsets = [self.filter_neighbors(off, n_off)
145                                for off, n_off in zip(offsets, neighbor_offsets)]

n_neighbors フィルタリング後に最も近いものを取得

148        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]

151        return neighbor_offsets

155if __name__ == '__main__':
156    build_database()