18import torch

サンプラー基本クラス

21class Sampler:

ロジットからのサンプル

  • logits 形状分布のロジットです [..., n_tokens]
25    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
31        raise NotImplementedError()