14from typing import Any
15
16import torch.nn as nn
17import torch.nn.functional as F
18import torch.utils.data
19
20from labml import experiment, tracker
21from labml.configs import option
22from labml_helpers.datasets.mnist import MNISTConfigs
23from labml_helpers.metrics.accuracy import AccuracyDirect
24from labml_helpers.module import Module
25from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
26from labml_nn.capsule_networks import Squash, Router, MarginLoss
29class MNISTCapsuleNetworkModel(Module):
34 def __init__(self):
35 super().__init__()
第一个卷积层有,卷积内核
37 self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
第二层(Primary Capsules)是卷积胶囊层,带有卷积胶囊通道(每个胶囊的特征)。也就是说,每个主胶囊包含 8 个卷积单位,内核为 9×9,步幅为 2。为了实现这一点,我们创建了一个带有通道的卷积层,并对其输出进行整形和排列,以获得每个特征的胶囊。
43 self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
44 self.squash = Squash()
路由层获取主胶囊并生成胶囊。每个主胶囊都有特征,而输出胶囊(Digit Capsules)都有特征。路由算法会迭代次数。
50 self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
这是本文中提到的解码器。它采用数字胶囊的输出,每个胶囊都有重现图像的功能。它穿过大小和激活的线性层。
55 self.decoder = nn.Sequential(
56 nn.Linear(16 * 10, 512),
57 nn.ReLU(),
58 nn.Linear(512, 1024),
59 nn.ReLU(),
60 nn.Linear(1024, 784),
61 nn.Sigmoid()
62 )
data
是 MNIST 图像,有形状[batch_size, 1, 28, 28]
64 def forward(self, data: torch.Tensor):
穿过第一个卷积层。此图层的输出具有形状[batch_size, 256, 20, 20]
70 x = F.relu(self.conv1(data))
穿过第二个卷积层。这个的输出有形状[batch_size, 32 * 8, 6, 6]
。请注意,此图层的步长为。
74 x = self.conv2(x)
调整大小并排列以获得胶囊
77 caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
挤压胶囊
79 caps = self.squash(caps)
带他们通过路由器获得数字胶囊。这有形状[batch_size, 10, 16]
。
82 caps = self.digit_capsules(caps)
获取用于重建的口罩
85 with torch.no_grad():
胶囊网络的预测是长度最长的胶囊
87 pred = (caps ** 2).sum(-1).argmax(-1)
创建遮罩以遮盖所有其他胶囊
89 mask = torch.eye(10, device=data.device)[pred]
掩盖数字胶囊以仅获得做出预测的胶囊,然后将其通过解码器进行重建
93 reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))
重塑重建以匹配图像尺寸
95 reconstructions = reconstructions.view(-1, 1, 28, 28)
96
97 return caps, reconstructions, pred
使用 MNIST 数据和训练与验证设置的配置
100class Configs(MNISTConfigs, SimpleTrainValidConfigs):
104 epochs: int = 10
105 model: nn.Module = 'capsule_network_model'
106 reconstruction_loss = nn.MSELoss()
107 margin_loss = MarginLoss(n_labels=10)
108 accuracy = AccuracyDirect()
110 def init(self):
印刷损耗和屏幕精度
112 tracker.set_scalar('loss.*', True)
113 tracker.set_scalar('accuracy.*', True)
我们需要设置指标来计算训练和验证时期的指标
116 self.state_modules = [self.accuracy]
这个方法被训练器调用
118 def step(self, batch: Any, batch_idx: BatchIndex):
设置模型模式
123 self.model.train(self.mode.is_train)
获取图像和标签并将其移动到模特的设备上
126 data, target = batch[0].to(self.device), batch[1].to(self.device)
在训练模式中增加步数
129 if self.mode.is_train:
130 tracker.add_global_step(len(data))
是否记录激活次数
133 with self.mode.update(is_log_activations=batch_idx.is_last):
运行模型
135 caps, reconstructions, pred = self.model(data)
计算总损失
138 loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
139 tracker.add("loss.", loss)
呼叫准确度指标
142 self.accuracy(pred, target)
143
144 if self.mode.is_train:
145 loss.backward()
146
147 self.optimizer.step()
日志参数和梯度
149 if batch_idx.is_last:
150 tracker.add('model', self.model)
151 self.optimizer.zero_grad()
152
153 tracker.save()
设置模型
156@option(Configs.model)
157def capsule_network_model(c: Configs):
159 return MNISTCapsuleNetworkModel().to(c.device)
运行实验
162def main():
166 experiment.create(name='capsule_network_mnist')
167 conf = Configs()
168 experiment.configs(conf, {'optimizer.optimizer': 'Adam',
169 'optimizer.learning_rate': 1e-3})
170
171 experiment.add_pytorch_models({'model': conf.model})
172
173 with experiment.start():
174 conf.run()
175
176
177if __name__ == '__main__':
178 main()