13import torch.nn as nn
14
15from labml import experiment, logger
16from labml.configs import option
17from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
18from labml_nn.normalization.batch_norm import BatchNorm21class Configs(CIFAR10Configs):28    pass31class SmallModel(CIFAR10VGGModel):コンボリューションレイヤーとアクティベーションの作成
38    def conv_block(self, in_channels, out_channels) -> nn.Module:42        return nn.Sequential(コンボリューションレイヤー
44            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),バッチ正規化
46            BatchNorm(out_channels, track_running_stats=False),ReLU アクティベーション
48            nn.ReLU(inplace=True),
49        )51    def __init__(self):与えられた畳み込みサイズ (チャネル) でモデルを作成
53        super().__init__([[32, 32], [64, 64], [128], [128], [128]])56@option(Configs.model)
57def _small_model(c: Configs):61    return SmallModel().to(c.device)64def main():実験を作成
66    experiment.create(name='cifar10', comment='small model')構成の作成
68    conf = Configs()構成をロード
70    experiment.configs(conf, {
71        'optimizer.optimizer': 'Adam',
72        'optimizer.learning_rate': 2.5e-4,
73    })保存/読み込み用のモデルを設定
75    experiment.add_pytorch_models({'model': conf.model})モデル内のパラメータの数を出力します
77    logger.inspect(params=(sum(p.numel() for p in conf.model.parameters() if p.requires_grad)))実験を開始し、トレーニングループを実行します
79    with experiment.start():
80        conf.run()84if __name__ == '__main__':
85    main()