15import torch
16
17from labml_nn.sampling import Sampler20class TopKSampler(Sampler):k
選択するトークンの数ですsampler
トップkのトークンに使用するサンプラーですsampler
ロジッツテンソルを入力として受け取り、トークンテンソルを返すサンプラーならどれでもかまいません(例:`TemperatureSampler')。
24 def __init__(self, k: int, sampler: Sampler):32 self.k = k
33 self.sampler = samplerロジットからのサンプル
35 def __call__(self, logits: torch.Tensor):新しいロジットを埋める、つまり確率がゼロ
40 zeros = logits.new_ones(logits.shape) * float('-inf')最大のロジットとそのインデックスを選択してください
42 values, indices = torch.topk(logits, self.k, dim=-1)選択した上位kのインデックスの値を実際のロジットに設定します。他のトークンのロジットは残ります
45 zeros.scatter_(-1, indices, values)指定されたサンプラーを使用して、上からk個のロジットをサンプリングします。
48 return self.sampler(zeros)