This is annotated PyTorch code to classify MNIST digits.
It implements the experiment described in the paper Dynamic Routing Between Capsules.
from typing import Any

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.metrics.accuracy import AccuracyDirect
from labml_helpers.module import Module
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.capsule_networks import Squash, Router, MarginLoss
class MNISTCapsuleNetworkModel(Module):
    def __init__(self):
        super().__init__()
The first convolution layer has 256 channels of 9×9 convolution kernels, with a stride of 1.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
The second layer (Primary Capsules) is a convolutional capsule layer with 32 channels of 8-dimensional convolutional capsules (8 features per capsule). That is, each primary capsule contains 8 convolutional units with a 9×9 kernel and a stride of 2. To implement this we create a convolutional layer with 32×8 channels, then reshape and permute its output to get capsules of 8 features each.
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
        self.squash = Squash()
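Squash is imported from labml_nn.capsule_networks. For reference, this is a minimal sketch of the squash non-linearity from the paper, which shrinks short vectors toward zero length and long vectors toward unit length (the eps term is added here for numerical stability, not part of the paper's equation):

def squash_sketch(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    # squash(s) = (||s||^2 / (1 + ||s||^2)) * (s / ||s||), applied per capsule
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * s / torch.sqrt(squared_norm + eps)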
The routing layer gets the 32×6×6 primary capsules and produces 10 output capsules. Each of the primary capsules has 8 features, while the output capsules (Digit Capsules) have 16 features. The routing algorithm iterates 3 times.
        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
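Router is also imported from labml_nn.capsule_networks. As a rough sketch (illustrative names, not the actual Router API), the dynamic routing algorithm from the paper works on the prediction vectors u_hat computed from the lower-level capsules, using squash_sketch from above:

def dynamic_routing_sketch(u_hat: torch.Tensor, iterations: int = 3) -> torch.Tensor:
    # u_hat: predictions from lower capsules, shape [batch, n_in, n_out, out_features]
    # Routing logits start at zero
    b = torch.zeros(u_hat.shape[:3], device=u_hat.device)
    for _ in range(iterations):
        # Coupling coefficients: softmax over the output capsules
        c = F.softmax(b, dim=-1)
        # Weighted sum of predictions for each output capsule, then squash
        s = (c[..., None] * u_hat).sum(dim=1)
        v = squash_sketch(s)
        # Agreement between predictions and outputs updates the logits
        b = b + (u_hat * v[:, None, :, :]).sum(dim=-1)
    return v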
This is the decoder mentioned in the paper. It takes the outputs of the 10 digit capsules, each with 16 features, and reproduces the image. It goes through linear layers of sizes 512 and 1024 with ReLU activations, followed by a 784-unit layer with a sigmoid activation.
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )
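The decoder input is the flattened 16×10 = 160 digit-capsule features, and its output is 28×28 = 784 pixel intensities. A quick illustrative sanity check of the shapes:

decoder = nn.Sequential(
    nn.Linear(16 * 10, 512), nn.ReLU(),
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 784), nn.Sigmoid())
assert decoder(torch.zeros(4, 16 * 10)).shape == (4, 784)  # reshaped later to [4, 1, 28, 28]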
data are the MNIST images, with shape [batch_size, 1, 28, 28]
    def forward(self, data: torch.Tensor):
Pass through the first convolution layer. Output of this layer has shape [batch_size, 256, 20, 20]
        x = F.relu(self.conv1(data))
Pass through the second convolution layer. Output of this layer has shape [batch_size, 32 * 8, 6, 6]. Note that this layer has a stride of 2.
        x = self.conv2(x)
Reshape and permute the output to get the capsules
        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
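As an illustration of this step, the 32 × 8 = 256 channels over the 6×6 grid are regrouped into 32 × 6 × 6 = 1152 capsules of 8 features each:

x = torch.zeros(4, 32 * 8, 6, 6)                 # dummy conv2 output
caps = x.view(4, 8, 32 * 6 * 6).permute(0, 2, 1)
assert caps.shape == (4, 32 * 6 * 6, 8)          # 1152 capsules, 8 features each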
Squash the capsules
        caps = self.squash(caps)
Take them through the router to get the digit capsules. This has shape [batch_size, 10, 16].
        caps = self.digit_capsules(caps)
Get masks for reconstruction
        with torch.no_grad():
The prediction of the capsule network is the capsule with the longest length
            pred = (caps ** 2).sum(-1).argmax(-1)
Create a mask to mask out all the other capsules
            mask = torch.eye(10, device=data.device)[pred]
Mask the digit capsules to get only the capsule that made the prediction, and take it through the decoder to get the reconstruction
        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))
Reshape the reconstruction to match the image dimensions
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return caps, reconstructions, pred
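A quick illustrative run of the model on random inputs, checking the output shapes:

model = MNISTCapsuleNetworkModel()
images = torch.randn(4, 1, 28, 28)
caps, reconstructions, pred = model(images)
assert caps.shape == (4, 10, 16)
assert reconstructions.shape == (4, 1, 28, 28)
assert pred.shape == (4,)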
Configurations with MNIST data and Train & Validation setup
class Configs(MNISTConfigs, SimpleTrainValidConfigs):
    epochs: int = 10
    model: nn.Module = 'capsule_network_model'
    reconstruction_loss = nn.MSELoss()
    margin_loss = MarginLoss(n_labels=10)
    accuracy = AccuracyDirect()
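For reference, MarginLoss implements the margin loss from the paper. For digit capsule $k$, with $T_k = 1$ iff digit $k$ is present, it is

$$\mathcal{L}_k = T_k \max(0, m^{+} - \lVert v_k \rVert)^2 + \lambda (1 - T_k) \max(0, \lVert v_k \rVert - m^{-})^2$$

with $m^{+} = 0.9$, $m^{-} = 0.1$ and $\lambda = 0.5$, summed over all digit capsules.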
    def init(self):
Print losses and accuracy to screen
        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('accuracy.*', True)
We need to register the metrics as state modules so they are computed per epoch for both training and validation
        self.state_modules = [self.accuracy]
This method gets called by the trainer for every batch
    def step(self, batch: Any, batch_idx: BatchIndex):
Set the model mode
        self.model.train(self.mode.is_train)
Get the images and labels and move them to the model's device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Whether to log activations
        with self.mode.update(is_log_activations=batch_idx.is_last):
Run the model
            caps, reconstructions, pred = self.model(data)
Calculate the total loss: the margin loss plus the reconstruction loss, which is scaled down by 0.0005 so that it does not dominate the margin loss
        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
        tracker.add("loss.", loss)
Call accuracy metric
        self.accuracy(pred, target)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
Log parameters and gradients
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()
Set the model
@option(Configs.model)
def capsule_network_model(c: Configs):
    return MNISTCapsuleNetworkModel().to(c.device)
Run the experiment
def main():
    experiment.create(name='capsule_network_mnist')
    conf = Configs()
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})

    experiment.add_pytorch_models({'model': conf.model})

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()