10from typing import List
11
12import torch.nn as nn
13
14from labml import lab
15from labml.configs import option
16from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
17from labml_helpers.module import Module
18from labml_nn.experiments.mnist import MNISTConfigs
class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
    """
    Configurations for CIFAR-10 experiments.

    Mixes the CIFAR-10 dataset configurations with the configurations
    defined by ``MNISTConfigs``; the only thing overridden here is the
    name of the dataset to use.
    """

    # Use the CIFAR-10 dataset by default
    dataset_name: str = 'CIFAR10'
@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    """
    Create the CIFAR-10 training dataset with data augmentation.

    Each image is padded by 4 pixels and randomly cropped back to 32x32,
    randomly flipped horizontally, converted to a tensor and normalized
    per channel with mean 0.5 and standard deviation 0.5.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    data_dir = str(lab.get_data_path())
    augment_and_normalize = transforms.Compose([
        # Pad and take a random 32x32 crop
        transforms.RandomCrop(32, padding=4),
        # Random horizontal flip
        transforms.RandomHorizontalFlip(),
        # Convert the image to a tensor
        transforms.ToTensor(),
        # Per-channel normalization with mean 0.5 and std 0.5
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return CIFAR10(data_dir, train=True, download=True, transform=augment_and_normalize)
@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    """
    Create the CIFAR-10 validation dataset (the test split) without augmentation.

    Images are only converted to tensors and normalized per channel with
    mean 0.5 and standard deviation 0.5, matching the training set.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    data_dir = str(lab.get_data_path())
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return CIFAR10(data_dir, train=False, download=True, transform=to_normalized_tensor)
class CIFAR10VGGModel(Module):
    """
    VGG-style convolutional classifier for CIFAR-10.

    The network is a sequence of 5 blocks. Each block is one or more
    convolution layers (built by ``conv_block``) followed by a 2x2 max-pool.
    A final linear layer maps the flattened features to class logits.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Build a 3x3 convolution (padding 1, so the spatial size is kept)
        followed by an in-place ReLU.

        ``__init__`` calls this for every convolution layer, so a subclass can
        override it to change the per-layer structure. The returned module's
        children are spliced into the top-level ``nn.Sequential``.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self, blocks: List[List[int]], n_classes: int = 10):
        """
        :param blocks: output channel counts of the convolution layers, one
            inner list per block; exactly 5 blocks are required.
        :param n_classes: number of output logits (defaults to 10 for CIFAR-10).
        """
        super().__init__()

        # Every block ends with a 2x2 max-pool, so 5 blocks reduce a 32x32
        # CIFAR-10 image to 1x1; the flattened features then equal the
        # channel count of the last convolution, which sizes the final layer.
        assert len(blocks) == 5, f'expected 5 blocks, got {len(blocks)}'

        layers = []
        # Input images have 3 (RGB) channels
        in_channels = 3
        for block in blocks:
            # Convolution + activation layers of this block
            for channels in block:
                # Extend (not append) so the conv block's children are added
                # directly to the top-level Sequential (keeps parameter names flat)
                layers += self.conv_block(in_channels, channels)
                in_channels = channels
            # Max-pool at the end of each block halves the spatial size
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # Sequential model of all the layers
        self.layers = nn.Sequential(*layers)
        # Final linear layer producing the logits
        self.fc = nn.Linear(in_channels, n_classes)

    def forward(self, x):
        """
        :param x: batch of images; presumably (batch, 3, 32, 32) — the final
            pooled feature map must be 1x1 for the flattened size to match ``fc``.
        :return: logits of shape (batch, n_classes)
        """
        # VGG layers
        x = self.layers(x)
        # Flatten to (batch, features) for the classification layer
        x = x.view(x.shape[0], -1)
        # Final linear layer
        return self.fc(x)