home sampling
View code on Github
这是一个使用这些采样技术的实验。
18import torch
21class Sampler:
logits
[..., n_tokens]
25 def __call__(self, logits: torch.Tensor) -> torch.Tensor:
31 raise NotImplementedError()