Hierarchical Transformers Are More Efficient Language Models

This is a PyTorch implementation of the paper Hierarchical Transformers Are More Efficient Language Models.

This paper introduces a hierarchical transformer architecture that handles long sequences efficiently. The first half of the transformer layers down-sample the token sequence and the second half up-sample it, with direct skip connections between layers of the same resolution. This is somewhat similar to U-Net for vision tasks.
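One way to see the down-sampling half of this U-Net-like shape is to track the sequence length at each resolution. The plain-Python sketch below uses a hypothetical helper, `hourglass_lengths`, and assumes each level rounds up, as the average pooling with `ceil_mode=True` used in this implementation does.

```python
import math
from typing import List


def hourglass_lengths(seq_len: int, shortening_factors: List[int]) -> List[int]:
    """Sequence length at each resolution, from the outermost level inwards.

    Hypothetical helper: each level shortens by its factor k and rounds up,
    so a trailing partial group still produces one token.
    """
    lengths = [seq_len]
    for k in shortening_factors:
        lengths.append(math.ceil(lengths[-1] / k))
    return lengths


print(hourglass_lengths(1024, [2, 2]))  # [1024, 512, 256]
print(hourglass_lengths(10, [2, 3]))    # [10, 5, 2]
```

The up-sampling half revisits the same lengths in reverse order, and layers at equal lengths are joined by skip connections.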

The authors try several up-sampling and down-sampling techniques and build a model with the best-performing ones, which they call the hourglass model.

Here we have implemented only the simplest up-sampling and down-sampling techniques. We will consider adding more complex (and better-performing) implementations later.

Here is the training code for the hourglass model.

from typing import List

import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask

Hourglass model

This model recursively adds layers to the middle while shortening the sequence by down-sampling. The shortened sequence, processed by another hourglass model, is sandwiched between two normal transformer layers. (A transformer layer has a self-attention layer and a position-wise feed-forward layer.)
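The recursion can be made concrete by listing the order in which sub-layers run for a given list of shortening factors. This is a hypothetical plain-Python trace, not part of the implementation; the labels (`pre`, `shorten`, `center`, and so on) are made up for illustration, and the bracketed number is the nesting depth.

```python
from typing import List


def layer_plan(shortening_factors: List[int], depth: int = 0) -> List[str]:
    """Order in which an hourglass applies its sub-layers (hypothetical trace)."""
    k = shortening_factors[0]
    # Transformer layer at full resolution, then shift right and shorten
    plan = [f"pre[{depth}]", f"shift_right({k - 1})", f"shorten({k})"]
    if len(shortening_factors) == 1:
        # No more factors: this is the middle of the hourglass
        plan.append(f"center[{depth + 1}]")
    else:
        # Recurse on the shortened sequence with the remaining factors
        plan.extend(layer_plan(shortening_factors[1:], depth + 1))
    # Up-sample back to this level's length, then the final transformer layer
    plan += [f"up_sample({k})", f"post[{depth}]"]
    return plan


plan = layer_plan([2, 3])
print(plan)
```

For `[2, 3]` the plan contains 2 × 2 + 1 = 5 transformer layers: a pre and a post layer at each level plus one center layer.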

class HourGlass(nn.Module):
    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
        super().__init__()

The transformer layer before down-sampling

        self.pre = TransformerLayer(d_model=d_model,
                                    self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                    feed_forward=FeedForward(d_model, d_ff, dropout),
                                    dropout_prob=dropout)

Auto-regressive mask

        self.mask = AutoregressiveMask()

The shortening factor k (the down-sampling rate)

        k = shortening_factors[0]

We shift the tokens to the right by k - 1 steps to make sure that information doesn't leak from future tokens to past tokens as a result of down-sampling and up-sampling.

        self.shift_right = ShiftRight(k - 1)

The shortening (down-sampling) layer. We use the simplest form, average pooling. The paper shows that attention-based down-sampling works best, which we haven't implemented yet.

        self.shortening = AvgPoolShortening(k)

If there are no more shortening factors (the middle of the hourglass)

        if len(shortening_factors) == 1:

The center layer is another transformer layer

            self.shortened = TransformerLayer(d_model=d_model,
                                              self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                              feed_forward=FeedForward(d_model, d_ff, dropout),
                                              dropout_prob=dropout)

Autoregressive mask

            self.mask_short = AutoregressiveMask()
            self.hour_glass = None
        else:

Insert another hourglass model recursively

            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])

Up-sampling layer. We use naive up-sampling for simplicity, although the paper shows that attention-based up-sampling works better.

        self.up_sampling = NaiveUpSampling(k)

The final transformer layer after up-sampling

        self.post = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)

    def forward(self, x: torch.Tensor):

Initial transformer layer

        x = self.pre(x=x, mask=self.mask(x))

Shifting and shortening

        x_short = self.shortening(self.shift_right(x))

If we are at the center of the hourglass

        if self.hour_glass is None:

Center transformer layer

            x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))

        else:

Otherwise, process the shortened sequence recursively with the inner hourglass model

            x_short = self.hour_glass(x_short)

Up-sample the shortened sequence and add a skip connection

        x = x + self.up_sampling(x, x_short)

Final transformer layer

        x = self.post(x=x, mask=self.mask(x))

        return x
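To follow the values through this recursive forward pass, here is a plain-Python sketch on a sequence of scalars. It is a toy under a strong assumption, not the module above: every transformer layer (pre, center, post) is replaced by the identity, so only the shift, pooling, up-sampling and skip connection remain. All function names are hypothetical.

```python
from typing import List


def shift_right(x: List[float], shift: int) -> List[float]:
    # Prepend `shift` zeros and drop the last `shift` elements
    return [0.0] * shift + x[:len(x) - shift]


def avg_pool(x: List[float], k: int) -> List[float]:
    # Average consecutive groups of k; a trailing partial group uses its own size
    return [sum(x[i:i + k]) / len(x[i:i + k]) for i in range(0, len(x), k)]


def naive_up_sample(x_short: List[float], k: int, length: int) -> List[float]:
    # Repeat each element k times, then truncate to the original length
    return [v for v in x_short for _ in range(k)][:length]


def hourglass(x: List[float], shortening_factors: List[int]) -> List[float]:
    # Hypothetical toy: the pre, center and post transformer layers are identities
    k = shortening_factors[0]
    x_short = avg_pool(shift_right(x, k - 1), k)
    if len(shortening_factors) > 1:
        # Recurse on the shortened sequence with the remaining factors
        x_short = hourglass(x_short, shortening_factors[1:])
    # Skip connection: add the up-sampled shortened sequence to x
    return [a + b for a, b in zip(x, naive_up_sample(x_short, k, len(x)))]


print(hourglass([1.0, 2.0, 3.0, 4.0], [2]))     # [1.5, 2.5, 5.5, 6.5]
print(hourglass([1.0, 2.0, 3.0, 4.0], [2, 2]))  # [1.75, 2.75, 5.75, 6.75]
```

Each output position keeps its own value through the skip connection and adds a smoothed summary, from the coarser levels, of tokens at or before it.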

Shift right operation

This shifts the sequence to the right by the given number of steps

class ShiftRight(nn.Module):
  • shift is the number of steps to shift by
    def __init__(self, shift: int):
        super().__init__()

shift cannot be negative

        assert shift >= 0

        self.shift = shift
  • x is a tensor of shape [seq_len, ...]
    def forward(self, x: torch.Tensor):

If the shift is 0, return the original

        if self.shift == 0:
            return x

Zeros to be prepended on the left

        prefix = x.new_zeros([self.shift, *x.shape[1:]])

Concatenate the zeros and truncate the right

        return torch.cat([prefix, x[:-self.shift]])
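Why shift by exactly k - 1? The hypothetical check below tracks, for every output position t, which input tokens reach t through the shortened path (shift right, average pool, naive up-sample), ignoring the transformer layers. In this sketch a shift of k - 1 leaves no position depending on a later token, while any smaller shift leaks future information.

```python
from typing import List, Set


def up_sampled_dependencies(seq_len: int, k: int, shift: int) -> List[Set[int]]:
    """Input tokens that the shortened path contributes to each output position."""
    # After shifting right, position i holds token i - shift (or a zero pad)
    shifted = [{i - shift} if i >= shift else set() for i in range(seq_len)]
    # Average pooling: group j covers shifted positions j*k .. j*k + k - 1
    groups = [set().union(*shifted[j:j + k]) for j in range(0, seq_len, k)]
    # Naive up-sampling: output position t reads group t // k
    return [groups[t // k] for t in range(seq_len)]


def leaks_future(seq_len: int, k: int, shift: int) -> bool:
    # True if some position t depends on a token with an index greater than t
    deps = up_sampled_dependencies(seq_len, k, shift)
    return any(t < max(dep, default=-1) for t, dep in enumerate(deps))


print(leaks_future(8, k=3, shift=2))  # False: shift = k - 1 is causal
print(leaks_future(8, k=3, shift=1))  # True
print(leaks_future(8, k=3, shift=0))  # True
```

With a shift of k - 1, group j covers tokens j·k - (k - 1) through j·k, and j·k is exactly the first output position that reads group j. The cost is that the last k - 1 tokens never enter the shortened path; the pre and post layers still see them.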

Average pool shortening

This down-samples by a given factor with average pooling
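The pooling step on a single feature channel can be sketched in plain Python. The helper `avg_pool_ceil` is hypothetical; to our understanding, `nn.AvgPool1d` with `ceil_mode=True` and no padding likewise keeps a trailing partial window and averages it over the elements it actually contains.

```python
from typing import List


def avg_pool_ceil(x: List[float], k: int) -> List[float]:
    """Average non-overlapping windows of size k, keeping a trailing partial window."""
    out = []
    for start in range(0, len(x), k):
        window = x[start:start + k]
        # Divide by the window's actual size, so a short last window is not diluted
        out.append(sum(window) / len(window))
    return out


print(avg_pool_ceil([1.0, 2.0, 3.0, 4.0, 5.0], 2))  # [1.5, 3.5, 5.0]
```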

class AvgPoolShortening(nn.Module):
  • k is the shortening factor
    def __init__(self, k: int):
        super().__init__()

Average pooling layer

        self.pool = nn.AvgPool1d(k, ceil_mode=True)
  • x is of shape [seq_len, batch_size, d_model]
    def forward(self, x: torch.Tensor):

The pooling layer expects input of shape [batch_size, d_model, seq_len], so we permute the axes before pooling and permute them back afterwards.

        return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
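The permutes only move the sequence axis to the last position, which `nn.AvgPool1d` pools over, and back again. A shape-only sketch with hypothetical helper names makes the bookkeeping explicit:

```python
import math
from typing import Sequence, Tuple


def permute_shape(shape: Sequence[int], dims: Sequence[int]) -> Tuple[int, ...]:
    # Shape after tensor.permute(*dims): output axis i takes input axis dims[i]
    return tuple(shape[d] for d in dims)


def avg_pool_shortening_shape(seq_len: int, batch_size: int, d_model: int, k: int) -> Tuple[int, ...]:
    # [seq_len, batch_size, d_model] -> [batch_size, d_model, seq_len]
    shape = permute_shape((seq_len, batch_size, d_model), (1, 2, 0))
    # Pooling with ceil mode shortens the last axis to ceil(seq_len / k)
    shape = shape[:-1] + (math.ceil(shape[-1] / k),)
    # Back to [shortened_len, batch_size, d_model]
    return permute_shape(shape, (2, 0, 1))


print(avg_pool_shortening_shape(seq_len=10, batch_size=4, d_model=512, k=3))  # (4, 4, 512)
```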

Naive up-sampling

This up-samples by repeating
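On plain lists, naive up-sampling is repetition followed by truncation. The function below is a hypothetical stand-in for `torch.repeat_interleave(..., dim=0)` plus the slice used in the forward method.

```python
from typing import List, TypeVar

T = TypeVar("T")


def naive_up_sample(x_short: List[T], k: int, seq_len: int) -> List[T]:
    """Repeat each shortened element k times and truncate to seq_len."""
    # Repeat along the sequence dimension, like repeat_interleave on dim 0
    expanded = [item for item in x_short for _ in range(k)]
    # ceil(seq_len / k) * k can exceed seq_len, so drop the surplus at the end
    return expanded[:seq_len]


print(naive_up_sample(["a", "b", "c"], 2, 5))  # ['a', 'a', 'b', 'b', 'c']
```

Truncation is needed because the shortened sequence has ceil(seq_len / k) entries, which repeat to a multiple of k that can exceed seq_len.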

class NaiveUpSampling(nn.Module):
  • k is the shortening factor
    def __init__(self, k: int):
        super().__init__()
        self.k = k
  • x is the tensor with embeddings before down-sampling
  • x_short is the tensor of higher density (to be up-sampled) representations
    def forward(self, x: torch.Tensor, x_short: torch.Tensor):

Repeat across the sequence dimension

        expanded = torch.repeat_interleave(x_short, self.k, dim=0)

Truncate the extra embeddings at the end

        expanded = expanded[:x.shape[0]]

        return expanded

Generate auto-regressive mask

class AutoregressiveMask(nn.Module):
    def __init__(self):
        super().__init__()
        self.mask = None

    def forward(self, x: torch.Tensor):

Create a mask if we haven't created one yet or if the sequence length has changed

        if self.mask is None or self.mask.size(0) != len(x):

Subsequent mask, which prevents tokens from attending to future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)

        return self.mask
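The mask logic and its caching can be mirrored in plain Python. This is a hypothetical analogue: `causal_mask` returns nested lists where `mask[i][j]` is True when position i may attend to position j, the lower-triangular pattern a subsequent mask encodes; it does not reproduce the exact tensor shape returned by `subsequent_mask`.

```python
from typing import List, Optional


def causal_mask(seq_len: int) -> List[List[bool]]:
    # mask[i][j] is True when position i may attend to position j, i.e. j <= i
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


class CachedCausalMask:
    """Hypothetical plain-Python analogue of AutoregressiveMask's caching."""

    def __init__(self) -> None:
        self.mask: Optional[List[List[bool]]] = None
        self.builds = 0  # how many times the mask has been (re)built

    def __call__(self, seq_len: int) -> List[List[bool]]:
        # Rebuild only if there is no mask yet or the sequence length changed
        if self.mask is None or len(self.mask) != seq_len:
            self.mask = causal_mask(seq_len)
            self.builds += 1
        return self.mask


masks = CachedCausalMask()
masks(3)
masks(3)  # reuses the cached mask
masks(2)  # length changed, so the mask is rebuilt
print(masks.builds)  # 2
print(causal_mask(3))  # [[True, False, False], [True, True, False], [True, True, True]]
```

Because the cache holds a single mask, `HourGlass` keeps a separate mask object for the shortened sequence at the center, and each nested hourglass creates its own; a shared object would be rebuilt every time the length alternated.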

🚧 Linear pooling for down-sampling

This concatenates the embeddings of consecutive tokens that need to be merged and applies a linear transformation that maps them to the size of a single token embedding.
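Since linear pooling is not implemented yet, here is a minimal plain-Python sketch of the idea, with hypothetical names and the simplifying assumption that seq_len is divisible by k: each group of k embeddings of size d is concatenated into one vector of size k·d and mapped back to size d by a weight matrix of shape d × k·d.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear_pool(x: List[Vector], k: int, weight: Matrix, bias: Vector) -> List[Vector]:
    """Hypothetical linear-pooling shortening.

    x:      seq_len embeddings of size d (seq_len assumed divisible by k)
    weight: d rows, each of length k * d
    bias:   length d
    """
    out = []
    for start in range(0, len(x), k):
        # Concatenate k consecutive embeddings into one vector of size k * d
        concat = [v for emb in x[start:start + k] for v in emb]
        # Linear map from size k * d back to size d
        out.append([sum(w * c for w, c in zip(row, concat)) + b for row, b in zip(weight, bias)])
    return out


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
# Weights that average the two embeddings in each group (d = 2, k = 2)
weight = [[0.5, 0.0, 0.5, 0.0],
          [0.0, 0.5, 0.0, 0.5]]
pooled = linear_pool(x, k=2, weight=weight, bias=[0.0, 0.0])
print(pooled)  # [[2.0, 3.0], [6.0, 7.0]]
```

With these particular weights linear pooling reduces to average pooling; a learned weight matrix can instead weight the positions within a group differently.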

class LinearPoolingShortening(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError

🚧 Down-sampling with attention

$$x' = S(x) + Attention\Big(Q = S(x), K = x, V = x\Big)$$
$$x' = x' + FFN(x')$$

where $S(x)$ is average pooling or linear pooling.

class AttentionBasedShortening(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError

🚧 Linear projection for up-sampling

Make a linear projection of each dense (shortened) token embedding to a size of $k \cdot d_{\text{model}}$, that is, enough for k token embeddings.
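A plain-Python sketch of linear up-sampling, under stated assumptions: each shortened embedding of size d is projected to size k·d by a weight matrix with k·d rows of length d, split into k embeddings, and the result is truncated to seq_len. The names are hypothetical; this class is not implemented in the code here.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear_up_sample(x_short: List[Vector], k: int, seq_len: int,
                     weight: Matrix, bias: Vector) -> List[Vector]:
    """Hypothetical linear up-sampling.

    weight: k * d rows, each of length d; bias: length k * d
    """
    d = len(x_short[0])
    out = []
    for emb in x_short:
        # Project one dense embedding of size d to a vector of size k * d
        proj = [sum(w * e for w, e in zip(row, emb)) + b for row, b in zip(weight, bias)]
        # Split the projection into k embeddings of size d
        out.extend(proj[i * d:(i + 1) * d] for i in range(k))
    # Drop surplus positions past the original sequence length
    return out[:seq_len]


# Stacked identity weights make linear up-sampling behave like naive repetition
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
up = linear_up_sample([[1.0, 2.0], [3.0, 4.0]], k=2, seq_len=3,
                      weight=weight, bias=[0.0] * 4)
print(up)  # [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0]]
```

With stacked identity weights each embedding is copied k times, so this reduces to naive up-sampling; learned weights let each of the k positions receive a different embedding.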

class LinearUpSampling(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError

🚧 Attention based up-sampling

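$$x = U(x, x') + Attention\Big(Q = U(x, x'), K = x', V = x'\Big)$$
$$x = x + FFN(x)$$

where $x'$ is the shortened sequence and $U(x, x') = x + LinearUpsampling(x')$.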

class AttentionBasedUpSampling(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError