PonderNet: Learning to Ponder

This is a PyTorch implementation of the paper PonderNet: Learning to Ponder.

PonderNet adapts the amount of computation to the input: it varies the number of steps a recurrent network takes for each input. It learns this behavior with end-to-end gradient descent.

PonderNet has a step function of the form

$$\hat{y}_n, h_{n+1}, \lambda_n = s(x, h_n)$$

where $x$ is the input, $h_n$ is the state, $\hat{y}_n$ is the prediction at step $n$, and $\lambda_n$ is the probability of halting (stopping) at the current step.

$s$ can be any neural network (e.g. LSTM, MLP, GRU, Attention layer).

The unconditioned probability of halting at step $n$ is then

$$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$$

That is, the probability of not having halted at any of the previous steps and then halting at step $n$.
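As a small numeric illustration, the sketch below multiplies each conditional halting probability $\lambda_n$ by the probability of not having halted earlier. The function name `halting_distribution` and the sample values are illustrative assumptions, not part of the implementation described here.

```python
from typing import List


def halting_distribution(lambdas: List[float]) -> List[float]:
    """Return p_n = lambda_n * prod_{j<n} (1 - lambda_j) for each step n."""
    probs = []
    un_halted = 1.0  # probability of not having halted before the current step
    for lam in lambdas:
        probs.append(un_halted * lam)
        un_halted *= 1.0 - lam
    return probs


# Hypothetical conditional halting probabilities; the last one is 1 so the
# process always halts by the final step.
p = halting_distribution([0.1, 0.5, 1.0])
# p is approximately [0.1, 0.45, 0.45]
```

Because the last conditional probability is 1, the resulting values sum to 1 and form a proper distribution over steps.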

During inference, we halt by sampling based on the halting probability $\lambda_n$ and get the prediction at the halting layer $\hat{y}_n$ as the final output.
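The sampling procedure can be sketched for a single example in plain Python. This is a minimal illustration under the assumption that the final step always halts; the helper name `sample_halting_step` is hypothetical and the batched model code works on tensors instead.

```python
import random
from typing import List, Optional


def sample_halting_step(lambdas: List[float],
                        rng: Optional[random.Random] = None) -> int:
    """Return the 1-based step at which computation halts.

    At step n we halt with probability lambdas[n - 1]; the final step
    always halts, which corresponds to forcing the last lambda to 1.
    """
    rng = rng or random.Random()
    for n, lam in enumerate(lambdas, start=1):
        if n == len(lambdas) or rng.random() < lam:
            return n
    raise ValueError("lambdas must not be empty")
```

The prediction $\hat{y}_n$ at the returned step would then be used as the final output.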

During training, we get the predictions from all the layers, calculate the loss for each of them, and take the weighted average of those losses using the probabilities $p_n$ of halting at each layer.

The step function is applied up to a maximum number of steps denoted by $N$.

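The overall loss of PonderNet is

$$L = \underbrace{\sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)}_{L_{Rec}} + \beta \underbrace{\mathop{KL} \Big(p_n \Vert p_G(\lambda_p) \Big)}_{L_{Reg}}$$

where $\beta$ is the coefficient that weights the regularization term.

$\mathcal{L}$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.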

$\mathop{KL}$ is the Kullback-Leibler divergence.

$p_G$ is the Geometric distribution parameterized by $\lambda_p$. $\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper.

The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e. promotes exploration.
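To make the $\frac{1}{\lambda_p}$ claim concrete, the following sketch builds a truncated geometric prior and checks its mean. The function names and the choice of $\lambda_p = 0.2$ with 200 steps are assumptions made for illustration only.

```python
from typing import List


def truncated_geometric(lambda_p: float, max_steps: int) -> List[float]:
    """Pr(X = k) = (1 - lambda_p)^(k - 1) * lambda_p for k = 1..max_steps."""
    return [(1.0 - lambda_p) ** (k - 1) * lambda_p
            for k in range(1, max_steps + 1)]


def expected_steps(probs: List[float]) -> float:
    """Mean step index under the (possibly truncated) distribution."""
    return sum(k * p for k, p in enumerate(probs, start=1))


prior = truncated_geometric(lambda_p=0.2, max_steps=200)
mean_steps = expected_steps(prior)  # close to 1 / 0.2 = 5
```

Every step receives non-zero prior mass, which is why matching this prior discourages the network from assigning zero probability to any step.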

The training code in experiment.py trains a PonderNet on the Parity Task.


from typing import Tuple

import torch
from torch import nn

from labml_helpers.module import Module

PonderNet with GRU for Parity Task

This is a simple model that uses a GRU Cell as the step function.

This model is for the Parity Task, where the input is a vector of n_elems elements. Each element is 0, 1, or -1, and the output is the parity: a binary value that is true if the number of 1s is odd and false otherwise.

The prediction of the model is the log probability of the parity being $1$.
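The task definition can be illustrated with a short sketch that draws one input vector and computes its label. Sampling each element uniformly from {-1, 0, 1} is an assumption made here for simplicity and may differ from how the actual experiment generates data; both function names are hypothetical.

```python
import random
from typing import List, Tuple


def parity_label(x: List[int]) -> int:
    """Return 1 if the number of 1s in x is odd, else 0."""
    return sum(1 for v in x if v == 1) % 2


def make_parity_sample(n_elems: int, rng: random.Random) -> Tuple[List[int], int]:
    """Draw a vector with entries in {-1, 0, 1} and compute its parity label."""
    x = [rng.choice([-1, 0, 1]) for _ in range(n_elems)]
    return x, parity_label(x)
```

For example, `parity_label([1, -1, 0, 1, 1])` counts three 1s and returns 1, while the -1 and 0 entries do not affect the label.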

class ParityPonderGRU(Module):
  • n_elems is the number of elements in the input vector
  • n_hidden is the state vector size of the GRU
  • max_steps is the maximum number of steps $N$
    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
        super().__init__()

        self.max_steps = max_steps
        self.n_hidden = n_hidden

GRU

        self.gru = nn.GRUCell(n_elems, n_hidden)

$\hat{y}_n = s_y(h_n)$. We could use a layer that takes the concatenation of $h$ and $x$ as input, but we use only $h$ for simplicity.

        self.output_layer = nn.Linear(n_hidden, 1)

$\lambda_n = s_\lambda(h_n)$

        self.lambda_layer = nn.Linear(n_hidden, 1)
        self.lambda_prob = nn.Sigmoid()

A flag to set during inference so that the computation actually stops once every sample has halted

        self.is_halt = False
  • x is the input of shape [batch_size, n_elems]

This outputs a tuple of four tensors:

  1. $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
  2. $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape [N, batch_size] - the log probabilities of the parity being $1$
  3. $p_m$ of shape [batch_size]
  4. $\hat{y}_m$ of shape [batch_size] where the computation was halted at step $m$
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = x.shape[0]

We get the initial state $h_1 = s_h(x)$

        h = x.new_zeros((x.shape[0], self.n_hidden))
        h = self.gru(x, h)

Lists to store $p_1 \dots p_N$ and $\hat{y}_1 \dots \hat{y}_N$

        p = []
        y = []

$\prod_{j=1}^{n-1} (1 - \lambda_j)$

        un_halted_prob = h.new_ones((batch_size,))

A vector to track which samples have halted computation

        halted = h.new_zeros((batch_size,))

$p_m$ and $\hat{y}_m$ where the computation was halted at step $m$

        p_m = h.new_zeros((batch_size,))
        y_m = h.new_zeros((batch_size,))

Iterate for $N$ steps

        for n in range(1, self.max_steps + 1):

The halting probability $\lambda_N = 1$ for the last step

            if n == self.max_steps:
                lambda_n = h.new_ones(h.shape[0])

$\lambda_n = s_\lambda(h_n)$

            else:
                lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]

$\hat{y}_n = s_y(h_n)$

            y_n = self.output_layer(h)[:, 0]

$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$

            p_n = un_halted_prob * lambda_n

Update $\prod_{j=1}^{n-1} (1 - \lambda_j)$

            un_halted_prob = un_halted_prob * (1 - lambda_n)

Halt based on halting probability $\lambda_n$

            halt = torch.bernoulli(lambda_n) * (1 - halted)

Collect $p_n$ and $\hat{y}_n$

            p.append(p_n)
            y.append(y_n)

Update $p_m$ and $\hat{y}_m$ based on what was halted at current step $n$

            p_m = p_m * (1 - halt) + p_n * halt
            y_m = y_m * (1 - halt) + y_n * halt

Update halted samples

            halted = halted + halt

Get next state $h_{n+1} = s_h(x, h_n)$

            h = self.gru(x, h)

Stop the computation if all samples have halted

            if self.is_halt and halted.sum() == batch_size:
                break

        return torch.stack(p), torch.stack(y), p_m, y_m

Reconstruction loss

$$L_{Rec} = \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$$

$\mathcal{L}$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.

class ReconstructionLoss(Module):
  • loss_func is the loss function $\mathcal{L}$
    def __init__(self, loss_func: nn.Module):
        super().__init__()
        self.loss_func = loss_func
  • p is $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
  • y_hat is $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape [N, batch_size, ...]
  • y is the target of shape [batch_size, ...]
    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):

The total $\sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$

        total_loss = p.new_tensor(0.)

Iterate up to $N$

        for n in range(p.shape[0]):

$p_n \mathcal{L}(y, \hat{y}_n)$ for each sample and the mean of them

            loss = (p[n] * self.loss_func(y_hat[n], y)).mean()

Add to total loss

            total_loss = total_loss + loss

        return total_loss
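The same quantity can also be computed without a Python loop. The sketch below is one possible vectorized form, not the implementation used here; it assumes $\hat{y}$ and $p$ both have shape [N, batch_size], that the target has shape [batch_size], and that the loss function returns per-element losses. The function name and the use of `nn.BCEWithLogitsLoss(reduction='none')` in the usage are illustrative assumptions.

```python
import torch
from torch import nn


def reconstruction_loss_vectorized(p: torch.Tensor, y_hat: torch.Tensor,
                                   y: torch.Tensor, loss_func: nn.Module) -> torch.Tensor:
    """Compute sum_n mean_batch(p_n * L(y, y_hat_n)) without a Python loop.

    Assumes p and y_hat have shape [N, batch_size], y has shape [batch_size],
    and loss_func returns per-element losses (reduction='none').
    """
    # Repeat the target across the step dimension so it matches y_hat
    y_expanded = y[None, :].expand_as(y_hat)
    per_step = p * loss_func(y_hat, y_expanded)  # [N, batch_size]
    # Mean over the batch for each step, then sum over steps
    return per_step.mean(dim=1).sum()


# Usage with hypothetical shapes: N = 4 steps, batch of 3
torch.manual_seed(0)
p = torch.softmax(torch.randn(4, 3), dim=0)
y_hat = torch.randn(4, 3)
y = torch.randint(0, 2, (3,)).float()
loss = reconstruction_loss_vectorized(p, y_hat, y, nn.BCEWithLogitsLoss(reduction='none'))
```

The loop form is easier to follow and handles targets with extra trailing dimensions, so the vectorized form is mainly useful when $N$ is large.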

Regularization loss

$$L_{Reg} = \mathop{KL} \Big(p_n \Vert p_G(\lambda_p) \Big)$$

$\mathop{KL}$ is the Kullback-Leibler divergence.

$p_G$ is the Geometric distribution parameterized by $\lambda_p$. $\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper.

The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e. promotes exploration.

class RegularizationLoss(Module):
  • lambda_p is $\lambda_p$ - the success probability of geometric distribution
  • max_steps is the highest $N$; we use this to pre-compute $p_G(\lambda_p)$
    def __init__(self, lambda_p: float, max_steps: int = 1_000):
        super().__init__()

Empty vector to calculate $p_G(\lambda_p)$

        p_g = torch.zeros((max_steps,))

$(1 - \lambda_p)^k$

        not_halted = 1.

Iterate up to max_steps

        for k in range(max_steps):

$Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$

            p_g[k] = not_halted * lambda_p

Update $(1 - \lambda_p)^k$

            not_halted = not_halted * (1 - lambda_p)

Save $Pr_{p_G(\lambda_p)}$

        self.p_g = nn.Parameter(p_g, requires_grad=False)

KL-divergence loss

        self.kl_div = nn.KLDivLoss(reduction='batchmean')
  • p is $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
    def forward(self, p: torch.Tensor):

Transpose p to [batch_size, N]

        p = p.transpose(0, 1)

Get $Pr_{p_G(\lambda_p)}$ up to $N$ and expand it across the batch dimension

        p_g = self.p_g[None, :p.shape[1]].expand_as(p)

Calculate the KL-divergence. The PyTorch KL-divergence implementation takes log probabilities as its input argument and probabilities as its target.

        return self.kl_div(p.log(), p_g)
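A quick way to sanity-check this term is to evaluate it on halting probabilities that already equal the geometric prior, where it should vanish. The sketch below is a standalone functional version written under that assumption; the names `geometric_prior` and `kl_regularizer` are hypothetical, and it uses `torch.nn.functional.kl_div` rather than the module above.

```python
import torch
import torch.nn.functional as F


def geometric_prior(lambda_p: float, n_steps: int) -> torch.Tensor:
    """(1 - lambda_p)^k * lambda_p for k = 0..n_steps-1 (not renormalized)."""
    k = torch.arange(n_steps, dtype=torch.float32)
    return lambda_p * (1.0 - lambda_p) ** k


def kl_regularizer(p: torch.Tensor, lambda_p: float) -> torch.Tensor:
    """KL term for halting probabilities p of shape [N, batch_size]."""
    p = p.transpose(0, 1)  # [batch_size, N]
    p_g = geometric_prior(lambda_p, p.shape[1])[None, :].expand_as(p)
    # kl_div takes log-probabilities as input and probabilities as target;
    # it sums target * (log(target) - input) and divides by the batch size.
    return F.kl_div(p.log(), p_g, reduction='batchmean')


# Halting probabilities equal to the prior for a batch of 2 and N = 5
prior = geometric_prior(0.2, 5)
reg = kl_regularizer(prior[:, None].expand(5, 2), lambda_p=0.2)
```

Note that with this argument order the log-ratio is weighted by the target, which here is the geometric prior, and that the truncated prior is not renormalized to sum to 1.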