This is a PyTorch implementation of the paper Hierarchical Transformers Are More Efficient Language Models.
This paper introduces a hierarchical transformer architecture to handle long sequences efficiently. The first half of the transformer layers down-sample tokens and the second half up-samples, with direct skip connections between layers of the same resolution. This is somewhat similar to U-Net for vision tasks.
The paper evaluates several up-sampling and down-sampling techniques and builds a model with the best-performing ones, which it calls the hourglass model.
Here we implement only the simplest up-sampling and down-sampling techniques. We will consider adding the more complex (and better-performing) variants later.
Here is the training code for the hourglass model.
from typing import List

import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask
This model recursively adds layers to the middle while shortening the sequence by down-sampling. The shortened sequence processed by another hourglass model is sandwiched between two normal transformer layers. (A transformer layer has a self-attention layer and a position-wise feed-forward layer).
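For example, with shortening_factors = [2, 2] the outer hourglass shortens the sequence by a factor of 2 and the nested hourglass shortens the already-shortened sequence by a further factor of 2, so the innermost transformer layer sees a sequence roughly four times shorter than the input.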
class HourGlass(nn.Module):

* n_heads is the number of heads in the multi-head attention layers
* d_model is the size of the token embeddings
* dropout is the dropout probability
* d_ff is the dimensionality of the hidden layer in the position-wise feed-forward layers
* shortening_factors is the list of shortening factors

    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
        super().__init__()

The transformer layer before down-sampling

        self.pre = TransformerLayer(d_model=d_model,
                                    self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                    feed_forward=FeedForward(d_model, d_ff, dropout),
                                    dropout_prob=dropout)

Auto-regressive mask

        self.mask = AutoregressiveMask()

The shortening factor (or the down-sampling rate)

        k = shortening_factors[0]

We shift the tokens to the right by k - 1 steps to make sure information doesn't leak from future tokens to past tokens as a result of down-sampling and up-sampling

        self.shift_right = ShiftRight(k - 1)

Shortening or the down-sampling layer. We use the simplest form, average pooling. The paper shows that attention-based down-sampling works best, which we haven't implemented yet.

        self.shortening = AvgPoolShortening(k)
If there is no more shortening (we are at the middle of the hourglass)

        if len(shortening_factors) == 1:

The center layer is another transformer layer

            self.shortened = TransformerLayer(d_model=d_model,
                                              self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                              feed_forward=FeedForward(d_model, d_ff, dropout),
                                              dropout_prob=dropout)

Auto-regressive mask

            self.mask_short = AutoregressiveMask()
            self.hour_glass = None
        else:

Insert another hourglass model recursively

            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])
Up-sampling layer. We use naive up-sampling for simplicity; the paper shows that attention-based up-sampling works better.

        self.up_sampling = NaiveUpSampling(k)

The final transformer layer after up-sampling

        self.post = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)
    def forward(self, x: torch.Tensor):

Initial transformer layer

        x = self.pre(x=x, mask=self.mask(x))

Shifting and shortening

        x_short = self.shortening(self.shift_right(x))

If we are at the center of the hourglass

        if self.hour_glass is None:

Center transformer layer

            x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))

Otherwise, pass the shortened sequence through the nested hourglass model

        else:
            x_short = self.hour_glass(x_short)

Up-sample the shortened sequence and add a skip connection

        x = x + self.up_sampling(x, x_short)

Final transformer layer

        x = self.post(x=x, mask=self.mask(x))

        return x
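To see the module in action, here is a small usage sketch (not part of the original code; the hyper-parameters are arbitrary):

hour_glass = HourGlass(n_heads=4, d_model=64, dropout=0.1, d_ff=256, shortening_factors=[2, 2])
x = torch.randn(32, 8, 64)  # [seq_len, batch_size, d_model]
y = hour_glass(x)
assert y.shape == (32, 8, 64)  # internally the sequence is shortened to 16 and then to 8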
Shift the sequence to the right by a given number of steps, padding with zeros at the start.

class ShiftRight(nn.Module):

* shift is the number of steps to shift by

    def __init__(self, shift: int):
        super().__init__()

shift cannot be negative

        assert shift >= 0
        self.shift = shift

* x is a tensor of shape [seq_len, ...]

    def forward(self, x: torch.Tensor):

If the shift is 0, return the original

        if self.shift == 0:
            return x

Zeros to be prepended on the left

        prefix = x.new_zeros([self.shift, *x.shape[1:]])

Concatenate the zeros and truncate the right

        return torch.cat([prefix, x[:-self.shift]])
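A tiny illustration of the shifting behavior (not part of the original code):

shift_right = ShiftRight(2)
x = torch.arange(1., 6.)[:, None, None]  # shape [5, 1, 1]
y = shift_right(x)
# y.squeeze() is [0., 0., 1., 2., 3.]: zeros are prepended and the last two steps are dropped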
Down-sample the sequence by average pooling.

class AvgPoolShortening(nn.Module):

* k is the shortening factor

    def __init__(self, k: int):
        super().__init__()

Average pooling layer

        self.pool = nn.AvgPool1d(k, ceil_mode=True)

* x is of shape [seq_len, batch_size, d_model]

    def forward(self, x: torch.Tensor):

The pooling layer accepts shape [batch_size, d_model, seq_len], so we permute axes.

        return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
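A quick shape check (not part of the original code):

shorten = AvgPoolShortening(2)
x = torch.randn(6, 4, 8)  # [seq_len=6, batch_size=4, d_model=8]
x_short = shorten(x)
assert x_short.shape == (3, 4, 8)  # each output embedding is the mean of two consecutive inputs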
Up-sample naively by repeating each shortened embedding k times.

class NaiveUpSampling(nn.Module):

* k is the shortening factor

    def __init__(self, k: int):
        super().__init__()
        self.k = k

* x is the tensor with embeddings before down-sampling
* x_short is the tensor of shortened (higher-density) representations to be up-sampled

    def forward(self, x: torch.Tensor, x_short: torch.Tensor):

Repeat across the sequence dimension

        expanded = torch.repeat_interleave(x_short, self.k, dim=0)

Truncate the extra embeddings at the end

        expanded = expanded[:x.shape[0]]

        return expanded
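A quick shape check (not part of the original code):

up_sample = NaiveUpSampling(2)
x = torch.randn(5, 4, 8)        # original sequence before down-sampling
x_short = torch.randn(3, 4, 8)  # shortened sequence, ceil(5 / 2) = 3
y = up_sample(x, x_short)
assert y.shape == (5, 4, 8)  # each shortened embedding is repeated twice, then truncated to match x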
Generate an auto-regressive mask and cache it, so it is only rebuilt when the sequence length changes.

class AutoregressiveMask(nn.Module):
    def __init__(self):
        super().__init__()
        self.mask = None

    def forward(self, x: torch.Tensor):

Create a mask if we haven't created one yet or if the sequence length has changed

        if self.mask is None or self.mask.size(0) != len(x):

Subsequent mask, will mask out tokens from seeing future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)

        return self.mask
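subsequent_mask comes from labml_nn. As a rough sketch of the idea, a causal mask is just a lower-triangular boolean matrix; the library's actual return shape may differ, so treat this as an assumption:

def causal_mask(seq_len: int) -> torch.Tensor:
    # Position i may attend to positions j <= i
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# causal_mask(3) is
# [[ True, False, False],
#  [ True,  True, False],
#  [ True,  True,  True]]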
This concatenates the embeddings of consecutive tokens that need to be merged and does a linear transformation to map the result to the size of a single token embedding.

class LinearPoolingShortening(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
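A minimal sketch of how this could look, based on the description above. This is not the paper's implementation; the class name, padding behavior, and shape handling are assumptions:

class LinearPoolingShorteningSketch(nn.Module):
    def __init__(self, d_model: int, k: int):
        super().__init__()
        self.k = k
        self.linear = nn.Linear(k * d_model, d_model)

    def forward(self, x: torch.Tensor):
        # x is of shape [seq_len, batch_size, d_model]
        seq_len, batch_size, d_model = x.shape
        # Zero-pad so the sequence length is a multiple of k (an assumption)
        padding = (-seq_len) % self.k
        if padding:
            x = torch.cat([x, x.new_zeros(padding, batch_size, d_model)])
        # Group k consecutive embeddings, concatenate them, and project back to d_model
        x = x.reshape(-1, self.k, batch_size, d_model)
        x = x.permute(0, 2, 1, 3).reshape(-1, batch_size, self.k * d_model)
        return self.linear(x)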
Attention-based shortening, which the paper finds to work best; not implemented here yet.

class AttentionBasedShortening(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
Make a linear projection of each dense (shortened) token embedding to a size of k × d_model, so that it expands into k token embeddings.

class LinearUpSampling(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
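A minimal sketch of how this could look, based on the description above. This is not the paper's implementation; the class name and shape handling are assumptions:

class LinearUpSamplingSketch(nn.Module):
    def __init__(self, d_model: int, k: int):
        super().__init__()
        self.k = k
        self.linear = nn.Linear(d_model, k * d_model)

    def forward(self, x: torch.Tensor, x_short: torch.Tensor):
        # x_short is of shape [short_seq_len, batch_size, d_model]
        short_seq_len, batch_size, d_model = x_short.shape
        # Project each shortened embedding to k embeddings
        expanded = self.linear(x_short).reshape(short_seq_len, batch_size, self.k, d_model)
        # Unfold the k embeddings along the sequence dimension
        expanded = expanded.permute(0, 2, 1, 3).reshape(short_seq_len * self.k, batch_size, d_model)
        # Truncate to the original sequence length, as in NaiveUpSampling
        return expanded[:x.shape[0]]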
Attention-based up-sampling, which the paper finds to work better; not implemented here yet.

class AttentionBasedUpSampling(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError