10from typing import List
11
12import torch.nn as nn
13
14from labml import lab
15from labml.configs import option
16from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
17from labml_helpers.module import Module
18from labml_nn.experiments.mnist import MNISTConfigs
# The configurations below extend the CIFAR-10 dataset configurations from `labml_helpers`
# and `MNISTConfigs`.
class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
    """
    ## Configurations

    Combines the CIFAR-10 dataset configurations from `labml_helpers`
    (`CIFAR10DatasetConfigs`) with `MNISTConfigs`.
    """

    # Use the CIFAR10 dataset by default
    dataset_name: str = 'CIFAR10'
@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    """
    ### Augmented CIFAR-10 training dataset

    Registered as an option for `CIFAR10Configs.train_dataset`.
    Returns the torchvision `CIFAR10` training split stored under the lab data path,
    with random crop and random horizontal flip augmentations followed by
    tensor conversion and normalization.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    # Directory where labml keeps datasets
    data_path = str(lab.get_data_path())

    augmentations = [
        # Pad by 4 pixels on each side and take a random 32x32 crop
        transforms.RandomCrop(32, padding=4),
        # Random horizontal flip
        transforms.RandomHorizontalFlip(),
    ]
    to_normalized_tensor = [
        transforms.ToTensor(),
        # Per-channel mean 0.5 and standard deviation 0.5
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    pipeline = transforms.Compose(augmentations + to_normalized_tensor)

    return CIFAR10(data_path, train=True, download=True, transform=pipeline)
@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    """
    ### Non-augmented CIFAR-10 validation dataset

    Registered as an option for `CIFAR10Configs.valid_dataset`.
    Returns the torchvision `CIFAR10` split selected by `train=False`, stored under
    the lab data path, with only tensor conversion and normalization applied.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    # Directory where labml keeps datasets
    data_path = str(lab.get_data_path())

    # No augmentation for validation; only convert and normalize
    # (per-channel mean 0.5 and standard deviation 0.5)
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    return CIFAR10(data_path, train=False, download=True, transform=pipeline)
class CIFAR10VGGModel(Module):
    """
    ### VGG-style model for CIFAR-10 classification

    Five blocks of convolution layers, each block followed by a $2 \times 2$
    max pooling layer, and a final linear layer that produces 10 logits.
    Sub-classes can override `conv_block` to change what a single
    convolution layer consists of.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Convolution and activation combined.

        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels

        Returns a $3 \times 3$ convolution with padding 1 followed by an in-place ReLU.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self, blocks: List[List[int]]):
        """
        * `blocks` has one list per block, giving the number of output channels
          of each convolution layer in that block. There must be exactly 5 blocks.
        """
        super().__init__()

        # 5 $2 \times 2$ pooling layers reduce a $32 \times 32$ CIFAR-10 image
        # to an output of size $1 \times 1$, which is what the final linear layer expects.
        assert len(blocks) == 5, f'Expected 5 blocks, got {len(blocks)}'
        layers = []
        # RGB channels
        in_channels = 3
        # Number of channels in each layer in each block
        for block in blocks:
            # Convolution and activation layers (whatever `conv_block` provides)
            for channels in block:
                conv = self.conv_block(in_channels, channels)
                try:
                    # Containers such as `nn.Sequential` are flattened into the layer list,
                    # so layer indices (and state-dict keys) do not get an extra nesting level.
                    sub_layers = iter(conv)
                except TypeError:
                    # `conv_block` is annotated to return any `nn.Module`;
                    # a plain (non-iterable) module is used as a single layer.
                    sub_layers = [conv]
                layers.extend(sub_layers)
                in_channels = channels
            # Max pooling at end of each block
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # Create a sequential model with the layers
        self.layers = nn.Sequential(*layers)
        # Final logits layer; its input size is the channel count of the last convolution
        self.fc = nn.Linear(in_channels, 10)

    def forward(self, x):
        """
        * `x` is a batch of images; assumed to have shape `[batch_size, 3, 32, 32]`
          so that the flattened features match the input size of the final layer.

        Returns the logits computed by the final linear layer (10 per sample).
        """
        # The VGG layers
        x = self.layers(x)
        # Reshape for classification layer: keep the batch dimension, flatten the rest
        x = x.view(x.shape[0], -1)
        # Final linear layer
        return self.fc(x)