Compressive Transformer

This is an implementation of Compressive Transformers for Long-Range Sequence Modelling in PyTorch.

This is an extension of Transformer XL where past memories are compressed to give a longer attention range. That is, the furthest $n_{cm} c$ memories are compressed into $n_{cm}$ memories, where $c$ is the compression rate.
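The memory bookkeeping described above can be sketched in pure Python as follows. This is a minimal sketch under stated assumptions: `update_memories` is a hypothetical helper (not part of this implementation), scalars stand in for activation vectors, and averaging each group of `rate` items stands in for the learned compression function. Memory is first-in-first-out, and whatever overflows `mem_len` (the oldest, i.e. furthest, items) is compressed by `rate` into a compressed memory of at most `c_mem_len` entries.

```python
from typing import List, Tuple


def update_memories(mem: List[float], c_mem: List[float], new: List[float],
                    mem_len: int, c_mem_len: int, rate: int) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: push `new` into memory and compress whatever overflows."""
    # The newest activations join the uncompressed memory (oldest first)
    mem = mem + new
    overflow = len(mem) - mem_len
    if overflow > 0:
        # Evict the oldest items, rounded up to a multiple of `rate` so they form full groups
        n_evict = -(-overflow // rate) * rate
        old, mem = mem[:n_evict], mem[n_evict:]
        # Stand-in for the learned compression: average each group of `rate` consecutive items
        compressed = [sum(old[i:i + rate]) / rate for i in range(0, n_evict, rate)]
        # The compressed memory is also first-in-first-out: keep its newest c_mem_len entries
        c_mem = c_mem + compressed
        c_mem = c_mem[max(0, len(c_mem) - c_mem_len):]
    return mem, c_mem


mem, c_mem = update_memories([], [], [1, 2, 3, 4], mem_len=4, c_mem_len=2, rate=2)
mem, c_mem = update_memories(mem, c_mem, [5, 6, 7, 8], mem_len=4, c_mem_len=2, rate=2)
print(mem, c_mem)  # [5, 6, 7, 8] [1.5, 3.5]
```

With these sizes a query can reach `mem_len` recent activations plus `c_mem_len` compressed slots that summarize `c_mem_len * rate` older ones.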

Compression operation

The compression operation is defined as $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$. The paper introduces multiple choices for $f_c$; we have only implemented 1D convolution, which seems to give the best results. Each layer has a separate compression operation $f_c^{(i)}$, where $i$ is the layer number.
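As a quick check on the shape of the compression for the 1D convolution variant, assuming kernel size and stride both equal to the compression rate $c$ and no padding (which is how `Conv1dCompression` below configures `nn.Conv1d`), the standard convolution output-length formula maps $nc$ input positions to exactly $n$ outputs:

$$
L_{out} = \left\lfloor \frac{L_{in} - c}{c} \right\rfloor + 1 = \left\lfloor \frac{nc - c}{c} \right\rfloor + 1 = (n - 1) + 1 = n
$$

Because stride equals kernel size, the windows do not overlap: each output summarizes its own block of $c$ consecutive memories.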

Training compression operation

Since training the compression with BPTT requires maintaining a very large computational graph (many time steps), the paper proposes an auto-encoding loss and an attention reconstruction loss. The auto-encoding loss decodes the original memories from the compressed memories and calculates the loss. The attention reconstruction loss computes the multi-headed attention results on the compressed memory and on the uncompressed memory and takes the mean squared error between them. We have implemented the latter here since it gives better results.
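Written as a formula, here is a sketch of the attention reconstruction loss as this implementation computes it (a mean squared error per layer, summed over layers $i$). Here $h^{(i)}$ denotes the token embeddings passed in for layer $i$, $m^{(i)}$ the memory of layer $i$ that is being compressed, $f_c^{(i)}$ that layer's compression function, and $\operatorname{attn}^{(i)}$ that layer's multi-head attention with its parameters frozen and no positional encoding:

$$
\mathcal{L}^{attn} = \sum_{i} \operatorname{MSE}\Big(\operatorname{attn}^{(i)}\big(h^{(i)}, f_c^{(i)}(m^{(i)})\big),\ \operatorname{attn}^{(i)}\big(h^{(i)}, m^{(i)}\big)\Big)
$$

Only $f_c^{(i)}$ receives gradients from this loss.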

This implementation uses pre-layer normalization while the paper uses post-layer normalization. Pre-layer norm applies layer normalization before the FFN and self-attention, and the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.
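A minimal PyTorch sketch contrasting the two residual arrangements, assuming a single `nn.Linear` as a stand-in for the self-attention or FFN sublayer (toy sizes, not the model's). Post-LN normalizes the residual sum. Pre-LN, as used here, normalizes only what goes into the sublayer, so the residual path carries `x` through unchanged.

```python
import torch
from torch import nn

d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for self-attention or the FFN


def post_ln(x: torch.Tensor) -> torch.Tensor:
    # Post-LN (as in the paper): normalize after adding the residual
    return norm(x + sublayer(x))


def pre_ln(x: torch.Tensor) -> torch.Tensor:
    # Pre-LN (this implementation): normalize only the sublayer input;
    # the residual pass-through x is left unnormalized
    return x + sublayer(norm(x))


x = torch.randn(5, 2, d_model) * 3 + 1
```

Because the residual stream is never normalized in the pre-LN arrangement, the model applies a final `nn.LayerNorm` after the last layer (the `self.norm` in `CompressiveTransformer`).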

Here are the training code and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.


from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from labml_helpers.module import Module, TypedModuleList
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.utils import clone_module_list

1D Convolution Compression

This is a simple wrapper around nn.Conv1d with some tensor dimension permutations.

class Conv1dCompression(Module):
  • compression_rate is the compression rate $c$
  • d_model is the embedding size
    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

mem has shape [seq_len, batch, d_model]

    def forward(self, mem: torch.Tensor):

Permute the dimensions of mem so that we can run it through the convolution layer. The convolution layer accepts input in the form [batch, features, sequence]

        mem = mem.permute(1, 2, 0)

Get compressed memory by running it through the convolution layer

        c_mem = self.conv(mem)

Permute back to form [seq_len, batch, d_model]

        return c_mem.permute(2, 0, 1)
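A small PyTorch shape check of the permute, convolve, permute-back pipeline, using a bare `nn.Conv1d` with toy sizes instead of the `Conv1dCompression` module (which needs `labml_helpers`). A memory of 12 steps with compression rate 4 becomes 3 compressed steps.

```python
import torch
from torch import nn

compression_rate, d_model = 4, 8
seq_len, batch = 12, 2  # seq_len is a multiple of the compression rate
conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

mem = torch.randn(seq_len, batch, d_model)             # [seq_len, batch, d_model]
c_mem = conv(mem.permute(1, 2, 0)).permute(2, 0, 1)    # [seq_len // rate, batch, d_model]
print(c_mem.shape)  # torch.Size([3, 2, 8])
```

If `seq_len` is not a multiple of the compression rate, the convolution silently drops the trailing positions that do not fill a whole window.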

Compressive Transformer Layer

This is the implementation of a single compressive transformer layer

class CompressiveTransformerLayer(Module):
  • d_model is the token embedding size
  • self_attn is the self attention module
  • feed_forward is the feed forward module
  • dropout_prob is the probability of dropping out after self attention and FFN
  • compress is the compression function
    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float,
                 compress: Conv1dCompression):
        super().__init__()
        self.compress = compress
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

Concatenate the normalized token embeddings with memory and compressed memory.

  • z is layer normalized token embeddings.
  • mem and c_mem are memory and compressed memory (not normalized).
    def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):

If there is no memory, just return the token embeddings

        if mem is None:
            return z

If there is compressed memory, concatenate it with the memory

        if c_mem is not None:
            mem = torch.cat((c_mem, mem), dim=0)

Run the memory through the normalization layer

        mem = self.norm_self_attn(mem)

Concatenate normalized memory and normalized token embeddings

        return torch.cat((mem, z), dim=0)
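A small PyTorch shape check of the concatenation order, assuming toy sizes and a plain `nn.LayerNorm` in place of the layer's `norm_self_attn`. The key/value sequence is ordered oldest first and has length `c_mem_len + mem_len + seq_len`, which is why the mask's second dimension has that size.

```python
import torch
from torch import nn

d_model, batch = 8, 2
seq_len, mem_len, c_mem_len = 5, 6, 3
norm = nn.LayerNorm(d_model)  # stand-in for the layer's norm_self_attn

x = torch.randn(seq_len, batch, d_model)        # current tokens
mem = torch.randn(mem_len, batch, d_model)      # uncompressed memory
c_mem = torch.randn(c_mem_len, batch, d_model)  # compressed memory

z = norm(x)
# Oldest first: compressed memory, then memory, then the current tokens
m_z = torch.cat((norm(torch.cat((c_mem, mem), dim=0)), z), dim=0)
print(m_z.shape)  # torch.Size([14, 2, 8])
```

Since layer normalization works on each position independently, normalizing the concatenated memories gives the same result as normalizing the compressed memory and the memory separately.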
  • x is a tensor of token level feature vectors of shape [seq_len, batch_size, d_model]
  • mem is a tensor of the past token level feature vectors (memory) of shape [mem_len, batch_size, d_model]
  • c_mem is a tensor of the compressed memory of shape [c_mem_len, batch_size, d_model]
  • mask is a matrix of shape [seq_len, c_mem_len + mem_len + seq_len, batch_size] or [seq_len, c_mem_len + mem_len + seq_len, 1]. mask[i, j] is true if the token at i can see the token at j.
    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                c_mem: Optional[torch.Tensor],
                mask: torch.Tensor):

Normalize the vectors before doing self attention

        z = self.norm_self_attn(x)

Normalize and concatenate memory and compressed memory

        m_z = self.concat_memory(z, mem, c_mem)

Attention

        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)

Add the attention results

        x = x + self.dropout(self_attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x
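One way to build a mask of the shape described for `forward`, as a hedged sketch: `build_mask` is a hypothetical helper (not part of this implementation), and it assumes every query token may see all compressed and uncompressed memory plus the current tokens up to and including itself.

```python
import torch


def build_mask(seq_len: int, mem_len: int, c_mem_len: int) -> torch.Tensor:
    """Hypothetical helper: causal mask over [compressed memory, memory, current tokens]."""
    n_mem = c_mem_len + mem_len
    # Keep entries with j - i <= n_mem: all memory slots plus tokens at or before i
    mask = torch.tril(torch.ones(seq_len, n_mem + seq_len), diagonal=n_mem).bool()
    # Trailing singleton dimension broadcasts over the batch: [seq_len, n_mem + seq_len, 1]
    return mask.unsqueeze(-1)


mask = build_mask(seq_len=3, mem_len=2, c_mem_len=1)
```

The first query row sees the three memory slots and itself; the last row sees everything.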

Compressive Transformer Model

This consists of multiple compressive transformer layers

class CompressiveTransformer(Module):
    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])
  • x is a tensor of the token embeddings vectors of shape [seq_len, batch_size, d_model]
  • mem is a list of tensors of the past token level feature vectors of shape [mem_len, batch_size, d_model] for each layer
  • c_mem is a list of tensors of the compressed memory of shape [c_mem_len, batch_size, d_model] for each layer
  • mask is the masking matrix, shared by all layers, with the same shape as the layer-level mask
    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):

List to store token level feature vectors, which will become the memories for the next sequential batch.

        new_mem = []

Run through each transformer layer

        for i, layer in enumerate(self.layers):

Add to the list of feature vectors

            new_mem.append(x.detach())

Memory

            m = mem[i] if mem else None

Compressed Memory

            cm = c_mem[i] if c_mem else None

Run through the compressive transformer layer

            x = layer(x=x, mem=m, c_mem=cm, mask=mask)

Finally, normalize the vectors

        return self.norm(x), new_mem
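As a toy illustration of the loop above, this pure-Python sketch uses hypothetical scalar "layers" (and drops `detach()`, which only matters for tensors) to make explicit that `new_mem[i]` is the input to layer `i`, not its output.

```python
from typing import Callable, List, Tuple


def run_layers(x: float, layers: List[Callable[[float], float]]) -> Tuple[float, List[float]]:
    new_mem = []
    for layer in layers:
        # The memory stored for a layer is that layer's input
        new_mem.append(x)
        x = layer(x)
    return x, new_mem


out, new_mem = run_layers(3.0, [lambda v: v + 1, lambda v: v * 2])
print(out, new_mem)  # 8.0 [3.0, 4.0]
```

This is the same kind of representation that layer `i` normalizes and uses as queries, so on the next segment the memory and the new tokens both pass through that layer's `norm_self_attn` before attention.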

Attention Reconstruction Loss

Attention reconstruction loss recreates the self-attention output with uncompressed memory and with compressed memory and calculates the mean squared error between the two. It does this without positional encoding.

When calculating and training the compression function $f_c$ with attention reconstruction loss, all parameters but $f_c$ are frozen. This includes the query/key/value projections and the bias/scaling after normalization.

Since this loss can be computed independently of the cross-entropy loss of the model, you can have a separate optimizer that only updates $f_c$. However, we use the same optimizer to update $f_c$, so when calculating the attention reconstruction loss, we detach all parameters other than $f_c$ from the gradient computation.
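For the separate-optimizer alternative mentioned above, a minimal PyTorch sketch of the parameter split might look like this. The `nn.ModuleDict` layers with `compress` and `ff` entries are hypothetical stand-ins for the real layers; only the idea of giving the compression parameters their own optimizer comes from the text.

```python
import torch
from torch import nn

d_model, rate = 8, 2
# Hypothetical stand-in layers, each with its own compression function and other parameters
layers = nn.ModuleList([
    nn.ModuleDict({
        "compress": nn.Conv1d(d_model, d_model, kernel_size=rate, stride=rate),
        "ff": nn.Linear(d_model, d_model),
    })
    for _ in range(2)
])

# Parameters of every layer's compression function
compress_params = [p for layer in layers for p in layer["compress"].parameters()]
# Everything else in the model
other_params = [p for layer in layers
                for name, p in layer.named_parameters() if not name.startswith("compress.")]

model_opt = torch.optim.Adam(other_params, lr=1e-4)
compress_opt = torch.optim.Adam(compress_params, lr=1e-4)
```

In that arrangement the attention reconstruction loss would step only `compress_opt`, and the model's cross-entropy loss only `model_opt`.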

class AttentionReconstructionLoss:

layers is the list of Compressive Transformer layers

    def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
        self.layers = layers
        self.loss_func = nn.MSELoss()

This is a reimplementation of 'PrepareForMultiHeadAttention' where the projections are done with the parameters detached from gradient computation.

    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):

Shape of the input except embedding dimension; [seq_len, batch_size].

        head_shape = x.shape[:-1]

Detach projection weights and bias

        weight = pmha.linear.weight.detach()
        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None

Linear transform

        x = F.linear(x, weight, bias)

Split last dimension into heads

        x = x.view(*head_shape, pmha.heads, pmha.d_k)

Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]

        return x
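To see what detaching the projection parameters does, here is a small PyTorch check with toy sizes that mirrors `prepare_for_attn`, with a plain `nn.Linear` standing in for `pmha.linear`. The gradient still flows back to the input (and, in the real loss, on to the compression function), but the projection weights receive none.

```python
import torch
import torch.nn.functional as F
from torch import nn

heads, d_k, d_model = 2, 4, 8
linear = nn.Linear(d_model, heads * d_k)            # stand-in for pmha.linear
x = torch.randn(5, 3, d_model, requires_grad=True)  # [seq_len, batch_size, d_model]

# Project with detached parameters, then split the last dimension into heads
y = F.linear(x, linear.weight.detach(), linear.bias.detach())
y = y.view(*x.shape[:-1], heads, d_k)               # [seq_len, batch_size, heads, d_k]

y.sum().backward()
print(linear.weight.grad is None, x.grad is not None)  # True True
```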

This is a reimplementation of 'Multi-Head Attention' which calls prepare_for_attn instead of 'PrepareForMultiHeadAttention' to detach projection parameters.

    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):

Calculate query, key and value projections

        query = self.prepare_for_attn(layer.query, query)
        key = self.prepare_for_attn(layer.key, key)
        value = self.prepare_for_attn(layer.value, value)

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, key_len, batch_size, heads], where key_len is the length of the (compressed or uncompressed) memory.

        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

Scale scores by $\frac{1}{\sqrt{d_k}}$

        scores *= layer.scale

Softmax attention along the key sequence dimension

        attn = layer.softmax(scores)

Multiply by values

        return torch.einsum("ijbh,jbhd->ibhd", attn, value)
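A toy PyTorch check of the einsum shapes, assuming random tensors that are already split into heads, a scale of $1/\sqrt{d_k}$, and a softmax over dimension 1 (the key positions `j`). `attend` is a hypothetical helper, not the module's API. It shows why the two attention outputs in the loss can be compared directly: the output length follows the query, not the memory.

```python
import torch

seq_len, mem_len, c_mem_len, batch, heads, d_k = 5, 8, 2, 3, 2, 4


def attend(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # [seq_len, key_len, batch, heads]
    scores = torch.einsum('ibhd,jbhd->ijbh', query, key) / d_k ** 0.5
    # Normalize over the key positions j
    attn = scores.softmax(dim=1)
    # Back to [seq_len, batch, heads, d_k]
    return torch.einsum('ijbh,jbhd->ibhd', attn, value)


query = torch.randn(seq_len, batch, heads, d_k)
mem = torch.randn(mem_len, batch, heads, d_k)
c_mem = torch.randn(c_mem_len, batch, heads, d_k)
out_mem = attend(query, mem, mem)
out_cmem = attend(query, c_mem, c_mem)
```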

Perform layer normalization with shift and scale parameters detached.

    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):

Detach the shift (bias) and scaling (weight) parameters

        weight = ln.weight.detach() if ln.weight is not None else None
        bias = ln.bias.detach() if ln.bias is not None else None

Layer normalization

        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)
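A quick PyTorch check, with toy sizes, that the detached `F.layer_norm` call computes the same values as the module itself while keeping gradients away from its affine parameters:

```python
import torch
import torch.nn.functional as F
from torch import nn

ln = nn.LayerNorm(8)
x = torch.randn(4, 2, 8, requires_grad=True)

# Same computation as ln(x), but with the weight and bias cut out of the graph
y = F.layer_norm(x, ln.normalized_shape, ln.weight.detach(), ln.bias.detach(), ln.eps)
y.pow(2).sum().backward()
```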

This calculates the loss for a layer

    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):

Detach the token embeddings and memory.

        h = h.detach()
        mem = mem.detach()

Compress the memory with $f_c^{(i)}$. The parameters of $f_c^{(i)}$ are the only parameters not detached from gradient computation.

        c_mem = layer.compress(mem)

Normalize the embeddings and memories

        h = self.norm(layer.norm_self_attn, h)
        mem = self.norm(layer.norm_self_attn, mem)
        c_mem = self.norm(layer.norm_self_attn, c_mem)

Calculate the attention with uncompressed memory

        attn_mem = self.attn(layer.self_attn, h, mem, mem)

Calculate the attention with compressed memory

        attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)

Calculate the mean squared error

        return self.loss_func(attn_cmem, attn_mem)
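A self-contained toy version of this per-layer loss, assuming a single attention head, hypothetical stand-in projections `w_q`, `w_k`, `w_v`, and omitting layer normalization for brevity. It checks the property the text relies on: after `backward()`, only the compression convolution receives gradients.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
d_model, rate = 8, 2
seq_len, mem_len, batch = 4, 6, 2
compress = nn.Conv1d(d_model, d_model, kernel_size=rate, stride=rate)  # the only trainable part
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))     # stand-in projections


def attn(h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    # Single-head attention with detached projection weights (no positional terms)
    q = F.linear(h, w_q.weight.detach(), w_q.bias.detach())
    k = F.linear(m, w_k.weight.detach(), w_k.bias.detach())
    v = F.linear(m, w_v.weight.detach(), w_v.bias.detach())
    scores = torch.einsum('ibd,jbd->ijb', q, k) / d_model ** 0.5
    return torch.einsum('ijb,jbd->ibd', scores.softmax(dim=1), v)


h = torch.randn(seq_len, batch, d_model)    # token embeddings (no grad, like the detached h)
mem = torch.randn(mem_len, batch, d_model)  # memories about to be compressed
c_mem = compress(mem.permute(1, 2, 0)).permute(2, 0, 1)  # [mem_len // rate, batch, d_model]

loss = F.mse_loss(attn(h, c_mem), attn(h, mem))
loss.backward()
```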
    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):

Calculate the losses for each layer

        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]

Sum of the losses

        return sum(losses)