Train a small model on CIFAR-10

This trains a small model on CIFAR-10 on its own, giving a baseline for measuring how much distillation helps.

15import torch.nn as nn
17from labml import experiment, logger
18from labml.configs import option
19from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
20from labml_nn.normalization.batch_norm import BatchNorm


We use `CIFAR10Configs`, which defines all the dataset-related configurations, the optimizer, and the training loop.

class Configs(CIFAR10Configs):
    """
    Experiment configurations.

    Everything (dataset, optimizer, training loop) is inherited from
    `CIFAR10Configs`; this subclass defines no additional options of its own.
    """

VGG-style model for CIFAR-10 classification

This derives from the generic VGG-style architecture.

class SmallModel(CIFAR10VGGModel):
    """
    Small VGG-style model for CIFAR-10 classification.

    Reuses the generic `CIFAR10VGGModel` architecture and only overrides the
    channel layout and the per-layer convolution block.
    """

    def __init__(self):
        # Convolution sizes (channels). Each inner list is one group of
        # convolution blocks; how groups are joined (e.g. pooling) is decided
        # by `CIFAR10VGGModel` — NOTE(review): confirm in the base class.
        channel_groups = [[32, 32], [64, 64], [128], [128], [128]]
        super().__init__(channel_groups)

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Create a convolution layer together with its normalization and activation.

        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :return: an `nn.Sequential` of convolution, batch norm and ReLU
        """
        layers = [
            # 3x3 convolution; `padding=1` keeps the spatial size unchanged
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Batch normalization with `track_running_stats=False`
            # (presumably no running mean/variance is kept, as in PyTorch —
            # confirm against labml's `BatchNorm`)
            BatchNorm(out_channels, track_running_stats=False),
            # ReLU activation, applied in place
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

Create model

@option(Configs.model)
def _small_model(c: Configs):
    """
    Create the small model and move it to the configured device.

    Registered with `@option` as the calculator for `Configs.model`, so that
    accessing `conf.model` builds a `SmallModel` on `c.device`.
    """
    return SmallModel().to(c.device)
def main():
    """Configure and run the small-model CIFAR-10 training experiment."""
    # Create experiment
    experiment.create(name='cifar10', comment='small model')
    # Create configurations
    conf = Configs()
    # Load configurations (override the optimizer and its learning rate)
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Print number of trainable parameters in the model
    logger.inspect(params=(sum(p.numel() for p in conf.model.parameters() if p.requires_grad)))
    # Start the experiment and run the training loop
    # (the original `with` block had an empty body, so training never ran)
    with experiment.start():
        conf.run()

# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()