# PonderNet：学会思考

PonderNet 根据输入自适应地调整计算量：它依据输入改变在循环网络上执行的步数，并通过端到端的梯度下降来学习这一点。

PonderNet 的步骤函数形如 $\hat{y}_n, h_n, \lambda_n = s(x, h_{n-1})$，其中 $\hat{y}_n$ 是第 $n$ 步的预测，$h_n$ 是新的状态，$\lambda_n$ 是第 $n$ 步的停止概率。

PonderNet 的总损失为 $L = L_{Rec} + \beta L_{Reg}$，即重建损失与正则化损失（预测的停止分布与几何先验分布之间的 KL 散度）之和。

63from typing import Tuple
64
65import torch
66from torch import nn
67
68from labml_helpers.module import Module

class ParityPonderGRU(Module):
    """
    ## PonderNet with a GRU for the parity task

    * `n_elems` is the number of elements in the input vector
    * `n_hidden` is the state vector size of the GRU
    * `max_steps` is the maximum number of steps $N$
    """

    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
        super().__init__()

        self.max_steps = max_steps
        self.n_hidden = n_hidden

        # GRU cell acts as the step function: takes the input and the previous state
        self.gru = nn.GRUCell(n_elems, n_hidden)
        # Produces the parity prediction logit $\hat{y}_n$ from the state
        self.output_layer = nn.Linear(n_hidden, 1)
        # Produces the halting probability $\lambda_n$ from the state
        self.lambda_layer = nn.Linear(n_hidden, 1)
        self.lambda_prob = nn.Sigmoid()
        # When `True`, stop iterating once every sample in the batch has halted.
        # Kept `False` during training so the full distribution over steps is computed.
        self.is_halt = False

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        * `x` is the input of shape `[batch_size, n_elems]`

        Returns a tuple of four tensors:

        1. $p_n$ in a tensor of shape `[N, batch_size]` — the probability of halting at each step
        2. $\hat{y}_n$ in a tensor of shape `[N, batch_size]` — the parity logit at each step
        3. $p_m$ of shape `[batch_size]` — halting probability at the sampled halting step
        4. $\hat{y}_m$ of shape `[batch_size]` — the prediction at the sampled halting step
        """
        batch_size = x.shape[0]

        # Initial state, then the first GRU step
        h = x.new_zeros((x.shape[0], self.n_hidden))
        h = self.gru(x, h)

        # Collected per-step halting probabilities and predictions
        p = []
        y = []
        # Probability of not having halted before the current step
        un_halted_prob = h.new_ones((batch_size,))

        # 1 for samples that have (stochastically) halted, 0 otherwise
        halted = h.new_zeros((batch_size,))
        # Halting probability and prediction captured at the sampled halting step
        p_m = h.new_zeros((batch_size,))
        y_m = h.new_zeros((batch_size,))

        for n in range(1, self.max_steps + 1):
            # Force halting at the last step so the p_n distribution sums to 1
            if n == self.max_steps:
                lambda_n = h.new_ones(h.shape[0])
            else:
                lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]
            # Prediction at step n
            y_n = self.output_layer(h)[:, 0]
            # p_n = probability of reaching step n and halting there
            p_n = un_halted_prob * lambda_n

            # Update the probability of not having halted yet
            un_halted_prob = un_halted_prob * (1 - lambda_n)

            # Sample whether each (still running) sample halts at this step
            halt = torch.bernoulli(lambda_n) * (1 - halted)

            p.append(p_n)
            y.append(y_n)

            # Capture p_n and y_n for samples that just halted
            p_m = p_m * (1 - halt) + p_n * halt
            y_m = y_m * (1 - halt) + y_n * halt

            halted = halted + halt

            # Next GRU step
            h = self.gru(x, h)

            # Optionally stop early once everything has halted
            if self.is_halt and halted.sum() == batch_size:
                break

        return torch.stack(p), torch.stack(y), p_m, y_m

class ReconstructionLoss(Module):
    """
    ## Reconstruction loss

    Expected prediction loss over the halting-step distribution:
    $L_{Rec} = \sum_n p_n \, \mathcal{L}(y, \hat{y}_n)$

    * `loss_func` is the per-step loss function $\mathcal{L}$
    """

    def __init__(self, loss_func: nn.Module):
        super().__init__()
        self.loss_func = loss_func

    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
        """
        * `p` is $p_n$ in a tensor of shape `[N, batch_size]`
        * `y_hat` is $\hat{y}_n$ in a tensor of shape `[N, batch_size, ...]`
        * `y` is the target of shape `[batch_size, ...]`
        """
        total_loss = p.new_tensor(0.)

        # Accumulate the loss of each step, weighted by its halting probability
        for n in range(p.shape[0]):
            loss = (p[n] * self.loss_func(y_hat[n], y)).mean()
            total_loss = total_loss + loss

        return total_loss

class RegularizationLoss(Module):
    """
    ## Regularization loss

    KL divergence between the predicted halting distribution $p_n$ and a
    geometric prior $p_G(\lambda_p)$, which encourages exploration of steps.

    * `lambda_p` is $\lambda_p$, the success probability of the geometric distribution
    * `max_steps` is the highest step number $N$; used to pre-compute $p_G$
    """

    def __init__(self, lambda_p: float, max_steps: int = 1_000):
        super().__init__()

        # Pre-compute the geometric distribution p_G(k) = (1 - lambda_p)^k * lambda_p
        p_g = torch.zeros((max_steps,))
        not_halted = 1.
        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)

        # Non-trainable parameter so it follows the module across devices
        self.p_g = nn.Parameter(p_g, requires_grad=False)

        # KL-divergence loss
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, p: torch.Tensor):
        """
        * `p` is $p_n$ in a tensor of shape `[N, batch_size]`
        """
        # [N, batch_size] -> [batch_size, N]
        p = p.transpose(0, 1)
        # Truncate the prior to the number of steps taken and broadcast over the batch
        p_g = self.p_g[None, :p.shape[1]].expand_as(p)

        # KLDivLoss expects log-probabilities as the first argument
        return self.kl_div(p.log(), p_g)