这是 P onderNet:学会思考这篇论文的 PyT orch 实现。
PonderNet 根据输入调整计算。它根据输入更改在循环网络上采取的步数。PonderNet 通过端到端梯度下降来了解这一点。
PonderNet 的步骤函数是这样的
其中,是输入,是状态,是步骤中的预测,是当前步骤停止(停止)的概率。
可以是任何神经网络(例如 LSTM、MLP、GRU、注意力层)。
因此,按步停顿的无条件概率是
也就是说,在之前的任何步骤中都不会停下来,而是一步步停下来的可能性。
在推理过程中,我们根据停止概率通过采样来停止,并将停顿层的预测作为最终输出。
在训练期间,我们会得到所有层的预测,并计算每个层的损失。然后根据每层停止的概率得出损失的加权平均值。
步长函数应用于捐赠的最大步数。
PonderNet 的总损失是
是目标和预测之间的正态损失函数。
是参数化的几何分布。无关;我们只是坚持使用与论文相同的符号。
。正则化损失使网络偏向于采取步骤,并激励所有步骤的非零概率;即促进探索。
以下是在 PonderNet 上训练 PonderNet 完成奇偶任务的训练代码experiment.py
。
63from typing import Tuple
64
65import torch
66from torch import nn
67
68from labml_helpers.module import Module
这是一个使用 GRU Cell 作为步进函数的简单模型。
此模型适用于奇偶校验任务,其中输入的向量为n_elems
。向量的每个元素都是1
或0
-1
,输出为奇偶校验——如果1
s的数量为奇数,则为true的二进制值否则为 false。
模型的预测是奇偶校验的对数概率。
71class ParityPonderGRU(Module):
n_elems
是输入向量中的元素数n_hidden
是 GRU 的状态向量大小max_steps
是最大步数85 def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
91 super().__init__()
92
93 self.max_steps = max_steps
94 self.n_hidden = n_hidden
GRU
98 self.gru = nn.GRUCell(n_elems, n_hidden)
我们可以使用一个将和的串联作为输入的层,但为了简单起见,我们使用了这个层。
102 self.output_layer = nn.Linear(n_hidden, 1)
104 self.lambda_layer = nn.Linear(n_hidden, 1)
105 self.lambda_prob = nn.Sigmoid()
在推理期间设置的选项,以便在推理时实际停止计算
107 self.is_halt = False
x
是形状的输入[batch_size, n_elems]
这将输出一个由四个张量组成的元组:
1.在形状为[N, batch_size]
2 的张量中。在形状张量中[N, batch_size]
-奇偶校验的对数概率为 3。形状为[batch_size]
4。[batch_size]
在步进时停止计算的形状
109 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
122 batch_size = x.shape[0]
我们得到初始状态
125 h = x.new_zeros((x.shape[0], self.n_hidden))
126 h = self.gru(x, h)
要存储的清单和
129 p = []
130 y = []
132 un_halted_prob = h.new_ones((batch_size,))
用于维护哪些样本已停止计算的向量
135 halted = h.new_zeros((batch_size,))
以及计算在何处停止
137 p_m = h.new_zeros((batch_size,))
138 y_m = h.new_zeros((batch_size,))
迭代步骤
141 for n in range(1, self.max_steps + 1):
最后一步的停止概率
143 if n == self.max_steps:
144 lambda_n = h.new_ones(h.shape[0])
146 else:
147 lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]
149 y_n = self.output_layer(h)[:, 0]
152 p_n = un_halted_prob * lambda_n
更新
154 un_halted_prob = un_halted_prob * (1 - lambda_n)
基于暂停概率暂停
157 halt = torch.bernoulli(lambda_n) * (1 - halted)
收集和
160 p.append(p_n)
161 y.append(y_n)
更新并基于当前步骤中止的内容
164 p_m = p_m * (1 - halt) + p_n * halt
165 y_m = y_m * (1 - halt) + y_n * halt
更新已暂停的样本
168 halted = halted + halt
获取下一个状态
170 h = self.gru(x, h)
如果所有样本都已停止,则停止计算
173 if self.is_halt and halted.sum() == batch_size:
174 break
177 return torch.stack(p), torch.stack(y), p_m, y_m
180class ReconstructionLoss(Module):
loss_func
是损失函数189 def __init__(self, loss_func: nn.Module):
193 super().__init__()
194 self.loss_func = loss_func
p
处于形状的张量[N, batch_size]
y_hat
处于形状的张量[N, batch_size, ...]
y
是形状的目标[batch_size, ...]
196 def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
总数
204 total_loss = p.new_tensor(0.)
迭代到
206 for n in range(p.shape[0]):
每个样本及其均值
208 loss = (p[n] * self.loss_func(y_hat[n], y)).mean()
再加上总亏损
210 total_loss = total_loss + loss
213 return total_loss
216class RegularizationLoss(Module):
lambda_p
is-几何分布的成功概率max_steps
是最高的;我们用它来预先计算232 def __init__(self, lambda_p: float, max_steps: int = 1_000):
237 super().__init__()
要计算的空向量
240 p_g = torch.zeros((max_steps,))
242 not_halted = 1.
迭代到max_steps
244 for k in range(max_steps):
246 p_g[k] = not_halted * lambda_p
更新
248 not_halted = not_halted * (1 - lambda_p)
保存
251 self.p_g = nn.Parameter(p_g, requires_grad=False)
KL-背离损失
254 self.kl_div = nn.KLDivLoss(reduction='batchmean')
p
处于形状的张量[N, batch_size]
256 def forward(self, p: torch.Tensor):
移调p
到[batch_size, N]
261 p = p.transpose(0, 1)
了解并将其扩展到整个批次维度
263 p_g = self.p_g[None, :p.shape[1]].expand_as(p)
计算 KL 背离。PyTorch KL-Divergen ce 实现接受对数概率。
268 return self.kl_div(p.log(), p_g)