CIFAR10 实验

10from typing import List
11
12import torch.nn as nn
13
14from labml import lab
15from labml.configs import option
16from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
17from labml_helpers.module import Module
18from labml_nn.experiments.mnist import MNISTConfigs

配置

此配置类扩展自 labml_helpers 中的 CIFAR 10 数据集配置以及 MNISTConfigs。

class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
    """
    Configurations for CIFAR 10 experiments.

    Combines the CIFAR 10 dataset configurations from ``labml_helpers`` with
    ``MNISTConfigs``. The only thing overridden here is the dataset name.
    """

    # Use the CIFAR10 dataset by default
    dataset_name: str = 'CIFAR10'

增强的 CIFAR 10 训练数据集

@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    """
    Augmented CIFAR 10 training dataset.

    Loads the training split (downloading it into the lab data path if
    needed). Each image goes through a random padded crop and a random
    horizontal flip, then is converted to a tensor and normalized with
    mean 0.5 and std 0.5 on each of the three channels.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    data_root = str(lab.get_data_path())
    pipeline = [
        # Pad by 4 pixels on each side, then take a random 32x32 crop
        transforms.RandomCrop(32, padding=4),
        # Random horizontal flip
        transforms.RandomHorizontalFlip(),
        # Convert to a tensor and normalize each channel (mean 0.5, std 0.5)
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return CIFAR10(data_root,
                   train=True,
                   download=True,
                   transform=transforms.Compose(pipeline))

非增强 CIFAR 10 验证数据集

@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    """
    Non-augmented CIFAR 10 validation dataset.

    Loads the non-training split (``train=False``), downloading it into the
    lab data path if needed. Images are only converted to tensors and
    normalized with mean 0.5 and std 0.5 on each of the three channels;
    no random augmentation is applied.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    data_root = str(lab.get_data_path())
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return CIFAR10(data_root, train=False, download=True,
                   transform=to_normalized_tensor)

用于 CIFAR-10 分类的 VGG 模型

class CIFAR10VGGModel(Module):
    """
    VGG-style model for CIFAR-10 classification.

    The network is a stack of five blocks. Each block is a sequence of
    3x3 convolution + ReLU layers followed by a 2x2 max-pool. A final
    linear layer maps the pooled features to 10 logits.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Convolution and activation combined.

        A 3x3 convolution with padding 1 (so the spatial size is unchanged)
        followed by an in-place ReLU. No normalization layer is included.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self, blocks: List[List[int]]):
        """
        :param blocks: exactly five lists, one per block; each inner list
            holds the output channel count of every convolution layer in
            that block.
        """
        super().__init__()

        # Five 2x2 max-pool layers halve the spatial size five times, so a
        # 32x32 CIFAR-10 image is reduced to 1x1 before the linear layer.
        assert len(blocks) == 5, f"Expected 5 blocks, got {len(blocks)}"
        layers = []

        # RGB channels
        in_channels = 3

        # Each block lists the number of channels of each of its layers
        for block in blocks:
            # Convolution + activation layers (no normalization)
            for channels in block:
                # nn.Sequential is iterable, so this adds Conv2d and ReLU
                # as separate entries of `layers`
                layers.extend(self.conv_block(in_channels, channels))
                in_channels = channels
            # Max-pooling layer at the end of each block
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # Build a sequential model from the collected layers
        self.layers = nn.Sequential(*layers)

        # Final logits layer; its input size assumes the pooled feature map
        # is 1x1 (i.e. 32x32 inputs), so features == in_channels
        self.fc = nn.Linear(in_channels, 10)

    def forward(self, x):
        # assumes x is (batch, 3, 32, 32) — TODO confirm against data loader
        # VGG layers
        x = self.layers(x)
        # Flatten to (batch, features) for the classification layer
        x = x.view(x.shape[0], -1)
        # Final linear layer
        return self.fc(x)