This is a PyTorch implementation of the paper PonderNet: Learning to Ponder.

PonderNet adapts the computation based on the input. It changes the number of steps to take on a recurrent network based on the input. PonderNet learns this with end-to-end gradient descent.

PonderNet has a step function of the form

$\hat{y}_n, h_{n+1}, \lambda_n = s(x, h_n)$

where $x$ is the input, $h_n$ is the state, $\hat{y}_n$ is the prediction at step $n$, and $\lambda_n$ is the probability of halting (stopping) at the current step.

$s$ can be any neural network (e.g. LSTM, MLP, GRU, Attention layer).

The unconditioned probability of halting at step $n$ is then,

$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$

That is the probability of not being halted at any of the previous steps and halting at step $n$.
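As a small worked example of this formula, the sketch below computes the unconditioned halting probabilities from a list of per-step halting probabilities in plain Python. The function name `halting_distribution` and the example values are hypothetical; setting the last halting probability to 1 is an assumption made here so that the probabilities sum to one.

```python
def halting_distribution(lambdas):
    """Return p_n = lambda_n * prod_{j<n} (1 - lambda_j) for each step n."""
    p = []
    un_halted = 1.0  # probability of not having halted before the current step
    for lam in lambdas:
        p.append(un_halted * lam)
        un_halted *= 1.0 - lam
    return p


# Hypothetical halting probabilities; the last one is 1 so halting is certain by step 3.
lambdas = [0.5, 0.5, 1.0]
p = halting_distribution(lambdas)
print(p)       # [0.5, 0.25, 0.25]
print(sum(p))  # 1.0
```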

During inference, we halt by sampling based on the halting probability $\lambda_n$ and take the prediction at the halting layer, $\hat{y}_n$, as the final output.
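The sampling rule can be sketched for a single input as follows. This is a plain-Python illustration rather than the batched implementation; `sample_halting_step` and the halting probabilities are hypothetical. At each step a Bernoulli draw with probability $\lambda_n$ decides whether to stop.

```python
import random


def sample_halting_step(lambdas, rng):
    """Return the 1-based step at which a Bernoulli(lambda_n) draw first succeeds."""
    for n, lam in enumerate(lambdas, start=1):
        if rng.random() < lam:
            return n
    return len(lambdas)  # fall back to the last step if no draw succeeded


rng = random.Random(0)
lambdas = [0.2, 0.5, 1.0]  # hypothetical values; the last step always halts
steps = [sample_halting_step(lambdas, rng) for _ in range(10_000)]

# Empirical frequency of halting at steps 1, 2 and 3
freq = [steps.count(n) / len(steps) for n in (1, 2, 3)]
print(freq)
```

Over many draws the frequencies should approach the unconditioned halting probabilities for these values, which are 0.2, 0.4 and 0.4.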

During training, we get the predictions from all the layers and calculate the loss for each of them. We then take the weighted average of the losses, weighted by the probability $p_n$ of halting at each layer.

The step function is applied up to a maximum number of steps, denoted by $N$.

The overall loss of PonderNet is

$\begin{aligned} L &= L_{Rec} + \beta L_{Reg} \\ L_{Rec} &= \sum_{n=1}^{N} p_n L(y, \hat{y}_n) \\ L_{Reg} &= KL(p_n \Vert p_G(\lambda_p)) \end{aligned}$

$L(y, \hat{y}_n)$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.

$KL$ is the Kullback–Leibler divergence.

$p_G$ is the Geometric distribution parameterized by $\lambda_p$. *$\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper*. $Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$.

The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e. promotes exploration.
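To see where the $\frac{1}{\lambda_p}$ figure comes from, the sketch below tabulates the geometric prior and computes the expected halting step numerically. It assumes that index $k$ of the prior corresponds to halting at step $k + 1$; `geometric_prior` and the value $\lambda_p = 0.2$ are hypothetical choices for illustration.

```python
def geometric_prior(lambda_p, max_steps):
    """Pr(X = k) = (1 - lambda_p)^k * lambda_p for k = 0, ..., max_steps - 1."""
    return [(1.0 - lambda_p) ** k * lambda_p for k in range(max_steps)]


lambda_p = 0.2
prior = geometric_prior(lambda_p, max_steps=1000)

# Index k corresponds to halting at step k + 1.
expected_steps = sum((k + 1) * pr for k, pr in enumerate(prior))
print(round(expected_steps, 6))  # 5.0, i.e. 1 / lambda_p

# Every step receives some probability mass under the prior.
print(all(pr > 0 for pr in prior[:20]))  # True
```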

Here is the training code `experiment.py` to train a PonderNet on the Parity Task.

```
from typing import Tuple

import torch
from torch import nn

from labml_helpers.module import Module
```

This is a simple model that uses a GRU Cell as the step function.

This model is for the Parity Task, where the input is a vector of `n_elems` elements. Each element of the vector is either `0`, `1` or `-1`, and the output is the parity: a binary value that is true if the number of `1`s is odd and false otherwise.

The prediction of the model is the log probability of the parity being $1$.
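For concreteness, the parity target described above can be computed as in the following sketch. The helper `parity` is hypothetical and is not part of the training code; it only illustrates the definition.

```python
def parity(x):
    """Return 1 if the number of elements equal to 1 is odd, else 0."""
    return sum(1 for v in x if v == 1) % 2


print(parity([1, 0, -1, 1, 1]))  # 1 (three ones)
print(parity([1, -1, 0, 1, 0]))  # 0 (two ones)
```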

`class ParityPonderGRU(Module):`

- `n_elems` is the number of elements in the input vector
- `n_hidden` is the state vector size of the GRU
- `max_steps` is the maximum number of steps $N$

`def __init__(self, n_elems: int, n_hidden: int, max_steps: int):`

```
super().__init__()

self.max_steps = max_steps
self.n_hidden = n_hidden
```

GRU $h_{n+1}=s_{h}(x,h_{n})$

`self.gru = nn.GRUCell(n_elems, n_hidden)`

$\hat{y}_n = s_y(h_n)$. We could use a layer that takes the concatenation of $h$ and $x$ as input, but we went with this for simplicity.

`self.output_layer = nn.Linear(n_hidden, 1)`

$λ_{n}=s_{λ}(h_{n})$

```
self.lambda_layer = nn.Linear(n_hidden, 1)
self.lambda_prob = nn.Sigmoid()
```

An option to set during inference so that computation is actually halted at inference time

`self.is_halt = False`

`x` is the input of shape `[batch_size, n_elems]`

This outputs a tuple of four tensors:

1. $p_1 \dots p_N$ in a tensor of shape `[N, batch_size]`
2. $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape `[N, batch_size]` - the log probabilities of the parity being $1$
3. $p_m$ of shape `[batch_size]`
4. $\hat{y}_m$ of shape `[batch_size]`, where the computation was halted at step $m$

`def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:`

`batch_size = x.shape[0]`

We get the initial state $h_1 = s_h(x)$

```
h = x.new_zeros((x.shape[0], self.n_hidden))
h = self.gru(x, h)
```

Lists to store $p_1 \dots p_N$ and $\hat{y}_1 \dots \hat{y}_N$

```
p = []
y = []
```

$\prod_{j=1}^{n-1} (1 - \lambda_j)$

`un_halted_prob = h.new_ones((batch_size,))`

A vector to track which samples have halted computation

`halted = h.new_zeros((batch_size,))`

$p_m$ and $\hat{y}_m$ where the computation was halted at step $m$

```
p_m = h.new_zeros((batch_size,))
y_m = h.new_zeros((batch_size,))
```

Iterate for $N$ steps

`for n in range(1, self.max_steps + 1):`

The halting probability $λ_{N}=1$ for the last step

```
if n == self.max_steps:
    lambda_n = h.new_ones(h.shape[0])
```

$λ_{n}=s_{λ}(h_{n})$

```
else:
    lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]
```

$\hat{y}_n = s_y(h_n)$

`y_n = self.output_layer(h)[:, 0]`

$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$

`p_n = un_halted_prob * lambda_n`

Update the running product to $\prod_{j=1}^{n} (1 - \lambda_j)$

`un_halted_prob = un_halted_prob * (1 - lambda_n)`

Halt based on halting probability $λ_{n}$

`halt = torch.bernoulli(lambda_n) * (1 - halted)`

Collect $p_n$ and $\hat{y}_n$

```
p.append(p_n)
y.append(y_n)
```

Update $p_m$ and $\hat{y}_m$ based on what was halted at the current step $n$

```
p_m = p_m * (1 - halt) + p_n * halt
y_m = y_m * (1 - halt) + y_n * halt
```

Update halted samples

`halted = halted + halt`

Get next state $h_{n+1} = s_h(x, h_n)$

`h = self.gru(x, h)`

Stop the computation if all samples have halted

```
if self.is_halt and halted.sum() == batch_size:
    break
```

`return torch.stack(p), torch.stack(y), p_m, y_m`
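The masked halting update used in the loop can be traced by hand on a tiny batch. The sketch below mirrors the arithmetic with plain Python lists instead of tensors; `update_halting` and the example values are hypothetical, and the Bernoulli draws are passed in explicitly so that the result is deterministic.

```python
def update_halting(halted, y_m, y_n, draws):
    """One step of the masked update for a batch, using plain Python lists.

    halted: 0/1 flags for samples that halted at an earlier step
    y_m:    outputs recorded at each sample's halting step (0.0 if not halted yet)
    y_n:    outputs of the current step
    draws:  0/1 outcomes of the Bernoulli(lambda_n) draws for the current step
    """
    # Only samples that have not halted yet may halt now
    halt = [d * (1 - h) for d, h in zip(draws, halted)]
    # Keep the old value unless the sample halts at this step
    y_m = [ym * (1 - s) + yn * s for ym, yn, s in zip(y_m, y_n, halt)]
    halted = [h + s for h, s in zip(halted, halt)]
    return halted, y_m


halted, y_m = [0, 0, 0], [0.0, 0.0, 0.0]
halted, y_m = update_halting(halted, y_m, y_n=[1.5, 2.5, 3.5], draws=[1, 0, 0])
halted, y_m = update_halting(halted, y_m, y_n=[4.5, 5.5, 6.5], draws=[1, 1, 0])
print(halted)  # [1, 1, 0]
print(y_m)     # [1.5, 5.5, 0.0]
```

The first sample draws a halt at both steps but keeps the output from the step where it first halted, which is what multiplying by `(1 - halted)` achieves.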

$L_{Rec} = \sum_{n=1}^{N} p_n L(y, \hat{y}_n)$

$L$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.

`class ReconstructionLoss(Module):`

`loss_func` is the loss function $L$

`def __init__(self, loss_func: nn.Module):`

```
super().__init__()
self.loss_func = loss_func
```

- `p` is $p_1 \dots p_N$ in a tensor of shape `[N, batch_size]`
- `y_hat` is $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape `[N, batch_size, ...]`
- `y` is the target of shape `[batch_size, ...]`

`def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):`

The total $\sum_{n=1}^{N} p_n L(y, \hat{y}_n)$

`total_loss = p.new_tensor(0.)`

Iterate up to $N$

`for n in range(p.shape[0]):`

$p_n L(y, \hat{y}_n)$ for each sample and the mean of them

`loss = (p[n] * self.loss_func(y_hat[n], y)).mean()`

Add to total loss

`total_loss = total_loss + loss`

`return total_loss`
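The same computation can be checked by hand on a small example. The sketch below is a plain-Python version that works on nested lists rather than tensors; `reconstruction_loss`, `squared_error` and the numbers are hypothetical stand-ins, with squared error playing the role of the per-sample loss function.

```python
def reconstruction_loss(p, y_hat, y, loss_func):
    """Sum over steps of the batch mean of p_n * loss(y_hat_n, y).

    p and y_hat are lists of length N, each holding one value per sample.
    """
    total = 0.0
    for p_n, y_hat_n in zip(p, y_hat):
        per_sample = [pi * loss_func(yh, yi) for pi, yh, yi in zip(p_n, y_hat_n, y)]
        total += sum(per_sample) / len(per_sample)
    return total


def squared_error(y_hat, y):
    """Stand-in for the per-sample loss function."""
    return (y_hat - y) ** 2


p = [[0.5, 0.25], [0.5, 0.75]]     # shape [N=2, batch_size=2]
y_hat = [[0.0, 1.0], [1.0, 3.0]]   # predictions at each step
y = [1.0, 1.0]                     # targets
print(reconstruction_loss(p, y_hat, y, squared_error))  # 1.75
```

Step 1 contributes a batch mean of 0.25 and step 2 contributes 1.5, which gives the total of 1.75.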

$L_{Reg} = KL(p_n \Vert p_G(\lambda_p))$

$KL$ is the Kullback–Leibler divergence.

$p_G$ is the Geometric distribution parameterized by $\lambda_p$. *$\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper*. $Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$.

The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e. promotes exploration.

`class RegularizationLoss(Module):`

- `lambda_p` is $\lambda_p$ - the success probability of the geometric distribution
- `max_steps` is the highest $N$; we use this to pre-compute $p_G(\lambda_p)$

`def __init__(self, lambda_p: float, max_steps: int = 1_000):`

`super().__init__()`

Empty vector to calculate $p_G(\lambda_p)$

`p_g = torch.zeros((max_steps,))`

$(1 - \lambda_p)^k$

`not_halted = 1.`

Iterate up to `max_steps`

`for k in range(max_steps):`

$Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$

`p_g[k] = not_halted * lambda_p`

Update $(1 - \lambda_p)^k$

`not_halted = not_halted * (1 - lambda_p)`

Save $Pr_{p_G(\lambda_p)}$

`self.p_g = nn.Parameter(p_g, requires_grad=False)`

KL-divergence loss

`self.kl_div = nn.KLDivLoss(reduction='batchmean')`
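The pre-computation loop can be reproduced in plain Python to inspect the resulting table. The sketch below uses a hypothetical helper `truncated_geometric` with a small `max_steps`; it also shows that a table truncated at `max_steps` entries sums to $1 - (1 - \lambda_p)^{N}$ rather than exactly one.

```python
def truncated_geometric(lambda_p, max_steps):
    """Build [lambda_p, (1 - lambda_p) * lambda_p, ...] with a running product."""
    p_g = []
    not_halted = 1.0  # holds (1 - lambda_p)^k at iteration k
    for _ in range(max_steps):
        p_g.append(not_halted * lambda_p)
        not_halted *= 1.0 - lambda_p
    return p_g


p_g = truncated_geometric(0.5, max_steps=4)
print(p_g)       # [0.5, 0.25, 0.125, 0.0625]
print(sum(p_g))  # 0.9375, i.e. 1 - 0.5 ** 4
```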

`p` is $p_1 \dots p_N$ in a tensor of shape `[N, batch_size]`

`def forward(self, p: torch.Tensor):`

Transpose `p` to `[batch_size, N]`

`p = p.transpose(0, 1)`

Get $Pr_{p_G(\lambda_p)}$ up to $N$ and expand it across the batch dimension

`p_g = self.p_g[None, :p.shape[1]].expand_as(p)`

Calculate the KL-divergence. *The PyTorch KL-divergence implementation accepts log probabilities.*

`return self.kl_div(p.log(), p_g)`
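As far as the PyTorch documentation describes it, `nn.KLDivLoss` computes `target * (log(target) - input)` pointwise, and `reduction='batchmean'` divides the summed result by the batch size. The sketch below mirrors that arithmetic in plain Python to make the call above easier to follow; `kl_div_batchmean` and the example values are hypothetical. Note that in this call the prior is passed as the target, so each term is weighted by $p_G$.

```python
import math


def kl_div_batchmean(log_p, target):
    """Sum of target * (log(target) - log_p) over all entries, divided by the batch size."""
    total = 0.0
    for log_p_row, target_row in zip(log_p, target):
        for lp, t in zip(log_p_row, target_row):
            if t > 0:  # entries with zero target contribute nothing
                total += t * (math.log(t) - lp)
    return total / len(log_p)


p = [[0.25, 0.25, 0.5]]      # halting probabilities, shape [batch_size=1, N=3]
p_g = [[0.5, 0.25, 0.125]]   # truncated geometric prior with lambda_p = 0.5
log_p = [[math.log(v) for v in row] for row in p]

loss = kl_div_batchmean(log_p, p_g)
print(loss)
```

For these values the result works out analytically to $0.25 \ln 2$, and the value is exactly zero when `p` equals the prior.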