从带温度的语言模型中采样

在这里,我们从以下概率分布中抽样,其中是词汇,是分布的对数,T 是温度:

是正常的随机抽样。

这是一个使用这些采样技术的实验

19import torch
20from torch.distributions import Categorical
21
22from labml_nn.sampling import Sampler

带温度的采样器

25class TemperatureSampler(Sampler):
  • temperature 是要采样的温度
29    def __init__(self, temperature: float = 1.0):
33        self.temperature = temperature

来自 logits 的样本

35    def __call__(self, logits: torch.Tensor):

使用温度调整后的对数创建分类分布

41        dist = Categorical(logits=logits / self.temperature)

样本

44        return dist.sample()