This is a PyTorch implementation of the paper Generalization through Memorization: Nearest Neighbor Language Models. It uses k-nearest neighbors to improve perplexity of autoregressive transformer models.

An autoregressive language model estimates $p(w_t \mid c_t)$, where $w_t$ is the token at step $t$ and $c_t$ is the context, $c_t = (w_1, w_2, \ldots, w_{t-1})$.

This paper improves $p(w_t \mid c_t)$ using a $k$-nearest-neighbor search over key-value pairs $(f(c_i), w_i)$, with search key $f(c_t)$. Here $f(c_t)$ is an embedding of the context $c_t$. The paper (and this implementation) uses the **input to the feed-forward layer of the final layer of the transformer** as $f(c_t)$.
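As a minimal sketch of how $f(c_t)$ can be captured, the snippet below registers a forward hook on the first feed-forward linear of a toy `nn.TransformerEncoderLayer`. The model here is purely illustrative (it is not this implementation's model); the point is that `inputs[0]` of the feed-forward sublayer is the embedding used as the search key.

```python
import torch
import torch.nn as nn

# Toy stand-in for the final transformer layer (illustrative only)
d_model = 64
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128)

captured = {}

def hook(module, inputs, output):
    # inputs[0] is the tensor fed into the feed-forward block,
    # i.e. f(c_t) in the notation above (one vector per position)
    captured['f_c'] = inputs[0].detach()

# linear1 is the first linear of the feed-forward block
# inside nn.TransformerEncoderLayer
layer.linear1.register_forward_hook(hook)

x = torch.randn(10, 1, d_model)  # (seq_len, batch, d_model)
with torch.no_grad():
    layer(x)

print(captured['f_c'].shape)  # one d_model-sized embedding per position
```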

We use FAISS to index $f(c_{i})$.
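To keep the example self-contained, here is a brute-force L2 nearest-neighbor lookup in NumPy playing the role that the FAISS index plays (e.g. `faiss.IndexFlatL2` performs the same exact L2 search, much faster). The keys, values, and dimensions are made-up illustrative data.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16          # dimensionality of the context embedding f(c_i)
n_keys = 1000   # number of stored (f(c_i), w_i) pairs

keys = rng.standard_normal((n_keys, d)).astype('float32')  # f(c_i)
values = rng.integers(0, 100, n_keys)                      # w_i (token ids)

def knn_search(query, k=8):
    """Brute-force squared-L2 search; a FAISS flat index does the same."""
    dists = ((keys - query) ** 2).sum(axis=1)
    idx = np.argsort(dists)[:k]
    return dists[idx], values[idx]

query = rng.standard_normal(d).astype('float32')  # search key f(c_t)
dists, neighbor_tokens = knn_search(query)
```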

So to run $k$NN-LM we need to:

- Train a transformer model
- Build an index of $(f(c_{i}),w_{i})$
- Evaluate $k$NN-LM using $k$NN search on $(f(c_i),w_i)$ with $f(c_t)$
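In the evaluation step, the paper forms a nearest-neighbor distribution $p_{kNN}$ by taking a softmax over negative distances (aggregating mass per retrieved token) and interpolates it with the model's distribution: $p(w_t \mid c_t) = \lambda\, p_{kNN}(w_t \mid c_t) + (1 - \lambda)\, p_{LM}(w_t \mid c_t)$. A sketch with made-up distances, tokens, and a uniform $p_{LM}$:

```python
import numpy as np

vocab_size = 10
lam = 0.1  # interpolation weight lambda (a tuned hyperparameter)

# Model's next-token distribution p_LM (uniform here for illustration)
p_lm = np.full(vocab_size, 1.0 / vocab_size)

# Nearest-neighbor results: squared distances and the stored tokens w_i
dists = np.array([0.5, 1.0, 2.0])
tokens = np.array([3, 3, 7])

# p_kNN: softmax over negative distances, mass summed per token
weights = np.exp(-dists)
weights /= weights.sum()
p_knn = np.zeros(vocab_size)
np.add.at(p_knn, tokens, weights)  # token 3 collects two neighbors' mass

# Final interpolated kNN-LM distribution
p = lam * p_knn + (1 - lam) * p_lm
```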

This experiment uses a small dataset so that we can run it without using up a few hundred gigabytes of disk space for the index.

The official implementation of $k$NN-LM can be found here.