This is an annotated PyTorch experiment to train an FNet model.
It is based on the general training loop and configurations for the AG News classification task.
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_classification import NLPClassificationConfigs
from labml_nn.transformers import Encoder
from labml_nn.transformers import TransformerConfigs
class TransformerClassifier(nn.Module):
* encoder is the transformer Encoder
* src_embed is the token embedding module (with positional encodings)
* generator is the final fully connected layer that gives the logits

    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Linear):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator
    def forward(self, x: torch.Tensor):
Get the token embeddings with positional encodings
        x = self.src_embed(x)
Transformer encoder
        x = self.encoder(x, None)
Get logits for classification. We set the [CLS] token at the last position of the sequence. This is extracted by x[-1], where x is of shape [seq_len, batch_size, d_model].
        x = self.generator(x[-1])
Return results (second value is for state, since our trainer is used with RNNs also)
        return x, None
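As a quick illustration of the [CLS] extraction above, here is a toy sketch with hypothetical shapes (seq_len, batch_size, and d_model are made-up values, and the random tensor stands in for the encoder output): x[-1] picks the last-position representation for every sample in the batch.

import torch

seq_len, batch_size, d_model = 16, 4, 512
x = torch.randn(seq_len, batch_size, d_model)  # stand-in for the encoder output
cls = x[-1]                                    # last position -> [batch_size, d_model]
assert cls.shape == (batch_size, d_model)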
class Configs(NLPClassificationConfigs):
Classification model
    model: TransformerClassifier
Transformer
    transformer: TransformerConfigs
@option(Configs.transformer)
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

    return conf
Create an FNetMix module that can replace the self-attention in the transformer encoder layer.
@option(TransformerConfigs.encoder_attn)
def fnet_mix():
    from labml_nn.transformers.fnet import FNetMix
    return FNetMix()
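For intuition, here is a rough sketch of the Fourier mixing idea that FNetMix implements (a simplified illustration, not the library's exact code; fourier_mix is a hypothetical helper name): apply an FFT along the hidden dimension, another along the sequence dimension, and keep only the real part. No attention weights are computed.

import torch

def fourier_mix(x: torch.Tensor) -> torch.Tensor:
    # x: [seq_len, batch_size, d_model]; FFT over the hidden dimension, then the
    # sequence dimension, keeping the real part (the FNet token-mixing idea, simplified)
    return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=0).real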
Create classification model
@option(Configs.model)
def _model(c: Configs):
    m = TransformerClassifier(c.transformer.encoder,
                              c.transformer.src_embed,
                              nn.Linear(c.d_model, c.n_classes)).to(c.device)

    return m
def main():
Create experiment
    experiment.create(name="fnet")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Use word level tokenizer
        'tokenizer': 'basic_english',
Train for 32 epochs
        'epochs': 32,
Switch between training and validation 10 times per epoch
        'inner_iterations': 10,
Transformer configurations (same as defaults)
        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,
Use the Noam optimizer (a schedule sketch follows at the end of the file)
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()
if __name__ == '__main__':
    main()
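A note on the Noam optimizer settings above: the Noam schedule (from "Attention Is All You Need") warms the learning rate up and then decays it with the inverse square root of the step, and the configured optimizer.learning_rate of 1.0 acts as a scale factor on that schedule. A minimal sketch, assuming the standard formula with an illustrative warmup of 4000 steps (noam_lr and its defaults are hypothetical, not part of this experiment's code):

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    # factor plays the role of 'optimizer.learning_rate' above; with factor=1 the
    # peak rate is determined by d_model and warmup alone
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)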