This is an implementation of Compressive Transformers for Long-Range Sequence Modelling in PyTorch.
This is an extension of Transformer XL where past memories are compressed to give a longer attention range. That is, the furthest $n_{cm} c$ memories are compressed into $n_{cm}$ memories, where $c$ is the compression rate.
The compression operation is defined as $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$. The paper introduces multiple choices for $f_c$, and we have only implemented 1D convolution, which seems to give the best results. Each layer has a separate compression operation $f_c^{(i)}$, where $i$ is the layer number.
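For example (illustrative sizes, not from the paper), with a compression rate of $c = 3$ the oldest $6$ memory vectors are mapped to $2$ compressed vectors:

$$f_c: \mathbb{R}^{6 \times d} \rightarrow \mathbb{R}^{2 \times d}, \qquad n_{cm} = 6 / 3 = 2$$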
Since training compression with BPTT requires maintaining a very large computational graph (many time steps), the paper proposes an auto-encoding loss and an attention reconstruction loss. The auto-encoding loss decodes the original memories from the compressed memories and computes a reconstruction loss between the two. The attention reconstruction loss computes the multi-headed attention results on the compressed memory and on the uncompressed memory and takes the mean squared error between them. We have implemented the latter here since it gives better results.
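Schematically (a sketch of the idea; $h$ is the current hidden state, $m$ the old memories, and $\text{attn}(\cdot,\cdot)$ the multi-headed attention computed without positional encoding), the attention reconstruction loss implemented below is

$$\mathcal{L}^{attn} = \text{MSE}\big(\text{attn}(h, f_c(m)),\ \text{attn}(h, m)\big)$$

where gradients flow only into the parameters of $f_c$.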
This implementation uses pre-layer normalization while the paper uses post-layer normalization. Pre-layer norm applies the layer normalization before the FFN and the self-attention, and the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.
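As a minimal sketch of the difference (a stand-in linear sub-layer instead of attention, with illustrative sizes):

import torch
from torch import nn

d_model = 16
x = torch.randn(10, 2, d_model)          # [seq_len, batch, d_model]
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)   # stand-in for self-attention or the FFN
dropout = nn.Dropout(0.1)

# Pre-layer norm (used here): normalize the input to the sub-layer;
# the residual pass-through stays unnormalized
pre_ln = x + dropout(sublayer(norm(x)))

# Post-layer norm (as in the paper): normalize after the residual addition
post_ln = norm(x + dropout(sublayer(x)))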
Here are the training code and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from labml_helpers.module import Module, TypedModuleList
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.utils import clone_module_list
This is a simple wrapper around nn.Conv1d
with some tensor dimension permutations.
class Conv1dCompression(Module):

compression_rate is the compression rate $c$
d_model is the embedding size

    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

mem has shape [seq_len, batch, d_model]

    def forward(self, mem: torch.Tensor):

Permute the dimensions of mem so that we can run it through the convolution layer. The convolution layer accepts input in the form [batch, features, sequence].

        mem = mem.permute(1, 2, 0)

Get the compressed memory by running it through the convolution layer

        c_mem = self.conv(mem)

Permute back to the form [seq_len, batch, d_model]

        return c_mem.permute(2, 0, 1)
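For example (a minimal usage sketch with illustrative sizes), a compression rate of $c = 4$ maps 8 memory steps to 2 compressed steps:

import torch

# Hypothetical sizes: compression rate 4, embedding size 16
compress = Conv1dCompression(compression_rate=4, d_model=16)

mem = torch.randn(8, 2, 16)   # [mem_len, batch, d_model]
c_mem = compress(mem)

print(c_mem.shape)            # torch.Size([2, 2, 16]); 8 / 4 = 2 compressed memories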
This is the implementation of a single compressive transformer layer
class CompressiveTransformerLayer(Module):

d_model is the token embedding size
self_attn is the self attention module
feed_forward is the feed forward module
dropout_prob is the probability of dropping out after self attention and FFN
compress is the compression function $f_c$

    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float,
                 compress: Conv1dCompression):
        super().__init__()
        self.compress = compress
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
Concatenate the normalized token embeddings with the memory and the compressed memory.

z is the layer normalized token embeddings. mem and c_mem are the memory and the compressed memory (not normalized).

    def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):

If there is no memory, just return the token embeddings

        if mem is None:
            return z

If there is compressed memory, concatenate it with the memory

        if c_mem is not None:
            mem = torch.cat((c_mem, mem), dim=0)

Run the memory through the normalization layer

        mem = self.norm_self_attn(mem)

Concatenate the normalized memory and the normalized token embeddings

        return torch.cat((mem, z), dim=0)
x is a tensor of token level feature vectors of shape [seq_len, batch_size, d_model]
mem is a tensor of the past token level feature vectors (memory) of shape [mem_len, batch_size, d_model]
c_mem is a tensor of the compressed memory of shape [c_mem_len, batch_size, d_model]
mask is a matrix of shape [seq_len, c_mem_len + mem_len + seq_len, batch_size] or [seq_len, c_mem_len + mem_len + seq_len, 1]. mask[i, j] is true if the token at i can attend to the token at j.

    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                c_mem: Optional[torch.Tensor],
                mask: torch.Tensor):

Normalize the vectors before doing self attention

        z = self.norm_self_attn(x)

Normalize and concatenate the memory and compressed memory

        m_z = self.concat_memory(z, mem, c_mem)

Attention

        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)

Add the attention results

        x = x + self.dropout(self_attn)

Normalize for the feed-forward network

        z = self.norm_ff(x)

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x
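A minimal usage sketch of a single layer (shapes, sizes and the mask construction are illustrative; the RelativeMultiHeadAttention and FeedForward constructor arguments are assumed to match labml_nn):

import torch

d_model, heads, c = 128, 4, 4
layer = CompressiveTransformerLayer(
    d_model=d_model,
    self_attn=RelativeMultiHeadAttention(heads, d_model, 0.1),
    feed_forward=FeedForward(d_model, 256, 0.1),
    dropout_prob=0.1,
    compress=Conv1dCompression(c, d_model))

seq_len, mem_len, c_mem_len, batch = 8, 16, 4, 2
x = torch.randn(seq_len, batch, d_model)
mem = torch.randn(mem_len, batch, d_model)
c_mem = torch.randn(c_mem_len, batch, d_model)

# Causal mask over keys ordered [compressed memory | memory | current tokens];
# mask[i, j] is true when query position i may attend to key position j
mask = torch.tril(torch.ones(seq_len, c_mem_len + mem_len + seq_len),
                  diagonal=c_mem_len + mem_len).bool().unsqueeze(-1)

out = layer(x=x, mem=mem, c_mem=c_mem, mask=mask)
print(out.shape)  # torch.Size([8, 2, 128])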
class CompressiveTransformer(Module):

    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])
x is a tensor of the token embedding vectors of shape [seq_len, batch_size, d_model]
mem is a list of tensors of the past token level feature vectors of shape [mem_len, batch_size, d_model] for each layer
c_mem is a list of tensors of the compressed memory of shape [c_mem_len, batch_size, d_model] for each layer
mask is the masking matrix

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):

List to store token level feature vectors, which will become the memories for the next sequential batch.

        new_mem = []

Run through each transformer layer

        for i, layer in enumerate(self.layers):

Add to the list of feature vectors

            new_mem.append(x.detach())

Memory

            m = mem[i] if mem else None

Compressed memory

            cm = c_mem[i] if c_mem else None

Run through the transformer XL layer

            x = layer(x=x, mem=m, c_mem=cm, mask=mask)

Finally, normalize the vectors

        return self.norm(x), new_mem
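Between batches the returned memories have to be merged into the running memory, and the oldest entries compressed. A simplified sketch of that bookkeeping (the actual experiment code also trims the compressed memory and computes the attention reconstruction loss; the helper and its names are illustrative):

import torch

def update_memory(mem, c_mem, new_mem, compress, mem_len, c):
    # mem, c_mem, new_mem: per-layer lists of [len, batch, d_model] tensors
    # compress: per-layer list of Conv1dCompression modules; c: compression rate
    next_mem, next_c_mem = [], []
    for m, cm, nm, f_c in zip(mem, c_mem, new_mem, compress):
        # Append this batch's hidden states to the memory
        m = torch.cat((m, nm), dim=0)
        # Number of oldest steps that overflow mem_len, rounded down to a multiple of c
        n_c = ((m.shape[0] - mem_len) // c) * c if m.shape[0] > mem_len else 0
        if n_c > 0:
            # Compress the overflowing memories and append them to the compressed memory
            cm = torch.cat((cm, f_c(m[:n_c])), dim=0)
            m = m[n_c:]
        next_mem.append(m)
        next_c_mem.append(cm)
    return next_mem, next_c_mem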
Attention reconstruction loss recreates the self-attention output with uncompressed memory and with compressed memory and calculates the mean squared error between the two. It does this without positional encoding.
When calculating and training the compression function $f_c$ with the attention reconstruction loss, all parameters but $f_c$ are frozen. This includes the key/value projections and the bias/scaling after normalization.
Since this loss can be computed independently of the cross-entropy loss of the model, you can have a separate optimizer that only updates $f_c$. However, we use the same optimizer to update $f_c$, so when calculating the attention reconstruction loss we detach all parameters except $f_c$ from the gradient computation.
class AttentionReconstructionLoss:

layers is the list of Compressive Transformer layers

    def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
        self.layers = layers
        self.loss_func = nn.MSELoss()
This is a reimplementation of PrepareForMultiHeadAttention where the projections are done with the parameters detached from gradient computation.

pmha is the PrepareForMultiHeadAttention module
x is the tensor with the token embeddings

    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):

Shape of the input except the embedding dimension; [seq_len, batch_size].

        head_shape = x.shape[:-1]

Detach the projection weights and bias

        weight = pmha.linear.weight.detach()
        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None

Linear transform

        x = F.linear(x, weight, bias)

Split the last dimension into heads

        x = x.view(*head_shape, pmha.heads, pmha.d_k)

Output has shape [seq_len, batch_size, heads, d_k]

        return x
This is a reimplementation of multi-head attention which calls prepare_for_attn instead of PrepareForMultiHeadAttention to detach the projection parameters.

    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):

Calculate the query, key and value projections

        query = self.prepare_for_attn(layer.query, query)
        key = self.prepare_for_attn(layer.key, key)
        value = self.prepare_for_attn(layer.value, value)

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].

        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

Scale the scores by $\frac{1}{\sqrt{d_k}}$

        scores *= layer.scale

$\text{softmax}$ attention along the key sequence dimension

        attn = layer.softmax(scores)

Multiply by the values

        return torch.einsum("ijbh,jbhd->ibhd", attn, value)
Perform layer normalization with the shift and scale parameters detached.

    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):

Detach the shift (bias) and scaling (weight) parameters

        weight = ln.weight.detach() if ln.weight is not None else None
        bias = ln.bias.detach() if ln.bias is not None else None

Layer normalization

        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)
This calculates the loss for a layer

    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):

Detach the token embeddings and the memory.

        h = h.detach()
        mem = mem.detach()

Compress the memory with $f_c^{(i)}$. The parameters of $f_c^{(i)}$ are the only parameters not detached from gradient computation.

        c_mem = layer.compress(mem)

Normalize the embeddings and the memories

        h = self.norm(layer.norm_self_attn, h)
        mem = self.norm(layer.norm_self_attn, mem)
        c_mem = self.norm(layer.norm_self_attn, c_mem)

Calculate the attention with the uncompressed memory

        attn_mem = self.attn(layer.self_attn, h, mem, mem)

Calculate the attention with the compressed memory

        attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)

Calculate the mean squared error

        return self.loss_func(attn_cmem, attn_mem)
    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):

Calculate the losses for each layer

        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]

Sum of the losses

        return sum(losses)
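A sketch of how this loss might be used in training (the model construction is illustrative; the RelativeMultiHeadAttention and FeedForward constructor arguments are assumed to match labml_nn, and the loss weighting is left to the experiment code):

import torch

d_model, heads, c, n_layers = 128, 4, 4, 2
layer = CompressiveTransformerLayer(
    d_model=d_model,
    self_attn=RelativeMultiHeadAttention(heads, d_model, 0.1),
    feed_forward=FeedForward(d_model, 256, 0.1),
    dropout_prob=0.1,
    compress=Conv1dCompression(c, d_model))
model = CompressiveTransformer(layer, n_layers)

ar_loss = AttentionReconstructionLoss(model.layers)

# h[i] is the input to layer i (what the model's forward pass returns as new_mem) and
# mems[i] is the chunk of old memories for layer i that is about to be compressed;
# its length should be a multiple of the compression rate c
h = [torch.randn(8, 2, d_model) for _ in range(n_layers)]
mems = [torch.randn(16, 2, d_model) for _ in range(n_layers)]

loss = ar_loss(h, mems)  # add this (with some weighting) to the cross-entropy loss
loss.backward()          # gradients flow only into the Conv1dCompression parameters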