Top-k Sampling

Here we first pick the k tokens with the highest logits, mask out the rest, and then sample from the remaining distribution.
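As a small sketch of the idea (toy logits, not from the experiment): keep only the k largest logits, set the rest to `-inf` so they get zero probability under softmax, then sample.

```python
import torch

# A toy distribution over 6 tokens (hypothetical values for illustration)
logits = torch.tensor([1.0, 3.0, 0.5, 2.5, -1.0, 0.0])
k = 3

# Keep only the k largest logits; mask the rest with -inf so their
# softmax probability becomes exactly zero
values, indices = torch.topk(logits, k)
masked = torch.full_like(logits, float('-inf'))
masked[indices] = values

probs = torch.softmax(masked, dim=-1)
# Only the top-3 tokens (indices 1, 3, 0) have non-zero probability
token = torch.multinomial(probs, num_samples=1)
```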

Here's an experiment that uses these sampling techniques.

import torch
from labml_nn.sampling import Sampler

Top-k Sampler

class TopKSampler(Sampler):
  • k is the number of tokens to pick
  • sampler is the sampler to use for the top-k tokens

sampler can be any sampler that takes a logits tensor as input and returns a token tensor; e.g. `TemperatureSampler`.

    def __init__(self, k: int, sampler: Sampler):
        self.k = k
        self.sampler = sampler

Sample from logits

    def __call__(self, logits: torch.Tensor):

New logits filled with `-inf`; i.e. zero probability

        zeros = logits.new_ones(logits.shape) * float('-inf')

Pick the largest logits and their indices

        values, indices = torch.topk(logits, self.k, dim=-1)

Set the values of the top-k selected indices to the actual logits. Logits of all other tokens remain `-inf`.

        zeros.scatter_(-1, indices, values)
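A small illustration of what `scatter_` does at this step, on toy numbers (not from the experiment): it writes the top-k values back into an all-`-inf` tensor at their original indices.

```python
import torch

logits = torch.tensor([[0.1, 2.0, 1.5, -0.3]])
values, indices = torch.topk(logits, k=2, dim=-1)

# Start from an all -inf tensor, then scatter the top-2 logits
# back to their original positions along the last dimension
zeros = logits.new_ones(logits.shape) * float('-inf')
zeros.scatter_(-1, indices, values)
# zeros is now [[-inf, 2.0, 1.5, -inf]]: top-2 kept, rest masked
```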

Sample from the top-k logits with the specified sampler.

        return self.sampler(zeros)
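To see the sampler in use end to end, here is a self-contained sketch. It re-states the class above and pairs it with a plain multinomial sampler as a stand-in for `TemperatureSampler` (the minimal `Sampler` interface, a callable from logits to token indices, is an assumption here, so no import from `labml_nn` is needed):

```python
import torch

class MultinomialSampler:
    """Stand-in inner sampler: softmax the logits, then sample one token."""
    def __call__(self, logits: torch.Tensor):
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

class TopKSampler:
    """Mask everything outside the top-k logits, then delegate sampling."""
    def __init__(self, k: int, sampler):
        self.k = k
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor):
        zeros = logits.new_ones(logits.shape) * float('-inf')
        values, indices = torch.topk(logits, self.k, dim=-1)
        zeros.scatter_(-1, indices, values)
        return self.sampler(zeros)

torch.manual_seed(0)
logits = torch.randn(2, 10)                       # batch of 2, vocabulary of 10
sampler = TopKSampler(k=3, sampler=MultinomialSampler())
tokens = sampler(logits)                          # one token index per batch element
```

Every sampled token is guaranteed to be one of the 3 highest-logit tokens in its row, whatever the inner sampler does.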