Hierarchical Transformers Are More Efficient Language Models

This is a PyTorch implementation of the paper Hierarchical Transformers Are More Efficient Language Models.

This paper introduces a hierarchical transformer architecture that handles long sequences efficiently. The first half of the transformer layers down-sample the token sequence and the second half up-sample it, with direct skip connections between layers of the same resolution. This is somewhat similar to U-Net for vision tasks.
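To make the efficiency argument concrete, here is a rough back-of-the-envelope sketch. It is not from the paper or this implementation. It counts only the size of the attention score matrices, which grow quadratically with sequence length, and it assumes one transformer layer before and one after each shortening step plus a single center layer, with ceil division for shortening. The names `level_lengths` and `attention_cost` are hypothetical.

```python
import math
from typing import List


def level_lengths(seq_len: int, factors: List[int]) -> List[int]:
    """Sequence length at each resolution, shortening by ceil division."""
    lengths = [seq_len]
    for k in factors:
        lengths.append(math.ceil(lengths[-1] / k))
    return lengths


def attention_cost(seq_len: int, factors: List[int]) -> int:
    """Rough proxy for attention cost: sum of L^2 over all layers, with two
    layers (before and after shortening) per level and one at the center."""
    lengths = level_lengths(seq_len, factors)
    outer = sum(2 * n * n for n in lengths[:-1])
    return outer + lengths[-1] ** 2


# Compare with a vanilla stack of the same depth at full length.
factors = [2, 2]
n_layers = 2 * len(factors) + 1
print(level_lengths(1024, factors))   # [1024, 512, 256]
print(attention_cost(1024, factors))  # 2686976
print(n_layers * 1024 ** 2)           # 5242880
```

Under this crude measure the hourglass needs roughly half the attention work of a five-layer vanilla stack; it ignores feed-forward cost and the shortening and up-sampling layers themselves.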

The paper tries different up-sampling and down-sampling techniques and builds a model with the best-performing ones, which the authors call the hourglass model.

Here we have implemented only the simplest up-sampling and down-sampling techniques. We will consider adding more complex (and better-performing) implementations later.

Here is the training code for the hourglass model.


from typing import List
import torch
from torch import nn
from labml_helpers.module import Module
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask

Hourglass model

This model is built recursively: it adds layers to the middle while shortening the sequence by down-sampling. The shortened sequence, processed by another hourglass model (or, at the innermost level, by a single transformer layer), is sandwiched between two normal transformer layers. (A transformer layer has a self-attention layer and a position-wise feed-forward layer.)
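The recursion can be traced without any tensors. The hypothetical helper below lists the transformer layers in execution order together with the sequence length each one sees, following the same branching as HourGlass.__init__: a layer before shortening, then either a single center layer (when only one factor is left) or an inner hourglass, then a layer after up-sampling. It assumes ceil division for shortening, matching the ceil_mode=True pooling used later.

```python
from typing import List, Tuple


def hourglass_schedule(seq_len: int, factors: List[int]) -> List[Tuple[str, int]]:
    """(layer name, sequence length) pairs in the order the layers run."""
    k = factors[0]
    short_len = -(-seq_len // k)  # ceil division
    if len(factors) == 1:
        middle = [("center", short_len)]  # innermost level: one transformer layer
    else:
        middle = hourglass_schedule(short_len, factors[1:])  # recurse
    return [("pre", seq_len)] + middle + [("post", seq_len)]


for name, length in hourglass_schedule(100, [2, 5]):
    print(name, length)
# pre 100
# pre 50
# center 10
# post 50
# post 100
```

Each "pre" layer pairs with the "post" layer at the same length, which is where the skip connection joins the two halves of the hourglass.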

class HourGlass(Module):
    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
        super().__init__()

The transformer layer before down-sampling

        self.pre = TransformerLayer(d_model=d_model,
                                    self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                    feed_forward=FeedForward(d_model, d_ff, dropout),
                                    dropout_prob=dropout)

Auto-regressive mask

        self.mask = AutoregressiveMask()

The shortening factor (or the down-sampling rate)

        k = shortening_factors[0]

We shift the tokens to the right by $k - 1$ steps to make sure information doesn't leak from future tokens to past tokens as a result of down-sampling and up-sampling.

        self.shift_right = ShiftRight(k - 1)
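Why $k - 1$? One way to check is to track, for each output position, the latest input position whose information can reach it through the shift, the pooling, the causal center layer, and the up-sampling. The hypothetical helper below does this index bookkeeping in plain Python (no tensors). It assumes the center attends causally over pooled groups, so position i sees pooled groups 0 .. i // k, the latest of which is group i // k.

```python
from typing import List


def latest_visible(seq_len: int, k: int, shift: int) -> List[int]:
    """For each position i, the largest original index that can influence it
    via shift-right -> pool by k -> causal center -> repeat up-sampling.
    A value of -1 means only zero padding is visible."""
    result = []
    for i in range(seq_len):
        group = i // k  # up-sampled position i copies pooled group i // k
        last_slot = min(group * k + k - 1, seq_len - 1)  # last shifted slot in that group
        result.append(max(last_slot - shift, -1))  # shifted slot j holds token j - shift
    return result


print(latest_visible(8, k=4, shift=3))  # [0, 0, 0, 0, 4, 4, 4, 4]
print(latest_visible(8, k=4, shift=0))  # [3, 3, 3, 3, 7, 7, 7, 7]
```

With a shift of k - 1 every position only sees tokens at or before itself, whereas without the shift position 0 would already see token 3.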

The shortening (down-sampling) layer. We use the simplest form, average pooling. The paper shows that attention-based down-sampling works best; we haven't implemented it yet.

        self.shortening = AvgPoolShortening(k)

If there are no more shortening factors (the middle of the hourglass)

        if len(shortening_factors) == 1:

The center layer is another transformer layer

            self.shortened = TransformerLayer(d_model=d_model,
                                              self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                              feed_forward=FeedForward(d_model, d_ff, dropout),
                                              dropout_prob=dropout)

Autoregressive mask

            self.mask_short = AutoregressiveMask()
            self.hour_glass = None
        else:

Insert another hourglass model recursively

            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])

The up-sampling layer. We use naive up-sampling for simplicity, although the paper shows that attention-based up-sampling works better.

        self.up_sampling = NaiveUpSampling(k)

The final transformer layer after up-sampling

        self.post = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)
    def forward(self, x: torch.Tensor):

Initial transformer layer

        x = self.pre(x=x, mask=self.mask(x))

Shifting and shortening

        x_short = self.shortening(self.shift_right(x))

If we are at the center of the hourglass (there is no inner hourglass model)

        if self.hour_glass is None:

Center transformer layer

            x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))

        else:

Otherwise, pass the shortened sequence through the inner hourglass model

            x_short = self.hour_glass(x_short)

Up-sample the shortened sequence and add a skip connection

        x = x + self.up_sampling(x, x_short)

Final transformer layer

        x = self.post(x=x, mask=self.mask(x))

        return x
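To see the whole forward pass at once, here is a toy scalar version in plain Python. Every transformer layer is replaced by a hypothetical stand-in, causal_mean (output i is the mean of inputs 0..i), which is causal like masked self-attention but has no parameters. The remaining steps mirror the code above: the first layer, a shift by k - 1, average pooling, the center layer or a recursive call, naive up-sampling with a skip connection, and the final layer. All function names here are illustrative.

```python
from typing import List


def causal_mean(x: List[float]) -> List[float]:
    """Stand-in for a masked transformer layer: output i = mean(x[0..i])."""
    out, total = [], 0.0
    for i, v in enumerate(x):
        total += v
        out.append(total / (i + 1))
    return out


def shift_right(x: List[float], shift: int) -> List[float]:
    return [0.0] * shift + x[:len(x) - shift]


def avg_pool(x: List[float], k: int) -> List[float]:
    return [sum(x[i:i + k]) / len(x[i:i + k]) for i in range(0, len(x), k)]


def up_sample(x_short: List[float], k: int, length: int) -> List[float]:
    return [v for v in x_short for _ in range(k)][:length]


def hourglass(x: List[float], factors: List[int]) -> List[float]:
    x = causal_mean(x)                             # layer before shortening
    k = factors[0]
    x_short = avg_pool(shift_right(x, k - 1), k)   # shift, then shorten
    if len(factors) == 1:
        x_short = causal_mean(x_short)             # center layer
    else:
        x_short = hourglass(x_short, factors[1:])  # inner hourglass
    up = up_sample(x_short, k, len(x))
    x = [a + b for a, b in zip(x, up)]             # skip connection
    return causal_mean(x)                          # layer after up-sampling


# Changing tokens 7 onward must not change outputs 0..6.
a = [float(i) for i in range(12)]
b = a[:7] + [100.0] * 5
out_a, out_b = hourglass(a, [2, 3]), hourglass(b, [2, 3])
print(out_a[:7] == out_b[:7], out_a[7] == out_b[7])  # True False
```

The check at the end is a cheap causality test: editing the future only changes outputs from the edited position onward.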

Shift right operation

This shifts the sequence to the right by the given number of steps

class ShiftRight(Module):
  • shift is the number of steps to shift by
    def __init__(self, shift: int):
        super().__init__()

The shift cannot be negative

        assert shift >= 0

        self.shift = shift
  • x is a tensor of shape [seq_len, ...]
    def forward(self, x: torch.Tensor):

If the shift is $0$, return the original

        if self.shift == 0:
            return x

Zeros to be prepended on the left

        prefix = x.new_zeros([self.shift, *x.shape[1:]])

Concatenate the zeros and drop the last shift elements, so the sequence length is unchanged

        return torch.cat([prefix, x[:-self.shift]])
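The shift == 0 branch above is needed for correctness, not only speed: in Python and PyTorch slicing, x[:-0] is the same as x[:0], which is empty, so the general path would return the wrong length. A minimal list-based analogue (the name shift_right_list is hypothetical, and it assumes shift <= len(x)) makes this visible.

```python
from typing import List


def shift_right_list(x: List[float], shift: int) -> List[float]:
    """Prepend `shift` zeros and drop the last `shift` items (length preserved)."""
    assert 0 <= shift <= len(x)
    if shift == 0:
        return list(x)  # x[:-0] would be x[:0] == [], so handle 0 separately
    return [0.0] * shift + x[:-shift]


print(shift_right_list([1.0, 2.0, 3.0, 4.0, 5.0], 2))  # [0.0, 0.0, 1.0, 2.0, 3.0]
print([1.0, 2.0, 3.0][:-0])                            # []
```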

Average pool shortening

This down-samples by a given factor with average pooling

class AvgPoolShortening(Module):
  • k is the shortening factor
    def __init__(self, k: int):
        super().__init__()

Average pooling layer

        self.pool = nn.AvgPool1d(k, ceil_mode=True)
  • x is of shape [seq_len, batch_size, d_model]
    def forward(self, x: torch.Tensor):

The pooling layer expects shape [batch_size, d_model, seq_len], so we permute the axes before pooling and permute them back afterwards.

        return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
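Two details of AvgPoolShortening can be checked by hand. First, the axis juggling: for x of shape [seq_len, batch_size, d_model], permute(1, 2, 0) moves the sequence axis last for pooling and permute(2, 0, 1) moves it back. Second, ceil_mode=True keeps a partial final window instead of dropping it, so the shortened length is ceil(seq_len / k). The plain-Python helpers below (hypothetical names) sketch both. Averaging a partial window over only the values it contains is what we believe nn.AvgPool1d does with zero padding, but treat that as an assumption.

```python
import math
from typing import List, Sequence, Tuple


def permuted_shape(shape: Sequence[int], dims: Sequence[int]) -> Tuple[int, ...]:
    """Shape after tensor.permute(*dims)."""
    return tuple(shape[d] for d in dims)


def avg_pool_shorten(x: List[float], k: int) -> List[float]:
    """Average each run of k values; a shorter final run is kept and averaged."""
    return [sum(x[i:i + k]) / len(x[i:i + k]) for i in range(0, len(x), k)]


seq_len, batch_size, d_model, k = 10, 2, 8, 4
pooled_in = permuted_shape((seq_len, batch_size, d_model), (1, 2, 0))  # (2, 8, 10)
pooled_out = pooled_in[:2] + (math.ceil(seq_len / k),)                 # (2, 8, 3)
print(permuted_shape(pooled_out, (2, 0, 1)))                           # (3, 2, 8)
print(avg_pool_shorten([1.0, 2.0, 3.0, 4.0, 5.0], 2))                  # [1.5, 3.5, 5.0]
```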

Naive up-sampling

This up-samples by repeating

class NaiveUpSampling(Module):
  • k is the shortening factor
    def __init__(self, k: int):
        super().__init__()
        self.k = k
  • x is the tensor with embeddings before down-sampling
  • x_short is the tensor of higher density (to be up-sampled) representations
    def forward(self, x: torch.Tensor, x_short: torch.Tensor):

Repeat across the sequence dimension

        expanded = torch.repeat_interleave(x_short, self.k, dim=0)

Truncate the extra embeddings at the end

        expanded = expanded[:x.shape[0]]

        return expanded
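In plain Python, the repeat-then-truncate step looks like this (naive_up_sample_list is a hypothetical name). Truncation is needed because ceil-mode pooling can produce one partial group, so len(x_short) * k may exceed the original sequence length.

```python
from typing import List, TypeVar

T = TypeVar("T")


def naive_up_sample_list(x_short: List[T], k: int, length: int) -> List[T]:
    """Repeat each item k times (like repeat_interleave along dim 0), then cut to `length`."""
    expanded = [item for item in x_short for _ in range(k)]
    return expanded[:length]


print(naive_up_sample_list(["a", "b", "c"], 2, 5))  # ['a', 'a', 'b', 'b', 'c']
```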

Generate auto-regressive mask

class AutoregressiveMask(Module):
    def __init__(self):
        super().__init__()
        self.mask = None
    def forward(self, x: torch.Tensor):

Create a mask if we haven't created one yet or if the sequence length has changed

        if self.mask is None or self.mask.size(0) != len(x):

Subsequent mask, which prevents tokens from attending to future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)

        return self.mask
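The caching logic can be mirrored with nested lists. In this sketch (the class CausalMaskCache is hypothetical), mask[i][j] is True when position i may attend to position j, and the mask is rebuilt only when the length changes. Because such a cache holds a single mask, this is presumably why HourGlass keeps separate AutoregressiveMask instances for the full-length layers (self.mask) and the center layer (self.mask_short): the two lengths would otherwise keep evicting each other.

```python
from typing import List, Optional


class CausalMaskCache:
    """Lower-triangular boolean mask, rebuilt only when the length changes."""

    def __init__(self) -> None:
        self.mask: Optional[List[List[bool]]] = None

    def __call__(self, seq_len: int) -> List[List[bool]]:
        if self.mask is None or len(self.mask) != seq_len:
            # Position i may see positions 0..i, never later ones.
            self.mask = [[j <= i for j in range(seq_len)] for i in range(seq_len)]
        return self.mask


cache = CausalMaskCache()
m = cache(3)
print(m)              # [[True, False, False], [True, True, False], [True, True, True]]
print(cache(3) is m)  # True: reused
print(cache(4) is m)  # False: rebuilt for a new length
```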

🚧 Linear pooling for down-sampling

This concatenates the embeddings of the consecutive tokens that need to be merged and applies a linear transformation to map the result to the size of a single token embedding.

class LinearPoolingShortening(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
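Since LinearPoolingShortening is not implemented yet, here is a hypothetical plain-Python sketch of the idea described above: concatenate each window of k embeddings (each of size d) into one vector of size k * d and map it back to size d with a weight matrix. Zero-padding the last partial window is our assumption, not something specified here. With a weight that averages the window, the sketch reduces to average pooling.

```python
from typing import List

Vector = List[float]


def linear_pool_shorten(x: List[Vector], k: int, weight: List[Vector]) -> List[Vector]:
    """x: seq_len embeddings of size d. weight: d rows, each of length k * d.
    Each window of k embeddings is concatenated and multiplied by `weight`."""
    d = len(x[0])
    out = []
    for start in range(0, len(x), k):
        window = x[start:start + k]
        window = window + [[0.0] * d] * (k - len(window))  # assumption: zero-pad last window
        concat = [v for emb in window for v in emb]          # length k * d
        out.append([sum(w * c for w, c in zip(row, concat)) for row in weight])
    return out


# d = 1, k = 2: a weight of [[0.5, 0.5]] averages each pair.
print(linear_pool_shorten([[1.0], [3.0], [5.0], [7.0]], 2, [[0.5, 0.5]]))  # [[2.0], [6.0]]
```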

🚧 Down-sampling with attention

$$x' = S(x) + \text{Attention}\big(Q = S(x), K = x, V = x\big)$$

where $S(x)$ is average pooling or linear pooling.

class AttentionBasedShortening(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
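A hypothetical plain-Python sketch of attention-based shortening: average pooling plays the role of S(x), each pooled vector queries the original tokens with single-head scaled dot-product attention, and the result is added back to the pooled vector. To keep it short, it assumes no learned query/key/value projections and lets each pooled query attend only to the k tokens in its own window. That window restriction is our simplification for causality, not something stated above.

```python
import math
from typing import List

Vector = List[float]


def attend(query: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Single-head scaled dot-product attention for one query (no projections)."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # numerically stable softmax
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def avg_pool(x: List[Vector], k: int) -> List[Vector]:
    """S(x): mean of each window of k embeddings (the last window may be shorter)."""
    out = []
    for start in range(0, len(x), k):
        window = x[start:start + k]
        out.append([sum(col) / len(window) for col in zip(*window)])
    return out


def attention_shortening(x: List[Vector], k: int) -> List[Vector]:
    """S(x) + Attention(Q=S(x), K=x, V=x), keys restricted to each window (assumption)."""
    out = []
    for g, query in enumerate(avg_pool(x, k)):
        window = x[g * k:(g + 1) * k]
        attended = attend(query, window, window)
        out.append([p + a for p, a in zip(query, attended)])
    return out


x = [[1.0, 2.0], [1.0, 2.0], [3.0, 0.0], [3.0, 0.0]]
print(attention_shortening(x, 2))  # [[2.0, 4.0], [6.0, 0.0]]
```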

🚧 Linear projection for up-sampling

Make a linear projection of the dense (shortened) token embeddings to a size of $k \cdot d_{\text{model}}$, which can then be split into $k$ token embeddings.

class LinearUpSampling(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
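A hypothetical plain-Python sketch of linear up-sampling: each shortened embedding of size d is projected to k * d values, which are split into k consecutive embeddings of size d, and the result is truncated to the original length as in NaiveUpSampling. The weight layout (k * d rows of length d) is our assumption. With a weight that copies the input into every slot, it reproduces naive repetition.

```python
from typing import List

Vector = List[float]


def linear_up_sample(x_short: List[Vector], k: int, length: int,
                     weight: List[Vector]) -> List[Vector]:
    """weight has k * d rows of length d; output has `length` embeddings of size d."""
    d = len(x_short[0])
    out: List[Vector] = []
    for emb in x_short:
        proj = [sum(w * e for w, e in zip(row, emb)) for row in weight]  # size k * d
        out.extend(proj[i * d:(i + 1) * d] for i in range(k))            # k embeddings
    return out[:length]


# d = 1, k = 2: weight [[1.0], [1.0]] copies each value into both slots.
print(linear_up_sample([[2.0], [5.0]], 2, 3, [[1.0], [1.0]]))  # [[2.0], [2.0], [5.0]]
```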

🚧 Attention based up-sampling

$$x \leftarrow U(x, x') = x + \text{Attention}\big(Q = x, K = x', V = x'\big)$$

class AttentionBasedUpSampling(Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
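A hypothetical plain-Python sketch of attention-based up-sampling, computing x + Attention(Q = x, K = x', V = x') with single-head attention and no learned projections. It assumes position i may attend only to shortened positions 0 .. i // k; with the k - 1 shift applied before shortening, those carry no information from after position i. Note that this function already adds the residual x, whereas HourGlass.forward adds x itself after up-sampling, so plugging it in would require dropping one of the two additions.

```python
import math
from typing import List

Vector = List[float]


def attend(query: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Single-head scaled dot-product attention for one query (no projections)."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # numerically stable softmax
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def attention_up_sample(x: List[Vector], x_short: List[Vector], k: int) -> List[Vector]:
    """x + Attention(Q=x, K=x', V=x'), with position i seeing x_short[0 .. i // k]."""
    out = []
    for i, query in enumerate(x):
        visible = x_short[:i // k + 1]  # assumption: causal restriction
        attended = attend(query, visible, visible)
        out.append([p + a for p, a in zip(query, attended)])
    return out


x = [[0.0]] * 4
x_short = [[1.0], [3.0]]
print(attention_up_sample(x, x_short, 2))  # [[1.0], [1.0], [2.0], [2.0]]
```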