LLM.int8 () 量子化を使用して GPT-Neox でテキストを生成

これは、LLM.int8 () 量子化を使用して GPT-Neox からテキストを生成する方法を示しています。

これには 24 GB のメモリを搭載した GPU が必要です。

15import torch
16from torch import nn
17
18from labml import monit
19from labml_nn.neox.model import LayerGenerator
20from labml_nn.neox.samples.generate import PROMPT, infer
21from labml_nn.neox.utils import get_tokens, print_tokens
22from labml_nn.neox.utils.cache import get_cache

テキストを生成

25def generate():
31    cache = get_cache()
32    cache.set('use_cache', True)

端末

35    device = torch.device('cuda:0')

float16 のレイヤーを CPU にロードします。レイヤーをGPUにロードした後にその場でこれを行うと、CUDAメモリの断片化が発生するため、後でレイヤーをint8に変換します(フラグメンテーションにより約3GBのメモリが失われる可能性があります

)。
40    layer_generator = LayerGenerator(is_clone_layers=True,
41                                     dtype=torch.float16,
42                                     device=torch.device('cpu'),
43                                     is_llm_int8=False,
44                                     )
45    layers = list(layer_generator.load())

これにより、CUDA メモリの断片化が減少します。

48    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
49        layer_generator.post_load_prepare(layer,
50                                          device=device,
51                                          is_llm_int8=True,
52                                          llm_int8_threshold=6.0,
53                                          )
54        layer.to(device)

nn.Sequential モデル作成

57    model = nn.Sequential(*layers)

デバッグ用にキャッシュをクリアしてメモリの概要を印刷

60    torch.cuda.empty_cache()
61    print(torch.cuda.memory_summary())

トークン ID を取得

64    ids = get_tokens(PROMPT)

モデルを実行します。infer で定義されている関数を使用します generate.py

68    cache.set('state_ids', (None, 1))
69    with monit.section('Infer'):
70        next_token = infer(model, ids, device)[-1]

予測トークンを追加

73    ids += [next_token]

トークンを100個予測する

76    for i in range(1, 100):

キャッシュされたアクティベーションを使用するように状態を設定します

78        cache.set('state_ids', (i, i + 1))

次のトークンを入手してください。以前のトークンのキーと値のペアをキャッシュするので、最後のトークンのみをモデルにフィードすることに注意してください

81        with monit.section('Infer'):
82            next_token = infer(model, [next_token], device)[-1]

予測トークンを追加

84        ids += [next_token]

プリント

86        print_tokens(ids, [ids])

90if __name__ == '__main__':
91    generate()