Trying out Sampling Techniques for Language Models

This experiment tries out the above sampling techniques on HuggingFace's GPT-2 model.

import torch

from labml import monit, logger, lab
from labml.logger import Text
from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml_nn.sampling.top_k import TopKSampler
from transformers import GPT2Tokenizer, GPT2LMHeadModel
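
All of these samplers share the same calling convention: a sampler is a callable that takes a batch of logits with shape [batch_size, vocab_size] and returns a batch of sampled token ids with shape [batch_size]. As a rough sketch of that interface (not the library's exact implementation), a purely greedy sampler could be written as:

# A minimal, illustrative sampler conforming to the interface used below;
# the real implementations live in labml_nn.sampling.*
class ArgmaxSampler(Sampler):
    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # Pick the highest-scoring token for each sequence in the batch
        return logits.argmax(dim=-1)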

Sample from model

  • model is the model to sample from
  • tokenizer is the tokenizer to use
  • sampler is the sampler to use
  • n_samples is the number of samples to generate
  • n_tokens is the number of tokens to generate
  • seq_len is the maximum sequence length for the model
  • prompt is the starting prompt
@torch.no_grad()
def sample(model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer, sampler: Sampler,
           n_samples: int, n_tokens: int, seq_len: int, prompt: str):

Tokenize the prompt and make n_samples copies of it

    data = torch.tile(torch.tensor(tokenizer.encode(prompt))[None, :], (n_samples, 1))
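
For example, if the prompt encodes to three token ids and n_samples is 4, the indexing and tiling above work out as follows (the token ids shown are illustrative only):

    # Illustration of the shapes produced above (token ids are made up):
    #   tokenizer.encode(prompt)           -> [40, 2497, 281]           list of length 3
    #   torch.tensor(...)[None, :]         -> tensor of shape [1, 3]
    #   torch.tile(..., (n_samples, 1))    -> tensor of shape [4, 3]    one row per sample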

Collect output for printing

    logs = [[(prompt, Text.meta)] for _ in range(n_samples)]

Sample n_tokens

    for i in monit.iterate('Sample', n_tokens):

Truncate the data to the maximum sequence length

        data = data[:, -seq_len:]

Get the model output. The logits have shape [batch_size, seq_len, vocab_size]

        logits = model(data)[0]
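
As a side note, with recent versions of transformers the forward call returns a model output object, so indexing with [0] and attribute access are equivalent ways to get the logits:

        # Equivalent with recent transformers versions (assuming return_dict is
        # left at its default): logits = model(data).logits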

Get the logits of the last token

        logits = logits[:, -1]

Sample from the logits

        res = sampler(logits)

Add the sampled token to the data

        data = torch.cat([data, res[:, None]], dim=1)

Decode and add the sampled token for logging

        for j in range(n_samples):
            logs[j] += [(tokenizer.decode(res[j]), Text.value)]

Print the sampled outputs

    for j in range(n_samples):
        logger.log(logs[j])

Try different sampling techniques

def main():

Load the model and tokenizer

    with monit.section('Load tokenizer/model'):
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
        model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')

Set the model to eval mode

    model.eval()

Prompt to use for sampling

    prompt = 'I saw an interesting dream last night. '

Greedy sampling

    with monit.section('greedy'):
        sample(model, tokenizer, GreedySampler(), 4, 32, 128, prompt)

Temperature sampling at different temperatures

    with monit.section('temperature=1.'):
        sample(model, tokenizer, TemperatureSampler(1.), 4, 32, 128, prompt)
    with monit.section('temperature=.1'):
        sample(model, tokenizer, TemperatureSampler(.1), 4, 32, 128, prompt)
    with monit.section('temperature=10.'):
        sample(model, tokenizer, TemperatureSampler(10.), 4, 32, 128, prompt)

Top-k sampling

    with monit.section('top_k=2'):
        sample(model, tokenizer, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, prompt)

Nucleus sampling

    with monit.section('nucleus p=.95'):
        sample(model, tokenizer, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, prompt)
    with monit.section('nucleus p=.1'):
        sample(model, tokenizer, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, prompt)

if __name__ == '__main__':
    main()