13import torch.nn as nn
14
15from labml import experiment, logger
16from labml.configs import option
17from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
18from labml_nn.normalization.batch_norm import BatchNorm
21class Configs(CIFAR10Configs):
28 pass
31class LargeModel(CIFAR10VGGModel):
コンボリューションレイヤーとアクティベーションの作成
38 def conv_block(self, in_channels, out_channels) -> nn.Module:
42 return nn.Sequential(
ドロップアウト
44 nn.Dropout(0.1),
コンボリューションレイヤー
46 nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
バッチ正規化
48 BatchNorm(out_channels, track_running_stats=False),
ReLU アクティベーション
50 nn.ReLU(inplace=True),
51 )
53 def __init__(self):
与えられた畳み込みサイズ (チャネル) でモデルを作成
55 super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])
58@option(Configs.model)
59def _large_model(c: Configs):
63 return LargeModel().to(c.device)
66def main():
実験を作成
68 experiment.create(name='cifar10', comment='large model')
構成の作成
70 conf = Configs()
構成をロード
72 experiment.configs(conf, {
73 'optimizer.optimizer': 'Adam',
74 'optimizer.learning_rate': 2.5e-4,
75 'is_save_models': True,
76 'epochs': 20,
77 })
保存/読み込み用のモデルを設定
79 experiment.add_pytorch_models({'model': conf.model})
モデル内のパラメータの数を出力します
81 logger.inspect(params=(sum(p.numel() for p in conf.model.parameters() if p.requires_grad)))
実験を開始し、トレーニングループを実行します
83 with experiment.start():
84 conf.run()
88if __name__ == '__main__':
89 main()