これは、カプセル間の動的ルーティングのPyTorch実装/チュートリアルです。
カプセルネットワークは、フィーチャをカプセルとして埋め込み、投票メカニズムを使用して次のカプセル層にルーティングするニューラルネットワークアーキテクチャです。
他のモデルの実装とは異なり、モジュールだけでは一部の概念を理解するのが難しいため、サンプルを用意しています。
これは、カプセルを使用して MNIST データセットを分類するモデルの注釈付きコードです。このファイルには、Capsule Networks のコアモジュールの実装が格納されています。
Jindongwang/Pytorch-Capsulenetを使って、論文に関する混乱を解消しました。
これは、MNISTデータセットでカプセルネットワークをトレーニングするためのノートブックです。
32import torch.nn as nn
33import torch.nn.functional as F
34import torch.utils.data
35
36from labml_helpers.module import Module
39class Squash(Module):
54 def __init__(self, epsilon=1e-8):
55 super().__init__()
56 self.epsilon = epsilon
s
の形は [batch_size, n_capsules, n_features]
58 def forward(self, s: torch.Tensor):
64 s2 = (s ** 2).sum(dim=-1, keepdims=True)
ゼロにならないように、計算時にイプシロンを追加します。これがゼロになると、nan
値が与えられ始め、トレーニングは失敗します。
70 return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))
これは、このホワイトペーパーで説明されているルーティングメカニズムです。モデルでは複数のルーティングレイヤーを使用できます。
これは、このレイヤーの計算と手順1で説明したルーティングアルゴリズムを組み合わせたものです。
73class Router(Module):
in_caps
はカプセルの数で、in_d
は下のレイヤーのカプセルあたりのフィーチャ数です。out_caps
out_d
このレイヤーでも同じです。
iterations
はルーティングの反復回数で、論文では以下のように表示されています。
84 def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
91 super().__init__()
92 self.in_caps = in_caps
93 self.out_caps = out_caps
94 self.iterations = iterations
95 self.softmax = nn.Softmax(dim=1)
96 self.squash = Squash()
これはウェイトマトリックスです。下位レイヤーの各カプセルをこのレイヤーの各カプセルにマッピングします
。100 self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)
u
の形は[batch_size, n_capsules, n_features]
.これらは下層のカプセルです
102 def forward(self, u: torch.Tensor):
ここでは、このレイヤーのカプセルのインデックスを作成し、下のレイヤー(前のレイヤー)のカプセルのインデックスに使用します。
111 u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
初期ロジットは、カプセルと組み合わせるべき対数事前確率です。これらはゼロで初期化します
。116 b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)
117
118 v = None
繰り返し
121 for i in range(self.iterations):
ルーティングソフトマックス
123 c = self.softmax(b)
125 s = torch.einsum('bij,bijm->bjm', c, u_hat)
127 v = self.squash(s)
129 a = torch.einsum('bjm,bijm->bij', v, u_hat)
131 b = b + a
132
133 return v
出力カプセルごとに個別のマージンロスが使用され、合計損失はそれらの合計になります。各出力カプセルの長さは、入力にクラスが存在する確率です。
各出力カプセルまたはクラスの損失は、
クラスが存在するかどうか、そうでない場合です。損失の最初の要素はクラスが存在しない場合で、 2番目の要素はクラスが存在する場合です。予測が極端になるのを防ぐために使用されます。新聞に掲載される予定で、掲載される予定です。
ダウンウエイトは、トレーニングの初期段階ですべてのカプセルの長さが落ちるのを防ぐために使用されます。
136class MarginLoss(Module):
156 def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
157 super().__init__()
158
159 self.m_negative = m_negative
160 self.m_positive = m_positive
161 self.lambda_ = lambda_
162 self.n_labels = n_labels
v
、は押しつぶされた出力カプセルです。これには形があります[batch_size, n_labels, n_features]
。つまり、ラベルごとにカプセルがあります。
labels
はラベルで、形をしています[batch_size]
。
164 def forward(self, v: torch.Tensor, labels: torch.Tensor):
172 v_norm = torch.sqrt((v ** 2).sum(dim=-1))
labels
ワンホットエンコードされた形状のラベルです [batch_size, n_labels]
176 labels = torch.eye(self.n_labels, device=labels.device)[labels]
182 loss = labels * F.relu(self.m_positive - v_norm) + \
183 self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)
186 return loss.sum(dim=-1).mean()