10from typing import List
11
12import torch.nn as nn
13
14from labml import lab
15from labml.configs import option
16from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
17from labml_helpers.module import Module
18from labml_nn.experiments.mnist import MNISTConfigs
21class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
デフォルトで CIFAR10 データセットを使用
30 dataset_name: str = 'CIFAR10'
33@option(CIFAR10Configs.train_dataset)
34def cifar10_train_augmented():
38 from torchvision.datasets import CIFAR10
39 from torchvision.transforms import transforms
40 return CIFAR10(str(lab.get_data_path()),
41 train=True,
42 download=True,
43 transform=transforms.Compose([
パッドとクロップ
45 transforms.RandomCrop(32, padding=4),
ランダム水平反転
47 transforms.RandomHorizontalFlip(),
49 transforms.ToTensor(),
50 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
51 ]))
54@option(CIFAR10Configs.valid_dataset)
55def cifar10_valid_no_augment():
59 from torchvision.datasets import CIFAR10
60 from torchvision.transforms import transforms
61 return CIFAR10(str(lab.get_data_path()),
62 train=False,
63 download=True,
64 transform=transforms.Compose([
65 transforms.ToTensor(),
66 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
67 ]))
70class CIFAR10VGGModel(Module):
コンボリューションとアクティベーションの組み合わせ
75 def conv_block(self, in_channels, out_channels) -> nn.Module:
79 return nn.Sequential(
80 nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
81 nn.ReLU(inplace=True),
82 )
84 def __init__(self, blocks: List[List[int]]):
85 super().__init__()
5つのプーリングレイヤーでサイズの出力が得られます。CIFAR 10 の画像サイズは
89 assert len(blocks) == 5
90 layers = []
RGB チャンネル
92 in_channels = 3
各ブロックの各レイヤーのチャンネル数
94 for block in blocks:
コンボリューション、ノーマライゼーション、アクティベーションレイヤー
96 for channels in block:
97 layers += self.conv_block(in_channels, channels)
98 in_channels = channels
各ブロック終了時の最大プーリング
100 layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
レイヤーを含むシーケンシャルモデルの作成
103 self.layers = nn.Sequential(*layers)
最終ロジットレイヤー
105 self.fc = nn.Linear(in_channels, 10)
107 def forward(self, x):
VGG レイヤー
109 x = self.layers(x)
分類レイヤーの形状を変更
111 x = x.view(x.shape[0], -1)
最終線形レイヤー
113 return self.fc(x)