This is a PyTorch implementation of the paper PonderNet: Learning to Ponder.
PonderNet adapts the computation based on the input. It changes the number of steps to take on a recurrent network based on the input. PonderNet learns this with end-to-end gradient descent.
PonderNet has a step function of the form
$$\hat{y}_n, h_{n+1}, \lambda_n = s(x, h_n)$$
where $x$ is the input, $h_n$ is the state, $\hat{y}_n$ is the prediction at step $n$, and $\lambda_n$ is the probability of halting (stopping) at the current step.
$s$ can be any neural network (e.g. LSTM, MLP, GRU, Attention layer).
The unconditioned probability of halting at step $n$ is then,
$$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$$
That is the probability of not being halted at any of the previous steps and halting at step $n$.
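As a rough illustration of this formula, the sketch below converts a list of conditional halting probabilities into the unconditioned halting distribution. The helper name `halting_distribution` and the example values are assumptions made for illustration; they are not part of the implementation.

```python
from typing import List


def halting_distribution(lambdas: List[float]) -> List[float]:
    """Convert conditional halting probabilities lambda_n into p_n.

    Hypothetical helper, for illustration only.
    """
    p = []
    not_halted = 1.0  # probability of not having halted before step n
    for lambda_n in lambdas:
        # Not halted at any previous step, and halting now
        p.append(not_halted * lambda_n)
        not_halted *= 1.0 - lambda_n
    return p


# Assumed example values; the last lambda is 1 so the process always halts
p = halting_distribution([0.1, 0.5, 1.0])
print(p)
```

Because the last conditional probability is 1, the resulting values form a proper distribution over the steps.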
During inference, we halt by sampling based on the halting probability $\lambda_n$ and get the prediction at the halting layer $\hat{y}_n$ as the final output.
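The sampling procedure at inference time can be sketched in plain Python as follows. This is a minimal sketch under assumed inputs: the function name `sample_halting_step` is hypothetical, and the conditional halting probabilities are given as a fixed list rather than produced by a network.

```python
import random
from typing import List


def sample_halting_step(lambdas: List[float], rng: random.Random) -> int:
    """Return the 1-indexed step at which computation halts (illustrative helper)."""
    for n, lambda_n in enumerate(lambdas, start=1):
        # Halt at step n with probability lambda_n
        if rng.random() < lambda_n:
            return n
    # Force a halt at the last step if nothing halted earlier
    return len(lambdas)


rng = random.Random(0)
# Assumed halting probabilities for three steps
steps = [sample_halting_step([0.1, 0.5, 1.0], rng) for _ in range(1000)]
```

The prediction returned to the caller would then be the one produced at the sampled step.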
During training, we get the predictions from all the layers and calculate the loss for each of them. We then take the weighted average of these losses, using the probabilities $p_n$ of halting at each layer as weights.
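For a single sample, this weighting amounts to an expectation of the per-step loss under the halting distribution. The sketch below uses made-up numbers and a hypothetical helper name `weighted_loss`.

```python
from typing import List


def weighted_loss(p: List[float], step_losses: List[float]) -> float:
    """Expected loss over halting steps: sum over n of p_n * loss_n."""
    assert len(p) == len(step_losses)
    return sum(p_n * loss_n for p_n, loss_n in zip(p, step_losses))


# Assumed values: halting distribution and per-step losses for one sample
p = [0.1, 0.45, 0.45]
step_losses = [2.0, 1.0, 0.5]
total = weighted_loss(p, step_losses)
print(total)
```

Steps with a high loss contribute little if the model rarely halts there, so gradients can lower the loss either by improving a step's prediction or by shifting halting probability to better steps.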
The step function is applied for a maximum number of steps denoted by $N$.
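The overall loss of PonderNet is
$$L = L_{Rec} + \beta L_{Reg}$$
$$L_{Rec} = \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$$
$$L_{Reg} = KL(p_n \Vert p_G(\lambda_p))$$
$\mathcal{L}$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.
$KL$ is the Kullback–Leibler divergence.
$p_G$ is the Geometric distribution parameterized by $\lambda_p$. $\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper. $\Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$.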
The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e. promotes exploration.
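To make the role of the geometric prior concrete, the sketch below builds a truncated prior, checks that its expected number of steps is close to the reciprocal of the prior parameter, and evaluates a discrete KL divergence. The helper names, the value 0.2, and the truncation at 200 steps are assumptions chosen for illustration.

```python
import math
from typing import List


def geometric_prior(lambda_p: float, max_steps: int) -> List[float]:
    """Truncated geometric prior: entry k is (1 - lambda_p)^k * lambda_p."""
    return [(1.0 - lambda_p) ** k * lambda_p for k in range(max_steps)]


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) for discrete distributions; zero-probability terms of p are skipped."""
    return sum(p_i * math.log(p_i / q_i) for p_i, q_i in zip(p, q) if p_i > 0.0)


lambda_p = 0.2  # assumed value
p_g = geometric_prior(lambda_p, max_steps=200)

# Entry k corresponds to halting at step k + 1
expected_steps = sum((k + 1) * p_k for k, p_k in enumerate(p_g))

# A halting distribution that always stops at the first step
always_first = [1.0] + [0.0] * 199
penalty = kl_divergence(always_first, p_g)
```

A model that always halts at the first step pays a non-zero penalty under this prior, which is how the regularizer discourages collapsing onto a single step.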
Here is the training code experiment.py to train a PonderNet on the Parity Task.
from typing import Tuple

import torch
from torch import nn
This is a simple model that uses a GRU Cell as the step function.
This model is for the Parity Task where the input is a vector of n_elems elements. Each element of the vector is either 0, 1 or -1 and the output is the parity - a binary value that is true if the number of 1s is odd and false otherwise.
The prediction of the model is the logit (unnormalized log-odds) of the parity being $1$; it comes straight from a linear layer, so it is not itself a log probability.
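A single input and target pair for this task could be generated along the following lines. This is only a sketch: the helper `make_parity_sample` is hypothetical, and sampling each element uniformly from the three values is an assumption that may differ from the dataset used in the training code.

```python
import random
from typing import List, Tuple


def make_parity_sample(n_elems: int, rng: random.Random) -> Tuple[List[int], int]:
    """Generate one (input, target) pair for the parity task (hypothetical helper)."""
    # Assumption: each element is drawn uniformly from {-1, 0, 1}
    x = [rng.choice([-1, 0, 1]) for _ in range(n_elems)]
    # Target is 1 if the number of 1s is odd, else 0
    y = sum(1 for v in x if v == 1) % 2
    return x, y


x, y = make_parity_sample(8, random.Random(0))
print(x, y)
```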
class ParityPonderGRU(nn.Module):
n_elems is the number of elements in the input vector
n_hidden is the state vector size of the GRU
max_steps is the maximum number of steps $N$
    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
        super().__init__()

        self.max_steps = max_steps
        self.n_hidden = n_hidden
GRU $h_{n+1} = s_h(x, h_n)$
        self.gru = nn.GRUCell(n_elems, n_hidden)
$\hat{y}_n = s_y(h_n)$. We could use a layer that takes the concatenation of $h$ and $x$ as input but we went with this for simplicity.
        self.output_layer = nn.Linear(n_hidden, 1)
$\lambda_n = s_\lambda(h_n)$
        self.lambda_layer = nn.Linear(n_hidden, 1)
        self.lambda_prob = nn.Sigmoid()
An option to set during inference so that computation is actually halted at inference time
        self.is_halt = False
x is the input of shape [batch_size, n_elems]
This outputs a tuple of four tensors:
1. $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
2. $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape [N, batch_size] - the logits of the parity being $1$
3. $p_m$ of shape [batch_size]
4. $\hat{y}_m$ of shape [batch_size], where the computation was halted at step $m$
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = x.shape[0]
We get the initial state $h_1 = s_h(x, h_0)$
        h = x.new_zeros((x.shape[0], self.n_hidden))
        h = self.gru(x, h)
Lists to store $p_1 \dots p_N$ and $\hat{y}_1 \dots \hat{y}_N$
        p = []
        y = []
$\prod_{j=1}^{n-1} (1 - \lambda_j)$
        un_halted_prob = h.new_ones((batch_size,))
A vector to maintain which samples have halted computation
        halted = h.new_zeros((batch_size,))
$p_m$ and $\hat{y}_m$ where the computation was halted at step $m$
        p_m = h.new_zeros((batch_size,))
        y_m = h.new_zeros((batch_size,))
Iterate for $N$ steps
        for n in range(1, self.max_steps + 1):
The halting probability $\lambda_N = 1$ for the last step
            if n == self.max_steps:
                lambda_n = h.new_ones(h.shape[0])
$\lambda_n = s_\lambda(h_n)$
            else:
                lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]
$\hat{y}_n = s_y(h_n)$
            y_n = self.output_layer(h)[:, 0]
$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$
            p_n = un_halted_prob * lambda_n
Update $\prod_{j=1}^{n} (1 - \lambda_j)$
            un_halted_prob = un_halted_prob * (1 - lambda_n)
Halt based on halting probability $\lambda_n$
            halt = torch.bernoulli(lambda_n) * (1 - halted)
Collect $p_n$ and $\hat{y}_n$
            p.append(p_n)
            y.append(y_n)
Update $p_m$ and $\hat{y}_m$ based on what was halted at current step $n$
            p_m = p_m * (1 - halt) + p_n * halt
            y_m = y_m * (1 - halt) + y_n * halt
Update halted samples
            halted = halted + halt
Get next state $h_{n+1} = s_h(x, h_n)$
            h = self.gru(x, h)
Stop the computation if all samples have halted
            if self.is_halt and halted.sum() == batch_size:
                break
        return torch.stack(p), torch.stack(y), p_m, y_m
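The masked update used in the loop to record the output at the halting step can be examined in isolation. The sketch below replaces the Bernoulli draws with fixed, assumed halting decisions so the result is deterministic; the tensor values are made up for illustration.

```python
import torch

batch_size = 3
# Assumed per-step predictions, one row per step and one column per sample
y_steps = torch.tensor([[10., 11., 12.],
                        [20., 21., 22.],
                        [30., 31., 32.]])
# Assumed halting draws (1 = the draw says halt) in the same layout
halt_draws = torch.tensor([[1., 0., 0.],
                           [1., 1., 0.],
                           [1., 1., 1.]])

halted = torch.zeros(batch_size)
y_m = torch.zeros(batch_size)
for y_n, draw in zip(y_steps, halt_draws):
    # Only samples that have not halted before can halt now
    halt = draw * (1 - halted)
    # Keep the old value where halt == 0, take y_n where halt == 1
    y_m = y_m * (1 - halt) + y_n * halt
    halted = halted + halt

print(y_m)
```

Multiplying by the mask of samples that are still running ensures each sample's output is written exactly once, at its first halting step, even though later draws may also say halt.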
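$$L_{Rec} = \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$$
$\mathcal{L}$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.
class ReconstructionLoss(nn.Module):
loss_func is the loss function $\mathcal{L}$
    def __init__(self, loss_func: nn.Module):
        super().__init__()
        self.loss_func = loss_func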
p is $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
y_hat is $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape [N, batch_size, ...]
y is the target of shape [batch_size, ...]
    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
The total $\sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$
        total_loss = p.new_tensor(0.)
Iterate up to $N$
        for n in range(p.shape[0]):
$p_n \mathcal{L}(y, \hat{y}_n)$ for each sample and the mean of them
            loss = (p[n] * self.loss_func(y_hat[n], y)).mean()
Add to total loss
            total_loss = total_loss + loss
        return total_loss
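The loop in the reconstruction loss can also be written without iterating over steps. The sketch below compares the two forms on random inputs. It assumes a binary target and uses `nn.BCEWithLogitsLoss(reduction='none')` as the per-sample loss; that choice, like the tensor sizes, is an assumption made for illustration. Whatever loss is passed in must return unreduced per-sample values, otherwise the weighting by the halting probabilities has no effect per sample.

```python
import torch
from torch import nn

torch.manual_seed(0)
n_steps, batch_size = 4, 5
# Assumed inputs: a halting distribution per sample, logits, and binary targets
p = torch.softmax(torch.randn(n_steps, batch_size), dim=0)
y_hat = torch.randn(n_steps, batch_size)
y = torch.randint(0, 2, (batch_size,)).float()

# Per-sample loss, so it can be weighted by p before averaging
loss_func = nn.BCEWithLogitsLoss(reduction='none')

# Loop form: weight each step's per-sample loss and average over the batch
loop_total = sum((p[n] * loss_func(y_hat[n], y)).mean() for n in range(n_steps))

# Vectorized form: broadcast the target across the step dimension
per_step = loss_func(y_hat, y.expand_as(y_hat))  # [n_steps, batch_size]
vector_total = (p * per_step).sum(dim=0).mean()
```

Both forms compute the same quantity because summing batch means over steps equals the batch mean of the per-sample sums.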
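$$L_{Reg} = KL(p_n \Vert p_G(\lambda_p))$$
$KL$ is the Kullback–Leibler divergence.
$p_G$ is the Geometric distribution parameterized by $\lambda_p$. $\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper. $\Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$.
The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e. promotes exploration.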
class RegularizationLoss(nn.Module):
lambda_p is $\lambda_p$ - the success probability of the geometric distribution
max_steps is the highest $N$; we use this to pre-compute $p_G$
    def __init__(self, lambda_p: float, max_steps: int = 1_000):
        super().__init__()
Empty vector to calculate $p_G$
        p_g = torch.zeros((max_steps,))
$(1 - \lambda_p)^k$
        not_halted = 1.
Iterate up to max_steps
        for k in range(max_steps):
$\Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$
            p_g[k] = not_halted * lambda_p
Update $(1 - \lambda_p)^k$
            not_halted = not_halted * (1 - lambda_p)
Save $p_G$
        self.p_g = nn.Parameter(p_g, requires_grad=False)
KL-divergence loss
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
p is $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
    def forward(self, p: torch.Tensor):
Transpose p to [batch_size, N]
        p = p.transpose(0, 1)
Get $p_G$ up to $N$ and expand it across the batch dimension
        p_g = self.p_g[None, :p.shape[1]].expand_as(p)
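Calculate the KL-divergence. The PyTorch KL-divergence implementation takes log probabilities as its first argument (the input) and probabilities as its second (the target), and computes the divergence of the target from the input. As called here it therefore evaluates $KL(p_G \Vert p_n)$ with $p_G$ truncated to $N$ steps, which is the reverse of the argument order written in the regularization loss.
        return self.kl_div(p.log(), p_g)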
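Since the KL divergence is not symmetric, the argument order of `nn.KLDivLoss` matters. The sketch below checks both orders against a direct evaluation of the definition; the two toy distributions are assumed values chosen for illustration.

```python
import torch
from torch import nn

# Assumed toy distributions over 3 steps for a batch of one sample
p = torch.tensor([[0.7, 0.2, 0.1]])
q = torch.tensor([[0.5, 0.3, 0.2]])

kl_div = nn.KLDivLoss(reduction='batchmean')

# nn.KLDivLoss(input, target) sums target * (log(target) - input),
# so the first argument must hold log probabilities
kl_q_p = kl_div(p.log(), q)  # KL(q || p)
kl_p_q = kl_div(q.log(), p)  # KL(p || q)

# Direct evaluation of the definition for comparison
manual_q_p = (q * (q / p).log()).sum()
manual_p_q = (p * (p / q).log()).sum()
```

With `reduction='batchmean'` the summed divergence is divided by the batch size, which is 1 in this example.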