バッチ正規化のための MNIST 実験

12import torch.nn as nn
13import torch.nn.functional as F
14import torch.utils.data
15
16from labml import experiment
17from labml.configs import option
18from labml_helpers.module import Module
19from labml_nn.experiments.mnist import MNISTConfigs
20from labml_nn.normalization.batch_norm import BatchNorm

モデル定義

23class Model(Module):
28    def __init__(self):
29        super().__init__()

バイアスパラメータは省略していることに注意してください。

31        self.conv1 = nn.Conv2d(1, 20, 5, 1, bias=False)

20 チャネル (畳み込み層の出力) によるバッチ正規化。このレイヤーへの入力はシェイプになります

[batch_size, 20, height(24), width(24)]
34        self.bn1 = BatchNorm(20)

36        self.conv2 = nn.Conv2d(20, 50, 5, 1, bias=False)

50 チャンネルのバッチ正規化。このレイヤーへの入力はシェイプになります

[batch_size, 50, height(8), width(8)]
39        self.bn2 = BatchNorm(50)

41        self.fc1 = nn.Linear(4 * 4 * 50, 500, bias=False)

500 チャネル (完全接続層の出力) によるバッチ正規化このレイヤーへの入力はシェイプになります

[batch_size, 500]
44        self.bn3 = BatchNorm(500)

46        self.fc2 = nn.Linear(500, 10)
48    def forward(self, x: torch.Tensor):
49        x = F.relu(self.bn1(self.conv1(x)))
50        x = F.max_pool2d(x, 2, 2)
51        x = F.relu(self.bn2(self.conv2(x)))
52        x = F.max_pool2d(x, 2, 2)
53        x = x.view(-1, 4 * 4 * 50)
54        x = F.relu(self.bn3(self.fc1(x)))
55        return self.fc2(x)

モデル作成

MNISTConfigs コンフィギュレーションを使用し、新しい関数を設定してモデルを計算します。

58@option(MNISTConfigs.model)
59def model(c: MNISTConfigs):
66    return Model().to(c.device)
69def main():

実験を作成

71    experiment.create(name='mnist_batch_norm')

構成の作成

73    conf = MNISTConfigs()

構成をロード

75    experiment.configs(conf, {
76        'optimizer.optimizer': 'Adam',
77        'optimizer.learning_rate': 0.001,
78    })

実験を開始し、トレーニングループを実行します

80    with experiment.start():
81        conf.run()

85if __name__ == '__main__':
86    main()