これは、RETROモデルのデータベースを構築し、最も近い近傍を検索するものです。
論文ではScanNライブラリを使用していましたが、データベースにはFAISSライブラリを使用しています。
16from typing import List, Optional
17
18import faiss
19import numpy as np
20import torch
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset
24from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings
chunk_len
チャンクの長さ (文字数)batch_size
計算時に使用するバッチサイズです d_emb
FAISS インデックスで選択する埋め込みリスト内のフィーチャの数ですn_centeroids
インデックス内のリストの数ですcode_size
インデックス内のエンコードされたベクトルサイズn_probe
は調べるリストの数です27def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
28 code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):
データセットのテキストファイルを読み込む
43 dataset = TextFileDataset(
44 lab.get_data_path() / 'tiny_shakespeare.txt',
45 list,
46 url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
トレーニングデータ (文字列) を取得
49 text = dataset.train
テキストを次のチャンクに分割します chunk_length
52 chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]
各チャンクのオフセットを取得
54 chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])
チャンクの数
56 n_chunks = len(chunks)
BERT を初期化して取得
59 bert = BERTChunkEmbeddings(torch.device('cuda:0'))
batch_size
各反復処理でチャンクの数を処理してチャンクの埋め込みを行う
62 chunk_emb = []
63 for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
64 chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())
それらを単一のテンソルにマージします
66 chunk_emb = torch.cat(chunk_emb, dim=0).numpy()
69 quantizer = faiss.IndexFlatL2(d_emb)
70 index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
71 index.nprobe = n_probe
チャンクインデックスのランダムサンプルを取得
74 random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)
キーを保存するようにインデックスをトレーニングする
77 with monit.section('Train index'):
78 index.train(chunk_emb[random_sample])
チャンクをサイズごとにインデックスに追加します。1024
81 for s in monit.iterate('Index', range(0, n_chunks, 1024)):
82 e = min(s + 1024, n_chunks)
索引に追加
84 index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])
インデックスを保存する
87 with monit.section('Save'):
88 faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))
91class RetroIndex:
chunk_len
はチャンクの長さですn_probe
は調べるリストの数ですn_neighbors
取得する近傍の数ですn_extra
クエリチャンクと重複している近傍を削除することになるため、取得する余分な近傍の数ですexclude_neighbor_span
オーバーラップをチェックするときに避けるべき余分なテキスト長です96 def __init__(self, chunk_len: int = 16, n_probe: int = 8,
97 n_neighbors: int = 2, n_extra: int = 2,
98 exclude_neighbor_span: int = 8):
108 self.n_neighbors = n_neighbors
109 self.chunk_len = chunk_len
110 self.exclude_neighbor_span = exclude_neighbor_span
111 self.n_extra = n_extra
BERT を初期化して取得
114 self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))
データベースを読み込む
116 with monit.section('Load index'):
117 self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
118 self.index.nprobe = n_probe
120 def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):
127 return [n for n in neighbor_offsets
128 if n < offset - (self.chunk_len + self.exclude_neighbor_span)
129 or n > offset + (self.chunk_len + self.exclude_neighbor_span)]
131 def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):
クエリチャンクの取得
137 emb = self.bert(query_chunks).cpu()
n_neighbors + n_extra
データベースから最も近い近傍を取得
140 distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)
クエリのチャンクオフセットが指定されている場合は、重複するチャンクを除外します。
143 if offsets is not None:
144 neighbor_offsets = [self.filter_neighbors(off, n_off)
145 for off, n_off in zip(offsets, neighbor_offsets)]
n_neighbors
フィルタリング後に最も近いものを取得
148 neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]
151 return neighbor_offsets
155if __name__ == '__main__':
156 build_database()