FNet Experiment

This is an annotated PyTorch experiment to train an FNet model.

This is based on the general training loop and configurations for the AG News classification task.

import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_classification import NLPClassificationConfigs
from labml_nn.transformers import Encoder
from labml_nn.transformers import TransformerConfigs

Transformer-based classifier model

class TransformerClassifier(nn.Module):
    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Linear):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

    def forward(self, x: torch.Tensor):

Get the token embeddings with positional encodings

        x = self.src_embed(x)

Run the transformer encoder; the mask is None since this is a classification task and every token can attend to every other token.

        x = self.encoder(x, None)

Get logits for classification.

We set the [CLS] token at the last position of the sequence. This is extracted by x[-1], where x is of shape [seq_len, batch_size, d_model].

        x = self.generator(x[-1])

Return results (the second value is the state, since our trainer is also used with RNNs)

        return x, None
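To make the shapes concrete, here is a minimal standalone sketch (the dimensions are hypothetical, chosen only for illustration):

import torch
from torch import nn

seq_len, batch_size, d_model, n_classes = 128, 32, 512, 4

# Pretend encoder output: one d_model vector per position and batch element
x = torch.randn(seq_len, batch_size, d_model)

# x[-1] selects the last position (the [CLS] slot): [batch_size, d_model]
cls = x[-1]

# The generator maps it to class logits: [batch_size, n_classes]
generator = nn.Linear(d_model, n_classes)
logits = generator(cls)
assert logits.shape == (batch_size, n_classes)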

Configurations

This inherits from NLPClassificationConfigs

class Configs(NLPClassificationConfigs):

Classification model

    model: TransformerClassifier

Transformer

    transformer: TransformerConfigs

Transformer configurations

@option(Configs.transformer)
def _transformer_configs(c: Configs):
    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

    return conf
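For background, labml computes options like this lazily: the function registered with @option runs the first time the config attribute is accessed, and it receives the Configs instance so it can read other values (here, c.n_tokens). A minimal sketch of the mechanism, with hypothetical names MyConfigs and _double:

from labml.configs import BaseConfigs, option

class MyConfigs(BaseConfigs):
    value: int = 5
    # `double` has no default; it is computed by the option below
    double: int

@option(MyConfigs.double)
def _double(c: MyConfigs):
    # Runs lazily, on first access to `double`
    return c.value * 2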

Create an FNetMix module that can replace the self-attention in the transformer encoder layer.

@option(TransformerConfigs.encoder_attn)
def fnet_mix():
    from labml_nn.transformers.fnet import FNetMix
    return FNetMix()
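FNet replaces self-attention with parameter-free Fourier mixing: apply an FFT along the hidden dimension, another along the sequence dimension, and keep the real part. The following is a rough sketch of that idea, not the labml_nn FNetMix implementation (which additionally matches the multi-head attention interface so it can be swapped in):

import torch

def fourier_mix(x: torch.Tensor) -> torch.Tensor:
    # x: [seq_len, batch_size, d_model]
    # FFT over the hidden dimension, then over the sequence dimension;
    # the result is complex, so keep only its real part
    return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=0).real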

Create classification model

@option(Configs.model)
def _model(c: Configs):
    m = TransformerClassifier(c.transformer.encoder,
                              c.transformer.src_embed,
                              nn.Linear(c.d_model, c.n_classes)).to(c.device)

    return m
def main():

Create experiment

    experiment.create(name="fnet")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use a word-level tokenizer

        'tokenizer': 'basic_english',

Train for 32 epochs

        'epochs': 32,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Transformer configurations (same as defaults)

        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,

Use FNet instead of self-attention

        'transformer.encoder_attn': 'fnet_mix',

Use the Noam optimizer (a sketch of its learning-rate schedule follows this block)

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
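The learning rate of 1.0 above is a multiplier for the Noam schedule from Attention Is All You Need, which warms up linearly and then decays as the inverse square root of the step count: lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5). A quick sketch of the schedule (warmup=4000 is the paper's default, not necessarily what this experiment uses):

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    # `step` starts at 1; linear warmup, then inverse square-root decay
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

The two terms coincide at step == warmup, which is where the learning rate peaks.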

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()