Generate Text with GPT-NeoX using LLM.int8() quantization

This shows how to generate text from GPT-NeoX using LLM.int8() quantization.

This requires a GPU with at least 24GB of memory.

import torch
from torch import nn

from labml import monit
from labml_nn.neox.model import LayerGenerator
from labml_nn.neox.samples.generate import PROMPT, infer
from labml_nn.neox.utils import get_tokens, print_tokens
from labml_nn.neox.utils.cache import get_cache

Generate text

def generate(prompt: str = PROMPT, n_tokens: int = 100, llm_int8_threshold: float = 6.0):
    """
    ## Generate text

    Loads the GPT-NeoX layers in float16 on the CPU. It then converts each layer to
    LLM.int8() and moves it to `cuda:0`, and finally predicts `n_tokens` tokens one at
    a time with `infer`. After every token following the first, the running token
    sequence is printed with `print_tokens`.

    :param prompt: text to start generation from (defaults to the sample `PROMPT`)
    :param n_tokens: total number of tokens to generate; must be at least 1
    :param llm_int8_threshold: threshold passed as `llm_int8_threshold` to
        `LayerGenerator.post_load_prepare` for the int8 conversion
    :raises ValueError: if `n_tokens` is less than 1
    """
    # The first token is always produced from the full prompt, so fewer than one
    # token cannot be honoured.
    if n_tokens < 1:
        raise ValueError(f'n_tokens must be at least 1, got {n_tokens}')

    # Setup cache to cache intermediate key/value pairs for faster generation
    cache = get_cache()
    cache.set('use_cache', True)

    # Device
    device = torch.device('cuda:0')

    # Load layers in float16 into CPU. We convert the layers to int8 later, because
    # doing that on the fly after loading layers to GPU causes CUDA memory
    # fragmentation (about 3GB memory can get lost due to fragmentation).
    layer_generator = LayerGenerator(is_clone_layers=True,
                                     dtype=torch.float16,
                                     device=torch.device('cpu'),
                                     is_llm_int8=False,
                                     )
    layers = list(layer_generator.load())

    # Convert the CPU-resident layers to int8 and move them to the GPU one at a
    # time. Doing the conversion here, rather than on the GPU after loading,
    # reduces CUDA memory fragmentation.
    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
        layer_generator.post_load_prepare(layer,
                                          device=device,
                                          is_llm_int8=True,
                                          llm_int8_threshold=llm_int8_threshold,
                                          )
        layer.to(device)

    # Create nn.Sequential model
    model = nn.Sequential(*layers)

    # Clear cache and print memory summary for debugging
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())

    # Get token ids (copied so appending never mutates what `get_tokens` returned)
    ids = list(get_tokens(prompt))

    # Run the model on the whole prompt. We use the `infer` function defined in
    # generate.py.
    cache.set('state_ids', (None, 1))
    with monit.section('Infer'):
        next_token = infer(model, ids, device)[-1]

    # Append the predicted token
    ids.append(next_token)

    # Predict the remaining tokens
    for i in range(1, n_tokens):
        # Set the state to use cached activations
        cache.set('state_ids', (i, i + 1))
        # Get next token. Note that we only feed the last token to the model because
        # we cache the key/value pairs of previous tokens.
        with monit.section('Infer'):
            next_token = infer(model, [next_token], device)[-1]
        # Append the predicted token
        ids.append(next_token)
        # Print
        print_tokens(ids, [ids])

if __name__ == '__main__':
    # Run the int8 text-generation demo with its default settings when executed as a script.
    generate()