FNet Experiment

This is an annotated PyTorch experiment to train an FNet model.

This is based on the general training loop and configurations for the AG News classification task.

15import torch
16from torch import nn
18from labml import experiment
19from labml.configs import option
20from labml_helpers.module import Module
21from labml_nn.experiments.nlp_classification import NLPClassificationConfigs
22from labml_nn.transformers import Encoder
23from labml_nn.transformers import TransformerConfigs

Transformer-based classifier model

class TransformerClassifier(nn.Module):
    """
    Transformer-based classifier model.

    Embeds the input tokens, runs them through a transformer encoder and
    projects the encoder output at the last sequence position to logits.
    """

    def __init__(self, encoder: Encoder, src_embed: Module, generator: nn.Linear):
        """
        * `encoder` is the transformer encoder; `forward` calls it with the
          embeddings and `None` as its second argument.
        * `src_embed` is the token embedding module (with positional encodings).
        * `generator` is the final linear layer that turns the encoder output
          at the classification position into logits.
        """
        super().__init__()
        # Keep this registration order: it fixes the order of sub-modules
        # (and therefore parameters / state_dict entries).
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

    def forward(self, x: torch.Tensor):
        """
        Classify a batch of token sequences.

        NOTE(review): `x` is presumably token ids of shape
        `[seq_len, batch_size]` — confirm against the data loader. The encoder
        output is `[seq_len, batch_size, d_model]`.

        Returns `(logits, None)`; the second value is a placeholder for a
        recurrent state, since the shared trainer is also used with RNNs.
        """
        # Token embeddings with positional encodings
        embedded = self.src_embed(x)
        # Transformer encoder; the second argument (presumably the mask) is None
        encoded = self.encoder(embedded, None)
        # The [CLS] token sits at the last position of the sequence, so the
        # last step along the sequence axis is its representation.
        cls_state = encoded[-1]
        logits = self.generator(cls_state)
        return logits, None


These configurations inherit from `NLPClassificationConfigs`.

class Configs(NLPClassificationConfigs):
    """
    Configurations for the FNet text-classification experiment.

    Inherits everything else from `NLPClassificationConfigs`; only the
    classification model and its transformer sub-configuration are
    declared here.
    """

    # The classification model
    model: TransformerClassifier
    # Transformer sub-configuration (embeddings, encoder, ...)
    transformer: TransformerConfigs

Transformer configurations

@option(Configs.transformer)
def _transformer_configs(c: Configs):
    """
    Transformer configurations.

    Registered with `@option` as the calculator for `Configs.transformer`;
    nothing else in this file provides a value for that config item.
    Source and target vocabulary sizes are taken from `c.n_tokens`.
    """
    conf = TransformerConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

    return conf

Create an `FNetMix` module that can replace the self-attention in the transformer encoder layer.

@option(TransformerConfigs.encoder_attn)
def fnet_mix():
    """
    Create an `FNetMix` module that can replace the self-attention in the
    transformer encoder layer.

    Registered with `@option` on `TransformerConfigs.encoder_attn` so that
    `'transformer.encoder_attn': 'fnet_mix'` in `main()` can select it
    (labml names the option after the function by default).
    """
    # Imported inside the function, so the FNet module is only loaded when
    # this option is actually computed.
    from labml_nn.transformers.fnet import FNetMix
    return FNetMix()

Create the classification model

@option(Configs.model)
def _model(c: Configs):
    """
    Create the classification model.

    Registered with `@option` as the calculator for `Configs.model`. The
    encoder and embeddings come from `c.transformer`; the generator maps the
    encoder output to `c.n_classes` logits. The model is moved to `c.device`.
    """
    # Size the generator from the transformer's own `d_model`: `main()`
    # overrides `transformer.d_model` separately from the top-level
    # `d_model`, and the generator input must match the encoder output width.
    m = TransformerClassifier(c.transformer.encoder,
                              c.transformer.src_embed,
                              nn.Linear(c.transformer.d_model, c.n_classes)).to(c.device)

    return m
def main():
    """Create, configure and run the FNet AG News classification experiment."""
    # Create experiment
    experiment.create(name="fnet")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Use a word-level tokenizer
        'tokenizer': 'basic_english',

        # Train for 32 epochs
        'epochs': 32,
        # Switch between training and validation 10 times per epoch
        'inner_iterations': 10,

        # Transformer configurations (same as defaults)
        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,

        # Use FNet mixing instead of self-attention
        'transformer.encoder_attn': 'fnet_mix',

        # Use the Noam optimizer.
        # NOTE(review): with Noam, `learning_rate` presumably scales the
        # warm-up schedule rather than being the raw step size — confirm.
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()

if __name__ == '__main__':
    # Run the experiment only when this file is executed as a script.
    main()