This is an implementation of Compressive Transformers for Long-Range Sequence Modelling in PyTorch.

This is an extension of Transformer XL where past memories are compressed to give a longer attention range. That is, the furthest $n_{cm} c$ memories are compressed into $n_{cm}$ memories, where $c$ is the compression rate.
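To make the longer range concrete, the number of time steps covered by one layer's attention can be counted by hand. The following is only an illustrative sketch; the helper name `attention_range` and all sizes are hypothetical and are not taken from the paper or from this implementation.

```python
def attention_range(seq_len: int, mem_len: int, c_mem_len: int, compression_rate: int) -> int:
    """Time steps covered by the last token's attention in a single layer.

    Each compressed memory slot summarizes `compression_rate` original memories.
    """
    return seq_len + mem_len + c_mem_len * compression_rate


# Hypothetical sizes, chosen only for illustration
xl_range = attention_range(seq_len=8, mem_len=8, c_mem_len=0, compression_rate=1)
compressive_range = attention_range(seq_len=8, mem_len=8, c_mem_len=8, compression_rate=4)
print(xl_range, compressive_range)
```

With these made-up numbers, adding 8 compressed slots at $c = 4$ extends the covered range by 32 time steps while the attention itself only sees 8 extra keys.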

The compression operation is defined as $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$. The paper introduces multiple choices for $f_c$ and we have only implemented 1D convolution which seems to give the best results. Each layer has a separate compression operation $f_c^{(i)}$ where $i$ is the layer number.

Since training compression with BPTT requires maintaining
a very large computational graph (many time steps), the paper proposes
an *auto-encoding loss* and an *attention reconstruction loss*.
The auto-encoding loss decodes the original memories from the compressed memories
and calculates the loss.
Attention reconstruction loss computes the multi-headed attention results
on the compressed memory and on uncompressed memory and gets a mean squared error
between them.
We have implemented the latter here since it gives better results.

This implementation uses pre-layer normalization while the paper uses post-layer normalization. Pre-layer norm applies layer normalization before the FFN and self-attention, and the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.
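The difference between the two orderings can be sketched in a few lines. This is a minimal illustration, assuming a plain `nn.Linear` as a hypothetical stand-in for the attention or FFN sub-layer; it is not the layer used in this implementation.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model = 8
norm = nn.LayerNorm([d_model])
sublayer = nn.Linear(d_model, d_model)  # hypothetical stand-in for attention or FFN

x = torch.randn(5, 2, d_model)  # [seq_len, batch_size, d_model]

# Pre-layer norm: only the sub-layer input is normalized; the residual path carries x unchanged
pre_ln = x + sublayer(norm(x))

# Post-layer norm: normalization happens after the residual sum, so the pass-through is normalized too
post_ln = norm(x + sublayer(x))
```

In the pre-layer norm variant the output is not normalized, which is why a final normalization layer is usually applied after the last layer.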

Here are the training code and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.

```
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from labml_helpers.module import Module, TypedModuleList
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.utils import clone_module_list
```

This is a simple wrapper around `nn.Conv1d` with some tensor dimension permutations.

`class Conv1dCompression(Module):`

- `compression_rate` is the compression rate $c$
- `d_model` is the embedding size

`def __init__(self, compression_rate: int, d_model: int):`

```
super().__init__()
self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)
```

`mem` has shape `[seq_len, batch, d_model]`

`def forward(self, mem: torch.Tensor):`

Permute the dimensions of `mem` so that we can run it through the convolution layer. The convolution layer accepts input in the form `[batch, features, sequence]`

`mem = mem.permute(1, 2, 0)`

Get compressed memory by running it through the convolution layer

`c_mem = self.conv(mem)`

Permute back to form `[seq_len, batch, d_model]`

`return c_mem.permute(2, 0, 1)`
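The effect of the permutations and the strided convolution on tensor shapes can be checked in isolation. The sketch below uses `nn.Conv1d` directly rather than the `Conv1dCompression` module, and the sizes are hypothetical.

```python
import torch
from torch import nn

compression_rate, d_model = 4, 16
# Kernel size equal to stride: each output step summarizes `compression_rate` inputs
conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

mem = torch.randn(12, 2, d_model)     # [seq_len, batch, d_model]
c_mem = conv(mem.permute(1, 2, 0))    # convolution over [batch, d_model, seq_len]
c_mem = c_mem.permute(2, 0, 1)        # back to [seq_len / compression_rate, batch, d_model]
print(c_mem.shape)
```

Twelve memories at a compression rate of 4 become three compressed memories, with batch and embedding dimensions unchanged.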

This is the implementation of a single compressive transformer layer.

`class CompressiveTransformerLayer(Module):`

- `d_model` is the token embedding size
- `self_attn` is the self attention module
- `feed_forward` is the feed forward module
- `dropout_prob` is the probability of dropping out after self attention and FFN
- `compress` is the compression function $f_c$

```
def __init__(self, *,
             d_model: int,
             self_attn: RelativeMultiHeadAttention,
             feed_forward: FeedForward,
             dropout_prob: float,
             compress: Conv1dCompression):
```

```
super().__init__()
self.compress = compress
self.size = d_model
self.self_attn = self_attn
self.feed_forward = feed_forward
self.dropout = nn.Dropout(dropout_prob)
self.norm_self_attn = nn.LayerNorm([d_model])
self.norm_ff = nn.LayerNorm([d_model])
```

Concatenate the normalized token embeddings with memory and compressed memory.

- `z` is the layer-normalized token embeddings.
- `mem` and `c_mem` are the memory and compressed memory (not normalized).

`def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):`

If there is no memory, just return the token embeddings

```
if mem is None:
    return z
```

If there is compressed memory, concatenate it with the memory

```
if c_mem is not None:
    mem = torch.cat((c_mem, mem), dim=0)
```

Run the memory through the normalization layer

`mem = self.norm_self_attn(mem)`

Concatenate normalized memory and normalized token embeddings

`return torch.cat((mem, z), dim=0)`
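The ordering along the sequence dimension matters, since keys are laid out oldest first: compressed memory, then memory, then the current tokens. The following standalone sketch reproduces the concatenation with hypothetical sizes and a freshly created `nn.LayerNorm`, outside of the layer class.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model = 4
norm = nn.LayerNorm([d_model])

z = norm(torch.randn(3, 1, d_model))   # already-normalized token embeddings
mem = torch.randn(5, 1, d_model)       # uncompressed memory (not normalized)
c_mem = torch.randn(2, 1, d_model)     # compressed memory, the oldest entries

# Oldest first: compressed memory, then memory, then the current tokens
m_z = torch.cat((norm(torch.cat((c_mem, mem), dim=0)), z), dim=0)
print(m_z.shape)
```

The result has `c_mem_len + mem_len + seq_len` entries along the first dimension, which is the key length the attention mask has to match.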

- `x` is a tensor of token level feature vectors of shape `[seq_len, batch_size, d_model]`
- `mem` is a tensor of the past token level feature vectors (memory) of shape `[mem_len, batch_size, d_model]`
- `c_mem` is a tensor of the compressed memory of shape `[c_mem_len, batch_size, d_model]`
- `mask` is a matrix of shape `[seq_len, c_mem_len + mem_len + seq_len, batch_size]` or `[seq_len, c_mem_len + mem_len + seq_len, 1]`. `mask[i, j]` is true if the token at `i` can see the token at `j`.

```
def forward(self, *,
            x: torch.Tensor,
            mem: Optional[torch.Tensor],
            c_mem: Optional[torch.Tensor],
            mask: torch.Tensor):
```

Normalize the vectors before doing self attention

`z = self.norm_self_attn(x)`

Normalize and concatenate memory and compressed memory

`m_z = self.concat_memory(z, mem, c_mem)`

Attention

`self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)`

Add the attention results

`x = x + self.dropout(self_attn)`

Normalize for feed-forward

`z = self.norm_ff(x)`

Pass through the feed-forward network

`ff = self.feed_forward(z)`

Add the feed-forward results back

`x = x + self.dropout(ff)`

`return x`
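A `mask` with the shape expected by this layer could be built as follows. This is a sketch under assumptions: the helper `build_mask` is hypothetical, and it assumes that every token may see all memories and compressed memories while attention among the current tokens is causal.

```python
import torch


def build_mask(seq_len: int, mem_len: int, c_mem_len: int) -> torch.Tensor:
    """Hypothetical helper: causal mask over tokens, full visibility of all memories."""
    # A token can see itself and earlier tokens only
    mask_x = torch.tril(torch.ones(seq_len, seq_len)).bool()
    # Every token can see every compressed-memory and memory slot
    mask_mem = torch.ones(seq_len, c_mem_len + mem_len).bool()
    # Memories come first along the key dimension; the trailing dimension broadcasts over the batch
    return torch.cat((mask_mem, mask_x), dim=1).unsqueeze(-1)


mask = build_mask(seq_len=3, mem_len=2, c_mem_len=1)
print(mask.shape)
```

The resulting tensor has shape `[seq_len, c_mem_len + mem_len + seq_len, 1]`, matching the second form described for `mask`.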

`class CompressiveTransformer(Module):`

```
def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
    super().__init__()
```

Make copies of the transformer layer

`self.layers = clone_module_list(layer, n_layers)`

Final normalization layer

`self.norm = nn.LayerNorm([layer.size])`

- `x` is a tensor of the token embedding vectors of shape `[seq_len, batch_size, d_model]`
- `mem` is a list of tensors of the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
- `c_mem` is a list of tensors of the compressed memory of shape `[c_mem_len, batch_size, d_model]` for each layer
- `mask` is the masking matrix

`def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):`

List to store token level feature vectors, which will become the memories for the next sequential batch.

`new_mem = []`

Run through each transformer layer

`for i, layer in enumerate(self.layers):`

Add the input of this layer to the list of feature vectors

`new_mem.append(x.detach())`

Memory

`m = mem[i] if mem else None`

Compressed Memory

`cm = c_mem[i] if c_mem else None`

Run through the compressive transformer layer

`x = layer(x=x, mem=m, c_mem=cm, mask=mask)`

Finally, normalize the vectors

`return self.norm(x), new_mem`
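The model only returns `new_mem`; merging it into the existing memory and compressing the oldest entries is left to the caller. One plausible way to do this for a single layer is sketched below. The helper `merge_and_compress`, its rounding rule, and the sizes are assumptions made for illustration, not something defined in this section.

```python
import torch
from torch import nn


def merge_and_compress(mem, new_mem, conv, mem_len, compression_rate):
    """Hypothetical helper for one layer: returns (kept memory, newly compressed memory or None)."""
    mem = new_mem if mem is None else torch.cat((mem, new_mem), dim=0)
    if len(mem) <= mem_len:
        return mem, None
    # Number of compressed slots to create, so that the overflow is a multiple of the rate
    n_c_mem = (len(mem) - mem_len + compression_rate - 1) // compression_rate
    n_old = n_c_mem * compression_rate
    old, mem = mem[:n_old], mem[n_old:]
    # Same permutations as `Conv1dCompression.forward`
    c_mem = conv(old.permute(1, 2, 0)).permute(2, 0, 1)
    return mem, c_mem


d_model, rate = 8, 4
conv = nn.Conv1d(d_model, d_model, kernel_size=rate, stride=rate)
mem, c_mem = merge_and_compress(torch.randn(8, 1, d_model), torch.randn(6, 1, d_model),
                                conv, mem_len=8, compression_rate=rate)
print(mem.shape, c_mem.shape)
```

Here 14 memories exceed the limit of 8, so the 8 oldest are compressed into 2 slots and 6 remain uncompressed. A full version would also append to the existing compressed memory and truncate it to its maximum length.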

Attention reconstruction loss recreates the self-attention output with uncompressed memory and with compressed memory and calculates the mean squared error between the two. It does this without positional encoding.

When calculating and training the compression function $f_c$ with attention reconstruction loss, all parameters but $f_c$ are frozen. This includes key/value projections and bias/scaling after normalization.

Since this loss can be computed independently of the cross-entropy loss of the model, you can have a separate optimizer that only updates $f_c$. However, we use the same optimizer to update $f_c$, so when calculating the attention reconstruction loss, we detach all other parameters except $f_c$ from the gradient computation.
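The effect of detaching parameters can be seen in a small experiment. In this sketch two `nn.Linear` modules are hypothetical stand-ins for $f_c$ and for a frozen projection; only the first should receive a gradient.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
compress = nn.Linear(4, 4)   # hypothetical stand-in for the compression function f_c
proj = nn.Linear(4, 4)       # hypothetical stand-in for a frozen key/value projection

x = torch.randn(3, 4)
# Detached parameters are treated as constants, so no gradient reaches `proj`
out = F.linear(compress(x), proj.weight.detach(), proj.bias.detach())
out.pow(2).mean().backward()

print(compress.weight.grad is not None, proj.weight.grad is None)
```

Gradients still flow through the projection to its input, which is what lets the loss train $f_c$ while a shared optimizer leaves the other parameters untouched by this loss.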

`class AttentionReconstructionLoss:`

`layers` is the list of Compressive Transformer layers

`def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):`

```
self.layers = layers
self.loss_func = nn.MSELoss()
```

This is a reimplementation of `PrepareForMultiHeadAttention` where the projections are done with the parameters detached from gradient computation.

- `pmha` is the `PrepareForMultiHeadAttention` module
- `x` is a tensor with the token embeddings

`def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):`

Shape of the input except the embedding dimension; `[seq_len, batch_size]`.

`head_shape = x.shape[:-1]`

Detach projection weights and bias

```
weight = pmha.linear.weight.detach()
bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None
```

Linear transform

`x = F.linear(x, weight, bias)`

Split last dimension into heads

`x = x.view(*head_shape, pmha.heads, pmha.d_k)`

Output has shape `[seq_len, batch_size, heads, d_k]`, or `[batch_size, heads, d_k]` for an input without the sequence dimension

`return x`
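The head split is a pure reshape of the last dimension. The toy example below, with hypothetical sizes, shows which slice of the projected vector each head receives.

```python
import torch

heads, d_k = 2, 3
# Projected vectors of size heads * d_k, filled with 0..11 so the slices are easy to follow
x = torch.arange(2 * 1 * heads * d_k, dtype=torch.float32).view(2, 1, heads * d_k)

head_shape = x.shape[:-1]                  # [seq_len, batch_size]
split = x.view(*head_shape, heads, d_k)    # [seq_len, batch_size, heads, d_k]
print(split[0, 0])
```

Head 0 gets the first `d_k` features of each vector and head 1 gets the next `d_k`, without copying any data.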

This is a reimplementation of Multi-Head Attention which calls `prepare_for_attn` instead of `PrepareForMultiHeadAttention` to detach projection parameters.

`def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):`

Calculate query, key and value projections

```
query = self.prepare_for_attn(layer.query, query)
key = self.prepare_for_attn(layer.key, key)
value = self.prepare_for_attn(layer.value, value)
```

Compute attention scores $Q K^\top$. This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.

`scores = torch.einsum('ibhd,jbhd->ijbh', query, key)`

Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$

`scores *= layer.scale`

Softmax attention along the key sequence dimension $\underset{seq}{\mathrm{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$

`attn = layer.softmax(scores)`

Multiply by values

`return torch.einsum("ijbh,jbhd->ibhd", attn, value)`
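The two `einsum` expressions can be tried on random tensors to confirm the shapes. This sketch assumes that the scale is $\frac{1}{\sqrt{d_k}}$ and that the softmax is taken over dimension 1 (the key positions); the sizes are hypothetical, and the query and key lengths differ on purpose.

```python
import torch

torch.manual_seed(0)
q_len, k_len, batch, heads, d_k = 3, 5, 2, 4, 8
query = torch.randn(q_len, batch, heads, d_k)
key = torch.randn(k_len, batch, heads, d_k)
value = torch.randn(k_len, batch, heads, d_k)

# Scaled scores for every (query position i, key position j) pair
scores = torch.einsum('ibhd,jbhd->ijbh', query, key) * d_k ** -0.5
# Normalize over the key positions (dimension 1)
attn = torch.softmax(scores, dim=1)
# Weighted sum of the values for each query position
out = torch.einsum('ijbh,jbhd->ibhd', attn, value)
print(scores.shape, out.shape)
```

The output length follows the query, not the key, which is why attention over memory and over compressed memory produce tensors of the same shape that can be compared directly.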

Perform layer normalization with shift and scale parameters detached.

`def norm(self, ln: nn.LayerNorm, x: torch.Tensor):`

Detach shift (`bias`) and scaling (`weight`) parameters

```
weight = ln.weight.detach() if ln.weight is not None else None
bias = ln.bias.detach() if ln.bias is not None else None
```

Layer normalization

`return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)`
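A quick check shows that the functional call gives the same values as the module while keeping its parameters out of the gradient computation. The tensors and the weighting vector below are hypothetical and exist only to produce a non-trivial gradient.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
ln = nn.LayerNorm([6])
x = torch.randn(4, 6, requires_grad=True)

# Functional layer norm with detached scale and shift
y = F.layer_norm(x, ln.normalized_shape, ln.weight.detach(), ln.bias.detach(), ln.eps)
same_values = torch.allclose(y, ln(x))

# Arbitrary weighted sum, used only to produce a gradient
weights = torch.arange(6, dtype=torch.float32)
(y * weights).sum().backward()

print(same_values, x.grad is not None, ln.weight.grad is None)
```

The input still receives a gradient, so the compressed memory upstream can be trained, but the normalization's own `weight` and `bias` do not.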

This calculates the loss for a layer.

`def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):`

Detach the token embeddings and memory.

```
h = h.detach()
mem = mem.detach()
```

Compress the memory with $f_c^{(i)}$. The parameters of $f_c^{(i)}$ are the only parameters not detached from gradient computation.

`c_mem = layer.compress(mem)`

Normalize the embeddings and memories

```
h = self.norm(layer.norm_self_attn, h)
mem = self.norm(layer.norm_self_attn, mem)
c_mem = self.norm(layer.norm_self_attn, c_mem)
```

Calculate the attention with uncompressed memory

`attn_mem = self.attn(layer.self_attn, h, mem, mem)`

Calculate the attention with compressed memory

`attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)`

Calculate the mean squared error

`return self.loss_func(attn_cmem, attn_mem)`
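The steps above can be put together in a simplified, self-contained form. This sketch is not the class above: it assumes a single attention head, omits layer normalization, and uses plain `nn.Linear` projections as hypothetical stand-ins for the layer's query, key and value projections.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
d_model, rate = 8, 2
conv = nn.Conv1d(d_model, d_model, kernel_size=rate, stride=rate)   # f_c
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))     # frozen projections


def attend(h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    """Single-head, content-only attention using detached projection parameters."""
    q = F.linear(h, w_q.weight.detach(), w_q.bias.detach())
    k = F.linear(m, w_k.weight.detach(), w_k.bias.detach())
    v = F.linear(m, w_v.weight.detach(), w_v.bias.detach())
    scores = torch.einsum('ibd,jbd->ijb', q, k) * d_model ** -0.5
    return torch.einsum('ijb,jbd->ibd', torch.softmax(scores, dim=1), v)


h = torch.randn(3, 1, d_model)      # token embeddings, [seq_len, batch, d_model]
mem = torch.randn(4, 1, d_model)    # memories to be compressed
c_mem = conv(mem.permute(1, 2, 0)).permute(2, 0, 1)

# Attention over compressed memory should match attention over the original memory
loss = F.mse_loss(attend(h, c_mem), attend(h, mem))
loss.backward()
print(conv.weight.grad is not None, w_k.weight.grad is None)
```

After `backward()`, only the convolution has gradients, so an optimizer step driven by this loss changes nothing but the compression function.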

`def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):`

Calculate the losses for each layer

`losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]`

Sum of the losses

`return sum(losses)`