原子核采样

这是原子核采样的一种实现,在论文《神经文本变性的好奇案例》中进行了介绍。

本文讨论了其他采样方法(例如光束搜索、纯采样、温度采样和T op-K采样)存在的问题。本文介绍了原子核采样的概念,在文本生成方面,核采样的效果实际上比其他采样方法要好。

Nucleus 采样首先选择词汇的一个子集,其中是最小的令牌集合

也就是说,我们选择可能性最高的代币,直到它们的概率总和小于该值为止

然后我们从选定的令牌中抽样。

这是一个使用这些采样技术的实验

29import torch
30from torch import nn
31
32from labml_nn.sampling import Sampler

Nucleus 采样器

35class NucleusSampler(Sampler):
  • p 是要选择的代币概率之和
  • sampler 是用于选定令牌的采样器
39    def __init__(self, p: float, sampler: Sampler):
44        self.p = p
45        self.sampler = sampler

要根据对数计算的 softmax

47        self.softmax = nn.Softmax(dim=-1)

使用 Nucleus 采样从 logits 中提取样本

49    def __call__(self, logits: torch.Tensor):

获取概率

55        probs = self.softmax(logits)

按降序对概率进行排序

58        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)

按排序顺序获取概率的累积总和

60        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)

找出小于的累计总和

62        nucleus = cum_sum_probs < self.p
在@@

前面加一个,这样我们就可以在累积概率小于该值的最小代币数量之后添加一个令牌

65        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)

获取对数概率并掩盖非核

68        sorted_log_probs = torch.log(sorted_probs)
69        sorted_log_probs[~nucleus] = float('-inf')

来自采样器的样本

72        sampled_sorted_indexes = self.sampler(sorted_log_probs)

获取实际索引

75        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))

78        return res.squeeze(-1)