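Here we sample from the following probability distribution, where $V$ is the vocabulary, $u_{1:|V|}$ are the logits of the distribution and $T$ is the temperature:

$$P(x_i=V_l | x_{1:i-1}) = \frac{\exp(\frac{u_l}{T})}{\sum_j \exp(\frac{u_j}{T})}$$

$T = 1$ is normal random sampling.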
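To make the effect of the temperature concrete, here is a minimal pure-Python sketch of temperature-scaled softmax, $P(V_l) \propto \exp(u_l / T)$. The helper name `temperature_softmax` and the logit values are illustrative assumptions, not part of `labml_nn`. Dividing by $T < 1$ sharpens the distribution toward the highest logit, while $T > 1$ flattens it toward uniform.

```python
import math
from typing import List


def temperature_softmax(logits: List[float], temperature: float) -> List[float]:
    # Hypothetical helper: P(V_l) = exp(u_l / T) / sum_j exp(u_j / T)
    scaled = [u / temperature for u in logits]
    # Subtracting the max leaves the ratios unchanged but avoids overflow in exp
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits over a toy 3-token vocabulary
logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    # Lower T puts more mass on index 0; higher T spreads it out
    print(t, [round(p, 3) for p in temperature_softmax(logits, t)])
```

With $T = 1$ this reduces to the ordinary softmax, which is why that setting corresponds to normal random sampling.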
Here's an experiment that uses these sampling techniques.
```python
import torch
from torch.distributions import Categorical

from labml_nn.sampling import Sampler


class TemperatureSampler(Sampler):
    def __init__(self, temperature: float = 1.0):
        """
        :param temperature: is the temperature to sample with
        """
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits
        """
        # Create a categorical distribution with temperature adjusted logits
        dist = Categorical(logits=logits / self.temperature)

        # Sample
        return dist.sample()
```
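As a usage sketch, the snippet below repeats the computation in `TemperatureSampler.__call__` as a standalone function so it runs without the `labml_nn` base class. The function names `sample_with_temperature` and `empirical_frequencies` and the logit values are hypothetical. It shows that a batch of logits yields one token index per row, and that over many draws the sampled frequencies approach the temperature-scaled softmax.

```python
import torch
from torch.distributions import Categorical


def sample_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Hypothetical standalone version of TemperatureSampler.__call__
    return Categorical(logits=logits / temperature).sample()


def empirical_frequencies(row: torch.Tensor, temperature: float, n: int = 20000) -> torch.Tensor:
    # Draw n samples from a single row of logits and count how often each index appears
    draws = sample_with_temperature(row.expand(n, -1), temperature)
    return torch.bincount(draws, minlength=row.numel()).float() / n


torch.manual_seed(0)

# A batch of 2 rows of made-up logits over a toy 3-token vocabulary
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [2.0, 1.0, 0.1]])

# One sampled token index per row
tokens = sample_with_temperature(logits, temperature=0.7)
print(tokens.shape)  # torch.Size([2])

# Compare the target distribution softmax(logits / T) with observed frequencies
for t in (0.5, 1.0, 2.0):
    expected = torch.softmax(logits[0] / t, dim=-1)
    observed = empirical_frequencies(logits[0], t)
    print(t, expected, observed)
```

Because `Categorical(logits=...)` normalizes the logits with a softmax internally, dividing by the temperature before constructing the distribution is all that is needed; no explicit normalization step appears in the sampler.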