语言模型的采样技术

这是一个使用这些采样技术的实验

18import torch

采样器基类

21class Sampler:

来自 logits 的样本

  • logits 是形状分布的对数[..., n_tokens]
25    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
31        raise NotImplementedError()