This is a PyTorch implementation of the paper Hierarchical Transformers Are More Efficient Language Models.
This paper introduces a hierarchical transformer architecture to handle long sequences efficiently. The first half of the transformer layers down-samples tokens and the second half up-samples them, with direct skip connections between layers of the same resolution. This is somewhat similar to U-Net for vision tasks.
They experiment with different up-sampling and down-sampling techniques and build a model using the best performing combination, which they call the hourglass model.
Here we have implemented only the simplest up-sampling and down-sampling techniques for clarity. We will consider adding more complex (and better performing) implementations later.
Here is the training code for the hourglass model.
from typing import List

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask
This model recursively adds layers to the middle while shortening the sequence by down-sampling. The shortened sequence processed by another hourglass model is sandwiched between two normal transformer layers. (A transformer layer has a self-attention layer and a position-wise feed-forward layer).
class HourGlass(Module):

n_heads is the number of heads in multi-head attention layers
d_model is the size of the token embeddings
dropout is the dropout probability
d_ff is the dimensionality of the hidden layer in position-wise feed-forward layers
shortening_factors is the list of shortening factors

    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
        super().__init__()
The transformer layer before down-sampling
        self.pre = TransformerLayer(d_model=d_model,
                                    self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                    feed_forward=FeedForward(d_model, d_ff, dropout),
                                    dropout_prob=dropout)
Auto-regressive mask
        self.mask = AutoregressiveMask()
The shortening factor (or the down-sampling rate)
        k = shortening_factors[0]
We shift the tokens to the right by k - 1 steps to make sure information doesn't leak from future tokens to past tokens as a result of down-sampling and up-sampling.
        self.shift_right = ShiftRight(k - 1)
Shortening or the down-sampling layer. We use the simplest form, average pooling. The paper shows that attention-based down-sampling works best, which we haven't implemented yet.
        self.shortening = AvgPoolShortening(k)
If there is no more shortening to be done (we are at the middle of the hourglass)
        if len(shortening_factors) == 1:
The center layer is another transformer layer
            self.shortened = TransformerLayer(d_model=d_model,
                                              self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                              feed_forward=FeedForward(d_model, d_ff, dropout),
                                              dropout_prob=dropout)
Autoregressive mask
            self.mask_short = AutoregressiveMask()
            self.hour_glass = None
        else:
Insert another hourglass model recursively
            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])
Up-sampling layer. We use naive up-sampling for simplicity; the paper shows attention-based up-sampling works better.
        self.up_sampling = NaiveUpSampling(k)
The final transformer layer after up-sampling
        self.post = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)

    def forward(self, x: torch.Tensor):
Initial transformer layer
        x = self.pre(x=x, mask=self.mask(x))
Shifting and shortening
        x_short = self.shortening(self.shift_right(x))
If we are at the center of the hourglass,
        if self.hour_glass is None:
Center transformer layer
            x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))
Otherwise, pass the shortened sequence through the inner hourglass model
        else:
            x_short = self.hour_glass(x_short)
Up-sample the shortened sequence and add a skip connection
        x = x + self.up_sampling(x, x_short)
Final transformer layer
        x = self.post(x=x, mask=self.mask(x))
        return x
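As a rough usage sketch (the hyper-parameter values here are illustrative, not taken from the paper), the model takes token embeddings of shape [seq_len, batch_size, d_model] and returns a tensor of the same shape:

hour_glass = HourGlass(n_heads=4, d_model=256, dropout=0.1, d_ff=1024, shortening_factors=[2, 2])
x = torch.randn(128, 8, 256)  # [seq_len, batch_size, d_model]
y = hour_glass(x)  # same shape as x: [128, 8, 256]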
class ShiftRight(Module):

shift is the number of steps to shift by

    def __init__(self, shift: int):
        super().__init__()
shift cannot be negative
        assert shift >= 0
        self.shift = shift
x is a tensor of shape [seq_len, ...]

    def forward(self, x: torch.Tensor):
If the shift is 0, return the original
        if self.shift == 0:
            return x
Zeros to be prepended on the left
        prefix = x.new_zeros([self.shift, *x.shape[1:]])
Concatenate the zeros and truncate the right
        return torch.cat([prefix, x[:-self.shift]])
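For example (an illustration, not part of the original code), shifting a short sequence by two steps pads two zero rows on the left and drops the last two rows:

shift_right = ShiftRight(2)
t = torch.arange(1., 5.)[:, None]  # shape [4, 1], rows 1., 2., 3., 4.
shift_right(t)  # rows 0., 0., 1., 2.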
class AvgPoolShortening(Module):

k is the shortening factor

    def __init__(self, k: int):
        super().__init__()
Average pooling layer
        self.pool = nn.AvgPool1d(k, ceil_mode=True)
x is of shape [seq_len, batch_size, d_model]

    def forward(self, x: torch.Tensor):
Pooling layer accepts shape [batch_size, d_model, seq_len] so we permute axes
        return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
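A quick shape check (illustrative values): with a shortening factor of 4, a sequence of length 10 is pooled down to ceil(10 / 4) = 3 steps because of ceil_mode=True:

shorten = AvgPoolShortening(4)
x = torch.randn(10, 8, 256)  # [seq_len, batch_size, d_model]
shorten(x).shape  # torch.Size([3, 8, 256])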
class NaiveUpSampling(Module):

k is the shortening factor

    def __init__(self, k: int):
        super().__init__()
        self.k = k
x is the tensor with embeddings before down-sampling
x_short is the tensor of higher density (to be up-sampled) representations

    def forward(self, x: torch.Tensor, x_short: torch.Tensor):
Repeat across the sequence dimension
        expanded = torch.repeat_interleave(x_short, self.k, dim=0)
Truncate the extra embeddings at the end
        expanded = expanded[:x.shape[0]]
        return expanded
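A quick shape check (illustrative values): each shortened embedding is repeated k times and the result is truncated back to the length of the original sequence:

up_sample = NaiveUpSampling(4)
x = torch.randn(10, 8, 256)  # original resolution
x_short = torch.randn(3, 8, 256)  # shortened resolution
up_sample(x, x_short).shape  # torch.Size([10, 8, 256])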
class AutoregressiveMask(Module):
    def __init__(self):
        super().__init__()
        self.mask = None

    def forward(self, x: torch.Tensor):
Create a mask if we haven't created one yet or if the sequence length has changed
        if self.mask is None or self.mask.size(0) != len(x):
Subsequent mask, which masks tokens from seeing future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
        return self.mask
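The mask depends only on the sequence length, so it is cached and reused across calls; for example (illustrative values):

mask = AutoregressiveMask()
m1 = mask(torch.randn(16, 8, 256))
m2 = mask(torch.randn(16, 8, 256))
assert m1 is m2  # the cached mask is reused while the sequence length stays 16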
This concatenates the consecutive token embeddings that need to be merged and applies a linear transformation to map them to the size of a single token embedding.
class LinearPoolingShortening(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
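The class above is left unimplemented. Below is a minimal sketch of how such a layer could look (our own hypothetical version, not the authors' code), assuming the sequence length is a multiple of the shortening factor k:

class LinearPoolingShorteningSketch(Module):
    def __init__(self, k: int, d_model: int):
        super().__init__()
        self.k = k
        # Maps k concatenated token embeddings to a single embedding of size d_model
        self.linear = nn.Linear(k * d_model, d_model)

    def forward(self, x: torch.Tensor):
        # x has shape [seq_len, batch_size, d_model]
        seq_len, batch_size, d_model = x.shape
        # Group every k consecutive tokens together
        x = x.reshape(seq_len // self.k, self.k, batch_size, d_model)
        # Concatenate the k embeddings in each group: [seq_len / k, batch_size, k * d_model]
        x = x.permute(0, 2, 1, 3).reshape(seq_len // self.k, batch_size, self.k * d_model)
        # Project back to d_model
        return self.linear(x)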
class AttentionBasedShortening(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
Make a linear projection of each dense (shortened) token embedding to a size of k * d_model and split it into k up-sampled token embeddings.
class LinearUpSampling(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
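Again, a minimal sketch of one possible implementation (our own hypothetical version, not the authors' code), mirroring the interface of NaiveUpSampling:

class LinearUpSamplingSketch(Module):
    def __init__(self, k: int, d_model: int):
        super().__init__()
        self.k = k
        # Projects one shortened embedding to k up-sampled embeddings
        self.linear = nn.Linear(d_model, k * d_model)

    def forward(self, x: torch.Tensor, x_short: torch.Tensor):
        # x_short has shape [short_seq_len, batch_size, d_model]
        short_len, batch_size, d_model = x_short.shape
        # Project and split each embedding into k embeddings
        expanded = self.linear(x_short).reshape(short_len, batch_size, self.k, d_model)
        expanded = expanded.permute(0, 2, 1, 3).reshape(short_len * self.k, batch_size, d_model)
        # Truncate to the original sequence length, like NaiveUpSampling
        return expanded[:x.shape[0]]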
class AttentionBasedUpSampling(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError