Train a ConvMixer on CIFAR-10

This script trains a ConvMixer on the CIFAR-10 dataset.

This is not an attempt to reproduce the results of the paper. The paper trains with the image augmentations provided by PyTorch Image Models (timm). For simplicity we do not use them, which lowers our validation accuracy.
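The script relies on simpler augmentations instead. The framework-free sketch below shows what such augmentations typically do: zero-pad the image, take a random crop back to the original size, and flip it horizontally at random. Several details are assumptions and are not confirmed here. These include the 4-pixel padding and the 0.5 flip probability, which are common CIFAR-10 choices, and the use of single-channel nested lists in place of tensors. `random_crop_flip` is a hypothetical helper and is not labml's implementation.

```python
import random
from typing import List, Optional

Image = List[List[float]]  # a single channel, indexed as image[row][col]


def random_crop_flip(image: Image, padding: int = 4, flip_prob: float = 0.5,
                     rng: Optional[random.Random] = None) -> Image:
    """Hypothetical helper: zero-pad, randomly crop back to size, maybe flip.

    Assumed to mirror a typical "pad and crop + horizontal flip" CIFAR-10
    augmentation; this is a sketch, not labml's implementation.
    """
    if rng is None:
        rng = random.Random()
    height, width = len(image), len(image[0])
    # Zero-pad `padding` pixels on every side
    padded_width = width + 2 * padding
    padded = [[0.0] * padded_width for _ in range(padding)]
    padded += [[0.0] * padding + list(row) + [0.0] * padding for row in image]
    padded += [[0.0] * padded_width for _ in range(padding)]
    # Pick a random top-left corner so the crop stays inside the padded image
    top = rng.randint(0, 2 * padding)
    left = rng.randint(0, 2 * padding)
    crop = [row[left:left + width] for row in padded[top:top + height]]
    # Mirror left-right with probability `flip_prob`
    if rng.random() < flip_prob:
        crop = [row[::-1] for row in crop]
    return crop


img = [[float(r * 4 + c) for c in range(4)] for r in range(4)]
aug = random_crop_flip(img, padding=1, rng=random.Random(42))
print(len(aug), len(aug[0]))  # 4 4
```

The output always has the input's size, so the model's input shape is unchanged. Validation images skip this step entirely.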

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs


We use CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and a training loop.

class Configs(CIFAR10Configs):

Size of a patch

    patch_size: int = 2

Number of channels in the patch embeddings

    d_model: int = 256

Number of ConvMixer layers (the depth)

    n_layers: int = 8

Kernel size of the depth-wise convolution

    kernel_size: int = 7

Number of classes in the task

    n_classes: int = 10
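With these defaults, a 32×32 CIFAR-10 image is cut into 2×2 patches. This gives a 16×16 grid of 256-channel embeddings, and every ConvMixer layer keeps that resolution. The sketch below traces the shapes using the standard convolution output-size formula. It rests on two assumptions taken from the ConvMixer paper's design. First, the patch embedding is a convolution whose kernel size and stride both equal the patch size. Second, the depth-wise convolution uses "same" padding of (kernel_size - 1) // 2 for an odd kernel. `conv_out` and `trace_shapes` are hypothetical helpers.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Standard output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(image_size: int = 32, in_channels: int = 3, patch_size: int = 2,
                 d_model: int = 256, n_layers: int = 8, kernel_size: int = 7,
                 n_classes: int = 10) -> List[Tuple[str, Tuple[int, ...]]]:
    """Hypothetical helper: (name, shape) pairs through an assumed ConvMixer."""
    shapes = [("input", (in_channels, image_size, image_size))]
    # Patch embedding: kernel size and stride both equal to the patch size
    size = conv_out(image_size, patch_size, stride=patch_size)
    shapes.append(("patch embeddings", (d_model, size, size)))
    # Each layer: depth-wise conv with "same" padding, then a 1x1 point-wise conv
    for i in range(n_layers):
        size = conv_out(size, kernel_size, padding=(kernel_size - 1) // 2)
        size = conv_out(size, 1)
        shapes.append((f"layer {i + 1}", (d_model, size, size)))
    # Classification head: global average pooling, then a linear layer
    shapes.append(("pooled", (d_model,)))
    shapes.append(("logits", (n_classes,)))
    return shapes


for name, shape in trace_shapes():
    print(name, shape)
```

Because the resolution never shrinks after the patch embedding, the patch size drives the cost. Setting patch_size to 1 would quadruple the number of spatial positions, from 16×16 to 32×32, in every layer.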

Create model

@option(Configs.model)
def _conv_mixer(c: Configs):
    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

Create the ConvMixer. PatchEmbeddings receives 3 input channels because CIFAR-10 images are RGB.

    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
                     PatchEmbeddings(c.d_model, c.patch_size, 3),
                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
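A rough parameter count shows where the model's capacity sits. The sketch below assumes the layer structure described in the ConvMixer paper, and each piece is an assumption rather than a reading of labml_nn.conv_mixer:

- a patch-embedding convolution followed by BatchNorm
- n_layers blocks, each made of a depth-wise convolution and a 1×1 point-wise convolution, with every convolution having a bias and being followed by BatchNorm
- a linear classifier on globally pooled features

Activations and pooling have no parameters. BatchNorm running statistics are buffers, so only BatchNorm's per-channel scale and shift are counted. `count_parameters` is a hypothetical helper.

```python
def count_parameters(d_model: int = 256, n_layers: int = 8, kernel_size: int = 7,
                     patch_size: int = 2, in_channels: int = 3,
                     n_classes: int = 10) -> dict:
    """Hypothetical helper: learnable parameters of an assumed ConvMixer layout."""
    batch_norm = 2 * d_model  # learnable scale and shift per channel
    # Patch embedding: conv with kernel = stride = patch_size, plus bias, then BatchNorm
    patch = in_channels * d_model * patch_size ** 2 + d_model + batch_norm
    # Depth-wise conv: one k x k filter per channel, plus bias, then BatchNorm
    depth_wise = d_model * kernel_size ** 2 + d_model + batch_norm
    # Point-wise conv: 1x1 convolution mixing all channels, plus bias, then BatchNorm
    point_wise = d_model * d_model + d_model + batch_norm
    # Classification head: linear layer on the pooled d_model-dimensional features
    head = d_model * n_classes + n_classes
    layers = n_layers * (depth_wise + point_wise)
    return {"patch": patch, "layers": layers, "head": head,
            "total": patch + layers + head}


print(count_parameters())
```

Under these assumptions, the 1×1 point-wise convolutions dominate. Each one has d_model² weights, which is roughly five times the depth-wise filter bank (d_model · kernel_size²) at d_model = 256 and kernel_size = 7.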

def main():

Create experiment

    experiment.create(name='ConvMixer', comment='cifar10')

Create configurations

    conf = Configs()

Load configurations

    experiment.configs(conf, {

Optimizer

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

Training epochs and batch size

        'epochs': 150,
        'train_batch_size': 64,

Simple image augmentations

        'train_dataset': 'cifar10_train_augmented',

Do not augment images for validation

        'valid_dataset': 'cifar10_valid_no_augment',
    })

Set model for saving/loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment and run the training loop

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
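The configured epochs and batch size fix the length of training. The quick check below counts optimizer steps under two assumptions. The first is the standard 50,000-image CIFAR-10 training split. The second is that the last, partial batch of each epoch is kept, which matches PyTorch's DataLoader default of drop_last=False. `training_steps` is a hypothetical helper.

```python
import math


def training_steps(n_train: int = 50_000, batch_size: int = 64, epochs: int = 150,
                   drop_last: bool = False) -> tuple:
    """Hypothetical helper: optimizer steps per epoch and in total."""
    if drop_last:
        # Incomplete final batch is discarded
        steps_per_epoch = n_train // batch_size
    else:
        # Incomplete final batch still produces one optimizer step
        steps_per_epoch = math.ceil(n_train / batch_size)
    return steps_per_epoch, steps_per_epoch * epochs


print(training_steps())  # (782, 117300)
```

Dropping the partial batch would remove one step per epoch, giving 781 steps per epoch.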