This is a PyTorch implementation of the paper Generalization through Memorization: Nearest Neighbor Language Models. It uses k-nearest neighbors to improve perplexity of autoregressive transformer models.
An autoregressive language model estimates $p(w_t \mid c_t)$, where $w_t$ is the token at step $t$ and $c_t$ is the context, $c_t = (w_1, w_2, \ldots, w_{t-1})$.
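As a minimal sketch of what "estimates $p(w_t \mid c_t)$" means in code: the model maps a context to a vector of logits, one per vocabulary token, and a softmax turns them into a probability distribution. The random logits below are a stand-in for a real transformer's output, not part of this implementation:

```python
import numpy as np

def softmax(z):
    """Normalize a vector of logits into a probability distribution."""
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

# Toy stand-in: random logits in place of a transformer's output at step t.
rng = np.random.default_rng(0)
vocab_size = 10
logits = rng.normal(size=vocab_size)
p = softmax(logits)  # p(w_t | c_t): one probability per vocabulary token
```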
This paper improves $p(w_t \mid c_t)$ using a k-nearest-neighbor search on key-value pairs $\big(f(c_i), w_i\big)$, with search key $f(c_t)$. Here $f(c)$ is an embedding of the context $c$. The paper (and this implementation) uses the input to the feed-forward layer of the final layer of the transformer as $f(c)$.
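A sketch of how the retrieved neighbors become a probability estimate, following the paper's formulation: the $k$ nearest keys are weighted by a softmax over negative distances, and each neighbor's weight is attributed to its stored target token $w_i$. The function name `knn_probs` and the use of plain L2 distance are illustrative choices, not names from this codebase:

```python
import numpy as np

def knn_probs(query, keys, values, vocab_size, k=4):
    """Turn the k nearest (f(c_i), w_i) pairs into a distribution over tokens."""
    d = np.linalg.norm(keys - query, axis=1)  # distance from f(c_t) to every key
    nn = np.argsort(d)[:k]                    # indices of the k nearest neighbors
    w = np.exp(-d[nn])                        # softmax over negative distances
    w = w / w.sum()
    p = np.zeros(vocab_size)
    for weight, token in zip(w, values[nn]):  # accumulate weight per target token
        p[token] += weight
    return p
```

The paper then interpolates this estimate with the base model's output, $p(w_t \mid c_t) = \lambda \, p_{kNN} + (1 - \lambda) \, p_{LM}$, which is where the perplexity improvement comes from.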
We use FAISS to index $f(c_i)$.
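To make the index's role concrete, here is a toy stand-in for FAISS's exact `IndexFlatL2` written in NumPy; the class is illustrative, but it mirrors the same `add`/`search` interface. FAISS computes the same exact L2 search, just far faster and at a much larger scale:

```python
import numpy as np

class FlatL2Index:
    """Toy stand-in for faiss.IndexFlatL2: exact L2 nearest-neighbor search."""

    def __init__(self, d):
        # Store all key vectors f(c_i) as one (n, d) matrix.
        self.keys = np.empty((0, d), dtype=np.float32)

    def add(self, x):
        self.keys = np.vstack([self.keys, np.asarray(x, dtype=np.float32)])

    def search(self, queries, k):
        # Squared L2 distance from each query f(c_t) to every stored key.
        d2 = ((queries[:, None, :] - self.keys[None, :, :]) ** 2).sum(-1)
        idx = np.argsort(d2, axis=1)[:, :k]
        return np.take_along_axis(d2, idx, axis=1), idx
```

With FAISS itself the equivalent calls are `index = faiss.IndexFlatL2(d)`, `index.add(keys)`, and `distances, ids = index.search(queries, k)`; FAISS also offers approximate indexes for when the key store no longer fits in memory.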
So to run kNN-LM we need to:

* Train a transformer model
* Build an index of key-value pairs $\big(f(c_i), w_i\big)$
* Evaluate the model with the kNN lookup
This experiment uses a small dataset so that we can run it without using up a few hundred gigabytes of disk space for the index.
The official implementation of kNN-LM can be found here.