This is an implementation of Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context in PyTorch.

A vanilla Transformer has a limited attention span, equal to the length of the sequence trained in parallel, and every position in that sequence has a fixed positional encoding. Transformer XL extends this attention span by letting each position also attend to pre-calculated embeddings from the past. For instance, if the context length is $l$, it keeps the embeddings of all layers for the previous batch of length $l$ and feeds them to the current step. With fixed positional encodings, these pre-calculated embeddings would have the same positions as the current context. The paper therefore introduces relative positional encodings, which are applied in the attention calculation.
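To make the position clash concrete, here is a minimal pure-Python sketch (not part of the implementation) that lists the relative distance from each query in the current segment to every key it can see. The function name `relative_distances` and the use of `None` for keys that are not visible are assumptions made for illustration.

```python
from typing import List, Optional


def relative_distances(seq_len: int, mem_len: int) -> List[List[Optional[int]]]:
    """Distance from each query in the current segment to each key.

    Keys are the cached memory followed by the current segment, so key `j`
    sits at offset `j` and query `i` sits at offset `mem_len + i`.
    Keys that come after the query are not visible and are marked `None`.
    """
    rows = []
    for i in range(seq_len):
        q_pos = mem_len + i
        rows.append([q_pos - j if j <= q_pos else None
                     for j in range(mem_len + seq_len)])
    return rows


# Context length 3 with 3 cached positions from the previous segment
for row in relative_distances(seq_len=3, mem_len=3):
    print(row)
# [3, 2, 1, 0, None, None]
# [4, 3, 2, 1, 0, None]
# [5, 4, 3, 2, 1, 0]
```

Each distance depends only on how far apart the query and key are, so a cached position and a current position never share the same encoding the way they would with fixed positions `0..l-1`.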

The annotated implementation of relative multi-headed attention is in `relative_mha.py`.

Training code and a notebook are available for training a Transformer XL model on the Tiny Shakespeare dataset.

```
from typing import List, Optional

import torch
import torch.nn as nn

from labml_helpers.module import Module
from labml_nn.utils import clone_module_list
from .relative_mha import RelativeMultiHeadAttention
from ..feed_forward import FeedForward
```

`class TransformerXLLayer(Module):`

- `d_model` is the token embedding size
- `self_attn` is the self attention module
- `feed_forward` is the feed forward module
- `dropout_prob` is the probability of dropping out after self attention and FFN

```
    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
```

```
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
```

- `x` is a tensor of the token level feature vectors of shape `[seq_len, batch_size, d_model]`
- `mem` is a tensor of the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
- `mask` is a matrix of shape `[seq_len, mem_len + seq_len, batch_size]` or `[seq_len, mem_len + seq_len, 1]`. `mask[i, j]` is true if the token at `i` can see the token at `j`.
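As an illustration of the mask layout described above, the following sketch builds a causal mask in which every query sees all of the memory plus the current segment up to its own position. The helper name `causal_mask_with_memory` is hypothetical, and the training code may construct its mask differently.

```python
import torch


def causal_mask_with_memory(seq_len: int, mem_len: int) -> torch.Tensor:
    """Boolean mask of shape `[seq_len, mem_len + seq_len, 1]`.

    `mask[i, j]` is True when query `i` may attend to key `j`, where the
    first `mem_len` keys are the memory and the rest are the current segment.
    """
    ones = torch.ones(seq_len, mem_len + seq_len)
    # Keep entries with j <= i + mem_len: all memory, then a causal triangle
    mask = torch.tril(ones, diagonal=mem_len).bool()
    # Trailing dimension of size 1 broadcasts over the batch
    return mask.unsqueeze(-1)


mask = causal_mask_with_memory(seq_len=2, mem_len=3)
print(mask.shape)  # torch.Size([2, 5, 1])
print(mask[:, :, 0])
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])
```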

```
    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                mask: torch.Tensor):
```

Normalize the vectors before doing self attention

`z = self.norm_self_attn(x)`

If there is memory

`if mem is not None:`

Normalize it

`mem = self.norm_self_attn(mem)`

Concatenate with `z`

`m_z = torch.cat((mem, z), dim=0)`

Ignore if there is no memory

```
        else:
            m_z = z
```

Attention

`self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)`

Add the attention results

`x = x + self.dropout(self_attn)`

Normalize for feed-forward

`z = self.norm_ff(x)`

Pass through the feed-forward network

`ff = self.feed_forward(z)`

Add the feed-forward results back

`x = x + self.dropout(ff)`

`return x`
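The forward pass above can be tried end to end without the rest of the package. The sketch below keeps the same pre-norm, residual and memory-concatenation steps but swaps in PyTorch's standard `nn.MultiheadAttention` for `RelativeMultiHeadAttention`, so it has no relative positional encodings. The class name `PreNormBlockSketch` and the small feed-forward stack are assumptions for illustration only.

```python
from typing import Optional

import torch
import torch.nn as nn


class PreNormBlockSketch(nn.Module):
    """Hypothetical stand-in for `TransformerXLLayer` with standard attention."""

    def __init__(self, d_model: int, heads: int, d_ff: int, dropout_prob: float):
        super().__init__()
        # Standard attention replaces the relative attention module here
        self.attn = nn.MultiheadAttention(d_model, heads)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                nn.Linear(d_ff, d_model))
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, x: torch.Tensor, mem: Optional[torch.Tensor],
                mask: torch.Tensor) -> torch.Tensor:
        # Normalize, then prepend the (normalized) memory to keys and values
        z = self.norm_self_attn(x)
        m_z = z if mem is None else torch.cat((self.norm_self_attn(mem), z), dim=0)
        # `mask` is True where attention is allowed; `attn_mask` marks
        # disallowed positions with True, so it is inverted
        attn_out, _ = self.attn(z, m_z, m_z, attn_mask=~mask)
        x = x + self.dropout(attn_out)
        # Pre-norm feed-forward with a residual connection
        return x + self.dropout(self.ff(self.norm_ff(x)))


torch.manual_seed(0)
block = PreNormBlockSketch(d_model=8, heads=2, d_ff=16, dropout_prob=0.0)
x = torch.randn(4, 2, 8)    # [seq_len, batch_size, d_model]
mem = torch.randn(3, 2, 8)  # [mem_len, batch_size, d_model]
mask = torch.tril(torch.ones(4, 7), diagonal=3).bool()  # [seq_len, mem_len + seq_len]
out = block(x, mem, mask)
print(out.shape)  # torch.Size([4, 2, 8])
```

Queries come only from the current segment, so the output keeps the shape of `x` even though the keys and values span `mem_len + seq_len` positions.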

`class TransformerXL(Module):`

```
    def __init__(self, layer: TransformerXLLayer, n_layers: int):
        super().__init__()
```

Make copies of the transformer layer

`self.layers = clone_module_list(layer, n_layers)`

Final normalization layer

`self.norm = nn.LayerNorm([layer.size])`
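Making copies of the layer gives each level of the stack its own parameters. The sketch below shows one way to do this with plain PyTorch, assuming that `clone_module_list` behaves like a list of deep copies. `clone_layers` is a hypothetical helper, not the library function.

```python
import copy

import torch
import torch.nn as nn


def clone_layers(layer: nn.Module, n_layers: int) -> nn.ModuleList:
    """Hypothetical helper: `n_layers` independent deep copies of `layer`."""
    return nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])


layers = clone_layers(nn.Linear(4, 4), n_layers=3)
# Copies start from identical values but do not share parameter tensors
same_values = torch.equal(layers[0].weight, layers[1].weight)
shared_storage = layers[0].weight is layers[1].weight
print(len(layers), same_values, shared_storage)  # 3 True False
```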

- `x` is a tensor of the token embedding vectors of shape `[seq_len, batch_size, d_model]`
- `mem` is a list of tensors of the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
- `mask` is the masking matrix

`def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):`

List to store token level feature vectors, which will become the memories for the next sequential batch.

`new_mem = []`

Run through each transformer layer

`for i, layer in enumerate(self.layers):`

Add to the list of feature vectors

`new_mem.append(x.detach())`
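The `detach()` call stores the feature vectors as constants, so a later batch that attends to them does not backpropagate into the computation that produced them. A small sketch with made-up tensors shows the effect on gradients:

```python
import torch

w = torch.ones(3, requires_grad=True)
x = w * 2.0
# The cached copy shares values with `x` but is cut off from the graph
cached = x.detach()

# Gradients flow through `x` only; `cached` acts as a constant
loss = (x * cached).sum()
loss.backward()

print(cached.requires_grad)  # False
print(w.grad)                # tensor([4., 4., 4.])
```

Without the `detach()`, the same loss would give a gradient of `8.0` per element, because both factors would depend on `w`.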

Memory

`m = mem[i] if mem else None`

Run through the transformer XL layer

`x = layer(x=x, mem=m, mask=mask)`

Finally, normalize the vectors

`return self.norm(x), new_mem`
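The model returns `new_mem` but leaves it to the caller to decide how much history to keep. One possible way to carry memories across batches is to append the new feature vectors to the old ones per layer and keep only the most recent positions. The helper `merge_memory` and its `mem_len` parameter are assumptions for this sketch and are not defined in this module.

```python
from typing import List

import torch


def merge_memory(old_mem: List[torch.Tensor], new_mem: List[torch.Tensor],
                 mem_len: int) -> List[torch.Tensor]:
    """Hypothetical helper: keep the latest `mem_len` positions per layer."""
    if mem_len == 0:
        return []
    if old_mem:
        # Concatenate along the sequence dimension, oldest first
        mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
    else:
        mem = new_mem
    # Truncate to the most recent `mem_len` positions
    return [m[-mem_len:] for m in mem]


mem: List[torch.Tensor] = []
for step in range(3):
    # Stand-in for the `new_mem` returned by the model: 2 layers,
    # segments of shape [seq_len=3, batch_size=1, d_model=4]
    new_mem = [torch.full((3, 1, 4), float(step)) for _ in range(2)]
    mem = merge_memory(mem, new_mem, mem_len=5)
    print(step, [tuple(m.shape) for m in mem])
# 0 [(3, 1, 4), (3, 1, 4)]
# 1 [(5, 1, 4), (5, 1, 4)]
# 2 [(5, 1, 4), (5, 1, 4)]
```

The merged list can then be passed as `mem` on the next call, with a mask whose second dimension matches the stored memory length plus the new sequence length.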