Retrieval-Enhanced Transformer (Retro)

This is a PyTorch implementation of the paper Improving language models by retrieving from trillions of tokens.

It builds a database of text chunks. This is a key-value database in which each key is the BERT embedding of a chunk, calculated with a frozen pre-trained BERT model. Each value is the corresponding chunk together with an equal length of text following that chunk.
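As a rough illustration of this layout, the sketch below builds such a database in plain Python. It is not this implementation's code: `toy_embed` is a hypothetical stand-in for the frozen BERT encoder, and `build_database` and `chunk_len` are names chosen only for this example.

```python
from typing import Callable, List, Tuple


def toy_embed(chunk: List[str]) -> List[float]:
    """Hypothetical stand-in for the frozen BERT encoder (assumption):
    a 26-dimensional letter-count vector."""
    vec = [0.0] * 26
    for token in chunk:
        for ch in token.lower():
            if "a" <= ch <= "z":
                vec[ord(ch) - ord("a")] += 1.0
    return vec


def build_database(
    tokens: List[str],
    chunk_len: int,
    embed: Callable[[List[str]], List[float]] = toy_embed,
) -> Tuple[List[List[float]], List[List[str]]]:
    """Return parallel lists of keys (chunk embeddings) and values
    (chunk followed by an equal-length continuation)."""
    keys: List[List[float]] = []
    values: List[List[str]] = []
    # Only keep chunks that have a full continuation after them.
    for start in range(0, len(tokens) - 2 * chunk_len + 1, chunk_len):
        chunk = tokens[start:start + chunk_len]
        continuation = tokens[start + chunk_len:start + 2 * chunk_len]
        keys.append(embed(chunk))            # key: embedding of the chunk only
        values.append(chunk + continuation)  # value: chunk plus what follows it
    return keys, values


keys, values = build_database("a b c d e f g h".split(), chunk_len=2)
```

Note that the key is computed from the chunk alone, while the value also carries the continuation, so a lookup on similar text returns what came next in the source.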

The model then retrieves the nearest neighbors of its input from this database, that is, the stored text most similar to the input. The retrieved text is used to help predict the output.
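The lookup can be sketched as a brute-force nearest-neighbor search over the keys. This is only an illustration under assumptions: the function names are hypothetical, the distance is squared L2, and a real database of this size would normally use an approximate nearest-neighbor index rather than scanning every key.

```python
import heapq
from typing import List, Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two embeddings."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def retrieve(
    query_key: Sequence[float],
    keys: List[Sequence[float]],
    values: List[str],
    k: int = 2,
) -> List[str]:
    """Return the values of the k keys closest to query_key, nearest first."""
    nearest = heapq.nsmallest(
        k, range(len(keys)), key=lambda i: squared_l2(query_key, keys[i])
    )
    return [values[i] for i in nearest]


# Toy 2-dimensional embeddings standing in for BERT embeddings (assumption).
toy_keys = [[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]]
toy_values = ["chunk A + continuation", "chunk B + continuation", "chunk C + continuation"]
neighbors = retrieve([0.9, 0.0], toy_keys, toy_values, k=2)
```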

Since we use a frozen BERT model for retrieval, the embeddings never change, so we can pre-calculate all the nearest neighbors for the training dataset. This speeds up the training process.
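A minimal sketch of this pre-calculation follows. Because the encoder is fixed, the neighbors of a training chunk depend only on the chunk and the database, so they can be computed once and stored as a table of indices. The names `precompute_neighbors` and `toy_embed` are hypothetical, and the toy embedding is an assumption standing in for BERT.

```python
from typing import Callable, List, Sequence


def toy_embed(chunk: str) -> List[float]:
    """Hypothetical stand-in for the frozen BERT encoder (assumption)."""
    return [float(len(chunk)), float(chunk.count("a"))]


def precompute_neighbors(
    train_chunks: List[str],
    embed: Callable[[str], Sequence[float]],
    keys: List[Sequence[float]],
    k: int = 2,
) -> List[List[int]]:
    """For every training chunk, store the indices of its k nearest keys."""
    table: List[List[int]] = []
    for chunk in train_chunks:
        query = embed(chunk)
        # Rank all database keys by squared L2 distance to the query.
        ranked = sorted(
            range(len(keys)),
            key=lambda i: sum((a - b) ** 2 for a, b in zip(query, keys[i])),
        )
        table.append(ranked[:k])
    return table


db_keys = [[1.0, 0.0], [3.0, 1.0], [6.0, 3.0]]
neighbor_table = precompute_neighbors(["b", "banana"], toy_embed, db_keys, k=2)
# During training, neighbor_table[i] is read directly instead of searching again.
```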