言語モデルのサンプリング手法の試み

この実験では、HuggingFaceのGPT2モデルで上記のサンプリング手法を使用しています。

18import torch
19
20from labml import monit, logger, lab
21
22from labml.logger import Text
23
24from labml_nn.sampling import Sampler
25from labml_nn.sampling.greedy import GreedySampler
26from labml_nn.sampling.nucleus import NucleusSampler
27from labml_nn.sampling.temperature import TemperatureSampler
28from labml_nn.sampling.top_k import TopKSampler
29from transformers import GPT2Tokenizer, GPT2LMHeadModel

モデルからのサンプル

  • model サンプリング元のモデルです
  • tokenizer 使用するトークナイザーです
  • sampler 使用するサンプラーは
  • n_samples は生成するサンプルの数です
  • n_tokens は生成するトークンの数です
  • seq_len モデルの最大シーケンス長です
  • prompt は開始プロンプトです
32@torch.no_grad()
33def sample(model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer, sampler: Sampler,
34           n_samples: int, n_tokens: int, seq_len: int, prompt: str):

prompt n_samples をトークン化してコピーを作成

47    data = torch.tile(torch.tensor(tokenizer.encode(prompt))[None, :], (n_samples, 1))

印刷用の出力を収集

50    logs = [[(prompt, Text.meta)] for _ in range(n_samples)]

[サンプル] n_tokens

52    for i in monit.iterate('Sample', n_tokens):

データを最大シーケンス長まで切り捨てる

54        data = data[-seq_len:]

モデル出力を取得します。「ロジット」には形があります [batch_size, seq_len, n_tokens]

56        logits = model(data)[0]

最後のトークンの取得 logits

58        logits = logits[:, -1]

からのサンプル logits

60        res = sampler(logits)

サンプリングしたトークンをデータに追加します

62        data = torch.cat([data, res[:, None]], dim=1)

サンプリングしたトークンをデコードしてロギング用に追加

64        for j in range(n_samples):
65            logs[j] += [('' + tokenizer.decode(res[j]), Text.value)]

サンプル出力を印刷

68    for j in range(n_samples):
69        logger.log(logs[j])

さまざまなサンプリング手法を試してください

72def main():

モデルとトークナイザーをロード

78    with monit.section('Load tokenizer/model'):
79        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
80        model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')

モデルを eval モードに設定

82    model.eval()

サンプリングに使用するプロンプト

85    prompt = 'I saw an interesting dream last night. '
88    with monit.section('greedy'):
89        sample(model, tokenizer, GreedySampler(), 4, 32, 128, prompt)
92    with monit.section('temperature=1.'):
93        sample(model, tokenizer, TemperatureSampler(1.), 4, 32, 128, prompt)
94    with monit.section('temperature=.1'):
95        sample(model, tokenizer, TemperatureSampler(.1), 4, 32, 128, prompt)
96    with monit.section('temperature=10.'):
97        sample(model, tokenizer, TemperatureSampler(10.), 4, 32, 128, prompt)
100    with monit.section('top_k=5'):
101        sample(model, tokenizer, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, prompt)
104    with monit.section('nucleus p=.95'):
105        sample(model, tokenizer, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, prompt)
106    with monit.section('nucleus p=.1'):
107        sample(model, tokenizer, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, prompt)

110if __name__ == '__main__':
111    main()