This shows how to generate text from GPT-NeoX using LLM.int8() quantization.
It needs a GPU with 24GB of memory.
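LLM.int8() runs the matrix multiplications of the transformer in int8, except for activation features whose magnitude crosses a threshold: those outlier features go through a separate higher-precision matmul. Below is a minimal sketch of that decomposition, simulating the int8 GEMM with integer tensors; the real implementation (bitsandbytes) uses custom CUDA int8 kernels, so treat this as an illustration of the idea only.

import torch

def llm_int8_matmul(x: torch.Tensor, w: torch.Tensor, threshold: float = 6.0) -> torch.Tensor:
    # x: (tokens, in_features) activations; w: (out_features, in_features) weights.
    # Feature columns where any activation magnitude exceeds `threshold` are
    # outliers and stay in higher precision; the rest use int8 absmax quantization.
    outliers = (x.abs() > threshold).any(dim=0)
    # Higher-precision matmul for the outlier features
    y = x[:, outliers].float() @ w[:, outliers].float().t()
    xs, ws = x[:, ~outliers], w[:, ~outliers]
    # Vector-wise absmax scales: per token row of x, per output row of w
    sx = xs.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127
    sw = ws.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127
    xq = (xs / sx).round().to(torch.int8)
    wq = (ws / sw).round().to(torch.int8)
    # Integer GEMM (simulated with int64 accumulation here), then dequantize
    y += (xq.long() @ wq.long().t()).float() * (sx * sw.t())
    return y.to(x.dtype)

x, w = torch.randn(4, 16) * 2, torch.randn(8, 16)
print((llm_int8_matmul(x, w) - x @ w.t()).abs().max())  # small quantization error

The llm_int8_threshold=6.0 passed later in this file is this threshold; only a few feature dimensions typically exceed it, so almost all of the multiplication happens in int8.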
import torch
from torch import nn

from labml import monit
from labml_nn.neox.model import LayerGenerator
from labml_nn.neox.samples.generate import PROMPT, infer
from labml_nn.neox.utils import get_tokens, print_tokens
from labml_nn.neox.utils.cache import get_cache
def generate():
Set up the cache to store intermediate key/value pairs for faster sampling
    cache = get_cache()
    cache.set('use_cache', True)
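For illustration only, the set/get pattern can be pictured as a process-wide dictionary; this hypothetical MiniCache is a stand-in, not the labml_nn.neox.utils.cache API:

class MiniCache:
    # Process-wide key/value store: layers read flags such as 'use_cache'
    # and stash activations under keys derived from 'state_ids'.
    def __init__(self):
        self._store = {}

    def set(self, key, value):
        self._store[key] = value

    def get(self, key, default=None):
        return self._store.get(key, default)

_CACHE = MiniCache()

def get_cache() -> MiniCache:
    # Return the shared cache instance
    return _CACHE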
Device
    device = torch.device('cuda:0')
Load the layers in float16 into CPU memory. We convert the layers to int8 later because doing that on the fly, after loading the layers to the GPU, causes CUDA memory fragmentation (about 3GB of memory can be lost to fragmentation).
    layer_generator = LayerGenerator(is_clone_layers=True,
                                     dtype=torch.float16,
                                     device=torch.device('cpu'),
                                     is_llm_int8=False,
                                     )
    layers = list(layer_generator.load())
Convert the layers to int8 and move them to the GPU one at a time. This reduces CUDA memory fragmentation.
    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
        layer_generator.post_load_prepare(layer,
                                          device=device,
                                          is_llm_int8=True,
                                          llm_int8_threshold=6.0,
                                          )
        layer.to(device)
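A conversion along these lines typically swaps each nn.Linear inside a layer for a bitsandbytes 8-bit linear layer. A sketch of what that swap might look like (assuming bitsandbytes is installed; the actual post_load_prepare code may differ):

import bitsandbytes as bnb
import torch
from torch import nn

def make_int8_linear(linear: nn.Linear, device: torch.device, threshold: float = 6.0):
    # Build an 8-bit linear layer; activation features with magnitude above
    # `threshold` take a separate fp16 path (the outlier decomposition).
    int8_lin = bnb.nn.Linear8bitLt(linear.in_features, linear.out_features,
                                   bias=linear.bias is not None,
                                   has_fp16_weights=False,
                                   threshold=threshold)
    int8_lin.weight = bnb.nn.Int8Params(linear.weight.data,
                                        requires_grad=False,
                                        has_fp16_weights=False)
    if linear.bias is not None:
        int8_lin.bias = linear.bias
    # The weights are quantized to int8 when the layer is moved to the GPU
    return int8_lin.to(device)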
Create an nn.Sequential model
    model = nn.Sequential(*layers)
Clear cache and print memory summary for debugging
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())
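If the full summary is too verbose, the allocator counters give a quicker check; memory that is reserved by the caching allocator but not allocated to tensors is a rough measure of fragmentation:

import torch

device = torch.device('cuda:0')
allocated = torch.cuda.memory_allocated(device) / 2 ** 30
reserved = torch.cuda.memory_reserved(device) / 2 ** 30
print(f'allocated {allocated:.2f}GB, reserved {reserved:.2f}GB, '
      f'slack {reserved - allocated:.2f}GB')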
Get token ids
    ids = get_tokens(PROMPT)
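get_tokens encodes the prompt into GPT-NeoX token ids. As a hypothetical stand-in (not the labml_nn helper), the same step with the Hugging Face transformers library would be:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
ids = tokenizer.encode('The quick brown fox')  # -> list of token ids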
Run the model. We use the infer function defined in generate.py. The state ids (None, 1) tell the cache that there is no previous state to load and that the new state should be saved under id 1.
    cache.set('state_ids', (None, 1))
    with monit.section('Infer'):
        next_token = infer(model, ids, device)[-1]
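For reference, a greedy-decoding helper along the lines of infer could look like the sketch below, assuming the model maps a batch of token ids to logits; the actual implementation lives in generate.py.

import torch
from torch import nn

def infer_sketch(model: nn.Module, ids: list, device: torch.device) -> list:
    # Forward pass without gradients; greedily pick the most likely
    # token at every position
    with torch.no_grad():
        x = torch.tensor(ids, dtype=torch.long, device=device)[None, :]
        logits = model(x)
    return logits[0].argmax(dim=-1).tolist()

Taking [-1] of the result gives the prediction for the last position, i.e. the next token.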
Append the predicted token
    ids += [next_token]
Predict 100 tokens
    for i in range(1, 100):
Set the state to use cached activations
        cache.set('state_ids', (i, i + 1))
Get next token. Note that we only feed the last token to the model because we cache the key/value pairs of previous tokens.
        with monit.section('Infer'):
            next_token = infer(model, [next_token], device)[-1]
Append the predicted token
        ids += [next_token]
Print the tokens generated so far
        print_tokens(ids, [ids])
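The (i, i + 1) pairs drive the key/value cache: each step loads the state saved by the previous step and saves the extended state under the next id. An illustrative, self-contained sketch of that mechanism (not the labml_nn layer code):

import torch

cache = {}

def extend_kv(k_new: torch.Tensor, v_new: torch.Tensor, state_ids):
    # k_new, v_new: (batch, new_tokens, d_head); dim 1 is the sequence dimension
    prev_id, next_id = state_ids
    if prev_id is not None:
        # Prepend the keys/values cached by the previous step
        k_new = torch.cat([cache[f'k_{prev_id}'], k_new], dim=1)
        v_new = torch.cat([cache[f'v_{prev_id}'], v_new], dim=1)
    # Save the extended state for the next step
    cache[f'k_{next_id}'], cache[f'v_{next_id}'] = k_new, v_new
    return k_new, v_new

# First call processes the whole prompt; later calls feed one token each
k, v = extend_kv(torch.randn(1, 5, 8), torch.randn(1, 5, 8), (None, 1))
for i in range(1, 4):
    k, v = extend_kv(torch.randn(1, 1, 8), torch.randn(1, 1, 8), (i, i + 1))
    assert k.shape[1] == 5 + i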
if __name__ == '__main__':
    generate()