This is a PyTorch implementation of the paper Hierarchical Transformers Are More Efficient Language Models.
This paper introduces a hierarchical transformer architecture to handle long sequences efficiently. The first half of the transformer layers down-sample tokens and the second half up-samples them, with direct skip connections between layers of the same resolution. This is somewhat similar to U-Net for vision tasks.
They experiment with different up-sampling and down-sampling techniques and build a model with the best-performing ones, which they call the hourglass model.
Here we have implemented only the simplest up-sampling and down-sampling techniques. We will consider adding more complex (and better performing) implementations later.
Here is the training code for the hourglass model.
from typing import List

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask
This model recursively adds layers to the middle while shortening the sequence by down-sampling. The shortened sequence, processed by another hourglass model, is sandwiched between two normal transformer layers. (A transformer layer has a self-attention layer and a position-wise feed-forward layer.)
class HourGlass(Module):
- n_heads is the number of heads in multi-head attention layers
- d_model is the size of the token embeddings
- dropout is the dropout probability
- d_ff is the dimensionality of the hidden layer in position-wise feed-forward layers
- shortening_factors is the list of shortening factors

    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
        super().__init__()
The transformer layer before down-sampling
        self.pre = TransformerLayer(d_model=d_model,
                                    self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                    feed_forward=FeedForward(d_model, d_ff, dropout),
                                    dropout_prob=dropout)
Auto-regressive mask
        self.mask = AutoregressiveMask()
The shortening factor (or the down-sampling rate)
        k = shortening_factors[0]
We shift the tokens to the right by k - 1 steps to make sure information doesn't leak from future tokens to past tokens as a result of down-sampling and up-sampling. For example, without the shift and with k = 2, the first shortened token would average the first two tokens, and up-sampling would add information about the second token back to the first position.
        self.shift_right = ShiftRight(k - 1)
Shortening (the down-sampling) layer. We use the simplest form: average pooling. The paper shows that attention-based down-sampling works best, which we haven't implemented yet.
        self.shortening = AvgPoolShortening(k)
If there is no more shortening (we are at the middle of the hourglass)
        if len(shortening_factors) == 1:
The center layer is another transformer layer
            self.shortened = TransformerLayer(d_model=d_model,
                                              self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                              feed_forward=FeedForward(d_model, d_ff, dropout),
                                              dropout_prob=dropout)
Autoregressive mask
            self.mask_short = AutoregressiveMask()
            self.hour_glass = None
        else:
Insert another hourglass model recursively
            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])
Up-sampling layer. We use naive up-sampling for simplicity, while the paper shows that attention-based up-sampling works better.
        self.up_sampling = NaiveUpSampling(k)
The final transformer layer after up-sampling
        self.post = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)
    def forward(self, x: torch.Tensor):
Initial transformer layer
        x = self.pre(x=x, mask=self.mask(x))
Shifting and shortening
        x_short = self.shortening(self.shift_right(x))
If we are at the center of the hourglass,
        if self.hour_glass is None:
Center transformer layer
            x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))
        else:
Otherwise, process the shortened sequence with the inner hourglass model
            x_short = self.hour_glass(x_short)
Up-sample the shortened sequence and add a skip connection
        x = x + self.up_sampling(x, x_short)
Final transformer layer
        x = self.post(x=x, mask=self.mask(x))
        return x
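For example, here is a minimal usage sketch of the HourGlass module; the hyper-parameter values below are illustrative assumptions, not the settings from the paper.

# A minimal usage sketch; these hyper-parameter values are illustrative assumptions.
hour_glass = HourGlass(n_heads=4, d_model=128, dropout=0.1, d_ff=512, shortening_factors=[2, 2])
# Token embeddings of shape [seq_len, batch_size, d_model]
x = torch.randn(32, 8, 128)
# The output has the same shape as the input
y = hour_glass(x)
assert y.shape == (32, 8, 128)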
class ShiftRight(Module):
- shift is the number of steps to shift by

    def __init__(self, shift: int):
        super().__init__()
shift cannot be negative
        assert shift >= 0
        self.shift = shift
- x is a tensor of shape [seq_len, ...]

    def forward(self, x: torch.Tensor):
If the shift is 0, return the original
        if self.shift == 0:
            return x
Zeros to be prepended on the left
        prefix = x.new_zeros([self.shift, *x.shape[1:]])
Concatenate the zeros on the left and truncate the right-most tokens
        return torch.cat([prefix, x[:-self.shift]])
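As a quick illustration (the values here are made up for the example), shifting a sequence of five tokens by two steps prepends two zeros and drops the last two tokens.

# Illustrative example only
x = torch.arange(5).view(5, 1)      # shape [seq_len=5, batch_size=1]
shift_right = ShiftRight(2)
# Prepends two zeros and drops the last two tokens
print(shift_right(x).squeeze(-1))   # tensor([0, 0, 0, 1, 2])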
class AvgPoolShortening(Module):
- k is the shortening factor

    def __init__(self, k: int):
        super().__init__()
Average pooling layer
        self.pool = nn.AvgPool1d(k, ceil_mode=True)
- x is of shape [seq_len, batch_size, d_model]

    def forward(self, x: torch.Tensor):
The pooling layer accepts input of shape [batch_size, d_model, seq_len], so we permute the axes before and after pooling.
        return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
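For instance, with a shortening factor of 3 and a sequence of 10 tokens (illustrative sizes), the output has ceil(10 / 3) = 4 tokens because of ceil_mode=True.

# Illustrative sizes only
shortening = AvgPoolShortening(3)
x = torch.randn(10, 4, 16)          # [seq_len, batch_size, d_model]
x_short = shortening(x)
# ceil(10 / 3) = 4 tokens remain
assert x_short.shape == (4, 4, 16)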
class NaiveUpSampling(Module):
- k is the shortening factor

    def __init__(self, k: int):
        super().__init__()
        self.k = k
- x is the tensor with embeddings before down-sampling
- x_short is the tensor of higher density (to be up-sampled) representations

    def forward(self, x: torch.Tensor, x_short: torch.Tensor):
Repeat across the sequence dimension
        expanded = torch.repeat_interleave(x_short, self.k, dim=0)
Truncate the extra embeddings at the end
        expanded = expanded[:x.shape[0]]
        return expanded
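For instance, continuing the illustrative sizes above, 4 shortened tokens are expanded back to the original 10 positions.

# Illustrative sizes only
up_sampling = NaiveUpSampling(3)
x = torch.randn(10, 4, 16)          # original resolution [seq_len, batch_size, d_model]
x_short = torch.randn(4, 4, 16)     # shortened resolution
# Each shortened token is repeated 3 times and the extra 2 positions are truncated
expanded = up_sampling(x, x_short)
assert expanded.shape == x.shape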
class AutoregressiveMask(Module):
    def __init__(self):
        super().__init__()
        self.mask = None
    def forward(self, x: torch.Tensor):
Create a mask if we haven't created one yet or if the sequence length has changed
        if self.mask is None or self.mask.size(0) != len(x):
Subsequent mask, which masks tokens from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
        return self.mask
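Conceptually, the subsequent mask is a lower-triangular boolean matrix where position i may only attend to positions j <= i; a rough equivalent (ignoring the exact shape convention of subsequent_mask) could be built like this:

# Conceptual sketch only; the actual subsequent_mask helper may use a different shape convention.
seq_len = 4
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()
print(causal)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])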
This concatenates the consecutive token embeddings that need to be merged and applies a linear transformation to map them to the size of a single token embedding.
class LinearPoolingShortening(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
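This class is left unimplemented here. A possible sketch based on the description above could look like the following; the padding, reshaping, and the single nn.Linear layer are assumptions for illustration, not the paper's exact formulation.

# A possible sketch of linear-pooling shortening (assumptions for illustration only)
class LinearPoolingShorteningSketch(Module):
    def __init__(self, k: int, d_model: int):
        super().__init__()
        self.k = k
        # Map k concatenated token embeddings to a single token embedding
        self.linear = nn.Linear(k * d_model, d_model)

    def forward(self, x: torch.Tensor):
        seq_len, batch_size, d_model = x.shape
        # Pad with zeros so the sequence length is a multiple of k
        pad = (-seq_len) % self.k
        if pad:
            x = torch.cat([x, x.new_zeros(pad, batch_size, d_model)])
        # Group k consecutive tokens: [padded_len / k, k, batch_size, d_model]
        x = x.reshape(-1, self.k, batch_size, d_model)
        # Concatenate the k embeddings in each group: [padded_len / k, batch_size, k * d_model]
        x = x.permute(0, 2, 1, 3).reshape(x.shape[0], batch_size, self.k * d_model)
        # Map each concatenated group to a single token embedding
        return self.linear(x)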
class AttentionBasedShortening(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
Make a linear projection of dense token embeddings to a size of d_model * k.
class LinearUpSampling(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
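This class is also left unimplemented here. A possible sketch following the description above; the projection and reshaping details are assumptions for illustration.

# A possible sketch of linear up-sampling (assumptions for illustration only)
class LinearUpSamplingSketch(Module):
    def __init__(self, k: int, d_model: int):
        super().__init__()
        self.k = k
        # Project each shortened token embedding to k token embeddings
        self.linear = nn.Linear(d_model, k * d_model)

    def forward(self, x: torch.Tensor, x_short: torch.Tensor):
        short_len, batch_size, d_model = x_short.shape
        # Project to [short_len, batch_size, k * d_model] and split into k embeddings per token
        expanded = self.linear(x_short).reshape(short_len, batch_size, self.k, d_model)
        # Rearrange to [short_len * k, batch_size, d_model]
        expanded = expanded.permute(0, 2, 1, 3).reshape(short_len * self.k, batch_size, d_model)
        # Truncate to the original sequence length, as in NaiveUpSampling
        return expanded[:x.shape[0]]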
class AttentionBasedUpSampling(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError