前 k 个采样

在这里,我们首先从logits分布中挑选top-k代币,然后从中采样。

这是一个使用这些采样技术的实验

15import torch
16
17from labml_nn.sampling import Sampler

Top-k 采样器

20class TopKSampler(Sampler):
  • k 是要挑选的代币数量
  • sampler 是用于前 k 个代币的采样器

sampler 可以是任何以 logits 张量作为输入并返回令牌张量的采样器;例如 “TemperatureSample”

24    def __init__(self, k: int, sampler: Sampler):
32        self.k = k
33        self.sampler = sampler

来自 logits 的样本

35    def __call__(self, logits: torch.Tensor):

新的 logit 填充了;即零概率

40        zeros = logits.new_ones(logits.shape) * float('-inf')

选择最大的对数及其指数

42        values, indices = torch.topk(logits, self.k, dim=-1)

将选定前 k 个索引的值设置为实际对数。其他代币的记录仍然存在

45        zeros.scatter_(-1, indices, values)

使用指定采样器从 top-k logits 中抽样。

48        return self.sampler(zeros)