CIFAR10 実験

10from typing import List
11
12import torch.nn as nn
13
14from labml import lab
15from labml.configs import option
16from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
17from labml_helpers.module import Module
18from labml_nn.experiments.mnist import MNISTConfigs

コンフィギュレーション

これは、およびの CIFAR 10 データセット構成を拡張したものですlabml_helpers MNISTConfigs

21class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):

デフォルトで CIFAR10 データセットを使用

30    dataset_name: str = 'CIFAR10'

拡張された CIFAR 10 トレインデータセット

33@option(CIFAR10Configs.train_dataset)
34def cifar10_train_augmented():
38    from torchvision.datasets import CIFAR10
39    from torchvision.transforms import transforms
40    return CIFAR10(str(lab.get_data_path()),
41                   train=True,
42                   download=True,
43                   transform=transforms.Compose([

パッドとクロップ

45                       transforms.RandomCrop(32, padding=4),

ランダム水平反転

47                       transforms.RandomHorizontalFlip(),

49                       transforms.ToTensor(),
50                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
51                   ]))

拡張されていない CIFAR 10 検証データセット

54@option(CIFAR10Configs.valid_dataset)
55def cifar10_valid_no_augment():
59    from torchvision.datasets import CIFAR10
60    from torchvision.transforms import transforms
61    return CIFAR10(str(lab.get_data_path()),
62                   train=False,
63                   download=True,
64                   transform=transforms.Compose([
65                       transforms.ToTensor(),
66                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
67                   ]))

CIFAR-10 分類用の VGG モデル

70class CIFAR10VGGModel(Module):

コンボリューションとアクティベーションの組み合わせ

75    def conv_block(self, in_channels, out_channels) -> nn.Module:
79        return nn.Sequential(
80            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
81            nn.ReLU(inplace=True),
82        )
84    def __init__(self, blocks: List[List[int]]):
85        super().__init__()

5つのプーリングレイヤーでサイズの出力が得られます。CIFAR 10 の画像サイズは

89        assert len(blocks) == 5
90        layers = []

RGB チャンネル

92        in_channels = 3

各ブロックの各レイヤーのチャンネル数

94        for block in blocks:

コンボリューション、ノーマライゼーション、アクティベーションレイヤー

96            for channels in block:
97                layers += self.conv_block(in_channels, channels)
98                in_channels = channels

各ブロック終了時の最大プーリング

100            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]

レイヤーを含むシーケンシャルモデルの作成

103        self.layers = nn.Sequential(*layers)

最終ロジットレイヤー

105        self.fc = nn.Linear(in_channels, 10)
107    def forward(self, x):

VGG レイヤー

109        x = self.layers(x)

分類レイヤーの形状を変更

111        x = x.view(x.shape[0], -1)

最終線形レイヤー

113        return self.fc(x)