Sampling from Language Models with Temperature

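Here we sample from the following probability distribution, where $V$ is the vocabulary, $u_{1:|V|}$ are the logits of the distribution, and $T$ is the temperature:

$$P(x_i = V_l \mid x_{1:i-1}) = \frac{\exp\left(\frac{u_l}{T}\right)}{\sum_j \exp\left(\frac{u_j}{T}\right)}$$

$T = 1$ is normal random sampling. $T < 1$ sharpens the distribution toward the highest-logit tokens, and $T > 1$ flattens it toward uniform.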
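As a worked illustration of the formula, here is a minimal sketch that evaluates $\mathrm{softmax}(u / T)$ for a few temperatures. It uses only the Python standard library rather than PyTorch, and `temperature_softmax` and `entropy` are hypothetical helper names introduced here, not part of `labml_nn`. One way to read the formula: since $\exp(u_l / T) = (\exp u_l)^{1/T}$, the temperature-scaled distribution is the $T = 1$ distribution with each probability raised to the power $1/T$ and renormalized. The printed entropies make the sharpening and flattening visible.

```python
import math
from typing import List


def temperature_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Hypothetical helper: return softmax(logits / temperature) as probabilities."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [u / temperature for u in logits]
    # Subtracting the max before exponentiating avoids overflow;
    # the shift cancels out in the normalization below.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def entropy(probs: List[float]) -> float:
    """Hypothetical helper: Shannon entropy in nats (lower means more peaked)."""
    return -sum(p * math.log(p) for p in probs if p > 0)


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.1]
    for t in (0.5, 1.0, 2.0):
        p = temperature_softmax(logits, t)
        print(t, [round(x, 3) for x in p], round(entropy(p), 3))
```

The temperature must be positive; as $T \to 0$ the distribution approaches putting all its mass on the argmax (greedy decoding), which is why the sketch rejects $T \le 0$ instead of dividing by zero.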

Here's an experiment that uses these sampling techniques.

import torch
from torch.distributions import Categorical

from labml_nn.sampling import Sampler

Sampler with Temperature

class TemperatureSampler(Sampler):
    def __init__(self, temperature: float = 1.0):
        """
        :param temperature: is the temperature to sample with
        """
        self.temperature = temperature

Sample from logits

    def __call__(self, logits: torch.Tensor):

Create a categorical distribution with temperature adjusted logits

        dist = Categorical(logits=logits / self.temperature)

Sample

        return dist.sample()
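`TemperatureSampler.__call__` delegates both normalization and sampling to PyTorch's `Categorical`, which treats the last dimension of `logits` as the vocabulary and so returns one token index per row of a batch. To make those two steps explicit, here is a minimal stdlib-only sketch for a single, unbatched logit vector. The class name `SimpleTemperatureSampler` and its `seed` parameter are hypothetical, introduced only for illustration; they are not part of `labml_nn`.

```python
import math
import random
from typing import List, Optional


class SimpleTemperatureSampler:
    """Hypothetical stdlib-only analogue of TemperatureSampler for one logit vector."""

    def __init__(self, temperature: float = 1.0, seed: Optional[int] = None):
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.temperature = temperature
        # Hypothetical seed parameter: a private RNG makes draws reproducible.
        self.rng = random.Random(seed)

    def probs(self, logits: List[float]) -> List[float]:
        # The same distribution Categorical(logits=logits / T) defines:
        # softmax over the temperature-scaled logits.
        scaled = [u / self.temperature for u in logits]
        m = max(scaled)
        exps = [math.exp(s - m) for s in scaled]
        z = sum(exps)
        return [e / z for e in exps]

    def __call__(self, logits: List[float]) -> int:
        # Draw one token index, weighted by the temperature-adjusted probabilities.
        p = self.probs(logits)
        return self.rng.choices(range(len(logits)), weights=p, k=1)[0]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.1]
    cold = SimpleTemperatureSampler(temperature=0.1, seed=0)
    hot = SimpleTemperatureSampler(temperature=5.0, seed=0)
    print("T=0.1:", [cold(logits) for _ in range(10)])
    print("T=5.0:", [hot(logits) for _ in range(10)])
```

At a low temperature the draws collapse onto the highest-logit index, while at a high temperature all three indices appear regularly, matching the sharpening and flattening described by the formula.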