核サンプリング

これは核サンプリングの実装で、論文「神経テキスト変性の奇妙な事例」で紹介されています。

この論文では、ビームサーチ、ピュアサンプリング、温度サンプリングTOP-Kサンプリングなどの他のサンプリング方法の問題について説明します。この論文では、核サンプリングのアイデアを紹介しています。核サンプリングは、テキスト生成において他のサンプリング方法よりも実質的に優れています

Nucleus サンプリングでは、最初にボキャブラリのサブセットを選択します。ここでは次のようなトークンの最小セットを選択します。

つまり、確率の合計がそれより小さくなるまで、最も可能性の高いトークンを選択します。

次に、選択したトークンからサンプリングします。

これは、これらのサンプリング手法を使用した実験です

29import torch
30from torch import nn
31
32from labml_nn.sampling import Sampler

核サンプラー

35class NucleusSampler(Sampler):
  • p ピックするトークンの確率の合計です
  • sampler 選択したトークンに使用するサンプラーです
39    def __init__(self, p: float, sampler: Sampler):
44        self.p = p
45        self.sampler = sampler

ロジットから計算するソフトマックス

47        self.softmax = nn.Softmax(dim=-1)

Nucleus サンプリングによるロジットからのサンプリング

49    def __call__(self, logits: torch.Tensor):

確率を取得

55        probs = self.softmax(logits)

確率を降順に並べ替える

58        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)

確率の累積合計をソートされた順序で求める

60        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)

より小さい累積和を求めます。

62        nucleus = cum_sum_probs < self.p

累積確率がそれより小さいトークンの最小数の後にトークンを1つ追加するように、1を先頭に追加します。

65        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)

対数確率を取得して非核をマスクする

68        sorted_log_probs = torch.log(sorted_probs)
69        sorted_log_probs[~nucleus] = float('-inf')

サンプラーからのサンプル

72        sampled_sorted_indexes = self.sampler(sorted_log_probs)

実際のインデックスを取得

75        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))

78        return res.squeeze(-1)