from typing import List, Optional

from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs
from labml_nn.resnet import ResNetBase
We use CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and the training loop.
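For orientation, here is a minimal sketch of how labml's option mechanism works, using hypothetical names (the model below is defined with the same pattern):

from labml.configs import BaseConfigs, option

class MyConfigs(BaseConfigs):
    # A plain configuration value with a default
    count: int = 10
    # A calculated value, provided by the option registered below
    total: int

# Register a calculator for MyConfigs.total; it receives the configs instance
@option(MyConfigs.total)
def _total(c: MyConfigs):
    return c.count * 2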
class Configs(CIFAR10Configs):
Number of blocks for each feature map size
    n_blocks: List[int] = [3, 3, 3]
Number of channels for each feature map size
    n_channels: List[int] = [16, 32, 64]
Bottleneck sizes
    bottlenecks: Optional[List[int]] = None
Kernel size of the initial convolution layer
    first_kernel_size: int = 3
Create the model

@option(Configs.model)
def _resnet(c: Configs):
Create the ResNet base from the configurations

    base = ResNetBase(c.n_blocks, c.n_channels, c.bottlenecks, img_channels=3, first_kernel_size=c.first_kernel_size)
Linear layer for classification
    classification = nn.Linear(c.n_channels[-1], 10)
Stack them
    model = nn.Sequential(base, classification)
Move the model to the device
    return model.to(c.device)
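As a standalone sanity check (not part of the experiment script), one could build the default configuration by hand and confirm the output shape; this assumes, as the linear layer above implies, that ResNetBase flattens its output to a vector of size n_channels[-1]:

import torch
from torch import nn

from labml_nn.resnet import ResNetBase

# Defaults from Configs: 3 blocks per feature map size, no bottlenecks
base = ResNetBase([3, 3, 3], [16, 32, 64], None, img_channels=3, first_kernel_size=3)
model = nn.Sequential(base, nn.Linear(64, 10))
# A batch of 4 CIFAR-10 sized images (3 channels, 32x32) should give 4x10 logits
print(model(torch.zeros(4, 3, 32, 32)).shape)  # expected: torch.Size([4, 10])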
def main():
Create experiment
    experiment.create(name='resnet', comment='cifar10')
Create configurations
    conf = Configs()
Load configurations
    experiment.configs(conf, {
        'bottlenecks': [8, 16, 16],
        'n_blocks': [6, 6, 6],

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        'epochs': 500,
        'train_batch_size': 256,

        'train_dataset': 'cifar10_train_augmented',
        'valid_dataset': 'cifar10_valid_no_augment',
    })
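These overrides train a deeper variant with bottleneck residual blocks (bottleneck sizes 8, 16 and 16) and 6 blocks per feature map size; for comparison, the plain-block defaults above correspond to the classic 20-layer CIFAR ResNet (9 blocks of 2 convolutions each, plus the initial convolution and the final linear layer).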
Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment and run the training loop
    with experiment.start():
        conf.run()
if __name__ == '__main__':
    main()