This is a PyTorch implementation of the paper "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
The vanilla Transformer has a limited attention span, equal to the length of the sequence trained in parallel, and all of these positions use a fixed positional encoding. Transformer-XL increases this attention span by letting each position attend to precalculated embeddings from the past. For instance, if the context length is `l`, it keeps the embeddings of all layers for the previous batch of length `l` and feeds them to the current step. If we used fixed positional encodings, these precalculated embeddings would have the same positions as the current context. To avoid this, Transformer-XL introduces relative positional encoding, where the positional information is injected at the attention calculation.
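To see why this matters, here is a minimal illustration (the segment and memory lengths are arbitrary): with absolute encodings the cached segment would reuse the same position indices as the current one, while relative encoding only depends on query-key distances, which remain consistent across segments.

```python
import torch

seq_len, mem_len = 4, 4

# With absolute encodings, both the cached segment and the current segment
# would be assigned positions 0..seq_len-1, so the model could not order them.
# Relative encoding only needs the distance i - j between a query at position i
# (in the current segment) and a key at position j (in memory + current segment),
# and these distances stay consistent across segments.
query_pos = torch.arange(seq_len)[:, None]          # i = 0..seq_len-1
key_pos = torch.arange(-mem_len, seq_len)[None, :]  # memory occupies offsets -mem_len..-1
rel_dist = query_pos - key_pos                      # shape [seq_len, mem_len + seq_len]
print(rel_dist)
```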
The annotated implementation of relative multi-headed attention is in `relative_mha.py`.
Here is the training code and a notebook for training a Transformer-XL model on the Tiny Shakespeare dataset.
from typing import List, Optional

import torch
import torch.nn as nn

from labml_nn.utils import clone_module_list
from .relative_mha import RelativeMultiHeadAttention
from ..feed_forward import FeedForward
class TransformerXLLayer(nn.Module):
* `d_model` is the token embedding size
* `self_attn` is the self attention module
* `feed_forward` is the feed forward module
* `dropout_prob` is the probability of dropping out after self attention and the FFN

    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
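As a quick sketch of how such a layer might be constructed (assuming the `RelativeMultiHeadAttention(heads, d_model, dropout_prob)` and `FeedForward(d_model, d_ff, dropout)` constructor signatures; the hyper-parameters here are illustrative):

```python
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.xl import TransformerXLLayer
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention

d_model, heads, d_ff = 128, 4, 256

# Pre-norm layer with relative multi-headed self attention and a position-wise FFN
layer = TransformerXLLayer(d_model=d_model,
                           self_attn=RelativeMultiHeadAttention(heads, d_model, 0.1),
                           feed_forward=FeedForward(d_model, d_ff, 0.1),
                           dropout_prob=0.1)
```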
* `x` is a tensor of the token level feature vectors of shape `[seq_len, batch_size, d_model]`
* `mem` is a tensor of the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
* `mask` is a matrix of shape `[seq_len, mem_len + seq_len, batch_size]` or `[seq_len, mem_len + seq_len, 1]`. `mask[i, j]` is true if the token at `i` can see the token at `j`. A sketch of building such a mask follows this method.

    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                mask: torch.Tensor):
Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
If there is memory
        if mem is not None:
Normalize it
            mem = self.norm_self_attn(mem)
Concatenate it with `z`
            m_z = torch.cat((mem, z), dim=0)
Ignore if there is no memory
        else:
            m_z = z
Attention
        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
Add the attention results
        x = x + self.dropout(self_attn)
Normalize for feed-forward
        z = self.norm_ff(x)
Pass through the feed-forward network
        ff = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)
        return x
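Since the mask has to cover both the memory and the current segment, here is a minimal sketch of one way to build a mask of the shape described above, where memory positions are fully visible and the current segment is causal. This is only an illustration, not necessarily how the experiment code constructs it:

```python
import torch

def build_mask(seq_len: int, mem_len: int) -> torch.Tensor:
    # mask[i, j] is True when query position i may attend to key position j,
    # where the keys are [memory, current segment]
    causal = torch.tril(torch.ones(seq_len, seq_len)).bool()  # current segment: causal
    mem_visible = torch.ones(seq_len, mem_len).bool()         # memory: fully visible
    # Shape [seq_len, mem_len + seq_len, 1]; the last dimension broadcasts over the batch
    return torch.cat((mem_visible, causal), dim=1)[:, :, None]
```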
class TransformerXL(nn.Module):
    def __init__(self, layer: TransformerXLLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
* `x` is a tensor of the token embedding vectors of shape `[seq_len, batch_size, d_model]`
* `mem` is a list of tensors of the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
* `mask` is the masking matrix

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
List to store token level feature vectors, which will become the memories for the next sequential batch.
        new_mem = []
Run through each transformer layer
        for i, layer in enumerate(self.layers):
Add the input to the list of feature vectors; it is detached so that gradients do not propagate into the stored memories
            new_mem.append(x.detach())
Memory for this layer from the previous batch, if any
            m = mem[i] if mem else None
Run through the transformer XL layer
            x = layer(x=x, mem=m, mask=mask)
Finally, normalize the vectors and return them along with the new memories
        return self.norm(x), new_mem
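Finally, a minimal usage sketch of carrying the returned memories across sequential batches. The hyper-parameters, the random inputs, and the truncation of memories to `mem_len` steps are illustrative assumptions, not the repository's training code:

```python
import torch

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention

d_model, heads, d_ff, n_layers = 128, 4, 256, 2
seq_len, mem_len, batch_size = 8, 16, 2

layer = TransformerXLLayer(d_model=d_model,
                           self_attn=RelativeMultiHeadAttention(heads, d_model, 0.1),
                           feed_forward=FeedForward(d_model, d_ff, 0.1),
                           dropout_prob=0.1)
model = TransformerXL(layer, n_layers)

mem = []  # no memory before the first segment
for step in range(3):
    x = torch.randn(seq_len, batch_size, d_model)  # token embeddings for this segment
    m_len = mem[0].shape[0] if mem else 0
    # Memory positions are fully visible; the current segment is causal
    mask = torch.cat((torch.ones(seq_len, m_len),
                      torch.tril(torch.ones(seq_len, seq_len))), dim=1).bool()[:, :, None]
    out, new_mem = model(x, mem, mask)  # `out` would feed a language-model head here
    # Keep only the most recent `mem_len` steps of each layer's memory for the next segment
    mem = [torch.cat((m, n), dim=0)[-mem_len:] for m, n in zip(mem, new_mem)] if mem else new_mem
```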