カプセルネットワークによる MNIST ディジットの分類

これは、MNISTの数字をPyTorchで分類するためのアノテーション付きのPyTorchコードです。

この論文では、論文「カプセル間の動的ルーティング」で説明されている実験を実装しています。

14from typing import Any
15
16import torch.nn as nn
17import torch.nn.functional as F
18import torch.utils.data
19
20from labml import experiment, tracker
21from labml.configs import option
22from labml_helpers.datasets.mnist import MNISTConfigs
23from labml_helpers.metrics.accuracy import AccuracyDirect
24from labml_helpers.module import Module
25from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
26from labml_nn.capsule_networks import Squash, Router, MarginLoss

MNIST ディジットを分類するためのモデル

29class MNISTCapsuleNetworkModel(Module):
34    def __init__(self):
35        super().__init__()

最初の畳み込み層には畳み込みカーネルがあります

37        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)

2 番目の層 (プライマリーカプセル) は、畳み込みカプセル (カプセルごとのフィーチャ) のチャネルがある畳み込みカプセル層です。つまり、各プライマリカプセルには、9 × 9 のカーネルとストライドが 2 の 8 つの畳み込みユニットが含まれています。これを実装するために、チャネルを含む畳み込み層を作成し、その出力を形状変更および置換して、それぞれの特徴のカプセルを取得します

43        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
44        self.squash = Squash()

ルーティング層は一次カプセルを取得し、カプセルを生成します。各プライマリーカプセルには特徴があり、出力カプセル(ディジットカプセル)には特徴があります。ルーティングアルゴリズムは何回も繰り返します。

50        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)

これは論文で言及されているデコーダーです。数字カプセルの出力を受け取り、それぞれに画像を再現する機能があります。サイズやアクティベーションが直線的に繰り返されます

55        self.decoder = nn.Sequential(
56            nn.Linear(16 * 10, 512),
57            nn.ReLU(),
58            nn.Linear(512, 1024),
59            nn.ReLU(),
60            nn.Linear(1024, 784),
61            nn.Sigmoid()
62        )

data MNIST の画像は形状付きです [batch_size, 1, 28, 28]

64    def forward(self, data: torch.Tensor):

最初の畳み込み層を通過します。このレイヤーの出力には形状があります [batch_size, 256, 20, 20]

70        x = F.relu(self.conv1(data))

2 番目の畳み込み層を通過します。これの出力には形状があります[batch_size, 32 * 8, 6, 6]このレイヤーのストライドの長さはであることに注意してください

74        x = self.conv2(x)

サイズを変更して並べ替えてカプセルにする

77        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)

カプセルを押しつぶす

79        caps = self.squash(caps)

それらをルーターに通して、数字のカプセルを入手してください。これは形があります[batch_size, 10, 16]

82        caps = self.digit_capsules(caps)

復興用マスクを入手

85        with torch.no_grad():

カプセルネットワークによる予測では、長さが最も長いカプセルです

87            pred = (caps ** 2).sum(-1).argmax(-1)

マスクを作成して、他のすべてのカプセルを覆い隠してください

89            mask = torch.eye(10, device=data.device)[pred]

数字のカプセルをマスクして予測を行ったカプセルのみを取得し、それをデコーダーに通して再構成します

93        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))

画像のサイズに合わせて再構成の形状を変更

95        reconstructions = reconstructions.view(-1, 1, 28, 28)
96
97        return caps, reconstructions, pred

MNISTデータとトレーニングと検証のセットアップを含む構成

100class Configs(MNISTConfigs, SimpleTrainValidConfigs):
104    epochs: int = 10
105    model: nn.Module = 'capsule_network_model'
106    reconstruction_loss = nn.MSELoss()
107    margin_loss = MarginLoss(n_labels=10)
108    accuracy = AccuracyDirect()
110    def init(self):

印刷ロスと画面の精度

112        tracker.set_scalar('loss.*', True)
113        tracker.set_scalar('accuracy.*', True)

トレーニングと検証のために、エポックに合わせてそれらを計算するメトリックを設定する必要があります

116        self.state_modules = [self.accuracy]

このメソッドはトレーナーによって呼び出されます

118    def step(self, batch: Any, batch_idx: BatchIndex):

モデルモードを設定

123        self.model.train(self.mode.is_train)

画像とラベルを取得してモデルのデバイスに移動します

126        data, target = batch[0].to(self.device), batch[1].to(self.device)

トレーニングモードでのインクリメントステップ

129        if self.mode.is_train:
130            tracker.add_global_step(len(data))

アクティベーションをログに記録するかどうか

133        with self.mode.update(is_log_activations=batch_idx.is_last):

モデルを実行

135            caps, reconstructions, pred = self.model(data)

総損失の計算

138        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
139        tracker.add("loss.", loss)

通話精度指標

142        self.accuracy(pred, target)
143
144        if self.mode.is_train:
145            loss.backward()
146
147            self.optimizer.step()

ログパラメータとグラデーション

149            if batch_idx.is_last:
150                tracker.add('model', self.model)
151            self.optimizer.zero_grad()
152
153            tracker.save()

モデルを設定する

156@option(Configs.model)
157def capsule_network_model(c: Configs):
159    return MNISTCapsuleNetworkModel().to(c.device)

実験を実行する

162def main():
166    experiment.create(name='capsule_network_mnist')
167    conf = Configs()
168    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
169                              'optimizer.learning_rate': 1e-3})
170
171    experiment.add_pytorch_models({'model': conf.model})
172
173    with experiment.start():
174        conf.run()
175
176
177if __name__ == '__main__':
178    main()