Here we first pick the top-k tokens from the distribution of logits, and then sample from them.
Here's an experiment that uses these sampling techniques.
import torch

from labml_nn.sampling import Sampler

class TopKSampler(Sampler):
- `k` is the number of tokens to pick
- `sampler` is the sampler to use for the top-k tokens

`sampler` can be any sampler that takes a logits tensor as input and returns a token tensor; e.g. `TemperatureSampler`.
    def __init__(self, k: int, sampler: Sampler):
        self.k = k
        self.sampler = sampler
Sample from logits
    def __call__(self, logits: torch.Tensor):
New logits filled with $-\infty$; i.e. zero probability
        zeros = logits.new_ones(logits.shape) * float('-inf')
Pick the largest $k$ logits and their indices
        values, indices = torch.topk(logits, self.k, dim=-1)
Set the values of the top-k selected indices to the actual logits. Logits of other tokens remain $-\infty$.
        zeros.scatter_(-1, indices, values)
Sample from the top-k logits with the specified sampler.
        return self.sampler(zeros)
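As a quick usage sketch, here is a minimal example of the class above on a toy logits tensor. `MultinomialSampler` is a hypothetical stand-in defined here only for illustration; any `Sampler` that maps a logits tensor to a token tensor, such as `TemperatureSampler`, could be plugged in instead.

import torch

# Hypothetical stand-in sampler for this sketch:
# samples directly from the softmax of the logits.
class MultinomialSampler(Sampler):
    def __call__(self, logits: torch.Tensor):
        return torch.distributions.Categorical(logits=logits).sample()

# Toy batch of logits over a six-token vocabulary
logits = torch.tensor([[2.0, 1.5, 0.3, -0.5, -1.0, -2.0]])

# Keep the 3 largest logits; the rest stay at -inf
top_k = TopKSampler(k=3, sampler=MultinomialSampler())

# The sampled token index is always one of 0, 1, or 2 here
print(top_k(logits))

Because the non-selected logits stay at $-\infty$, the softmax inside the inner sampler assigns them exactly zero probability, so only the top-k tokens can ever be drawn.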