15import torch
16from torch import nn
17
18from labml import monit
19from labml_nn.neox.model import LayerGenerator
20from labml_nn.neox.samples.generate import PROMPT, infer
21from labml_nn.neox.utils import get_tokens, print_tokens
22from labml_nn.neox.utils.cache import get_cache
25def generate():
设备
35 device = torch.device('cuda:0')
float16 中的层加载到 CPU 中。我们稍后将图层转换为int8,因为在将图层加载到GPU后即时执行此操作会导致CUDA内存碎片(大约3GB的内存可能会由于碎片而丢失)。
40 layer_generator = LayerGenerator(is_clone_layers=True,
41 dtype=torch.float16,
42 device=torch.device('cpu'),
43 is_llm_int8=False,
44 )
45 layers = list(layer_generator.load())
这减少了 CUDA 内存碎片
48 for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
49 layer_generator.post_load_prepare(layer,
50 device=device,
51 is_llm_int8=True,
52 llm_int8_threshold=6.0,
53 )
54 layer.to(device)
创建nn.Sequential
模型
57 model = nn.Sequential(*layers)
清除缓存和打印内存摘要以进行调试
60 torch.cuda.empty_cache()
61 print(torch.cuda.memory_summary())
获取代币 ID
64 ids = get_tokens(PROMPT)
运行模型。我们使用中定义的infer
函数 generate.py
68 cache.set('state_ids', (None, 1))
69 with monit.section('Infer'):
70 next_token = infer(model, ids, device)[-1]
追加预测的令牌
73 ids += [next_token]
预测 100 个代币
76 for i in range(1, 100):
设置状态以使用缓存的激活
78 cache.set('state_ids', (i, i + 1))
获取下一个令牌。请注意,我们只将最后一个令牌提供给模型,因为我们缓存了先前令牌的键/值对。
81 with monit.section('Infer'):
82 next_token = infer(model, [next_token], device)[-1]
追加预测的令牌
84 ids += [next_token]
打印
86 print_tokens(ids, [ids])
90if __name__ == '__main__':
91 generate()