15import torch
16
17from labml_nn.sampling import Sampler
20class TopKSampler(Sampler):
k
選択するトークンの数ですsampler
トップkのトークンに使用するサンプラーですsampler
ロジッツテンソルを入力として受け取り、トークンテンソルを返すサンプラーならどれでもかまいません(例:`TemperatureSampler')。
24 def __init__(self, k: int, sampler: Sampler):
32 self.k = k
33 self.sampler = sampler
ロジットからのサンプル
35 def __call__(self, logits: torch.Tensor):
新しいロジットを埋める、つまり確率がゼロ
40 zeros = logits.new_ones(logits.shape) * float('-inf')
最大のロジットとそのインデックスを選択してください
42 values, indices = torch.topk(logits, self.k, dim=-1)
選択した上位kのインデックスの値を実際のロジットに設定します。他のトークンのロジットは残ります
45 zeros.scatter_(-1, indices, values)
指定されたサンプラーを使用して、上からk個のロジットをサンプリングします。
48 return self.sampler(zeros)