15import typing
16from typing import List, Optional
17
18import torch
19
20from labml import logger
21from labml.logger import Text
22from labml_nn.neox.tokenizer import get_tokenizer
23
24if typing.TYPE_CHECKING:
25    from tokenizers import Tokenizer

トークナイザーシングルトン

28_TOKENIZER: Optional['Tokenizer'] = None

トークン ID を取得

  • text トークン化するテキストです
  • トークン ID を返します

31def get_tokens(text: str) -> List[int]:
38    global _TOKENIZER
39    if _TOKENIZER is None:
40        _TOKENIZER = get_tokenizer()
41    return _TOKENIZER.encode_batch([text])[0].ids

モデル出力からトークンを印刷

Pretty は、モデルからの出力と一緒にターゲットトークンを出力します。

  • ids ターゲットトークン ID
  • xs はモデルの出力です
44def print_token_outputs(ids: List[int], *xs: torch.Tensor):
53    ids = ids + [-1]
54    xs = [[-1] + x[0].max(dim=-1)[1].tolist() for x in xs]
55
56    print_tokens(ids, xs)

トークンの印刷

Pretty は比較用のトークンを印刷します

  • target ターゲットトークン ID
  • others はモデルからサンプリングされた出力です
59def print_tokens(target: List[int], others: List[List[int]]):

トークナイザーをロード

70    global _TOKENIZER
71    if _TOKENIZER is None:
72        _TOKENIZER = get_tokenizer()

トークンを文字列のリストに変換

75    text = []
76    for i in range(len(target)):
77        tokens = [_TOKENIZER.decode([target[i]]) if target[i] != -1 else '---']
78        for j in range(len(others)):
79            tokens.append(_TOKENIZER.decode([others[j][i]]) if others[j][i] != -1 else '---')
80
81        text.append(tokens)

統計情報

84    correct = [0 for _ in others]
85    total = 0

トークンを繰り返し処理

88    for i in range(len(target)):
89        parts = [(f'{i}: ', Text.meta)]
90        parts += [('"', Text.subtle), (text[i][0], Text.subtle), ('"', Text.subtle), '\t']

空のターゲット

93        if target[i] == -1:
94            for j in range(len(others)):
95                parts += [('"', Text.subtle), (text[i][j + 1], Text.subtle), ('"', Text.subtle), '\t']
96
97            logger.log(parts)
98            continue

トークンの数

101        total += 1

その他の出力

104        for j in range(len(others)):
105            correct[j] += 1 if others[j][i] == target[i] else 0
106
107            parts += [('"', Text.subtle),
108                      (text[i][j + 1], Text.success if others[j][i] == target[i] else Text.danger),
109                      ('"', Text.subtle), '\t']
110
111        logger.log(parts)

統計情報

114    parts = [(f'{total}', Text.highlight), '\t']
115    for j in range(len(others)):
116        parts += [(f'{correct[j]}', Text.value), '\t']
117    logger.log(parts)

バランスレイヤー

に分割n_chunks . n_layers これはパイプラインの並列トレーニングに使用されます。

  • n_layers はレイヤーの数
  • n_chunks はチャンクの数
  • Returns は、各チャンクのレイヤー数のリストを返します

120def balance_layers_simple(n_layers: int, n_chunks: int):
130    balance = []
131    for i in range(n_chunks):
132        balance.append((n_layers - sum(balance)) // (n_chunks - i))
133
134    return list(reversed(balance))