Hierarchical Transformers Are More Efficient Language Models

This is a PyTorch implementation of the paper Hierarchical Transformers Are More Efficient Language Models.

This paper introduces a hierarchical transformer architecture that handles long sequences efficiently. The first half of the transformer layers down-sample the tokens and the second half up-sample them, with direct skip connections between layers of the same resolution. This is somewhat similar to U-Net for vision tasks.

The authors try several up-sampling and down-sampling techniques and build a model from the best-performing ones, which they call the hourglass model.

We have implemented only the simplest up-sampling and down-sampling techniques here. We may add more complex (and better-performing) implementations later.

Here is the training code for the hourglass model.

from typing import List
import torch
from torch import nn
from labml_helpers.module import Module
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask

Hourglass model

This model recursively adds layers to the middle while shortening the sequence by down-sampling. The shortened sequence, processed by another hourglass model, is sandwiched between two normal transformer layers. At the innermost level, a single transformer layer takes the place of the nested hourglass. (A transformer layer consists of a self-attention layer and a position-wise feed-forward layer.)
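The recursion can be traced without any tensors. The plain-Python sketch below (the function `hourglass_plan` is hypothetical, not part of labml_nn) lists the steps an `HourGlass` built with a given `shortening_factors` list applies, together with the sequence length at each level. It assumes lengths shrink by ceiling division, which is what average pooling with `ceil_mode=True` produces.

```python
from typing import List


def hourglass_plan(seq_len: int, shortening_factors: List[int], depth: int = 0) -> List[str]:
    """List the steps a hypothetical HourGlass(shortening_factors) applies, with sequence lengths."""
    indent = "  " * depth
    k = shortening_factors[0]
    # Ceiling division: a partial last window still produces one shortened token
    short_len = -(-seq_len // k)
    plan = [f"{indent}pre layer (len={seq_len})",
            f"{indent}shift right by {k - 1}, shorten by {k} (len={short_len})"]
    if len(shortening_factors) == 1:
        # Middle of the hourglass: one transformer layer on the shortest sequence
        plan.append(f"{indent}  center layer (len={short_len})")
    else:
        # Recurse with the remaining shortening factors, one level deeper
        plan.extend(hourglass_plan(short_len, shortening_factors[1:], depth + 1))
    plan += [f"{indent}up-sample by {k} + skip connection (len={seq_len})",
             f"{indent}post layer (len={seq_len})"]
    return plan


for step in hourglass_plan(1024, [4, 2]):
    print(step)
```

With `shortening_factors=[4, 2]`, a 1024-token sequence is processed at lengths 1024, 256 and 128, and the model holds five transformer layers: a pre and a post layer per level plus the center layer.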

class HourGlass(Module):
    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
        super().__init__()

The transformer layer before down-sampling

        self.pre = TransformerLayer(d_model=d_model,
                                    self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                    feed_forward=FeedForward(d_model, d_ff, dropout),
                                    dropout_prob=dropout)

Auto-regressive mask

        self.mask = AutoregressiveMask()

The shortening factor k (the down-sampling rate) for this level

        k = shortening_factors[0]

We shift the tokens to the right by k - 1 steps to make sure information doesn't leak from future tokens to past tokens as a result of down-sampling and up-sampling

        self.shift_right = ShiftRight(k - 1)

The shortening (down-sampling) layer. We use the simplest form, average pooling. The paper shows that attention-based down-sampling works best; we haven't implemented it yet.

        self.shortening = AvgPoolShortening(k)

If there are no more shortening factors (the middle of the hourglass)

        if len(shortening_factors) == 1:

The center layer is another transformer layer

            self.shortened = TransformerLayer(d_model=d_model,
                                              self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                              feed_forward=FeedForward(d_model, d_ff, dropout),
                                              dropout_prob=dropout)

Auto-regressive mask for the shortened sequence

            self.mask_short = AutoregressiveMask()
            self.hour_glass = None
        else:

Insert another hourglass model recursively

            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])

Up-sampling layer. We use naive up-sampling for simplicity, although the paper shows that attention-based up-sampling works better.

        self.up_sampling = NaiveUpSampling(k)

The final transformer layer after up-sampling

        self.post = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)
    def forward(self, x: torch.Tensor):

Initial transformer layer

        x = self.pre(x=x, mask=self.mask(x))

Shifting and shortening

        x_short = self.shortening(self.shift_right(x))

If we are at the center of the hourglass (there is no nested hourglass model),

        if self.hour_glass is None:

apply the center transformer layer

            x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))

Otherwise, process the shortened sequence with the nested hourglass model

        else:
            x_short = self.hour_glass(x_short)

Up-sample the shortened sequence and add a skip connection

        x = x + self.up_sampling(x, x_short)

Final transformer layer

        x = self.post(x=x, mask=self.mask(x))

        return x
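Why shift by k - 1? The plain-Python sketch below (the helpers `sources` and `leaks` are hypothetical) tracks which original positions reach each output position through shift-right, average-pool shortening and naive up-sampling. It ignores the skip connection, which only adds x[i] to position i, and the causally masked transformer layers, since neither can look ahead on its own.

```python
from typing import List


def sources(seq_len: int, k: int, shift: int) -> List[List[int]]:
    """For each output position i of shift-right -> average-pool shortening -> naive up-sampling,
    list the original positions whose values reach position i (skip connection ignored)."""
    result = []
    for i in range(seq_len):
        # Naive up-sampling copies shortened token j = i // k to positions j*k .. j*k + k - 1
        j = i // k
        # Shortened token j averages shifted positions j*k .. j*k + k - 1 (clipped at the end)
        window = range(j * k, min((j + 1) * k, seq_len))
        # Shifted position p holds original position p - shift, or a zero pad if p < shift
        result.append(sorted(p - shift for p in window if p - shift >= 0))
    return result


def leaks(seq_len: int, k: int, shift: int) -> bool:
    """True if any position receives information from a later (future) position."""
    return any(s > i for i, src in enumerate(sources(seq_len, k, shift)) for s in src)


print(leaks(16, 4, shift=0))  # True: without a shift, future tokens leak
print(leaks(16, 4, shift=3))  # False: shifting by k - 1 prevents the leak
print(leaks(16, 4, shift=2))  # True: a smaller shift is not enough
```

With a shift of k - 1, the latest original position pooled into shortened token j is j·k, which is also the first position that token is up-sampled to, so no position sees its future.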

Shift right operation

This shifts the sequence to the right by the given number of steps

class ShiftRight(Module):
  • shift is the number of steps to shift by
    def __init__(self, shift: int):
        super().__init__()

The shift cannot be negative

        assert shift >= 0

        self.shift = shift
  • x is a tensor of shape [seq_len, ...]
    def forward(self, x: torch.Tensor):

If the shift is 0, return the original tensor unchanged (x[:-0] would be empty)

        if self.shift == 0:
            return x

Zeros to be appended to the left

        prefix = x.new_zeros([self.shift, *x.shape[1:]])

Concatenate the zeros on the left and drop the last shift elements, so the sequence length is unchanged

        return torch.cat([prefix, x[:-self.shift]])
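A list version of `ShiftRight` makes the behavior concrete. This is a minimal sketch: `shift_right` and its `pad` argument are hypothetical stand-ins for the tensor module, with `pad` playing the role of the zeros created by `new_zeros`.

```python
from typing import List


def shift_right(seq: List[int], shift: int, pad: int = 0) -> List[int]:
    """List analogue of ShiftRight: prepend `shift` pad values and drop as many from the end."""
    assert shift >= 0
    if shift == 0:
        return list(seq)
    return [pad] * shift + list(seq[:-shift])


print(shift_right([1, 2, 3, 4, 5], 2))  # [0, 0, 1, 2, 3]
print([1, 2, 3][:-0])                   # []: why a zero shift needs its own branch
```

The explicit `shift == 0` branch matters: `seq[:-0]` is the same slice as `seq[:0]`, which is empty, so without it a zero shift would return an empty sequence.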

Average pool shortening

This down-samples by a given factor with average pooling

class AvgPoolShortening(Module):
  • k is the shortening factor
    def __init__(self, k: int):
        super().__init__()

Average pooling layer

        self.pool = nn.AvgPool1d(k, ceil_mode=True)
  • x is of shape [seq_len, batch_size, d_model]
    def forward(self, x: torch.Tensor):

The pooling layer expects shape [batch_size, d_model, seq_len], so we permute the axes before pooling and permute them back afterwards.

        return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
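The same shortening on one feature channel, in plain Python. The function `avg_pool_shorten` is hypothetical, and the sketch assumes (this is our understanding of `nn.AvgPool1d(k, ceil_mode=True)` with no padding) that a partial last window is averaged over only the elements it actually covers.

```python
from typing import List


def avg_pool_shorten(seq: List[float], k: int) -> List[float]:
    """Average non-overlapping windows of size k over one feature channel.

    The last window may be partial; it is averaged over the elements it covers
    (our assumption about nn.AvgPool1d(k, ceil_mode=True) with no padding).
    """
    windows = [seq[i:i + k] for i in range(0, len(seq), k)]
    return [sum(w) / len(w) for w in windows]


print(avg_pool_shorten([1.0, 2.0, 3.0, 4.0, 5.0], 2))  # [1.5, 3.5, 5.0]
```

Five tokens shortened by k = 2 give ceil(5 / 2) = 3 tokens, the last of which is just the fifth token.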

Naive up-sampling

This up-samples by repeating each shortened embedding k times along the sequence dimension

class NaiveUpSampling(Module):
  • k is the shortening factor
    def __init__(self, k: int):
        super().__init__()
        self.k = k
  • x is the tensor with embeddings before down-sampling
  • x_short is the tensor of shortened (denser) representations to be up-sampled
    def forward(self, x: torch.Tensor, x_short: torch.Tensor):

Repeat across the sequence dimension

        expanded = torch.repeat_interleave(x_short, self.k, dim=0)

Truncate the extra embeddings at the end

        expanded = expanded[:x.shape[0]]

        return expanded
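Naive up-sampling in plain Python, as a minimal sketch (the function `naive_up_sample` is a hypothetical illustration): each shortened entry is repeated k times in the same order `torch.repeat_interleave` produces along dim=0, and the result is cut back to the original length.

```python
from typing import List


def naive_up_sample(x_short: List[float], k: int, seq_len: int) -> List[float]:
    """Repeat each shortened entry k times, then truncate to the original length."""
    # Same order as torch.repeat_interleave(x_short, k, dim=0): a, a, b, b, ...
    expanded = [v for v in x_short for _ in range(k)]
    return expanded[:seq_len]


print(naive_up_sample([1.5, 3.5, 5.0], 2, 5))  # [1.5, 1.5, 3.5, 3.5, 5.0]
```

Truncation is needed whenever the original length is not a multiple of k: three shortened tokens expand to six, but only five are kept.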

Generate auto-regressive mask

class AutoregressiveMask(Module):
    def __init__(self):
        super().__init__()
        self.mask = None
    def forward(self, x: torch.Tensor):

Create a mask if we haven't created one yet or if the sequence length has changed

        if self.mask is None or self.mask.size(0) != len(x):

Subsequent mask, which prevents each token from attending to future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)

        return self.mask
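A plain-Python analogue of `AutoregressiveMask` shows both the mask pattern and the caching. The class `CachedCausalMask` is hypothetical; it assumes, following the comment above, that query position i may attend to key position j exactly when j <= i, and it leaves out device placement and any extra broadcast dimensions.

```python
from typing import List, Optional


class CachedCausalMask:
    """Hypothetical plain-Python analogue of AutoregressiveMask.

    mask[i][j] is True when query position i may attend to key position j (j <= i).
    """

    def __init__(self) -> None:
        self.mask: Optional[List[List[bool]]] = None

    def __call__(self, seq_len: int) -> List[List[bool]]:
        # Rebuild only if no mask exists yet or the sequence length changed
        if self.mask is None or len(self.mask) != seq_len:
            self.mask = [[j <= i for j in range(seq_len)] for i in range(seq_len)]
        return self.mask


masks = CachedCausalMask()
m = masks(3)
print(m)              # [[True, False, False], [True, True, False], [True, True, True]]
print(masks(3) is m)  # True: the cached mask is reused
```

Because the cache holds a single mask, `HourGlass` keeps a separate `mask_short` for the center layer: sharing one cache between the full-length and the shortened sequence would rebuild the mask on every call.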

🚧 Linear pooling for down-sampling

This concatenates the embeddings of consecutive tokens that need to be merged and applies a linear transformation to map the result to the size of a single token embedding.

class LinearPoolingShortening(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
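Linear pooling is not implemented here. As a hedged illustration of the description above, the plain-Python sketch below (function and weight names are hypothetical) concatenates each group of k consecutive embeddings into a vector of size k · d_model and maps it back to size d_model with a linear layer y = Wz + b. It assumes the sequence length is a multiple of k and leaves padding out.

```python
from typing import List

Vector = List[float]


def linear_pool_shorten(seq: List[Vector], k: int, weight: List[Vector], bias: Vector) -> List[Vector]:
    """Hypothetical linear pooling: concatenate k consecutive embeddings (size k * d_model)
    and map them to one embedding (size d_model) with y = W z + b.

    Assumes len(seq) is a multiple of k; weight has d_model rows of length k * d_model.
    """
    assert len(seq) % k == 0
    out = []
    for start in range(0, len(seq), k):
        # Concatenate k consecutive token embeddings into one vector z
        z = [v for emb in seq[start:start + k] for v in emb]
        # Linear map back to the size of a single token embedding
        out.append([sum(w * zi for w, zi in zip(row, z)) + b for row, b in zip(weight, bias)])
    return out


# d_model = 2, k = 2: weights that average the two tokens channel by channel
W = [[0.5, 0.0, 0.5, 0.0],
     [0.0, 0.5, 0.0, 0.5]]
print(linear_pool_shorten([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], 2, W, [0.0, 0.0]))
# [[2.0, 3.0], [6.0, 7.0]]
```

The example weights reproduce average pooling, which shows that average pooling is one particular linear pooling; a learned W can represent other weightings as well.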

🚧 Down-sampling with attention

x' = S(x) + Attention(Q = S(x), K = x, V = x)
x' = x' + FFN(x')

where S(x) is average pooling or linear pooling.

class AttentionBasedShortening(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError

🚧 Linear projection for up-sampling

Make a linear projection of the dense (shortened) token embeddings to a size of k · d_model, so that each one can be split into k token embeddings.

class LinearUpSampling(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
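Linear up-sampling is also only a stub. Assuming each shortened embedding is projected to size k · d_model and then split into k consecutive embeddings, a plain-Python sketch (function and weight names are hypothetical) looks like this; the result is truncated to the original length, as in `NaiveUpSampling`.

```python
from typing import List

Vector = List[float]


def linear_up_sample(x_short: List[Vector], k: int, seq_len: int,
                     weight: List[Vector], bias: Vector) -> List[Vector]:
    """Hypothetical linear up-sampling: project each shortened embedding (size d_model)
    to size k * d_model, split it into k embeddings, then truncate to seq_len.

    weight has k * d_model rows of length d_model.
    """
    d_model = len(x_short[0])
    out = []
    for emb in x_short:
        # y = W e + b, a vector of size k * d_model
        y = [sum(w * e for w, e in zip(row, emb)) + b for row, b in zip(weight, bias)]
        # Split y into k consecutive embeddings of size d_model
        out.extend(y[i * d_model:(i + 1) * d_model] for i in range(k))
    return out[:seq_len]


# d_model = 2, k = 2: two stacked identity blocks in W
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
print(linear_up_sample([[1.0, 2.0], [3.0, 4.0]], 2, 3, W, [0.0] * 4))
# [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0]]
```

Stacking k identity blocks in W reproduces naive up-sampling, so this projection generalizes the repetition used in `NaiveUpSampling`.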

🚧 Attention based up-sampling


class AttentionBasedUpSampling(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError