Configurable Transformer Components

9import copy
11import torch.nn as nn
13from labml.configs import BaseConfigs, option, calculate, aggregate
14from labml_helpers.module import Module
15from .feed_forward import FeedForward
16from .mha import MultiHeadAttention
17from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \
18    Encoder, Decoder, Generator, EncoderDecoder

FFN Configurations

Creates a Position-wise FeedForward Network defined in `feed_forward.py`.

class FeedForwardConfigs(BaseConfigs):
    """
    ## FFN Configurations

    Configuration items for a position-wise feed-forward network.
    Each annotated attribute below is one configuration item.
    """

    # Position-wise feedforward layer
    ffn: FeedForward
    # Number of features in the embedding
    d_model: int
    # Number of features in the hidden layer
    d_ff: int = 2048
    # Dropout probability
    dropout: float = 0.1
    # Activation in position-wise feedforward layer
    # (the default names the `'ReLU'` option registered in this module)
    activation: nn.Module = 'ReLU'
    # Whether the FFN layer should be gated
    is_gated: bool = False
    # Whether the first fully connected layer should have a learnable bias
    bias1: bool = True
    # Whether the second fully connected layer should have a learnable bias
    bias2: bool = True
    # Whether the fully connected layer for the gate should have a learnable bias
    bias_gate: bool = False
    # Predefined GLU variants; this module registers aggregates named
    # `'GLU'`, `'Bilinear'`, `'ReGLU'`, `'GEGLU'` and `'SwiGLU'`
    glu_variant: str = 'none'

ReLU activation

@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
    """
    ### ReLU activation

    Returns a new `nn.ReLU` module for the `'ReLU'` activation option.
    """
    return nn.ReLU()

GELU activation

$$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$

It was introduced in the paper Gaussian Error Linear Units.

@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
    """
    ### GELU activation

    Returns a new `nn.GELU` module for the `'GELU'` activation option.
    """
    return nn.GELU()

Initialize a feed forward network

@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
    """
    ### Initialize a feed forward network

    Builds a `FeedForward` module from the sizes, dropout, activation,
    gating flag and bias flags held in the FFN configurations `c`.
    """
    return FeedForward(
        c.d_model,
        c.d_ff,
        dropout=c.dropout,
        activation=c.activation,
        is_gated=c.is_gated,
        bias1=c.bias1,
        bias2=c.bias2,
        bias_gate=c.bias_gate,
    )

GLU Variants

These are variants with gated hidden layers for the FFN, as introduced in the paper GLU Variants Improve Transformer. We have omitted the bias terms, as specified in the paper.

FFN with Gated Linear Units

# Each `aggregate` call below registers one named value of `glu_variant`
# together with the set of FFN configuration values that belong to it.
# Every variant is gated and has all three bias flags switched off;
# the variants differ only in the activation.

# FFN with Gated Linear Units (sigmoid gate)
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Sigmoid()))

# FFN with Bilinear hidden layer (identity activation)
aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Identity()))

# FFN with ReLU gate
aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.ReLU()))

# FFN with GELU gate
aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.GELU()))

# FFN with Swish gate (`nn.SiLU`)
aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.SiLU()))

Transformer Configurations

This defines configurations for a transformer. The configurations are calculated using option functions. These are lazily loaded, and therefore only the necessary modules are calculated.

class TransformerConfigs(BaseConfigs):
    """
    ## Transformer Configurations

    Configuration items for an encoder-decoder transformer.
    Each annotated attribute below is one configuration item; string
    defaults name an option registered for that item in this module.
    """

    # Number of attention heads
    n_heads: int = 8
    # Transformer embedding size
    d_model: int = 512
    # Number of layers
    n_layers: int = 6
    # Dropout probability
    dropout: float = 0.1
    # Number of tokens in the source vocabulary (for token embeddings)
    n_src_vocab: int
    # Number of tokens in the target vocabulary (to generate logits for prediction)
    n_tgt_vocab: int

    # The encoder self attention
    encoder_attn: MultiHeadAttention = 'mha'
    # The decoder self attention
    decoder_attn: MultiHeadAttention = 'mha'
    # The decoder memory attention
    decoder_mem_attn: MultiHeadAttention = 'mha'

    # Configurable Feedforward Layer
    ffn: FeedForwardConfigs

    # Encoder layer
    encoder_layer: TransformerLayer = 'default'
    # Decoder layer
    decoder_layer: TransformerLayer = 'default'

    # Encoder consisting of multiple encoder layers
    encoder: Encoder = 'default'
    # Decoder consisting of multiple decoder layers
    decoder: Decoder = 'default'

    # Embedding layer for source
    src_embed: Module = 'fixed_pos'
    # Embedding layer for target (for decoder)
    tgt_embed: Module = 'fixed_pos'

    # Logit generator for prediction
    generator: Generator = 'default'

    # Encoder-decoder model assembled from the encoder, decoder,
    # source/target embeddings and generator
    encoder_decoder: EncoderDecoder

Multi-head Attention

def _mha(c: TransformerConfigs):
    """
    ### Multi-head Attention

    Builds a `MultiHeadAttention` module from the number of heads,
    embedding size and dropout probability in the configurations `c`.
    """
    return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)


