k-Nearest Neighbor Language Models

This is a PyTorch implementation of the paper Generalization through Memorization: Nearest Neighbor Language Models. It uses k-nearest neighbors to improve the perplexity of autoregressive transformer models.

An autoregressive language model estimates $p(w_t \mid c_t)$, where $w_t$ is the token at step $t$ and $c_t$ is the context, $c_t = (w_1, w_2, \ldots, w_{t-1})$.
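As a minimal, self-contained sketch of this definition (a toy mean-pooled encoder stands in for the transformer; all names here are illustrative), $p(w_t \mid c_t)$ is just a softmax over the model's output logits:

```python
import torch
import torch.nn.functional as F

vocab_size, d_model = 100, 32
embed = torch.nn.Embedding(vocab_size, d_model)
lm_head = torch.nn.Linear(d_model, vocab_size)

context = torch.tensor([3, 17, 42])    # c_t = (w_1, ..., w_{t-1})
h = embed(context).mean(dim=0)         # toy context encoding; a real model uses a transformer
p_w_t = F.softmax(lm_head(h), dim=-1)  # p(w_t | c_t): a distribution over the vocabulary
```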

This paper improves $p(w_t \mid c_t)$ using a $k$-nearest neighbor search on key-value pairs $\big(f(c_i), w_i\big)$, with search key $f(c_t)$. Here $f$ is an embedding of the context $c$. The paper (and this implementation) uses the input to the feed-forward layer of the final layer of the transformer as $f(c)$. The retrieved neighbors induce a distribution $p_{knn}(w_t \mid c_t)$, which is interpolated with the model's own distribution to give the final prediction.
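One way to capture $f(c)$ is a forward hook on the feed-forward block of the final layer. The sketch below is an assumption-laden stand-in: it uses PyTorch's built-in `nn.TransformerEncoder` (whose `layers` and `linear1` attribute names come from PyTorch, not from the paper's model) rather than the trained model from this implementation:

```python
import torch
from torch import nn

# Toy stand-in for the trained transformer (hypothetical configuration).
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=6)

keys = []

def save_key(module, inputs, output):
    # inputs[0] is the input to the final layer's first feed-forward linear,
    # approximating f(c) in the paper's notation
    keys.append(inputs[0].detach())

# Hook the first linear of the last layer's feed-forward network.
handle = model.layers[-1].linear1.register_forward_hook(save_key)

tokens = torch.randn(10, 1, 512)  # (seq, batch, d_model) input embeddings
with torch.no_grad():
    model(tokens)
handle.remove()
```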

We use FAISS to index $f(c_i)$.
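A minimal sketch of building and querying such an index with FAISS, using random vectors as stand-ins for the stored keys (the dimensionality and dataset size here are assumptions, not this implementation's settings):

```python
import faiss
import numpy as np

d = 512                       # key dimensionality (assumed to match the model's d_model)
index = faiss.IndexFlatL2(d)  # exact L2 search; large datasets need a quantized index

keys = np.random.rand(10_000, d).astype('float32')   # stand-ins for the stored keys f(c_i)
index.add(keys)

queries = np.random.rand(4, d).astype('float32')     # stand-ins for search keys f(c_t)
distances, neighbor_ids = index.search(queries, 10)  # 10 nearest neighbors per query
# neighbor_ids index into the stored (f(c_i), w_i) pairs
```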

Implementation

So to run kNN-LM we need to:

* Train a transformer language model
* Build an index of the keys $f(c_i)$ and their target tokens $w_i$
* Evaluate the language model, interpolating its predictions with the $k$-nearest neighbor search results (see the sketch after this list)
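The last step can be sketched as follows, per the paper: neighbor weights are a softmax over negative distances, aggregated per target token, and the result is interpolated with the model's distribution using a tunable weight $\lambda$ (the value below is only a placeholder):

```python
import torch
import torch.nn.functional as F

def knn_distribution(distances: torch.Tensor, neighbor_targets: torch.Tensor,
                     vocab_size: int) -> torch.Tensor:
    """Turn k retrieved neighbors into p_knn(w_t | c_t): a softmax over
    negative distances, aggregated per target token.

    distances:        (k,) distances of the k nearest keys
    neighbor_targets: (k,) long tensor of target tokens w_i stored with those keys
    """
    weights = F.softmax(-distances, dim=-1)
    p_knn = torch.zeros(vocab_size)
    p_knn.index_add_(0, neighbor_targets, weights)
    return p_knn

def knn_lm(p_model: torch.Tensor, p_knn: torch.Tensor, lam: float = 0.25) -> torch.Tensor:
    """Final kNN-LM prediction: interpolate the two distributions."""
    return lam * p_knn + (1.0 - lam) * p_model
```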

This experiment uses a small dataset so that we can run it without using up a few hundred gigabytes of disk space for the index.

The official implementation of kNN-LM can be found here.