This is an implementation of nucleus sampling, introduced in the paper The Curious Case of Neural Text Degeneration.
The paper discusses the problems with other sampling methods such as beam search, pure sampling, temperature sampling, and top-k sampling. It introduces nucleus sampling, which in practice performs better than these methods for text generation.
Nucleus sampling first picks a subset of the vocabulary $V^{(p)} \subset V$, where $V^{(p)}$ is the smallest set of tokens such that

$$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \geq p$$
That is, we pick the most probable tokens until the sum of their probabilities reaches $p$.
Then we sample from the selected tokens.
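For intuition, here is a minimal standalone sketch (not part of the implementation below) that builds the nucleus for a made-up five-token distribution with $p = 0.9$; the probability values and the threshold are chosen purely for illustration.

import torch

# A made-up probability distribution over a five-token vocabulary (illustrative values only)
probs = torch.tensor([0.45, 0.25, 0.15, 0.10, 0.05])
p = 0.9

# Sort the probabilities and accumulate them
sorted_probs, indices = torch.sort(probs, descending=True)
cum_sum = torch.cumsum(sorted_probs, dim=-1)   # [0.45, 0.70, 0.85, 0.95, 1.00]

# Keep tokens while the cumulative sum is below p, then include one more token
nucleus = cum_sum < p                          # [True, True, True, False, False]
nucleus = torch.cat([nucleus.new_ones(1), nucleus[:-1]])

print(indices[nucleus])                        # tensor([0, 1, 2, 3]), the nucleus to sample from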
Here's an experiment that uses these sampling techniques.
import torch
from torch import nn

from labml_nn.sampling import Sampler

class NucleusSampler(Sampler):
p is the sum of probabilities of tokens to pick
sampler is the sampler to use for the selected tokens

    def __init__(self, p: float, sampler: Sampler):
        self.p = p
        self.sampler = sampler
Softmax to compute $P(x_i | x_{1:i-1})$ from the logits
        self.softmax = nn.Softmax(dim=-1)
Sample from logits with Nucleus Sampling
    def __call__(self, logits: torch.Tensor):
Get probabilities
        probs = self.softmax(logits)
Sort probabilities in descending order
        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
Get the cumulative sum of probabilities in the sorted order
        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
Find the cumulative sums less than $p$.
        nucleus = cum_sum_probs < self.p
Prepend ones so that we add one token after the minimum number of tokens with cumulative probability less than $p$.
        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
Get log probabilities and mask out the non-nucleus
        sorted_log_probs = torch.log(sorted_probs)
        sorted_log_probs[~nucleus] = float('-inf')
Sample from the sampler
        sampled_sorted_indexes = self.sampler(sorted_log_probs)
Get the actual indexes
        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))

        return res.squeeze(-1)
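As a rough usage sketch, assuming the class above is available together with the labml_nn.sampling.Sampler base class, a NucleusSampler can wrap any base sampler that draws from the masked log probabilities. The MultinomialSampler below is a hypothetical stand-in (it is not part of labml_nn), and the logits are random placeholders.

import torch

from labml_nn.sampling import Sampler


class MultinomialSampler(Sampler):
    # Hypothetical stand-in base sampler: draws one index per row from the given logits
    def __call__(self, logits: torch.Tensor):
        return torch.distributions.Categorical(logits=logits).sample()


# Random placeholder logits for a batch of 2 sequences over a 10-token vocabulary
logits = torch.randn(2, 10)

# Sample with nucleus sampling at p = 0.95, using the stand-in sampler inside the nucleus
sampler = NucleusSampler(p=0.95, sampler=MultinomialSampler())
tokens = sampler(logits)   # shape (2,): one sampled token id per sequence
print(tokens)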