# Register `_mha` as the `'mha'` option of all three attention items
calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)

Relative Multi-head Attention

def _relative_mha(c: TransformerConfigs):
    """
    ### Relative Multi-head Attention

    Builds a `RelativeMultiHeadAttention` module from the number of heads,
    embedding size and dropout probability in the configurations `c`.

    The dropout probability is forwarded from `c.dropout`, the same way
    `_mha` does it, so that choosing the `'relative'` option does not
    silently ignore the configured dropout.
    """
    # Kept as a function-local import, as in the original
    # (presumably so the module is only imported when this option is used - confirm)
    from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
    # NOTE(review): assumes the constructor accepts `dropout_prob`, like
    # `MultiHeadAttention` does - confirm against `relative_mha.py`
    return RelativeMultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)


# Register `_relative_mha` as the `'relative'` option of all three attention items
calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)

Create feedforward layer configurations

@option(TransformerConfigs.ffn, 'default')
def _feed_forward(c: TransformerConfigs):
    """
    ### Create feedforward layer configurations

    Creates a `FeedForwardConfigs` whose `d_model` and `dropout` defaults
    are taken from the transformer configurations `c`.
    """
    # NOTE(review): this function shares its module-level name with the
    # option function registered for `FeedForwardConfigs.ffn`; both are
    # registered through `@option` - confirm nothing refers to them by name.
    conf = FeedForwardConfigs()
    # The defaults are given as functions, so `c.d_model` / `c.dropout`
    # are read when the function is called rather than here
    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
    return conf

Encoder layer

@option(TransformerConfigs.encoder_layer, 'default')
def _encoder_layer(c: TransformerConfigs):
    """
    ### Encoder layer

    Builds a `TransformerLayer` with the encoder self attention and no
    source (memory) attention.
    """
    return TransformerLayer(
        d_model=c.d_model,
        self_attn=c.encoder_attn,
        src_attn=None,
        # Deep copy, so the layer gets its own copy of the FFN module
        feed_forward=copy.deepcopy(c.ffn.ffn),
        dropout_prob=c.dropout,
    )

Decoder layer

@option(TransformerConfigs.decoder_layer, 'default')
def _decoder_layer(c: TransformerConfigs):
    """
    ### Decoder layer

    Builds a `TransformerLayer` with the decoder self attention and the
    decoder memory attention as its source attention.
    """
    return TransformerLayer(
        d_model=c.d_model,
        self_attn=c.decoder_attn,
        src_attn=c.decoder_mem_attn,
        # Deep copy, so the layer gets its own copy of the FFN module
        feed_forward=copy.deepcopy(c.ffn.ffn),
        dropout_prob=c.dropout,
    )


@option(TransformerConfigs.encoder, 'default')
def _encoder(c: TransformerConfigs):
    """
    ### Encoder

    Builds an `Encoder` from the configured encoder layer and number of layers.
    """
    return Encoder(c.encoder_layer, c.n_layers)


@option(TransformerConfigs.decoder, 'default')
def _decoder(c: TransformerConfigs):
    """
    ### Decoder

    Builds a `Decoder` from the configured decoder layer and number of layers.
    """
    return Decoder(c.decoder_layer, c.n_layers)

Logit generator

@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
    """
    ### Logit generator

    Builds a `Generator` from the target vocabulary size and embedding size.
    """
    return Generator(c.n_tgt_vocab, c.d_model)

Fixed Positional Embeddings

Source embedding with fixed positional encodings

@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    """
    ### Source embedding with fixed positional encodings

    Built from the embedding size and the source vocabulary size.
    """
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)

Target embedding with fixed positional encodings

@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
    """
    ### Target embedding with fixed positional encodings

    Built from the embedding size and the target vocabulary size.
    """
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)

Learned Positional Embeddings

Source embedding with learned positional encodings

@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    """
    ### Source embedding with learned positional encodings

    Built from the embedding size and the source vocabulary size.
    """
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)

Target embedding with learned positional encodings

@option(TransformerConfigs.tgt_embed, 'learned_pos')
def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    """
    ### Target embedding with learned positional encodings

    Built from the embedding size and the target vocabulary size.
    """
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)

No Positional Embeddings

Source embedding without positional encodings

@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
    """
    ### Source embedding without positional encodings

    Returns a plain `nn.Embedding` with one `d_model`-sized vector per
    source vocabulary token.
    """
    return nn.Embedding(c.n_src_vocab, c.d_model)
@option(TransformerConfigs.tgt_embed, 'no_pos')
def _tgt_embed_without_positional(c: TransformerConfigs):
    """
    ### Target embedding without positional encodings

    Returns a plain `nn.Embedding` with one `d_model`-sized vector per
    target vocabulary token.
    """
    return nn.Embedding(c.n_tgt_vocab, c.d_model)
@option(TransformerConfigs.encoder_decoder, 'default')
def _encoder_decoder(c: TransformerConfigs):
    """
    ### Encoder-decoder

    Assembles an `EncoderDecoder` from the configured encoder, decoder,
    source embedding, target embedding and generator.
    """
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)