Capsule Networks

This is a PyTorch implementation/tutorial of Dynamic Routing Between Capsules.

A capsule network is a neural network architecture that embeds features as capsules and routes them to the next layer of capsules with a voting mechanism.

Unlike in our other model implementations, we've included a sample, because some concepts are difficult to understand from the modules alone. This is the annotated code for a model that uses capsules to classify the MNIST dataset.

This file holds the implementations of the core modules of Capsule Networks.

I used jindongwang/Pytorch-CapsuleNet to clarify some confusions I had with the paper.

Here's a notebook for training a Capsule Network on the MNIST dataset.


import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml_helpers.module import Module

Squash

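This is the squashing function from the paper, given by equation $(1)$.

$$\mathbf{v}_j = \frac{\lVert \mathbf{s}_j \rVert^2}{1 + \lVert \mathbf{s}_j \rVert^2} \frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$$

$\frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$ normalizes the length of all the capsules, whilst $\frac{\lVert \mathbf{s}_j \rVert^2}{1 + \lVert \mathbf{s}_j \rVert^2}$ shrinks the capsules that have a length smaller than one.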

class Squash(Module):
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

The shape of s is [batch_size, n_capsules, n_features].

    def forward(self, s: torch.Tensor):

Squared length $\lVert \mathbf{s}_j \rVert^2$ of each capsule

        s2 = (s ** 2).sum(dim=-1, keepdim=True)

We add an epsilon when calculating $\lVert \mathbf{s}_j \rVert$ to make sure it doesn't become zero. If it becomes zero, the division produces nan values and training fails.

        return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))
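As a quick check of the arithmetic, here is a minimal plain-Python sketch of the same formula applied to a single capsule vector. The helpers `squash_vector` and `length` are illustrative names that are not part of the module above. The output length works out to $\lVert \mathbf{s} \rVert^2 / (1 + \lVert \mathbf{s} \rVert^2)$, so short vectors shrink towards zero and long vectors approach unit length.

```python
import math


def squash_vector(s, epsilon=1e-8):
    """Squash one capsule vector given as a list of floats (illustrative helper)."""
    # Squared length of the vector
    s2 = sum(x * x for x in s)
    # s2 / (1 + s2) sets the new length; 1 / sqrt(s2 + epsilon) normalizes the direction
    scale = (s2 / (1.0 + s2)) / math.sqrt(s2 + epsilon)
    return [scale * x for x in s]


def length(v):
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in v))


short = squash_vector([0.3, 0.4])  # input length 0.5
long_ = squash_vector([3.0, 4.0])  # input length 5.0

print(round(length(short), 3))  # 0.2
print(round(length(long_), 3))  # 0.962
```

With the epsilon in place, an all-zero capsule maps to an all-zero output instead of nan.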

Routing Algorithm

This is the routing mechanism described in the paper. You can use multiple routing layers in your models.

This combines calculating $\mathbf{s}_j$ for this layer and the routing algorithm described in Procedure 1.

class Router(Module):

in_caps is the number of capsules, and in_d is the number of features per capsule from the layer below. out_caps and out_d are the same for this layer.

iterations is the number of routing iterations, symbolized by $r$ in the paper.

    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
        super().__init__()
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.iterations = iterations
        # Normalize over the output capsules (last dimension of the logits),
        # so the couplings of each lower-layer capsule sum to one
        self.softmax = nn.Softmax(dim=2)
        self.squash = Squash()

This is the weight matrix $\mathbf{W}_{ij}$. It maps each capsule in the lower layer to each capsule in this layer.

        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)

The shape of u is [batch_size, n_capsules, n_features]. These are the capsules from the lower layer.

    def forward(self, u: torch.Tensor):

$$\hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \mathbf{u}_i$$

Here $j$ is used to index capsules in this layer, whilst $i$ is used to index capsules in the layer below (previous).

        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
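The einsum pattern 'ijnm,bin->bijm' can be read as one matrix-vector product per pair of lower-layer and current-layer capsules. The plain-Python sketch below spells out the same index pattern for a single batch element with nested loops. The function name `predict` and the toy sizes are assumptions made for illustration only.

```python
def predict(weight, u):
    """Compute u_hat[i][j][m] = sum_n weight[i][j][n][m] * u[i][n] for one example."""
    in_caps, out_caps = len(weight), len(weight[0])
    in_d, out_d = len(weight[0][0]), len(weight[0][0][0])
    u_hat = [[[0.0] * out_d for _ in range(out_caps)] for _ in range(in_caps)]
    for i in range(in_caps):        # capsule in the layer below
        for j in range(out_caps):   # capsule in this layer
            for m in range(out_d):  # feature of the output capsule
                u_hat[i][j][m] = sum(weight[i][j][n][m] * u[i][n] for n in range(in_d))
    return u_hat


# Toy sizes: 1 input capsule with 2 features, 2 output capsules with 2 features
weight = [[
    [[1.0, 0.0], [0.0, 1.0]],  # matrix for (i=0, j=0): identity
    [[0.0, 2.0], [2.0, 0.0]],  # matrix for (i=0, j=1): swap the features and double them
]]
u = [[3.0, 5.0]]

u_hat = predict(weight, u)
print(u_hat[0][0])  # [3.0, 5.0]
print(u_hat[0][1])  # [10.0, 6.0]
```

Each lower-layer capsule therefore produces a separate prediction vector for every capsule in this layer.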

Initial logits $b_{ij}$ are the log prior probabilities that capsule $i$ should be coupled with $j$. We initialize these at zero.

        b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)

        v = None

Iterate

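        for i in range(self.iterations):

Routing softmax
$$c_{ij} = \frac{\exp(b_{ij})}{\sum_k \exp(b_{ik})}$$

            c = self.softmax(b)

$$\mathbf{s}_j = \sum_i c_{ij} \hat{\mathbf{u}}_{j|i}$$

            s = torch.einsum('bij,bijm->bjm', c, u_hat)

$$\mathbf{v}_j = squash(\mathbf{s}_j)$$

            v = self.squash(s)

$$a_{ij} = \mathbf{v}_j \cdot \hat{\mathbf{u}}_{j|i}$$

            a = torch.einsum('bjm,bijm->bij', v, u_hat)

$$b_{ij} \gets b_{ij} + \mathbf{v}_j \cdot \hat{\mathbf{u}}_{j|i}$$

            b = b + a

        return v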
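To see how agreement shifts the coupling, here is a small plain-Python sketch of routing by agreement for a single example, following Procedure 1 of the paper with the softmax taken over the capsules of this layer. The names `route`, `squash` and `softmax` are illustrative helpers, and unlike the module above this sketch also returns the coupling coefficients so they can be inspected.

```python
import math


def squash(s, eps=1e-8):
    """Squash one vector so that its length lies in [0, 1)."""
    s2 = sum(x * x for x in s)
    scale = (s2 / (1.0 + s2)) / math.sqrt(s2 + eps)
    return [scale * x for x in s]


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route(u_hat, iterations):
    """Routing by agreement for one example; u_hat[i][j] is a prediction vector."""
    in_caps, out_caps = len(u_hat), len(u_hat[0])
    out_d = len(u_hat[0][0])
    # Routing logits b_ij start at zero
    b = [[0.0] * out_caps for _ in range(in_caps)]
    for _ in range(iterations):
        # Coupling coefficients: softmax over output capsules j for each input capsule i
        c = [softmax(row) for row in b]
        # Weighted sum of the predictions for each output capsule, then squash
        v = [squash([sum(c[i][j] * u_hat[i][j][m] for i in range(in_caps))
                     for m in range(out_d)])
             for j in range(out_caps)]
        # Add the agreement between each output and each prediction to the logits
        for i in range(in_caps):
            for j in range(out_caps):
                b[i][j] += sum(v[j][m] * u_hat[i][j][m] for m in range(out_d))
    return v, c


# Both input capsules agree on output 0 and cancel each other on output 1
u_hat = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[1.0, 0.0], [0.0, -1.0]],
]
v, c = route(u_hat, iterations=3)
print(c[0])  # coupling of input capsule 0, shifted towards output capsule 0
print(v)     # output capsule 1 stays at zero length
```

In this toy case the coupling starts uniform at 0.5 and moves towards the output capsule on which the predictions agree, while the conflicting predictions for the other output cancel out.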

Margin loss for class existence

A separate margin loss is used for each output capsule and the total loss is the sum of them. The length of each output capsule is the probability that the corresponding class is present in the input.

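Loss for each output capsule or class $k$ is,

$$\mathcal{L}_k = T_k \max(0, m^{+} - \lVert \mathbf{v}_k \rVert)^2 + \lambda (1 - T_k) \max(0, \lVert \mathbf{v}_k \rVert - m^{-})^2$$

$T_k$ is $1$ if the class $k$ is present and $0$ otherwise. The first component of the loss is $0$ when the class is not present, and the second component is $0$ if the class is present. The $\max(0, x)$ is used to avoid predictions going to extremes. $m^{+}$ is set to be $0.9$ and $m^{-}$ to be $0.1$ in the paper.

The $\lambda$ down-weighting is used to stop the length of all capsules from falling during the initial phase of training.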
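A worked example may help with the two margins. The plain-Python sketch below evaluates the margin loss for one example from the capsule lengths, using the squared hinge terms of the paper's loss and the values $m^{+} = 0.9$, $m^{-} = 0.1$ and $\lambda = 0.5$. The function name `margin_loss` and the sample lengths are made up for illustration.

```python
def margin_loss(lengths, target, m_positive=0.9, m_negative=0.1, lambda_=0.5):
    """Sum of per-class margin losses for one example.

    lengths[k] is the length of output capsule k; target is the true class index.
    """
    total = 0.0
    for k, norm in enumerate(lengths):
        if k == target:
            # Present class: penalize lengths below m_positive
            total += max(0.0, m_positive - norm) ** 2
        else:
            # Absent class: penalize lengths above m_negative, down-weighted by lambda_
            total += lambda_ * max(0.0, norm - m_negative) ** 2
    return total


# Three classes, the true class is 0
confident = margin_loss([0.95, 0.05, 0.05], target=0)
confused = margin_loss([0.4, 0.6, 0.05], target=0)

print(round(confident, 4))  # 0.0
print(round(confused, 4))   # 0.375
```

In the second case the true class contributes $(0.9 - 0.4)^2 = 0.25$ and the wrongly active class contributes $0.5 \times (0.6 - 0.1)^2 = 0.125$. Lengths already beyond the margins contribute nothing.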

class MarginLoss(Module):
    def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
        super().__init__()

        self.m_negative = m_negative
        self.m_positive = m_positive
        self.lambda_ = lambda_
        self.n_labels = n_labels

v holds the squashed output capsules. It has shape [batch_size, n_labels, n_features]; that is, there is a capsule for each label.

labels holds the class labels, and has shape [batch_size].

    def forward(self, v: torch.Tensor, labels: torch.Tensor):

Length $\lVert \mathbf{v}_k \rVert$ of each output capsule

        v_norm = torch.sqrt((v ** 2).sum(dim=-1))

Convert labels to a one-hot encoding of shape [batch_size, n_labels] by indexing the rows of an identity matrix.

        labels = torch.eye(self.n_labels, device=labels.device)[labels]

loss has shape [batch_size, n_labels]. We have parallelized the computation of $\mathcal{L}_k$ for all $k$.

        loss = labels * F.relu(self.m_positive - v_norm) ** 2 + \
               self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative) ** 2

        return loss.sum(dim=-1).mean()