This is a PyTorch implementation/tutorial of Dynamic Routing Between Capsules.

A capsule network is a neural network architecture that embeds features as capsules and routes them, using a voting mechanism, to the next layer of capsules.

Unlike in our other model implementations, we have included a sample, because some concepts are hard to understand from the modules alone. This is the annotated code for a model that uses capsules to classify the MNIST dataset.

This file holds the implementations of the core modules of Capsule Networks.

I used jindongwang/Pytorch-CapsuleNet to clear up some points of confusion I had with the paper.

Here's a notebook for training a Capsule Network on the MNIST dataset.

```
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml_helpers.module import Module
```

This is the **squashing** function from the paper, given by equation $(1)$.

$v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2} \frac{s_j}{\|s_j\|}$

$\frac{s_j}{\|s_j\|}$ normalizes the length of every capsule to one, whilst $\frac{\|s_j\|^2}{1 + \|s_j\|^2}$ rescales that length into $[0, 1)$: short capsules shrink to almost zero length, and long capsules end up with a length slightly below one.
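To see the two factors at work, here is a minimal pure-Python sketch of equation $(1)$ on plain lists. It assumes no PyTorch and no epsilon, and `length` and `squash` are hypothetical helpers for illustration, not the module below. A short vector shrinks to almost nothing, while a long one ends up just below unit length with its direction unchanged.

```python
import math


def length(v):
    """Euclidean norm of a vector given as a list of floats."""
    return math.sqrt(sum(x * x for x in v))


def squash(s):
    """Equation (1) without epsilon: scale s so that its length lies in [0, 1)."""
    norm = length(s)
    scale = norm ** 2 / (1 + norm ** 2)  # length factor ||s||^2 / (1 + ||s||^2)
    return [scale * x / norm for x in s]  # unit direction s / ||s|| times that factor


short = squash([0.1, 0.0])   # original length 0.1
long_ = squash([6.0, 8.0])   # original length 10

print(round(length(short), 4))  # 0.0099
print(round(length(long_), 4))  # 0.9901
```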

`class Squash(Module):`

```
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon
```

The shape of `s` is `[batch_size, n_capsules, n_features]`.

`def forward(self, s: torch.Tensor):`

$\|s_j\|^2$

`s2 = (s ** 2).sum(dim=-1, keepdim=True)`

We add an epsilon when calculating $\|s_j\|$ to make sure it doesn't become zero. If it becomes zero, the division produces `nan` values and training fails.

$v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2} \frac{s_j}{\sqrt{\|s_j\|^2 + \epsilon}}$
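The effect of $\epsilon$ is easiest to see on an all-zero capsule. The following is a minimal pure-Python sketch (the helper name `squash_eps` is hypothetical) that mirrors the formula with $\epsilon$ under the square root. Plain Python raises `ZeroDivisionError` on `0.0 / 0.0`, whereas PyTorch tensors silently yield `nan`; with $\epsilon$ the zero capsule simply stays zero, and ordinary capsules are barely affected.

```python
import math


def squash_eps(s, epsilon=1e-8):
    """Squash with epsilon added under the square root, following the formula above."""
    s2 = sum(x * x for x in s)           # ||s||^2
    scale = s2 / (1 + s2)                # length factor in [0, 1)
    denom = math.sqrt(s2 + epsilon)      # approximately ||s||, never exactly zero
    return [scale * x / denom for x in s]


print(squash_eps([0.0, 0.0, 0.0]))  # [0.0, 0.0, 0.0]

try:
    squash_eps([0.0, 0.0, 0.0], epsilon=0.0)
except ZeroDivisionError:
    # With PyTorch tensors this 0 / 0 would instead produce nan.
    print("division by zero without epsilon")
```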

`return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))`

This is the routing mechanism described in the paper. You can use multiple routing layers in your models.

This combines the calculation of $s_j$ for this layer with the routing algorithm described in *Procedure 1*.

`class Router(Module):`

`in_caps` is the number of capsules, and `in_d` is the number of features per capsule, in the layer below. `out_caps` and `out_d` are the same quantities for this layer.

`iterations` is the number of routing iterations, symbolized by $r$ in the paper.

`def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):`

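```
        super().__init__()
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.iterations = iterations
        # b has shape [batch_size, in_caps, out_caps]; normalize over the upper-layer
        # capsules j (dim 2), as in the routing softmax c_ij = exp(b_ij) / sum_k exp(b_ik)
        self.softmax = nn.Softmax(dim=2)
        self.squash = Squash()
```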

This is the weight matrix $W_{ij}$. It maps each capsule in the lower layer to a prediction for each capsule in this layer.

`self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)`
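To get a feel for the size of `weight`, the following worked count assumes the MNIST configuration from the paper (1152 primary capsules of 8 dimensions routed to 10 digit capsules of 16 dimensions). These sizes are an assumption here; this file does not define them. There is one `in_d × out_d` matrix per pair $(i, j)$.

```python
# Assumed sizes, taken from the paper's MNIST architecture (not defined in this file).
in_caps, out_caps, in_d, out_d = 1152, 10, 8, 16

# One in_d x out_d matrix W_ij for every (lower capsule i, upper capsule j) pair.
weight_shape = (in_caps, out_caps, in_d, out_d)
n_params = in_caps * out_caps * in_d * out_d

print(weight_shape)  # (1152, 10, 8, 16)
print(n_params)      # 1474560
```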

The shape of `u` is `[batch_size, n_capsules, n_features]`. These are the capsules from the lower layer.

`def forward(self, u: torch.Tensor):`

$\hat{u}_{j|i} = W_{ij} u_i$. Here $j$ indexes capsules in this layer, whilst $i$ indexes capsules in the layer below (previous).
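The einsum `'ijnm,bin->bijm'` contracts the input feature index `n`, so with `weight[i, j]` stored as an `[in_d, out_d]` matrix it computes $u_i^\top W_{ij}$, the row-vector form of $W_{ij} u_i$. As a sketch of what that means index by index, here is an explicit-loop equivalent on nested lists (the function name `predict` is hypothetical and is meant for reading, not speed):

```python
def predict(weight, u):
    """Pure-Python equivalent of torch.einsum('ijnm,bin->bijm', weight, u).

    weight: nested lists of shape [in_caps][out_caps][in_d][out_d]
    u:      nested lists of shape [batch][in_caps][in_d]
    returns u_hat with shape [batch][in_caps][out_caps][out_d]
    """
    batch, in_caps, in_d = len(u), len(u[0]), len(u[0][0])
    out_caps, out_d = len(weight[0]), len(weight[0][0][0])
    u_hat = [[[[0.0] * out_d for _ in range(out_caps)] for _ in range(in_caps)] for _ in range(batch)]
    for b in range(batch):
        for i in range(in_caps):
            for j in range(out_caps):
                for m in range(out_d):
                    # Sum over the input feature n: (u_i^T W_ij)_m
                    u_hat[b][i][j][m] = sum(u[b][i][n] * weight[i][j][n][m] for n in range(in_d))
    return u_hat


# One lower capsule, two upper capsules: W_i0 is the identity, W_i1 swaps the features.
weight = [[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]]
u = [[[3.0, 5.0]]]
print(predict(weight, u))  # [[[[3.0, 5.0], [5.0, 3.0]]]]
```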

`u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)`

The initial logits $b_{ij}$ are the log prior probabilities that capsule $i$ should be coupled with capsule $j$. We initialize them to zero.

```
        b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)

        v = None
```

Iterate

`for i in range(self.iterations):`

Routing softmax: $c_{ij} = \frac{\exp(b_{ij})}{\sum_k \exp(b_{ik})}$
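For a fixed lower capsule $i$, the sum in the denominator runs over the upper-layer capsules, so each row $b_{i\cdot}$ becomes a distribution over $j$. The following is a minimal pure-Python sketch of one row (the helper name `routing_softmax` is hypothetical). With the zero initial logits every upper capsule gets an equal share, and after some agreement the larger logits dominate while the row still sums to one.

```python
import math


def routing_softmax(b_i):
    """c_ij = exp(b_ij) / sum_k exp(b_ik) for one lower capsule i."""
    m = max(b_i)  # subtract the max for numerical stability; the result is unchanged
    exps = [math.exp(x - m) for x in b_i]
    total = sum(exps)
    return [e / total for e in exps]


# Zero initial logits: capsule i is coupled equally to every upper capsule.
print(routing_softmax([0.0, 0.0, 0.0, 0.0]))  # [0.25, 0.25, 0.25, 0.25]

# Larger logits take a larger share of the coupling.
c = routing_softmax([2.0, 0.0, -1.0])
print(c[0] > c[1] > c[2])  # True
```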

`c = self.softmax(b)`

$s_j = \sum_i c_{ij} \hat{u}_{j|i}$

`s = torch.einsum('bij,bijm->bjm', c, u_hat)`
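Dropping the batch index, the einsum `'bij,bijm->bjm'` is a coupling-weighted sum of the predictions over the lower capsules $i$. The following explicit-loop sketch uses a hypothetical helper, `weighted_sum`, on nested lists for a single example:

```python
def weighted_sum(c, u_hat):
    """s_j = sum_i c_ij * u_hat_{j|i} for one example.

    c:     [in_caps][out_caps] coupling coefficients
    u_hat: [in_caps][out_caps][out_d] predictions
    returns s with shape [out_caps][out_d]
    """
    in_caps, out_caps, out_d = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    s = [[0.0] * out_d for _ in range(out_caps)]
    for j in range(out_caps):
        for i in range(in_caps):
            for m in range(out_d):
                s[j][m] += c[i][j] * u_hat[i][j][m]
    return s


# Two lower capsules, each with coupling 0.5 to a single upper capsule.
print(weighted_sum([[0.5], [0.5]], [[[2.0, 0.0]], [[0.0, 4.0]]]))  # [[1.0, 2.0]]
```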

$v_j = \operatorname{squash}(s_j)$

`v = self.squash(s)`

$a_{ij} = v_j \cdot \hat{u}_{j|i}$

`a = torch.einsum('bjm,bijm->bij', v, u_hat)`

$b_{ij} \gets b_{ij} + v_j \cdot \hat{u}_{j|i}$

```
            b = b + a

        return v
```
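The update rewards agreement. When a lower capsule's prediction $\hat{u}_{j|i}$ points the same way as the output $v_j$, the dot product $a_{ij}$ is large, so $b_{ij}$ grows and capsule $i$ is coupled more strongly to $j$ on the next iteration. Here is a minimal single-example sketch on nested lists (the helper name `update_logits` is hypothetical):

```python
def update_logits(b, v, u_hat):
    """b_ij <- b_ij + v_j . u_hat_{j|i} for one example.

    b:     [in_caps][out_caps] routing logits
    v:     [out_caps][out_d] squashed outputs
    u_hat: [in_caps][out_caps][out_d] predictions
    """
    return [
        [b[i][j] + sum(vm * um for vm, um in zip(v[j], u_hat[i][j])) for j in range(len(v))]
        for i in range(len(b))
    ]


# One lower capsule predicting roughly the first axis for both upper capsules;
# only upper capsule 0 actually outputs along that axis, so it gains the larger logit.
print(update_logits([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[[0.9, 0.1], [0.9, 0.1]]]))  # [[0.9, 0.1]]
```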

A separate margin loss is used for each output capsule, and the total loss is their sum. The length of each output capsule is the probability that the corresponding class is present in the input.

The loss for each output capsule or class $k$ is $L_k = T_k \max(0, m^{+} - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^{-})^2$

$T_k$ is $1$ if class $k$ is present and $0$ otherwise. The first component of the loss is $0$ when the class is not present, and the second component is $0$ when the class is present. The $\max(0, x)$ hinge keeps predictions from being pushed to the extremes: a present class incurs no loss once $\|v_k\| \ge m^{+}$, and an absent class incurs none once $\|v_k\| \le m^{-}$. In the paper, $m^{+}$ is set to $0.9$ and $m^{-}$ to $0.1$.

The $\lambda$ down-weighting of the loss for absent classes stops the lengths of all the output capsules from shrinking during the initial phase of training.
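A few worked values make the hinge and the $\lambda$ factor concrete. The following is a minimal sketch of $L_k$ for a single class, using the paper's defaults $m^{+} = 0.9$, $m^{-} = 0.1$, $\lambda = 0.5$ (the helper name `margin_loss_k` is hypothetical). Note that at the same length $0.5$, the absent-class penalty is halved by $\lambda$.

```python
def margin_loss_k(length, present, m_pos=0.9, m_neg=0.1, lambda_=0.5):
    """L_k for one class, given the capsule length ||v_k|| and whether class k is present (T_k)."""
    t = 1.0 if present else 0.0
    return t * max(0.0, m_pos - length) ** 2 + lambda_ * (1.0 - t) * max(0.0, length - m_neg) ** 2


print(margin_loss_k(0.95, present=True))   # 0.0, already above m+
print(margin_loss_k(0.05, present=False))  # 0.0, already below m-
print(margin_loss_k(0.5, present=True))    # (0.9 - 0.5)^2, about 0.16
print(margin_loss_k(0.5, present=False))   # 0.5 * (0.5 - 0.1)^2, about 0.08
```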

`class MarginLoss(Module):`

```
    def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
        super().__init__()

        self.m_negative = m_negative
        self.m_positive = m_positive
        self.lambda_ = lambda_
        self.n_labels = n_labels
```

`v`, the squashed output capsules $v_j$, has shape `[batch_size, n_labels, n_features]`; that is, there is a capsule for each label.

`labels` holds the labels and has shape `[batch_size]`.

`def forward(self, v: torch.Tensor, labels: torch.Tensor):`

$\|v_j\|$

`v_norm = torch.sqrt((v ** 2).sum(dim=-1))`

`labels` is converted to one-hot encoded labels $T_k$ of shape `[batch_size, n_labels]`.

`labels = torch.eye(self.n_labels, device=labels.device)[labels]`
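Indexing the identity matrix with the integer labels picks out row $y$ for each example, and that row is exactly the one-hot vector $T$. The following pure-Python analogue (the helper name `one_hot` is hypothetical) shows the same row selection:

```python
def one_hot(labels, n_labels):
    """Pure-Python analogue of torch.eye(n_labels)[labels]: select row y of the identity."""
    eye = [[1.0 if r == c else 0.0 for c in range(n_labels)] for r in range(n_labels)]
    return [list(eye[y]) for y in labels]  # copy each row so results don't share lists


print(one_hot([2, 0], n_labels=3))  # [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
```

In PyTorch, `torch.nn.functional.one_hot(labels, num_classes)` produces the same rows as an integer (`int64`) tensor; indexing `torch.eye` gives floating-point rows on the chosen device directly.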

$L_k = T_k \max(0, m^{+} - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^{-})^2$

`loss` has shape `[batch_size, n_labels]`. We have parallelized the computation of $L_k$ for all $k$.

```
        # Squared hinge terms, matching the L_k equation
        loss = labels * F.relu(self.m_positive - v_norm) ** 2 + \
            self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative) ** 2
```

$\sum_k L_k$, averaged over the batch

`return loss.sum(dim=-1).mean()`
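The final reduction first sums $L_k$ over the classes of each example and then averages over the batch. The following small sketch on a nested list of per-class losses shows the same arithmetic (the helper name `reduce_loss` is hypothetical):

```python
def reduce_loss(per_class):
    """Sum L_k over classes for each example, then average over the batch.

    Mirrors loss.sum(dim=-1).mean() on a [batch_size][n_labels] nested list.
    """
    per_example = [sum(row) for row in per_class]
    return sum(per_example) / len(per_example)


# Per-example totals are 1.0 and 2.0, so the batch mean is 1.5.
print(reduce_loss([[0.25, 0.5, 0.25], [1.0, 0.0, 1.0]]))  # 1.5
```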