Here's an experiment that uses these sampling techniques.
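```python
import torch


class Sampler:
    """
    Base class for samplers.

    Subclasses implement a specific sampling strategy by overriding __call__.
    """

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Sample from `logits`, a tensor of shape [..., n_tokens].
        """
        raise NotImplementedError()
```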
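To show how the `Sampler` interface is meant to be filled in, here is a minimal sketch of two possible subclasses: greedy sampling (take the highest-scoring token) and temperature sampling (draw from `softmax(logits / temperature)`). The class names `GreedySampler` and `TemperatureSampler`, and the use of `torch.distributions.Categorical`, are assumptions made for illustration. This is not the code in the linked repository. Both subclasses return token indices with the leading shape `[...]` of the input.

```python
import torch


class Sampler:
    """Base class for samplers (re-declared so this sketch is self-contained)."""

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class GreedySampler(Sampler):
    """Hypothetical subclass: always picks the token with the highest logit."""

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # argmax over the last (vocabulary) dimension; output shape is [...]
        return logits.argmax(dim=-1)


class TemperatureSampler(Sampler):
    """Hypothetical subclass: samples from softmax(logits / temperature)."""

    def __init__(self, temperature: float = 1.0):
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # Lower temperature sharpens the distribution, higher flattens it.
        # Categorical treats the last dimension as the categories and
        # every leading dimension as a batch dimension.
        dist = torch.distributions.Categorical(logits=logits / self.temperature)
        # One sampled token index per batch position; output shape is [...]
        return dist.sample()


# Usage: a batch of two distributions over three tokens
logits = torch.tensor([[1.0, 3.0, 2.0], [0.5, 0.1, 4.0]])
greedy_tokens = GreedySampler()(logits)
sampled_tokens = TemperatureSampler(temperature=0.7)(logits)
```

Because both samplers share the same call signature, a generation loop can swap one strategy for another without changing any other code.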