温度を用いた言語モデルからのサンプリング

ここでは、次の確率分布からサンプリングします。ここで、はボキャブラリー、は分布のロジット、Tは温度です。

は通常のランダムサンプリングです。

これは、これらのサンプリング手法を使用した実験です

19import torch
20from torch.distributions import Categorical
21
22from labml_nn.sampling import Sampler

温度機能付きサンプラー

25class TemperatureSampler(Sampler):
  • temperature はサンプリングする温度
29    def __init__(self, temperature: float = 1.0):
33        self.temperature = temperature

ロジットからのサンプル

35    def __call__(self, logits: torch.Tensor):

温度調整済みロジットによるカテゴリ分布の作成

41        dist = Categorical(logits=logits / self.temperature)

[サンプル]

44        return dist.sample()