这是原子核采样的一种实现,在论文《神经文本变性的好奇案例》中进行了介绍。
本文讨论了其他采样方法(例如光束搜索、纯采样、温度采样和T op-K采样)存在的问题。本文介绍了原子核采样的概念,在文本生成方面,核采样的效果实际上比其他采样方法要好。
Nucleus 采样首先选择词汇的一个子集,其中是最小的令牌集合
也就是说,我们选择可能性最高的代币,直到它们的概率总和小于该值为止。
然后我们从选定的令牌中抽样。
这是一个使用这些采样技术的实验。
29import torch
30from torch import nn
31
32from labml_nn.sampling import Sampler
35class NucleusSampler(Sampler):
p
是要选择的代币概率之和sampler
是用于选定令牌的采样器39 def __init__(self, p: float, sampler: Sampler):
44 self.p = p
45 self.sampler = sampler
要根据对数计算的 softmax
47 self.softmax = nn.Softmax(dim=-1)
使用 Nucleus 采样从 logits 中提取样本
49 def __call__(self, logits: torch.Tensor):
获取概率
55 probs = self.softmax(logits)
按降序对概率进行排序
58 sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
按排序顺序获取概率的累积总和
60 cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
找出小于的累计总和。
62 nucleus = cum_sum_probs < self.p
前面加一个,这样我们就可以在累积概率小于该值的最小代币数量之后添加一个令牌。
65 nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
获取对数概率并掩盖非核
68 sorted_log_probs = torch.log(sorted_probs)
69 sorted_log_probs[~nucleus] = float('-inf')
来自采样器的样本
72 sampled_sorted_indexes = self.sampler(sorted_log_probs)
获取实际索引
75 res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
78 return res.squeeze(-1)