カプセルネットワーク

これは、カプセル間の動的ルーティングのPyTorch実装/チュートリアルです

カプセルネットワークは、フィーチャをカプセルとして埋め込み、投票メカニズムを使用して次のカプセル層にルーティングするニューラルネットワークアーキテクチャです。

他のモデルの実装とは異なり、モジュールだけでは一部の概念を理解するのが難しいため、サンプルを用意しています。

これは、カプセルを使用して MNIST データセットを分類するモデルの注釈付きコードです。

このファイルには、Capsule Networks のコアモジュールの実装が格納されています。

Jindongwang/Pytorch-Capsulenetを使って、論文に関する混乱を解消しました。

これは、MNISTデータセットでカプセルネットワークをトレーニングするためのノートブックです。

Open In Colab

32import torch.nn as nn
33import torch.nn.functional as F
34import torch.utils.data
35
36from labml_helpers.module import Module

スカッシュ

これは、方程式で与えられる紙からの押しつぶし関数です

すべてのカプセルの長さを正規化し、長さが 1 より短いカプセルを縮小します。

39class Squash(Module):
54    def __init__(self, epsilon=1e-8):
55        super().__init__()
56        self.epsilon = epsilon

s の形は [batch_size, n_capsules, n_features]

58    def forward(self, s: torch.Tensor):

64        s2 = (s ** 2).sum(dim=-1, keepdims=True)

ゼロにならないように、計算時にイプシロンを追加します。これがゼロになると、nan 値が与えられ始め、トレーニングは失敗します。

70        return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))

ルーティングアルゴリズム

これは、このホワイトペーパーで説明されているルーティングメカニズムです。モデルでは複数のルーティングレイヤーを使用できます。

これは、このレイヤーの計算と手順1で説明したルーティングアルゴリズムを組み合わせたものです

73class Router(Module):

in_caps はカプセルの数で、in_d は下のレイヤーのカプセルあたりのフィーチャ数です。out_caps out_d このレイヤーでも同じです。

iterations はルーティングの反復回数で、論文では以下のように表示されています。

84    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
91        super().__init__()
92        self.in_caps = in_caps
93        self.out_caps = out_caps
94        self.iterations = iterations
95        self.softmax = nn.Softmax(dim=1)
96        self.squash = Squash()

これはウェイトマトリックスです。下位レイヤーの各カプセルをこのレイヤーの各カプセルにマッピングします

100        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)

u の形は[batch_size, n_capsules, n_features] .これらは下層のカプセルです

102    def forward(self, u: torch.Tensor):

ここでは、このレイヤーのカプセルのインデックスを作成し、下のレイヤー(前のレイヤー)のカプセルのインデックスに使用します。

111        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)

初期ロジットは、カプセルと組み合わせるべき対数事前確率です。これらはゼロで初期化します

116        b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)
117
118        v = None

繰り返し

121        for i in range(self.iterations):

ルーティングソフトマックス

123            c = self.softmax(b)

125            s = torch.einsum('bij,bijm->bjm', c, u_hat)

127            v = self.squash(s)

129            a = torch.einsum('bjm,bijm->bij', v, u_hat)

131            b = b + a
132
133        return v

クラス存在によるマージンロス

出力カプセルごとに個別のマージンロスが使用され、合計損失はそれらの合計になります。各出力カプセルの長さは、入力にクラスが存在する確率です。

各出力カプセルまたはクラスの損失は、

クラスが存在するかどうか、そうでない場合です。損失の最初の要素はクラスが存在しない場合で、 2番目の要素はクラスが存在する場合です。予測が極端になるのを防ぐために使用されます。新聞に掲載される予定で、掲載される予定です。

ダウンウエイトは、トレーニングの初期段階ですべてのカプセルの長さが落ちるのを防ぐために使用されます。

136class MarginLoss(Module):
156    def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
157        super().__init__()
158
159        self.m_negative = m_negative
160        self.m_positive = m_positive
161        self.lambda_ = lambda_
162        self.n_labels = n_labels

vは押しつぶされた出力カプセルです。これには形があります[batch_size, n_labels, n_features] 。つまり、ラベルごとにカプセルがあります。

labels はラベルで、形をしています[batch_size]

164    def forward(self, v: torch.Tensor, labels: torch.Tensor):

172        v_norm = torch.sqrt((v ** 2).sum(dim=-1))

labels ワンホットエンコードされた形状のラベルです [batch_size, n_labels]

176        labels = torch.eye(self.n_labels, device=labels.device)[labels]

loss 形があります[batch_size, n_labels]for の計算を並列化しました

182        loss = labels * F.relu(self.m_positive - v_norm) + \
183               self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)

186        return loss.sum(dim=-1).mean()