10from typing import List, Optional
11
12from torch import nn
13
14from labml import experiment
15from labml.configs import option
16from labml_nn.experiments.cifar10 import CIFAR10Configs
17from labml_nn.resnet import ResNetBase

コンフィギュレーション

データセットに関連するすべての構成、オプティマイザー、トレーニングループを定義するものを使用していますCIFAR10Configs

20class Configs(CIFAR10Configs):

各フィーチャマップサイズのブロック数

29    n_blocks: List[int] = [3, 3, 3]

各フィーチャマップサイズのチャンネル数

31    n_channels: List[int] = [16, 32, 64]

ボトルネックサイズ

33    bottlenecks: Optional[List[int]] = None

初期畳み込み層のカーネルサイズ

35    first_kernel_size: int = 3

モデル作成

38@option(Configs.model)
39def _resnet(c: Configs):
44    base = ResNetBase(c.n_blocks, c.n_channels, c.bottlenecks, img_channels=3, first_kernel_size=c.first_kernel_size)

分類用の線形レイヤー

46    classification = nn.Linear(c.n_channels[-1], 10)

積み重ねて

49    model = nn.Sequential(base, classification)

モデルをデバイスに移動

51    return model.to(c.device)
54def main():

実験を作成

56    experiment.create(name='resnet', comment='cifar10')

構成の作成

58    conf = Configs()

構成をロード

60    experiment.configs(conf, {
61        'bottlenecks': [8, 16, 16],
62        'n_blocks': [6, 6, 6],
63
64        'optimizer.optimizer': 'Adam',
65        'optimizer.learning_rate': 2.5e-4,
66
67        'epochs': 500,
68        'train_batch_size': 256,
69
70        'train_dataset': 'cifar10_train_augmented',
71        'valid_dataset': 'cifar10_valid_no_augment',
72    })

保存/読み込み用のモデルを設定

74    experiment.add_pytorch_models({'model': conf.model})

実験を開始し、トレーニングループを実行します

76    with experiment.start():
77        conf.run()

81if __name__ == '__main__':
82    main()