Sampling Techniques for Language Models

Here's an experiment that uses these sampling techniques.

import torch

Sampler base class

class Sampler:

Sample from logits

  • logits are the unnormalized log-probabilities of the distribution, of shape [..., n_tokens]
  • the method returns the sampled token indices, of shape [...]
    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()
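The base class only fixes the interface: a sampler is called with logits of shape [..., n_tokens] and returns token indices of shape [...]. The minimal sketch below shows how concrete samplers might plug into that interface. The subclass names GreedySampler and TemperatureSampler, and the temperature parameter, are illustrative assumptions for this example, not necessarily the classes defined later on this page. The temperature variant uses torch.distributions.Categorical, which accepts logits directly and samples one index per leading position.

```python
import torch


class Sampler:
    """Base interface: map logits [..., n_tokens] to token indices [...]."""

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class GreedySampler(Sampler):
    """Hypothetical subclass: always pick the highest-scoring token."""

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # argmax over the last (vocabulary) dimension removes it, giving shape [...]
        return logits.argmax(dim=-1)


class TemperatureSampler(Sampler):
    """Hypothetical subclass: sample from softmax(logits / temperature)."""

    def __init__(self, temperature: float = 1.0):
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # Temperatures below 1 sharpen the distribution; above 1 flatten it
        dist = torch.distributions.Categorical(logits=logits / self.temperature)
        # sample() draws one index per distribution, so the result has shape [...]
        return dist.sample()


# A batch of 2 distributions over a vocabulary of 3 tokens
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.0, 3.0, 1.0]])

print(GreedySampler()(logits))  # tensor([0, 1])

torch.manual_seed(0)
print(TemperatureSampler(0.7)(logits).shape)  # torch.Size([2])
```

Because every sampler shares the same call signature, a generation loop can take any Sampler instance and swap strategies without changing the loop itself.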