使用胶囊网络对 MNIST 数字进行分类

这是一个带注释的 PyTorch 代码,用于使用 PyTorch 对 MNIST 数字进行分类。

本文实施了论文《胶囊间动态路由》中描述的实验。

14from typing import Any
15
16import torch.nn as nn
17import torch.nn.functional as F
18import torch.utils.data
19
20from labml import experiment, tracker
21from labml.configs import option
22from labml_helpers.datasets.mnist import MNISTConfigs
23from labml_helpers.metrics.accuracy import AccuracyDirect
24from labml_helpers.module import Module
25from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
26from labml_nn.capsule_networks import Squash, Router, MarginLoss

用于对 MNIST 数字进行分类的模型

29class MNISTCapsuleNetworkModel(Module):
34    def __init__(self):
35        super().__init__()

第一个卷积层有卷积内核

37        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)

第二层(Primary Capsules)是卷积胶囊层,带有卷积胶囊通道(每个胶囊的特征)。也就是说,每个主胶囊包含 8 个卷积单位,内核为 9×9,步幅为 2。为了实现这一点,我们创建了一个带有通道的卷积层,并对其输出进行整形和排列,以获得每个特征的胶囊。

43        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
44        self.squash = Squash()

路由层获取主胶囊并生成胶囊。每个主胶囊都有特征,而输出胶囊(Digit Capsules)都有特征。路由算法会迭代次数。

50        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)

这是本文中提到的解码器。它采用数字胶囊的输出,每个胶囊都有重现图像的功能。它穿过大小激活的线性层。

55        self.decoder = nn.Sequential(
56            nn.Linear(16 * 10, 512),
57            nn.ReLU(),
58            nn.Linear(512, 1024),
59            nn.ReLU(),
60            nn.Linear(1024, 784),
61            nn.Sigmoid()
62        )

data 是 MNIST 图像,有形状[batch_size, 1, 28, 28]

64    def forward(self, data: torch.Tensor):

穿过第一个卷积层。此图层的输出具有形状[batch_size, 256, 20, 20]

70        x = F.relu(self.conv1(data))

穿过第二个卷积层。这个的输出有形状[batch_size, 32 * 8, 6, 6]请注意,此图层的步长为

74        x = self.conv2(x)

调整大小并排列以获得胶囊

77        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)

挤压胶囊

79        caps = self.squash(caps)

带他们通过路由器获得数字胶囊。这有形状[batch_size, 10, 16]

82        caps = self.digit_capsules(caps)

获取用于重建的口罩

85        with torch.no_grad():

胶囊网络的预测是长度最长的胶囊

87            pred = (caps ** 2).sum(-1).argmax(-1)

创建遮罩以遮盖所有其他胶囊

89            mask = torch.eye(10, device=data.device)[pred]

掩盖数字胶囊以仅获得做出预测的胶囊,然后将其通过解码器进行重建

93        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))

重塑重建以匹配图像尺寸

95        reconstructions = reconstructions.view(-1, 1, 28, 28)
96
97        return caps, reconstructions, pred

使用 MNIST 数据和训练与验证设置的配置

100class Configs(MNISTConfigs, SimpleTrainValidConfigs):
104    epochs: int = 10
105    model: nn.Module = 'capsule_network_model'
106    reconstruction_loss = nn.MSELoss()
107    margin_loss = MarginLoss(n_labels=10)
108    accuracy = AccuracyDirect()
110    def init(self):

印刷损耗和屏幕精度

112        tracker.set_scalar('loss.*', True)
113        tracker.set_scalar('accuracy.*', True)

我们需要设置指标来计算训练和验证时期的指标

116        self.state_modules = [self.accuracy]

这个方法被训练器调用

118    def step(self, batch: Any, batch_idx: BatchIndex):

设置模型模式

123        self.model.train(self.mode.is_train)

获取图像和标签并将其移动到模特的设备上

126        data, target = batch[0].to(self.device), batch[1].to(self.device)

在训练模式中增加步数

129        if self.mode.is_train:
130            tracker.add_global_step(len(data))

是否记录激活次数

133        with self.mode.update(is_log_activations=batch_idx.is_last):

运行模型

135            caps, reconstructions, pred = self.model(data)

计算总损失

138        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
139        tracker.add("loss.", loss)

呼叫准确度指标

142        self.accuracy(pred, target)
143
144        if self.mode.is_train:
145            loss.backward()
146
147            self.optimizer.step()

日志参数和梯度

149            if batch_idx.is_last:
150                tracker.add('model', self.model)
151            self.optimizer.zero_grad()
152
153            tracker.save()

设置模型

156@option(Configs.model)
157def capsule_network_model(c: Configs):
159    return MNISTCapsuleNetworkModel().to(c.device)

运行实验

162def main():
166    experiment.create(name='capsule_network_mnist')
167    conf = Configs()
168    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
169                              'optimizer.learning_rate': 1e-3})
170
171    experiment.add_pytorch_models({'model': conf.model})
172
173    with experiment.start():
174        conf.run()
175
176
177if __name__ == '__main__':
178    main()