CIFAR10 Experiment

10from typing import List
12import torch.nn as nn
14from labml import lab
15from labml.configs import option
16from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
17from labml_helpers.module import Module
18from labml_nn.experiments.mnist import MNISTConfigs


This extends the CIFAR 10 dataset configurations from `labml_helpers` and `MNISTConfigs`.

class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
    """
    CIFAR 10 experiment configurations.

    Combines the CIFAR 10 dataset configurations from ``labml_helpers`` with
    the training configurations of ``MNISTConfigs``.
    """

    # Default to the CIFAR10 dataset
    dataset_name: str = 'CIFAR10'

Augmented CIFAR 10 train dataset

@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    """
    Augmented CIFAR 10 train dataset.

    Registered via ``option`` as the computation for
    ``CIFAR10Configs.train_dataset``. Without this registration the function
    is never referenced and the augmentation is silently unused.

    Each training image is padded by 4 pixels and randomly cropped back to
    32x32. It is then randomly flipped horizontally, converted to a tensor and
    normalized per channel as ``(x - 0.5) / 0.5``.

    :return: the torchvision ``CIFAR10`` training split, downloaded into the
        labml data path if it is not already present.
    """
    # Function-local imports: torchvision is only imported when this runs
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    return CIFAR10(str(lab.get_data_path()),
                   train=True,
                   download=True,
                   transform=transforms.Compose([
                       # Pad and crop
                       transforms.RandomCrop(32, padding=4),
                       # Random horizontal flip
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       # Per-channel (x - 0.5) / 0.5 with mean = std = 0.5
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))

Non-augmented CIFAR 10 validation dataset

@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    """
    Non-augmented CIFAR 10 validation dataset.

    Registered via ``option`` as the computation for
    ``CIFAR10Configs.valid_dataset``. Without this registration the function
    is never referenced.

    Images are only converted to tensors and normalized per channel as
    ``(x - 0.5) / 0.5``. This matches the training normalization, but no
    random augmentation is applied.

    :return: the torchvision ``CIFAR10`` test split, downloaded into the labml
        data path if it is not already present.
    """
    # Function-local imports: torchvision is only imported when this runs
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    return CIFAR10(str(lab.get_data_path()),
                   train=False,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       # Same per-channel normalization as the training set
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))

VGG model for CIFAR-10 classification

class CIFAR10VGGModel(Module):
    """
    VGG model for CIFAR-10 classification.

    The model is five blocks of 3x3 convolutions, each followed by a ReLU.
    Every block ends with a 2x2 max-pooling layer, and a final linear layer
    produces the class logits.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Convolution and activation combined.

        A 3x3 convolution with padding 1, so the spatial size is preserved,
        followed by an in-place ReLU.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self, blocks: List[List[int]], n_classes: int = 10):
        """
        :param blocks: exactly five blocks; each block lists the output channel
            counts of its successive convolution layers.
        :param n_classes: number of output logits (10 for CIFAR-10).
        """
        super().__init__()

        # Five 2x2 pooling layers will produce an output of size 1x1, since the
        # CIFAR 10 image size is 32x32 (32 / 2**5 == 1).
        assert len(blocks) == 5, f"expected 5 blocks, got {len(blocks)}"
        layers = []

        # RGB channels
        in_channels = 3

        # Number of channels in each layer in each block
        for block in blocks:
            # Convolution and activation layers (no normalization is used)
            for channels in block:
                # `conv_block` returns an `nn.Sequential`; unpack its modules
                # so they are flattened into the single top-level Sequential
                layers.extend(self.conv_block(in_channels, channels))
                in_channels = channels

            # Max pooling at end of each block
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # Create a sequential model with the layers
        self.layers = nn.Sequential(*layers)

        # Final logits layer; `in_channels` is now the last block's width
        self.fc = nn.Linear(in_channels, n_classes)

    def forward(self, x):
        """
        :param x: image batch; presumably (batch, 3, 32, 32). Other spatial
            sizes leave more than 1x1 after pooling, and the flattened size
            would not match `fc` — TODO confirm the input contract with callers.
        :return: logits of shape (batch, n_classes).
        """
        # The VGG layers
        x = self.layers(x)

        # Reshape for classification layer
        x = x.view(x.shape[0], -1)

        # Final linear layer
        return self.fc(x)