k-最近傍言語モデルの評価

11from typing import Optional, List
12
13import faiss
14import numpy as np
15import torch
16
17from labml import monit, lab
18from labml.logger import inspect
19from labml_nn.transformers.knn.train_model import Configs

-NN で取得

ここでは、クエリ、キー、値と呼びます。

22def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray, vals_store: np.ndarray, n_tokens: int):

クエリの形状を保存して結果の形状を変える

31    queries_shape = queries.shape

batch sequence クエリの次元と次元を平坦化

34    queries = queries.view(-1, queries_shape[-1])

その中から最も近い隣人を10人見つける。distance はFAISSで与えられた距離でidxはその距離のインデックスです

keys_store
38    distance, idx = index.search(queries.numpy(), 10)

取得

41    keys_found = queries.new_tensor(keys_store[idx])

取得

43    vals_found = torch.tensor(vals_store[idx]).squeeze(-1)

正規化されたベクトル間のコサイン類似度を計算します

ノーマライズ

48    keys_found_n = keys_found / torch.sqrt((keys_found ** 2).sum(-1, keepdims=True) + 1e-10)

ノーマライズ

50    queries_n = queries / torch.sqrt((queries ** 2).sum(-1, keepdims=True) + 1e-10)

点積またはコサイン類似度を求める

53    dot_prod = (keys_found_n * queries_n.unsqueeze(1)).sum(-1)

トークンごとのロジット

56    logits_token = dot_prod.new_zeros(queries.shape[0], n_tokens)

最も近い隣人に基づいてトークンロジットを分散して蓄積する

58    _ = logits_token.scatter_(dim=1, index=vals_found, src=dot_prod, reduce='add')

ロジットの形状を変える

61    logits_token = logits_token.reshape(queries_shape[0], queries_shape[1], -1)
62
63    return logits_token

検証損失の計算

-NN 予測とトランスフォーマー予測を組み合わせた場合の検証損失を計算します。-NN モデルに与えられる重みはで与えられます。knn_weight これは重みのリストで、それぞれの検証損失を計算します

66def validation_loss(knn_weights: List[float], last_n: Optional[int], conf: Configs, index: faiss.IndexFlatL2,
67                    keys_store: np.ndarray, vals_store: np.ndarray):

それぞれの損失のリスト knn_weights

77    losses = [[] for _ in knn_weights]

各バッチのサンプル数

79    n_samples = []
80    with torch.no_grad():

検証データを繰り返し処理

82        for i, batch in monit.enum("Validation", conf.validator.data_loader, is_children_silent=True):

データとターゲットラベルを取得

84            data, target = batch[0].to(conf.device), batch[1].to(conf.device)

モデルを実行して予測を取得

86            res = conf.model(data)

-NN 予測を取得

88            res_knn = knn(conf.model.ff_input.cpu(), index, keys_store, vals_store, conf.n_tokens)
89            res_knn = res_knn.to(conf.device)

last_n これはトークンの損失のみを計算するためのものです。トランスフォーマーモデルの(シーケンスに沿った)最初の予測では、調べるべき過去のトークンがほとんどないため、これは重要です

94            if last_n:
95                res = res[-last_n:]
96                res_knn = res_knn[-last_n:]
97                target = target[-last_n:]

サンプル数

100            n_s = res.shape[0] * data.shape[1]
101            n_samples.append(n_s)

それぞれのスコアを計算しますknn_weights

104            for i, c in enumerate(knn_weights):

損失の計算

106                loss = conf.loss_func(res_knn * c + (1 - c) * res, target)
107                losses[i].append(loss * n_s)
108
109    return losses, n_samples

インデックスを読み込む

112def load_index(conf: Configs, n_probe: int = 8):

の寸法

117    d_model = conf.transformer.d_model

トレーニングデータローダー

119    data_loader = conf.trainer.data_loader

コンテキストの数。つまり、トレーニングデータ内のトークン数から 1 を引いた数です。

にとって
122    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

FAISS インデックスをロード

125    with monit.section('Load index'):
126        index = faiss.read_index(str(lab.get_data_path() / 'faiss.index'))

プローブするセルの数を設定

128    index.nprobe = n_probe

メモリマップされた numpy 配列をロード

131    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))
132    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int, mode='r', shape=(n_keys, 1))
133
134    return index, keys_store, vals_store
137def main():
138    from labml_nn.transformers.knn.build_index import load_experiment
141    conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')

モデルを評価モードに設定

143    conf.model.eval()

ロードインデックス

146    index, keys_store, vals_store = load_index(conf)

-NN 予測に与えられる重みのリスト。それぞれの重みの検証損失を評価します。

149    knn_weights = [i / 20 for i in range(10)]

検証損失の評価

151    losses, n_samples = validation_loss(knn_weights, None, conf, index, keys_store, vals_store)

それぞれの損失を出力しますknn_weights

153    inspect({c: np.sum(losses[i]) / np.sum(n_samples) for i, c in enumerate(knn_weights)})
154
155
156if __name__ == '__main__':
157    main()