Compressive Transformer

This is an implementation of Compressive Transformers for Long-Range Sequence Modelling in PyTorch.

This is an extension of Transformer XL where past memories are compressed to give a longer attention range. That is, the furthest $n_{cm} c$ memories are compressed into $n_{cm}$ memories, where $c$ is the compression rate.
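
To make the idea concrete, here is a minimal sketch (not the actual implementation) of how the two memory buffers could be updated after each step. The names `update_memories`, `n_mem`, `n_c_mem`, and `compress` are hypothetical, and memories are assumed to be stored as `[seq_len, batch, d_model]` tensors.

```python
import torch

def update_memories(mem: torch.Tensor, c_mem: torch.Tensor, new_hidden: torch.Tensor,
                    n_mem: int, n_c_mem: int, compress) -> tuple:
    """Append new hidden states to memory and compress whatever overflows."""
    # Newest hidden states go into the regular (uncompressed) memory
    mem = torch.cat((mem, new_hidden), dim=0)
    if mem.shape[0] > n_mem:
        # The furthest (oldest) states overflow; compress them
        # (in practice the overflow length is kept a multiple of the compression rate)
        n_overflow = mem.shape[0] - n_mem
        old, mem = mem[:n_overflow], mem[n_overflow:]
        c_mem = torch.cat((c_mem, compress(old)), dim=0)
        # Keep only the most recent compressed memories
        c_mem = c_mem[-n_c_mem:]
    return mem, c_mem
```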

Compression operation

The compression operation is defined as $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$. The paper introduces multiple choices for $f_c$, and we have only implemented 1D convolution, which seems to give the best results. Each layer has a separate compression operation $f_c^{(i)}$, where $i$ is the layer number.
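
Below is a minimal sketch of a per-layer 1D-convolution compression module, assuming memories are kept as `[seq_len, batch, d_model]` tensors; the class name and shapes are illustrative rather than the repository's exact API.

```python
import torch
from torch import nn

class Conv1dCompression(nn.Module):
    """Compress a block of memories by a factor `c` with a 1D convolution
    whose kernel size and stride both equal the compression rate."""
    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model,
                              kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # `mem` has shape `[seq_len, batch, d_model]`;
        # `nn.Conv1d` expects `[batch, channels, seq_len]`
        mem = mem.permute(1, 2, 0)
        c_mem = self.conv(mem)
        # Back to `[compressed_len, batch, d_model]`
        return c_mem.permute(2, 0, 1)
```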

Training compression operation

Since training the compression function with BPTT would require maintaining a very large computational graph (many time steps), the paper proposes two auxiliary losses: an auto-encoding loss and an attention reconstruction loss. The auto-encoding loss decodes the original memories from the compressed memories and computes a reconstruction loss over them. The attention reconstruction loss computes multi-headed attention over the compressed memory and over the uncompressed memory, and takes the mean squared error between the two results. We have implemented the latter here since it gives better results.
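
For illustration, here is one way the attention reconstruction loss could be written against a stock `nn.MultiheadAttention` layer (the actual implementation works on the model's own attention layers); `attn`, `compress`, `h`, and `old_mem` are placeholder names, and `torch.func.functional_call` (PyTorch ≥ 2.0) is used to detach the attention parameters so that only the compression function is trained by this loss.

```python
import torch
import torch.nn.functional as F
from torch import nn

def attention_reconstruction_loss(attn: nn.MultiheadAttention,
                                  compress: nn.Module,
                                  h: torch.Tensor,
                                  old_mem: torch.Tensor) -> torch.Tensor:
    """MSE between attention over the original memories and attention
    over their compressed version; gradients reach only `compress`."""
    # Detach inputs and attention parameters so the main network is untouched
    h, old_mem = h.detach(), old_mem.detach()
    frozen = {name: p.detach() for name, p in attn.named_parameters()}
    # Attention over the uncompressed memories is the reconstruction target
    target, _ = torch.func.functional_call(attn, frozen, (h, old_mem, old_mem))
    # Attention over the compressed memories; gradients flow into `compress` only
    c_mem = compress(old_mem)
    reconstruction, _ = torch.func.functional_call(attn, frozen, (h, c_mem, c_mem))
    return F.mse_loss(reconstruction, target)
```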

This implementation uses pre-layer normalization, while the paper uses post-layer normalization. Pre-layer norm applies the layer norm before the FFN and self-attention, and the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.
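
As a point of reference, a minimal pre-layer-norm block looks like the sketch below; it leaves out memories, relative positional encoding, and masking, so it only illustrates where the layer norms sit relative to the residual connections.

```python
import torch
from torch import nn

class PreNormBlock(nn.Module):
    """Transformer block with pre-layer normalization: layer norm is applied
    to the inputs of self-attention and the FFN, and the residual
    pass-through is left unnormalized."""
    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.norm_attn = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads)
        self.norm_ff = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                nn.Linear(d_ff, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize before self-attention; add the unnormalized residual
        z = self.norm_attn(x)
        attn_out, _ = self.attn(z, z, z)
        x = x + attn_out
        # Normalize before the FFN; the residual is again unnormalized
        return x + self.ff(self.norm_ff(x))
```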

Here are the training code and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.
