15import torch
16
17from labml_nn.sampling import Sampler
20class TopKSampler(Sampler):
k
是要挑选的代币数量sampler
是用于前 k 个代币的采样器sampler
可以是任何以 logits 张量作为输入并返回令牌张量的采样器;例如 “TemperatureSample”。
24 def __init__(self, k: int, sampler: Sampler):
32 self.k = k
33 self.sampler = sampler
来自 logits 的样本
35 def __call__(self, logits: torch.Tensor):
新的 logit 填充了;即零概率
40 zeros = logits.new_ones(logits.shape) * float('-inf')
选择最大的对数及其指数
42 values, indices = torch.topk(logits, self.k, dim=-1)
将选定前 k 个索引的值设置为实际对数。其他代币的记录仍然存在
45 zeros.scatter_(-1, indices, values)
使用指定采样器从 top-k logits 中抽样。
48 return self.sampler(zeros)