これは、畳み込みニューラルネットワークのインスタンス正規化層を分類に使用していることを示しています。インスタンスの正規化がスタイル転送のために設計されたわけではなく、これは単なるデモです
。16import torch.nn as nn
17
18from labml import experiment
19from labml.configs import option
20from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
21from labml_nn.normalization.instance_norm import InstanceNorm
24class Model(CIFAR10VGGModel):
31 def conv_block(self, in_channels, out_channels) -> nn.Module:
32 return nn.Sequential(
33 nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
34 InstanceNorm(out_channels),
35 nn.ReLU(inplace=True),
36 )
38 def __init__(self):
39 super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])
42@option(CIFAR10Configs.model)
43def _model(c: CIFAR10Configs):
47 return Model().to(c.device)
50def main():
実験を作成
52 experiment.create(name='cifar10', comment='instance norm')
構成の作成
54 conf = CIFAR10Configs()
構成をロード
56 experiment.configs(conf, {
57 'optimizer.optimizer': 'Adam',
58 'optimizer.learning_rate': 2.5e-4,
59 })
実験を開始し、トレーニングループを実行します
61 with experiment.start():
62 conf.run()
66if __name__ == '__main__':
67 main()