トップkサンプリング

ここでは、最初にロジットの分布から上位k個のトークンを選択し、次にそれらからサンプリングします。

これは、これらのサンプリング手法を使用した実験です

15import torch
16
17from labml_nn.sampling import Sampler

トップkサンプラー

20class TopKSampler(Sampler):
  • k 選択するトークンの数です
  • sampler トップkのトークンに使用するサンプラーです

sampler ロジッツテンソルを入力として受け取り、トークンテンソルを返すサンプラーならどれでもかまいません(例:`TemperatureSampler')。

24    def __init__(self, k: int, sampler: Sampler):
32        self.k = k
33        self.sampler = sampler

ロジットからのサンプル

35    def __call__(self, logits: torch.Tensor):

新しいロジットを埋める、つまり確率がゼロ

40        zeros = logits.new_ones(logits.shape) * float('-inf')

最大のロジットとそのインデックスを選択してください

42        values, indices = torch.topk(logits, self.k, dim=-1)

選択した上位kのインデックスの値を実際のロジットに設定します。他のトークンのロジットは残ります

45        zeros.scatter_(-1, indices, values)

指定されたサンプラーを使用して、上からk個のロジットをサンプリングします。

48        return self.sampler(zeros)