これは核サンプリングの実装で、論文「神経テキスト変性の奇妙な事例」で紹介されています。
この論文では、ビームサーチ、ピュアサンプリング、温度サンプリング、TOP-Kサンプリングなどの他のサンプリング方法の問題について説明します。この論文では、核サンプリングのアイデアを紹介しています。核サンプリングは、テキスト生成において他のサンプリング方法よりも実質的に優れています
。Nucleus サンプリングでは、最初にボキャブラリのサブセットを選択します。ここで、は次のようなトークンの最小セットを選択します。
つまり、確率の合計がそれより小さくなるまで、最も可能性の高いトークンを選択します。
次に、選択したトークンからサンプリングします。
29import torch
30from torch import nn
31
32from labml_nn.sampling import Sampler
35class NucleusSampler(Sampler):
p
ピックするトークンの確率の合計です sampler
選択したトークンに使用するサンプラーです39 def __init__(self, p: float, sampler: Sampler):
44 self.p = p
45 self.sampler = sampler
ロジットから計算するソフトマックス
47 self.softmax = nn.Softmax(dim=-1)
Nucleus サンプリングによるロジットからのサンプリング
49 def __call__(self, logits: torch.Tensor):
確率を取得
55 probs = self.softmax(logits)
確率を降順に並べ替える
58 sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
確率の累積合計をソートされた順序で求める
60 cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
より小さい累積和を求めます。
62 nucleus = cum_sum_probs < self.p
累積確率がそれより小さいトークンの最小数の後にトークンを1つ追加するように、1を先頭に追加します。
65 nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
対数確率を取得して非核をマスクする
68 sorted_log_probs = torch.log(sorted_probs)
69 sorted_log_probs[~nucleus] = float('-inf')
サンプラーからのサンプル
72 sampled_sorted_indexes = self.sampler(sorted_log_probs)
実際のインデックスを取得
75 res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
78 return res.squeeze(-1)