Train a Vision Transformer (ViT) on CIFAR-10

11from labml import experiment
12from labml.configs import option
13from labml_nn.experiments.cifar10 import CIFAR10Configs
14from labml_nn.transformers import TransformerConfigs

Configurations

We use `CIFAR10Configs`, which defines all the dataset-related configurations, the optimizer, and a training loop.

class Configs(CIFAR10Configs):
    """
    ## Configurations

    Extends `CIFAR10Configs` (which already provides the dataset, optimizer
    and training-loop configurations) with the settings needed to build a
    vision transformer classifier.
    """

    # Transformer configurations; the default is produced by the
    # `@option(Configs.transformer)` function below
    transformer: TransformerConfigs

    # Size of a patch (passed to `PatchEmbeddings` when the model is created)
    patch_size: int = 4

    # Size of the hidden layer in the classification head
    n_hidden_classification: int = 2048

    # Number of classes in the task
    n_classes: int = 10

Create transformer configs

@option(Configs.transformer)
def _transformer():
    """
    Create the default transformer configurations.

    Individual settings (for example `transformer.d_model`) can be overridden
    through `experiment.configs`.
    """
    transformer_configs = TransformerConfigs()
    return transformer_configs

Create model

@option(Configs.model)
def _vit(c: Configs):
    """
    Create the vision transformer model and move it to `c.device`.

    The model is assembled from the transformer encoder layer and layer count
    in `c.transformer`, a patch-embedding module, learned positional
    embeddings, and a classification head.
    """
    # Imported inside the function, so the module is loaded only when this
    # option is actually evaluated
    from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
        PatchEmbeddings

    # Transformer size from the transformer configurations
    d_model = c.transformer.d_model

    # Components are built in the same order the original passed them as
    # arguments, so any parameter initialization happens in the same sequence
    encoder_layer = c.transformer.encoder_layer
    n_layers = c.transformer.n_layers
    # NOTE(review): the literal 3 is presumably the number of input (RGB)
    # channels of CIFAR-10 images; confirm against `PatchEmbeddings`
    patch_embeddings = PatchEmbeddings(d_model, c.patch_size, 3)
    positional_embeddings = LearnedPositionalEmbeddings(d_model)
    classifier = ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)

    # Create the vision transformer and place it on the configured device
    model = VisionTransformer(encoder_layer, n_layers,
                              patch_embeddings,
                              positional_embeddings,
                              classifier)
    return model.to(c.device)


def main():
    """Create the experiment, load the configurations, and run training."""
    # Create experiment
    experiment.create(comment='cifar10', name='ViT')

    # Create configurations
    conf = Configs()

    # Configuration overrides (insertion order kept as in the original)
    overrides = {
        # Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        # Transformer embedding size
        'transformer.d_model': 512,

        # Training epochs and batch size
        'epochs': 32,
        'train_batch_size': 64,

        # Augment CIFAR-10 images for training
        'train_dataset': 'cifar10_train_augmented',
        # Do not augment CIFAR-10 images for validation
        'valid_dataset': 'cifar10_valid_no_augment',
    }

    # Load configurations
    experiment.configs(conf, overrides)

    # Set model for saving/loading
    models_to_track = {'model': conf.model}
    experiment.add_pytorch_models(models_to_track)

    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


# Run training only when executed as a script
if __name__ == '__main__':
    main()