胶囊网络

这是胶囊间动态路由PyTorch 实现/教程。

Capsule 网络是一种神经网络架构,它以胶囊的形式嵌入特征,并通过投票机制将它们路由到下一层胶囊。

与其他模型实现不同,我们提供了一个示例,因为仅使用模块很难理解某些概念。这是使用胶囊对 MNIST 数据集进行分类的模型的带注释的代码

该文件包含了 Capsule Networks 核心模块的实现。

我用 jindongwang/pytorch-CapsuleNet 来澄清我对这篇论文的一些困惑。

这是一本在 MNIST 数据集上训练 Capsule 网络的笔记本。

Open In Colab

32import torch.nn as nn
33import torch.nn.functional as F
34import torch.utils.data
35
36from labml_helpers.module import Module

壁球

这是来自纸张的挤压函数,由方程给出

标准化所有胶囊的长度,同时缩小长度小于一个的胶囊。

39class Squash(Module):
54    def __init__(self, epsilon=1e-8):
55        super().__init__()
56        self.epsilon = epsilon

的形状s[batch_size, n_capsules, n_features]

58    def forward(self, s: torch.Tensor):

64        s2 = (s ** 2).sum(dim=-1, keepdims=True)

我们在计算时添加一个 epsilon,以确保它不会变为零。如果该值变为零,则开始给出nan 值,并且训练失败。

70        return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))

路由算法

这是白皮书中描述的路由机制。可以在模型中使用多个布线层。

这结合了此层的计算和过程 1 中描述的路由算法。

73class Router(Module):

in_caps 是胶囊的数量,in_d 是下方图层中每个胶囊的特征数。out_caps 对于这个层来说out_d 是相同的。

iterations 是路由迭代次数,在论文中用符号表示。

84    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
91        super().__init__()
92        self.in_caps = in_caps
93        self.out_caps = out_caps
94        self.iterations = iterations
95        self.softmax = nn.Softmax(dim=1)
96        self.squash = Squash()

这是权重矩阵。它将下层中的每个胶囊映射到该层中的每个胶囊体

100        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)

的形状u[batch_size, n_capsules, n_features] 。这些是下层的胶囊。

102    def forward(self, u: torch.Tensor):

这里用于索引该层中的胶囊,而用于索引下层(上一层)中的胶囊。

111        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)

初始对数是胶囊应与之相结合的对数先验概率。我们将它们初始化为零

116        b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)
117
118        v = None

迭代

121        for i in range(self.iterations):

路由软最大

123            c = self.softmax(b)

125            s = torch.einsum('bij,bijm->bjm', c, u_hat)

127            v = self.squash(s)

129            a = torch.einsum('bjm,bijm->bij', v, u_hat)

131            b = b + a
132
133        return v

阶级存在的保证金损失

每个输出胶囊使用单独的保证金损失,总亏损是它们的总和。每个输出胶囊的长度是输入中存在类的概率。

每个输出胶囊或类的损失为,

是类是否存在,否则。损失的第一个组成部分是当类不存在时,第二个组成部分是类是否存在。用于避免预测走向极端。被设置和将在报纸上。

在训练的初始阶段,减重用于防止所有胶囊的长度掉落。

136class MarginLoss(Module):
156    def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
157        super().__init__()
158
159        self.m_negative = m_negative
160        self.m_positive = m_positive
161        self.lambda_ = lambda_
162        self.n_labels = n_labels

v是压扁的输出胶囊。它有形状[batch_size, n_labels, n_features] ;也就是说,每个标签都有一个胶囊。

labels 是标签,有形状[batch_size]

164    def forward(self, v: torch.Tensor, labels: torch.Tensor):

172        v_norm = torch.sqrt((v ** 2).sum(dim=-1))

labels 是形状的一热编码标签[batch_size, n_labels]

176        labels = torch.eye(self.n_labels, device=labels.device)[labels]

loss 有形状[batch_size, n_labels] 。我们已经并行化了 for all 的计算

182        loss = labels * F.relu(self.m_positive - v_norm) + \
183               self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)

186        return loss.sum(dim=-1).mean()