Capsule 网络是一种神经网络架构,它以胶囊的形式嵌入特征,并通过投票机制将它们路由到下一层胶囊。
与其他模型实现不同,我们提供了一个示例,因为仅使用模块很难理解某些概念。这是使用胶囊对 MNIST 数据集进行分类的模型的带注释的代码
该文件包含了 Capsule Networks 核心模块的实现。
我用 jindongwang/pytorch-CapsuleNet 来澄清我对这篇论文的一些困惑。
这是一本在 MNIST 数据集上训练 Capsule 网络的笔记本。
32import torch.nn as nn
33import torch.nn.functional as F
34import torch.utils.data
35
36from labml_helpers.module import Module
39class Squash(Module):
54 def __init__(self, epsilon=1e-8):
55 super().__init__()
56 self.epsilon = epsilon
的形状s
是[batch_size, n_capsules, n_features]
58 def forward(self, s: torch.Tensor):
64 s2 = (s ** 2).sum(dim=-1, keepdims=True)
我们在计算时添加一个 epsilon,以确保它不会变为零。如果该值变为零,则开始给出nan
值,并且训练失败。
70 return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))
73class Router(Module):
84 def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
91 super().__init__()
92 self.in_caps = in_caps
93 self.out_caps = out_caps
94 self.iterations = iterations
95 self.softmax = nn.Softmax(dim=1)
96 self.squash = Squash()
这是权重矩阵。它将下层中的每个胶囊映射到该层中的每个胶囊体
100 self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)
的形状u
是[batch_size, n_capsules, n_features]
。这些是下层的胶囊。
102 def forward(self, u: torch.Tensor):
这里用于索引该层中的胶囊,而用于索引下层(上一层)中的胶囊。
111 u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
初始对数是胶囊应与之相结合的对数先验概率。我们将它们初始化为零
116 b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)
117
118 v = None
迭代
121 for i in range(self.iterations):
路由软最大
123 c = self.softmax(b)
125 s = torch.einsum('bij,bijm->bjm', c, u_hat)
127 v = self.squash(s)
129 a = torch.einsum('bjm,bijm->bij', v, u_hat)
131 b = b + a
132
133 return v
每个输出胶囊使用单独的保证金损失,总亏损是它们的总和。每个输出胶囊的长度是输入中存在类的概率。
每个输出胶囊或类的损失为,
是类是否存在,否则。损失的第一个组成部分是当类不存在时,第二个组成部分是类是否存在。用于避免预测走向极端。被设置为和将在报纸上。
在训练的初始阶段,减重用于防止所有胶囊的长度掉落。
136class MarginLoss(Module):
156 def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
157 super().__init__()
158
159 self.m_negative = m_negative
160 self.m_positive = m_positive
161 self.lambda_ = lambda_
162 self.n_labels = n_labels
164 def forward(self, v: torch.Tensor, labels: torch.Tensor):
172 v_norm = torch.sqrt((v ** 2).sum(dim=-1))
labels
是形状的一热编码标签[batch_size, n_labels]
176 labels = torch.eye(self.n_labels, device=labels.device)[labels]
loss
有形状[batch_size, n_labels]
。我们已经并行化了 for all 的计算。
182 loss = labels * F.relu(self.m_positive - v_norm) + \
183 self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)
186 return loss.sum(dim=-1).mean()