使用 llm.int8 () 量化使用 GPT-NEOX 生成文本

这说明了如何使用 llm.int8 () 量化从 GPT-NEOX 生成文本。

这需要一个具有 24GB 内存的 GPU。

15import torch
16from torch import nn
17
18from labml import monit
19from labml_nn.neox.model import LayerGenerator
20from labml_nn.neox.samples.generate import PROMPT, infer
21from labml_nn.neox.utils import get_tokens, print_tokens
22from labml_nn.neox.utils.cache import get_cache

生成文本

25def generate():

设置缓存以缓存中间键/值对以加快生成速度

31    cache = get_cache()
32    cache.set('use_cache', True)

设备

35    device = torch.device('cuda:0')

float16 中的层加载到 CPU 中。我们稍后将图层转换为int8,因为在将图层加载到GPU后即时执行此操作会导致CUDA内存碎片(大约3GB的内存可能会由于碎片而丢失)。

40    layer_generator = LayerGenerator(is_clone_layers=True,
41                                     dtype=torch.float16,
42                                     device=torch.device('cpu'),
43                                     is_llm_int8=False,
44                                     )
45    layers = list(layer_generator.load())

这减少了 CUDA 内存碎片

48    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
49        layer_generator.post_load_prepare(layer,
50                                          device=device,
51                                          is_llm_int8=True,
52                                          llm_int8_threshold=6.0,
53                                          )
54        layer.to(device)

创建nn.Sequential 模型

57    model = nn.Sequential(*layers)

清除缓存和打印内存摘要以进行调试

60    torch.cuda.empty_cache()
61    print(torch.cuda.memory_summary())

获取代币 ID

64    ids = get_tokens(PROMPT)

运行模型。我们使用中定义的infer 函数 generate.py

68    cache.set('state_ids', (None, 1))
69    with monit.section('Infer'):
70        next_token = infer(model, ids, device)[-1]

追加预测的令牌

73    ids += [next_token]

预测 100 个代币

76    for i in range(1, 100):

设置状态以使用缓存的激活

78        cache.set('state_ids', (i, i + 1))

获取下一个令牌。请注意,我们只将最后一个令牌提供给模型,因为我们缓存了先前令牌的键/值对。

81        with monit.section('Infer'):
82            next_token = infer(model, [next_token], device)[-1]

追加预测的令牌

84        ids += [next_token]

打印

86        print_tokens(ids, [ids])

90if __name__ == '__main__':
91    generate()