19import torch
20from torch.distributions import Categorical
21
22from labml_nn.sampling import Sampler
25class TemperatureSampler(Sampler):
temperature
是要采样的温度29 def __init__(self, temperature: float = 1.0):
33 self.temperature = temperature
来自 logits 的样本
35 def __call__(self, logits: torch.Tensor):
使用温度调整后的对数创建分类分布
41 dist = Categorical(logits=logits / self.temperature)
样本
44 return dist.sample()