"""
This script trains a ConvMixer on the CIFAR-10 dataset.

This is not an attempt to reproduce the results of the paper.
The paper uses image augmentations present in PyTorch Image Models (timm)
for training. For simplicity we haven't done this, which causes our
validation accuracy to drop.
"""
from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs

# We use `CIFAR10Configs`, which defines all the dataset related
# configurations, the optimizer, and a training loop.
class Configs(CIFAR10Configs):
    """
    ## Configurations

    Extends `CIFAR10Configs`, which provides the dataset related
    configurations, the optimizer, and the training loop.
    """
    # Size of a patch (patch embedding stride/kernel)
    patch_size: int = 2
    # Number of channels in patch embeddings
    d_model: int = 256
    # Number of ConvMixer layers, i.e. the depth of the model
    n_layers: int = 8
    # Kernel size of the depth-wise convolution
    kernel_size: int = 7
    # Number of classes in the classification task
    n_classes: int = 10
@option(Configs.model)
def _conv_mixer(c: Configs):
    """
    ### Create the ConvMixer model

    Builds the model from the configuration values and moves it to the
    configured device.
    """
    # Imported locally so the module loads without pulling in model code
    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

    # Create ConvMixer: patch embeddings (3 input channels for RGB),
    # `n_layers` mixer layers, and a classification head.
    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
                     PatchEmbeddings(c.d_model, c.patch_size, 3),
                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
def main():
    """
    Create the experiment, load the configurations, and run the
    training loop.
    """
    # Create experiment
    experiment.create(name='ConvMixer', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        # Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        # Training epochs and batch size
        'epochs': 150,
        'train_batch_size': 64,

        # Simple image augmentations
        'train_dataset': 'cifar10_train_augmented',
        # Do not augment images for validation
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()
# Run the training when executed as a script
if __name__ == '__main__':
    main()