これは、MNISTの数字をPyTorchで分類するためのアノテーション付きのPyTorchコードです。
この論文では、論文「カプセル間の動的ルーティング」で説明されている実験を実装しています。
14from typing import Any
15
16import torch.nn as nn
17import torch.nn.functional as F
18import torch.utils.data
19
20from labml import experiment, tracker
21from labml.configs import option
22from labml_helpers.datasets.mnist import MNISTConfigs
23from labml_helpers.metrics.accuracy import AccuracyDirect
24from labml_helpers.module import Module
25from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
26from labml_nn.capsule_networks import Squash, Router, MarginLoss
29class MNISTCapsuleNetworkModel(Module):
34 def __init__(self):
35 super().__init__()
最初の畳み込み層には、畳み込みカーネルがあります
37 self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
2 番目の層 (プライマリーカプセル) は、畳み込みカプセル (カプセルごとのフィーチャ) のチャネルがある畳み込みカプセル層です。つまり、各プライマリカプセルには、9 × 9 のカーネルとストライドが 2 の 8 つの畳み込みユニットが含まれています。これを実装するために、チャネルを含む畳み込み層を作成し、その出力を形状変更および置換して、それぞれの特徴のカプセルを取得します
。43 self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
44 self.squash = Squash()
ルーティング層は一次カプセルを取得し、カプセルを生成します。各プライマリーカプセルには特徴があり、出力カプセル(ディジットカプセル)には特徴があります。ルーティングアルゴリズムは何回も繰り返します。
50 self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
これは論文で言及されているデコーダーです。数字カプセルの出力を受け取り、それぞれに画像を再現する機能があります。サイズやアクティベーションが直線的に繰り返されます
。55 self.decoder = nn.Sequential(
56 nn.Linear(16 * 10, 512),
57 nn.ReLU(),
58 nn.Linear(512, 1024),
59 nn.ReLU(),
60 nn.Linear(1024, 784),
61 nn.Sigmoid()
62 )
data
MNIST の画像は形状付きです [batch_size, 1, 28, 28]
64 def forward(self, data: torch.Tensor):
最初の畳み込み層を通過します。このレイヤーの出力には形状があります [batch_size, 256, 20, 20]
70 x = F.relu(self.conv1(data))
2 番目の畳み込み層を通過します。これの出力には形状があります[batch_size, 32 * 8, 6, 6]
。このレイヤーのストライドの長さはであることに注意してください
74 x = self.conv2(x)
サイズを変更して並べ替えてカプセルにする
77 caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
カプセルを押しつぶす
79 caps = self.squash(caps)
それらをルーターに通して、数字のカプセルを入手してください。これは形があります[batch_size, 10, 16]
。
82 caps = self.digit_capsules(caps)
復興用マスクを入手
85 with torch.no_grad():
カプセルネットワークによる予測では、長さが最も長いカプセルです
87 pred = (caps ** 2).sum(-1).argmax(-1)
マスクを作成して、他のすべてのカプセルを覆い隠してください
89 mask = torch.eye(10, device=data.device)[pred]
数字のカプセルをマスクして予測を行ったカプセルのみを取得し、それをデコーダーに通して再構成します
93 reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))
画像のサイズに合わせて再構成の形状を変更
95 reconstructions = reconstructions.view(-1, 1, 28, 28)
96
97 return caps, reconstructions, pred
MNISTデータとトレーニングと検証のセットアップを含む構成
100class Configs(MNISTConfigs, SimpleTrainValidConfigs):
104 epochs: int = 10
105 model: nn.Module = 'capsule_network_model'
106 reconstruction_loss = nn.MSELoss()
107 margin_loss = MarginLoss(n_labels=10)
108 accuracy = AccuracyDirect()
110 def init(self):
印刷ロスと画面の精度
112 tracker.set_scalar('loss.*', True)
113 tracker.set_scalar('accuracy.*', True)
トレーニングと検証のために、エポックに合わせてそれらを計算するメトリックを設定する必要があります
116 self.state_modules = [self.accuracy]
このメソッドはトレーナーによって呼び出されます
118 def step(self, batch: Any, batch_idx: BatchIndex):
モデルモードを設定
123 self.model.train(self.mode.is_train)
画像とラベルを取得してモデルのデバイスに移動します
126 data, target = batch[0].to(self.device), batch[1].to(self.device)
トレーニングモードでのインクリメントステップ
129 if self.mode.is_train:
130 tracker.add_global_step(len(data))
アクティベーションをログに記録するかどうか
133 with self.mode.update(is_log_activations=batch_idx.is_last):
モデルを実行
135 caps, reconstructions, pred = self.model(data)
総損失の計算
138 loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
139 tracker.add("loss.", loss)
通話精度指標
142 self.accuracy(pred, target)
143
144 if self.mode.is_train:
145 loss.backward()
146
147 self.optimizer.step()
ログパラメータとグラデーション
149 if batch_idx.is_last:
150 tracker.add('model', self.model)
151 self.optimizer.zero_grad()
152
153 tracker.save()
モデルを設定する
156@option(Configs.model)
157def capsule_network_model(c: Configs):
159 return MNISTCapsuleNetworkModel().to(c.device)
実験を実行する
162def main():
166 experiment.create(name='capsule_network_mnist')
167 conf = Configs()
168 experiment.configs(conf, {'optimizer.optimizer': 'Adam',
169 'optimizer.learning_rate': 1e-3})
170
171 experiment.add_pytorch_models({'model': conf.model})
172
173 with experiment.start():
174 conf.run()
175
176
177if __name__ == '__main__':
178 main()