k-NN 検索用の FAISS インデックスを作成

のインデックスを作成したい.メモリにマップされたnumpy配列を格納します。FAISSを使うのに一番近いと思いますFAISSはインデックスを作成し、クエリを実行します。

15from typing import Optional
16
17import faiss
18import numpy as np
19import torch
20
21from labml import experiment, monit, lab
22from labml.utils.pytorch import get_modules
23from labml_nn.transformers.knn.train_model import Configs
26def load_experiment(run_uuid: str, checkpoint: Optional[int] = None):

設定オブジェクトの作成

32    conf = Configs()

実験で使用したカスタム構成を読み込む

34    conf_dict = experiment.load_configs(run_uuid)

フィードフォワード層への入力が必要ですが

36    conf_dict['is_save_ff_input'] = True

この実験は単なる評価です。つまり、何も追跡も保存もされていません

39    experiment.evaluate()

構成を初期化

41    experiment.configs(conf, conf_dict)

保存/読み込み用のモデルを設定

43    experiment.add_pytorch_models(get_modules(conf))

ロード元のテストを指定してください

45    experiment.load(run_uuid, checkpoint)

実験を開始します。この時点で、実際にモデルが読み込まれます。

48    experiment.start()
49
50    return conf

それらを集めてnumpy配列に保存する

これらの大量の配列は、データセットのサイズにもよりますが、(数百ギガバイトでも)多くのスペースを占めることに注意してください

53def gather_keys(conf: Configs):

の寸法

62    d_model = conf.transformer.d_model

トレーニングデータローダー

64    data_loader = conf.trainer.data_loader

コンテキストの数。つまり、トレーニングデータ内のトークン数から 1 を引いた数です。

にとって
67    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

のナンピー配列

69    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='w+', shape=(n_keys, d_model))

のナンピー配列

71    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int, mode='w+', shape=(n_keys, 1))

収集されたキーの数

74    added = 0
75    with torch.no_grad():

データをループスルーする

77        for i, batch in monit.enum("Collect data", data_loader, is_children_silent=True):

ターゲットラベル

79            vals = batch[1].view(-1, 1)

入力データをモデルのデバイスに移動

81            data = batch[0].to(conf.device)

モデルを実行

83            _ = conf.model(data)

取得

85            keys = conf.model.ff_input.view(-1, d_model)

キーをメモリマップされたnumpy配列に保存

87            keys_store[added: added + keys.shape[0]] = keys.cpu()

値をメモリマップされたnumpy配列に保存する

89            vals_store[added: added + keys.shape[0]] = vals

収集したキーの数を増やしてください

91            added += keys.shape[0]

FAISS インデックスをビルド

FAISSの使い方高速検索メモリ使用量の削減に関するチュートリアルは、FAISSの使用法についてさらに学ぶのに役立ちます。

94def build_index(conf: Configs, n_centeroids: int = 2048, code_size: int = 64, n_probe: int = 8, n_train: int = 200_000):

の寸法

104    d_model = conf.transformer.d_model

トレーニングデータローダー

106    data_loader = conf.trainer.data_loader

コンテキストの数。つまり、トレーニングデータ内のトークン数から 1 を引いた数です。

にとって
109    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

Verenoi のセルベースの高速検索でインデックスを構築します。圧縮ではベクトル全体は保存されません。

113    quantizer = faiss.IndexFlatL2(d_model)
114    index = faiss.IndexIVFPQ(quantizer, d_model, n_centeroids, code_size, 8)
115    index.nprobe = n_probe

メモリマップされたキーのnumpy配列をロードします

118    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))

インデックスのトレーニングに使用するキーのサンプルをランダムに選んでください

121    random_sample = np.random.choice(np.arange(n_keys), size=[min(n_train, n_keys)], replace=False)
122
123    with monit.section('Train index'):

キーを保存するようにインデックスをトレーニングする

125        index.train(keys_store[random_sample])

インデックスにキーを追加します。

128    for s in monit.iterate('Index', range(0, n_keys, 1024)):
129        e = min(s + 1024, n_keys)

131        keys = keys_store[s:e]

133        idx = np.arange(s, e)

索引に追加

135        index.add_with_ids(keys, idx)
136
137    with monit.section('Save'):

インデックスを保存する

139        faiss.write_index(index, str(lab.get_data_path() / 'faiss.index'))
142def main():
145    conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')

モデルを評価モードに設定

147    conf.model.eval()

収集

150    gather_keys(conf)

索引に追加するとすばやく検索できます

152    build_index(conf)
153
154
155if __name__ == '__main__':
156    main()