This is a PyTorch implementation of the paper Generalization through Memorization: Nearest Neighbor Language Models. It uses k-nearest neighbors to improve perplexity of autoregressive transformer models.

An autoregressive language model estimates $p(w_t | \color{yellowgreen}{c_t})$, where $w_t$ is the token at step $t$ and $c_t$ is the context, $\color{yellowgreen}{c_t} = (w_1, w_2, …, w_{t-1})$.

This paper improves $p(w_t | \color{yellowgreen}{c_t})$ using a k-nearest neighbor search
on key-value pairs $\big(f(c_i), w_i\big)$, with search key $f(\color{yellowgreen}{c_t})$.
Here $f(\color{yellowgreen}{c_t})$ is an embedding of the context $\color{yellowgreen}{c_t}$.
The paper (and this implementation) uses the **input to the feed-forward layer of the
final layer of the transformer** as $f(\color{yellowgreen}{c_t})$.
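One way to capture this embedding without modifying the model is a forward pre-hook on the feed-forward sub-module. The sketch below uses a hypothetical minimal transformer block (not the paper's model) purely to show where $f(c_t)$ is read from:

```python
import torch
import torch.nn as nn

# Hypothetical minimal transformer block, only to illustrate where f(c_t) is taken
class Block(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(),
                                 nn.Linear(4 * d_model, d_model))

    def forward(self, x):
        a, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm(x + a)
        # The tensor entering self.ffn here is what we use as f(c_t)
        return x + self.ffn(x)

captured = []

def save_ffn_input(module, inputs):
    # inputs[0] is the tensor entering the feed-forward network
    captured.append(inputs[0].detach())

block = Block(d_model=16)
block.ffn.register_forward_pre_hook(save_ffn_input)

x = torch.randn(2, 5, 16)  # (batch, seq, d_model)
_ = block(x)
print(captured[0].shape)   # one embedding per position: torch.Size([2, 5, 16])
```

In the real setup the hook would be registered on the feed-forward module of the transformer's final layer, and the captured tensors saved to disk while iterating over the training set.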

We use FAISS to index $f(c_i)$.

So to run $k$NN-LM we need to:

- Train a transformer model
- Build an index of $\big(f(c_i), w_i\big)$
- Evaluate $k$NN-LM using $k$NN search on $\big(f(c_i), w_i\big)$ with $f(\color{yellowgreen}{c_t})$
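In the evaluation step, the paper forms a $k$NN distribution by softmax-weighting retrieved neighbors by negative distance, and interpolates it with the model's distribution. A minimal sketch, with the interpolation weight `lam` as a hypothetical example value:

```python
import torch
import torch.nn.functional as F

def knn_lm_probs(p_model: torch.Tensor, distances: torch.Tensor,
                 neighbor_tokens: torch.Tensor, vocab_size: int,
                 lam: float = 0.25) -> torch.Tensor:
    """Interpolate the model distribution with a kNN distribution.

    p_model:         (vocab_size,) softmax output of the transformer
    distances:       (k,) distances to the retrieved keys f(c_i)
    neighbor_tokens: (k,) token ids w_i paired with those keys
    lam:             interpolation weight for the kNN distribution
    """
    # Weight each neighbor by softmax of negative distance
    weights = F.softmax(-distances, dim=0)
    # Aggregate weights of neighbors that share the same target token
    p_knn = torch.zeros(vocab_size)
    p_knn.index_add_(0, neighbor_tokens, weights)
    return lam * p_knn + (1.0 - lam) * p_model

vocab_size = 10
p_model = F.softmax(torch.randn(vocab_size), dim=0)
distances = torch.tensor([0.1, 0.2, 0.5])
neighbor_tokens = torch.tensor([3, 3, 7])
p = knn_lm_probs(p_model, distances, neighbor_tokens, vocab_size)
print(p.sum())  # still a valid distribution: sums to 1
```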

This experiment uses a small dataset so that we can run it without using a few hundred gigabytes of disk space for the index.

The official implementation of $k$NN-LM can be found here.