home sampling
View code on Github
在这里,我们从日志分布中抽取最有可能的令牌。
这是一个使用这些采样技术的实验。
14import torch 15 16from labml_nn.sampling import Sampler
19class GreedySampler(Sampler):
从日志分布中抽取最有可能的令牌
20 def __call__(self, logits: torch.Tensor):
24 return logits.argmax(dim=-1)