# 胶囊网络

Capsule 网络是一种神经网络架构，它以胶囊的形式嵌入特征，并通过投票机制将它们路由到下一层胶囊。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml_helpers.module import Module

## Squash（挤压函数）

class Squash(Module):
    """
    Squash non-linearity from "Dynamic Routing Between Capsules":

        v = (|s|^2 / (1 + |s|^2)) * (s / |s|)

    It shrinks short capsule vectors to near-zero length and scales long
    ones to just below unit length, so a capsule's length can represent
    the probability that the entity it detects is present.
    """

    def __init__(self, epsilon=1e-8):
        """
        `epsilon` is a small constant that keeps the norm computation
        numerically stable (and differentiable) when `s` is the zero vector.
        """
        super().__init__()
        self.epsilon = epsilon

    def forward(self, s: torch.Tensor):
        """
        `s` holds capsule vectors along the last dimension; the squash is
        applied independently to each capsule. Shape is preserved.
        """
        # Squared norm per capsule. Use the canonical `keepdim` keyword:
        # the numpy-style `keepdims` alias is not accepted by all PyTorch
        # versions and raises a TypeError there.
        s2 = (s ** 2).sum(dim=-1, keepdim=True)

        # Scale factor |s|^2 / (1 + |s|^2) times the unit vector s / |s|;
        # `epsilon` guards the sqrt against a zero norm.
        return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))

## 路由算法

class Router(Module):
    """
    Dynamic routing layer.

    `in_caps` is the number of capsules, and `in_d` the number of features
    per capsule, in the layer below. `out_caps` and `out_d` are the same
    for this layer's output capsules.

    `iterations` is the number of routing iterations, denoted r in the paper.
    """

    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
        super().__init__()
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.iterations = iterations
        # Paper Eq. (3): c_ij = exp(b_ij) / sum_k exp(b_ik), where k runs
        # over the capsules in the layer *above*. `b` has shape
        # [batch, in_caps, out_caps], so the softmax is over the last
        # dimension (output capsules), not dim=1 (input capsules).
        self.softmax = nn.Softmax(dim=-1)
        self.squash = Squash()

        # Prediction weights W_ij: map capsule i's in_d features to
        # capsule j's out_d features.
        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)

    def forward(self, u: torch.Tensor):
        """
        `u` are the lower-layer capsules, shape [batch_size, in_caps, in_d].
        Returns the routed output capsules, shape [batch_size, out_caps, out_d].
        """
        # Prediction vectors u_hat_{j|i} = W_ij u_i
        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)

        # Initial routing logits b_ij = 0, same device/dtype as `u`
        b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)

        v = None

        for i in range(self.iterations):
            # Coupling coefficients c_ij (softmax of the logits)
            c = self.softmax(b)
            # Weighted sum of predictions over the input capsules
            s = torch.einsum('bij,bijm->bjm', c, u_hat)
            # Squash to get this iteration's output capsules
            v = self.squash(s)
            # Agreement between each prediction and the output capsule
            a = torch.einsum('bjm,bijm->bij', v, u_hat)
            # Reinforce routes whose predictions agree with the output
            b = b + a

        return v

## 类别存在的边际损失（Margin Loss）

class MarginLoss(Module):
    """
    Margin loss for class presence, Eq. (4) of the paper:

        L_k = T_k * max(0, m+ - |v_k|)^2
              + lambda * (1 - T_k) * max(0, |v_k| - m-)^2

    where T_k = 1 iff class k is present. The length of output capsule k
    encodes the probability that class k is present.
    """

    def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
        """
        `n_labels` is the number of classes (one output capsule per class).
        `lambda_` down-weights the absent-class term so that early in
        training the loss does not shrink all capsule lengths.
        `m_positive`/`m_negative` are the margins m+ and m-.
        """
        super().__init__()

        self.m_negative = m_negative
        self.m_positive = m_positive
        self.lambda_ = lambda_
        self.n_labels = n_labels

    def forward(self, v: torch.Tensor, labels: torch.Tensor):
        """
        `v` are the squashed output capsules, shape
        [batch_size, n_labels, n_features] — one capsule per label.

        `labels` are the class indices, shape [batch_size].
        """
        # Capsule lengths |v_k|
        v_norm = torch.sqrt((v ** 2).sum(dim=-1))

        # One-hot encode the labels: shape [batch_size, n_labels]
        labels = torch.eye(self.n_labels, device=labels.device)[labels]

        # `loss` has shape [batch_size, n_labels]; the computation for all
        # classes k is done in parallel. Eq. (4) squares both hinge terms —
        # omitting the squares changes the loss landscape.
        loss = labels * F.relu(self.m_positive - v_norm) ** 2 + \
               self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative) ** 2

        # Sum over classes, mean over the batch
        return loss.sum(dim=-1).mean()