PonderNet:学会思考

这是 P onderNet:学会思考这篇论文的 PyT orch 实现。

PonderNet 根据输入调整计算。它根据输入更改在循环网络上采取的步数。PonderNet 通过端到端梯度下降来了解这一点。

PonderNet 的步骤函数是这样的

其中,是输入,是状态,是步骤中的预测是当前步骤停止(停止)的概率。

可以是任何神经网络(例如 LSTM、MLP、GRU、注意力层)。

因此,按步停顿的无条件概率

也就是说,在之前的任何步骤中都不会停下来,而是一步步停下来的可能性

在推理过程中,我们根据停止概率通过采样来停止,并将停顿层的预测作为最终输出。

在训练期间,我们会得到所有层的预测,并计算每个层的损失。然后根据每层停止的概率得出损失的加权平均值

步长函数应用于捐赠的最大步数

PonderNet 的总损失是

是目标和预测之间的正态损失函数

Kullback—Leibler 的分歧

是参数化的几何分布无关;我们只是坚持使用与论文相同的符号

正则化损失使网络偏向于采取步骤,并激励所有步骤的非零概率;即促进探索。

以下是在 PonderNet 上训练 PonderNet 完成奇偶任务的训练代码experiment.py

63from typing import Tuple
64
65import torch
66from torch import nn
67
68from labml_helpers.module import Module

PonderNet 与 GRU 一起执行奇偶校验任务

这是一个使用 GRU Cell 作为步进函数的简单模型。

此模型适用于奇偶校验任务,其中输入的向量为n_elems 。向量的每个元素都是10 -1 ,输出为奇偶校验——如果1 s的数量为奇数,则为true的二进制值否则为 false。

模型的预测是奇偶校验的对数概率

71class ParityPonderGRU(Module):
  • n_elems 是输入向量中的元素数
  • n_hidden 是 GRU 的状态向量大小
  • max_steps 是最大步数
85    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
91        super().__init__()
92
93        self.max_steps = max_steps
94        self.n_hidden = n_hidden

GRU

98        self.gru = nn.GRUCell(n_elems, n_hidden)

我们可以使用一个将和的串联作为输入的层,但为了简单起见,我们使用了这个层。

102        self.output_layer = nn.Linear(n_hidden, 1)

104        self.lambda_layer = nn.Linear(n_hidden, 1)
105        self.lambda_prob = nn.Sigmoid()

在推理期间设置的选项,以便在推理时实际停止计算

107        self.is_halt = False
  • x 是形状的输入[batch_size, n_elems]

这将输出一个由四个张量组成的元组:

1.在形状为[N, batch_size] 2 的张量中。在形状张量中[N, batch_size] -奇偶校验的对数概率为 3。形状为[batch_size] 4。[batch_size] 在步进时停止计算的形状

109    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

122        batch_size = x.shape[0]

我们得到初始状态

125        h = x.new_zeros((x.shape[0], self.n_hidden))
126        h = self.gru(x, h)

要存储的清单

129        p = []
130        y = []

132        un_halted_prob = h.new_ones((batch_size,))

用于维护哪些样本已停止计算的向量

135        halted = h.new_zeros((batch_size,))

以及计算在何处停止

137        p_m = h.new_zeros((batch_size,))
138        y_m = h.new_zeros((batch_size,))

迭代步骤

141        for n in range(1, self.max_steps + 1):

最后一步的停止概率

143            if n == self.max_steps:
144                lambda_n = h.new_ones(h.shape[0])

146            else:
147                lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]

149            y_n = self.output_layer(h)[:, 0]

152            p_n = un_halted_prob * lambda_n

更新

154            un_halted_prob = un_halted_prob * (1 - lambda_n)

基于暂停概率暂停

157            halt = torch.bernoulli(lambda_n) * (1 - halted)

收集

160            p.append(p_n)
161            y.append(y_n)

更新基于当前步骤中止的内容

164            p_m = p_m * (1 - halt) + p_n * halt
165            y_m = y_m * (1 - halt) + y_n * halt

更新已暂停的样本

168            halted = halted + halt

获取下一个状态

170            h = self.gru(x, h)

如果所有样本都已停止,则停止计算

173            if self.is_halt and halted.sum() == batch_size:
174                break

177        return torch.stack(p), torch.stack(y), p_m, y_m

重建损失

是目标和预测之间的正常损失函数

180class ReconstructionLoss(Module):
  • loss_func 是损失函数
189    def __init__(self, loss_func: nn.Module):
193        super().__init__()
194        self.loss_func = loss_func
  • p 处于形状的张量[N, batch_size]
  • y_hat 处于形状的张量[N, batch_size, ...]
  • y 是形状的目标[batch_size, ...]
196    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):

总数

204        total_loss = p.new_tensor(0.)

迭代到

206        for n in range(p.shape[0]):

每个样本及其均值

208            loss = (p[n] * self.loss_func(y_hat[n], y)).mean()

再加上总亏损

210            total_loss = total_loss + loss

213        return total_loss

正规化损失

Kullback—Leibler 背离

是通过参数化的几何分布与之无关;我们只是坚持使用与报纸相同的符号

正则化损失使网络偏向于采取措施,并激励所有步骤的非零概率;即促进探索。

216class RegularizationLoss(Module):
  • lambda_p is-几何分布的成功概率
  • max_steps 是最高的;我们用它来预先计算
232    def __init__(self, lambda_p: float, max_steps: int = 1_000):
237        super().__init__()

要计算的空向量

240        p_g = torch.zeros((max_steps,))

242        not_halted = 1.

迭代到max_steps

244        for k in range(max_steps):

246            p_g[k] = not_halted * lambda_p

更新

248            not_halted = not_halted * (1 - lambda_p)

保存

251        self.p_g = nn.Parameter(p_g, requires_grad=False)

KL-背离损失

254        self.kl_div = nn.KLDivLoss(reduction='batchmean')
  • p 处于形状的张量[N, batch_size]
256    def forward(self, p: torch.Tensor):

移调p[batch_size, N]

261        p = p.transpose(0, 1)

了解并将其扩展到整个批次维度

263        p_g = self.p_g[None, :p.shape[1]].expand_as(p)

计算 KL 背离。PyTorch KL-Divergen ce 实现接受对数概率。

268        return self.kl_div(p.log(), p_g)