Classify MNIST digits with Capsule Networks

This is an annotated PyTorch implementation that classifies MNIST digits with a capsule network.

This implements the experiment described in the paper [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829).

from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.metrics.accuracy import AccuracyDirect
from labml_helpers.module import Module
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.capsule_networks import Squash, Router, MarginLoss

Model for classifying MNIST digits

class MNISTCapsuleNetworkModel(Module):
    """
    Model for classifying MNIST digits with a capsule network.

    The image goes through a ReLU convolution layer and a primary-capsule
    convolution layer. A routing layer then produces the digit capsules, and a
    fully-connected decoder reconstructs the image from the masked digit capsules.
    """

    def __init__(self):
        super().__init__()

        # First convolution layer has 256 convolution kernels of size 9 x 9 (stride 1)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)

        # The second layer (Primary Capsules) is a convolutional capsule layer with
        # 32 channels of convolutional 8D capsules (8 features per capsule).
        # That is, each primary capsule contains 8 convolutional units with a
        # 9 x 9 kernel and a stride of 2.
        # To implement this, we create a convolutional layer with 32 x 8 channels
        # and reshape and permute its output to get the 32 x 6 x 6 capsules of
        # 8 features each.
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
        self.squash = Squash()

        # Routing layer gets the 32 x 6 x 6 primary capsules and produces 10 capsules.
        # Each primary capsule has 8 features, while the output capsules
        # (Digit Capsules) have 16 features. The routing algorithm iterates 3 times.
        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)

        # This is the decoder mentioned in the paper.
        # It takes the outputs of the 10 digit capsules, each with 16 features,
        # and reproduces the 28 x 28 = 784 pixel image.
        # It goes through linear layers of sizes 512 and 1024 with ReLU activations,
        # and a final 784-unit layer with a sigmoid activation.
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, data: torch.Tensor, target: Optional[torch.Tensor] = None):
        """
        * `data` are the MNIST images, with shape `[batch_size, 1, 28, 28]`.
        * `target` (optional) are the true labels, with shape `[batch_size]`.
          When given, only the capsule of the true label is fed to the decoder.
          The paper does this during training. When omitted, the predicted
          capsule is used, which was the only behaviour before.

        Returns a tuple of:
        * the digit capsules, shape `[batch_size, 10, 16]`
        * the reconstructions, shape `[batch_size, 1, 28, 28]`
        * the predicted labels, shape `[batch_size]`
        """
        batch_size = data.shape[0]

        # Pass through the first convolution layer.
        # Output of this layer has shape [batch_size, 256, 20, 20]
        x = F.relu(self.conv1(data))
        # Pass through the second convolution layer.
        # Output of this layer has shape [batch_size, 32 * 8, 6, 6].
        # Note that this layer has a stride length of 2.
        x = self.conv2(x)

        # Resize and permute to get the capsules: [batch_size, 32 * 6 * 6, 8]
        caps = x.view(batch_size, 8, 32 * 6 * 6).permute(0, 2, 1)
        # Squash the capsules
        caps = self.squash(caps)
        # Take them through the router to get the digit capsules.
        # This has shape [batch_size, 10, 16].
        caps = self.digit_capsules(caps)

        # Get the masks for reconstruction; no gradients flow through the mask
        with torch.no_grad():
            # The prediction by the capsule network is the capsule with the longest length
            pred = (caps ** 2).sum(-1).argmax(-1)
            # Pick which capsule to keep: the true label if given, else the prediction
            keep = pred if target is None else target
            # Create a one-hot mask to mask out all the other capsules
            mask = torch.eye(caps.shape[1], device=data.device)[keep]

        # Keep only the selected capsule and take it through the decoder
        # to get the reconstruction
        reconstructions = self.decoder((caps * mask[:, :, None]).view(batch_size, -1))
        # Reshape the reconstruction to match the image dimensions
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return caps, reconstructions, pred

Configurations with MNIST data and Train & Validation setup

class Configs(MNISTConfigs, SimpleTrainValidConfigs):
    """
    Configurations that combine the MNIST data set with a simple
    train & validation loop for the capsule network.
    """

    # Number of epochs to train for
    epochs: int = 10
    # Name of the option that creates the model (see `capsule_network_model`)
    model: nn.Module = 'capsule_network_model'
    # Mean squared error between the reconstruction and the input image
    reconstruction_loss = nn.MSELoss()
    # Margin loss over the 10 digit capsules
    margin_loss = MarginLoss(n_labels=10)
    # Accuracy metric; `step` feeds it the predicted labels
    accuracy = AccuracyDirect()

    def init(self):
        """
        Set up console output and the metrics tracked per epoch.
        """
        # Show every loss and accuracy indicator on screen
        for indicator in ('loss.*', 'accuracy.*'):
            tracker.set_scalar(indicator, True)

        # Metrics listed here are calculated over the epoch,
        # for both training and validation
        self.state_modules = [self.accuracy]

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        Process one batch; the trainer calls this for both training and validation.
        """
        # Put the model in training or evaluation mode
        self.model.train(self.mode.is_train)

        # Move the images and labels to the configured device
        images = batch[0].to(self.device)
        labels = batch[1].to(self.device)

        # The global step only advances while training
        if self.mode.is_train:
            tracker.add_global_step(len(images))

        # Log activations only for the last batch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            caps, reconstructions, pred = self.model(images)

        # Total loss: margin loss plus a down-weighted reconstruction loss
        margin = self.margin_loss(caps, labels)
        reconstruction = self.reconstruction_loss(reconstructions, images)
        loss = margin + 0.0005 * reconstruction
        tracker.add("loss.", loss)

        # Update the accuracy metric
        self.accuracy(pred, labels)

        # Nothing more to do when validating
        if not self.mode.is_train:
            return

        loss.backward()
        self.optimizer.step()
        # Log parameters and gradients on the last batch, before they are cleared
        if batch_idx.is_last:
            tracker.add('model', self.model)
        self.optimizer.zero_grad()

Create the capsule network model for the configurations

@option(Configs.model)
def capsule_network_model(c: Configs):
    """
    Create the capsule network model and move it to the configured device.

    Registered as an option for `Configs.model`, so the string value
    `'capsule_network_model'` selects this function.
    """
    return MNISTCapsuleNetworkModel().to(c.device)

Run the experiment

def main():
    """
    Create, configure and run the capsule network MNIST experiment.
    """
    # Create the experiment
    experiment.create(name='capsule_network_mnist')
    # Create the configurations
    conf = Configs()
    # Override configurations: Adam optimizer with a learning rate of 1e-3
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})
    # Register the model with the experiment (presumably for checkpointing)
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training & validation loop.
    # Without this call the `with` body would be empty (a syntax error)
    # and no training would happen.
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()