This is annotated PyTorch code to classify MNIST digits with a capsule network.
It implements the experiment described in the paper Dynamic Routing Between Capsules.
from typing import Any

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from labml import experiment, tracker
from labml.configs import option
from labml_nn.capsule_networks import Squash, Router, MarginLoss
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.metrics import AccuracyDirect
from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex
class MNISTCapsuleNetworkModel(nn.Module):
    def __init__(self):
        super().__init__()
The first convolution layer has 256 channels of 9 × 9 convolution kernels, with a stride of 1.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
The second layer (Primary Capsules) is a convolutional capsule layer with 32 channels of convolutional capsules (8 features per capsule). That is, each primary capsule contains 8 convolutional units with a 9 × 9 kernel and a stride of 2. To implement this we create a convolutional layer with 32 × 8 channels, then reshape and permute its output to get capsules of 8 features each.
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
        self.squash = Squash()
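As a quick sanity check on the spatial sizes (a minimal sketch with throwaway layers using the same hyper-parameters): with no padding, a 9 × 9 kernel at stride 1 maps 28 × 28 down to 20 × 20, and a 9 × 9 kernel at stride 2 maps 20 × 20 down to 6 × 6, since ⌊(20 − 9) / 2⌋ + 1 = 6.

import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
conv2 = nn.Conv2d(256, 32 * 8, kernel_size=9, stride=2)
x = torch.zeros(4, 1, 28, 28)      # a dummy MNIST batch
print(conv1(x).shape)              # torch.Size([4, 256, 20, 20])
print(conv2(conv1(x)).shape)       # torch.Size([4, 256, 6, 6])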
The routing layer gets the 32 × 6 × 6 primary capsules and produces 10 capsules. Each of the primary capsules has 8 features, while the output capsules (Digit Capsules) have 16 features. The routing algorithm iterates 3 times.
        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
This is the decoder mentioned in the paper. It takes the outputs of the digit capsules, each with 16 features, and tries to reproduce the image. It goes through linear layers of sizes 512 and 1024 with ReLU activations, followed by a final layer of size 784 (= 28 × 28) with a sigmoid activation.
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )
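A minimal shape check of the same stack (standalone, with a zero tensor standing in for the masked digit capsules) shows how the decoder maps 16 × 10 = 160 masked capsule features to the 784 pixels of an MNIST image:

import torch
import torch.nn as nn

decoder = nn.Sequential(
    nn.Linear(16 * 10, 512), nn.ReLU(),
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 784), nn.Sigmoid(),
)
masked_caps = torch.zeros(4, 16 * 10)   # flattened, masked digit capsules
print(decoder(masked_caps).shape)       # torch.Size([4, 784]); reshaped later to [4, 1, 28, 28]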
data are the MNIST images, with shape [batch_size, 1, 28, 28]
    def forward(self, data: torch.Tensor):
Pass through the first convolution layer. Output of this layer has shape [batch_size, 256, 20, 20]
        x = F.relu(self.conv1(data))
Pass through the second convolution layer. The output of this layer has shape [batch_size, 32 * 8, 6, 6]. Note that this layer has a stride of 2.
        x = self.conv2(x)
Resize and permute to get the capsules
        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
Squash the capsules
        caps = self.squash(caps)
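Squash is imported from labml_nn.capsule_networks. As a sketch (an illustrative re-implementation, not the library's code), the squashing non-linearity from the paper is squash(s) = (‖s‖² / (1 + ‖s‖²)) · (s / ‖s‖), which keeps a capsule's direction while mapping its length into [0, 1):

import torch

def squash_sketch(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    # s: capsule vectors; squash along the feature dimension
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = sq_norm / (1 + sq_norm)              # -> 0 for short vectors, -> 1 for long ones
    return scale * s / torch.sqrt(sq_norm + eps)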
Take them through the router to get digit capsules. This has shape [batch_size, 10, 16].
        caps = self.digit_capsules(caps)
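Router is also imported from labml_nn.capsule_networks. As a rough sketch of what dynamic routing does (an illustrative re-implementation of the routing procedure from the paper, assuming the Router arguments above are in-capsules, out-capsules, in-features, out-features, and iterations):

import torch
import torch.nn as nn

class RoutingSketch(nn.Module):
    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
        super().__init__()
        self.iterations = iterations
        # One weight matrix W_ij per (input capsule, output capsule) pair
        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d) * 0.01)

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        # u: [batch, in_caps, in_d] -> prediction vectors u_hat: [batch, in_caps, out_caps, out_d]
        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
        # Routing logits b_ij start at zero
        b = u.new_zeros(u.shape[0], u.shape[1], self.weight.shape[1])
        for _ in range(self.iterations):
            c = b.softmax(dim=-1)                        # coupling coefficients
            s = torch.einsum('bij,bijm->bjm', c, u_hat)  # weighted sum over input capsules
            sq = (s ** 2).sum(dim=-1, keepdim=True)
            v = (sq / (1 + sq)) * s / torch.sqrt(sq + 1e-8)  # squash, as sketched above
            b = b + torch.einsum('bijm,bjm->bij', u_hat, v)  # agreement update
        return v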
Get masks for reconstruction
        with torch.no_grad():
The prediction by the capsule network is the capsule with the longest length
            pred = (caps ** 2).sum(-1).argmax(-1)
Create a mask to mask out all the other capsules
            mask = torch.eye(10, device=data.device)[pred]
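Indexing an identity matrix with the predicted labels is a compact way to build one-hot masks; for example:

import torch

pred = torch.tensor([3, 0])   # predicted digits for a batch of two
mask = torch.eye(10)[pred]    # one-hot rows, shape [2, 10]
print(mask[0])                # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])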
Mask the digit capsules to get only the capsule that made the prediction, and take it through the decoder to get the reconstruction
        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))
Reshape the reconstruction to match the image dimensions
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return caps, reconstructions, pred
Configurations with MNIST data and Train & Validation setup
class Configs(MNISTConfigs, SimpleTrainValidConfigs):
    epochs: int = 10
    model: nn.Module = 'capsule_network_model'
    reconstruction_loss = nn.MSELoss()
    margin_loss = MarginLoss(n_labels=10)
    accuracy = AccuracyDirect()
    def init(self):
Print losses and accuracy to screen
        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('accuracy.*', True)
We need to set up the metrics so they are computed for each epoch, for both training and validation
        self.state_modules = [self.accuracy]
This method gets called by the trainer
    def step(self, batch: Any, batch_idx: BatchIndex):
Set the model mode
        self.model.train(self.mode.is_train)
Get the images and labels and move them to the model's device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Run the model
        caps, reconstructions, pred = self.model(data)
Calculate the total loss: the margin loss plus the reconstruction loss scaled down by 0.0005 so that it does not dominate the margin loss
        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
        tracker.add("loss.", loss)
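MarginLoss is imported from labml_nn.capsule_networks. As a sketch, the margin loss from the paper for each digit capsule k is L_k = T_k max(0, m⁺ − ‖v_k‖)² + λ (1 − T_k) max(0, ‖v_k‖ − m⁻)², with m⁺ = 0.9, m⁻ = 0.1, and λ = 0.5. An illustrative re-implementation (not the library's code):

import torch
import torch.nn.functional as F

def margin_loss_sketch(caps: torch.Tensor, target: torch.Tensor,
                       n_labels: int = 10, m_pos: float = 0.9,
                       m_neg: float = 0.1, lam: float = 0.5) -> torch.Tensor:
    # caps: [batch, n_labels, n_features]; target: [batch] of class indices
    v_norm = torch.sqrt((caps ** 2).sum(-1) + 1e-8)      # capsule lengths, [batch, n_labels]
    t = torch.eye(n_labels, device=caps.device)[target]  # one-hot targets
    loss = t * F.relu(m_pos - v_norm) ** 2 + lam * (1 - t) * F.relu(v_norm - m_neg) ** 2
    return loss.sum(-1).mean()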
Call accuracy metric
        self.accuracy(pred, target)
        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
Log parameters and gradients
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()
Set the model
@option(Configs.model)
def capsule_network_model(c: Configs):
    return MNISTCapsuleNetworkModel().to(c.device)
Run the experiment
def main():
    experiment.create(name='capsule_network_mnist')
    conf = Configs()
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})

    experiment.add_pytorch_models({'model': conf.model})

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()