15import typing
16from typing import List, Optional
17
18import torch
19
20from labml import logger
21from labml.logger import Text
22from labml_nn.neox.tokenizer import get_tokenizer
23
24if typing.TYPE_CHECKING:
25 from tokenizers import Tokenizer
分词器单例
28_TOKENIZER: Optional['Tokenizer'] = None
def get_tokens(text: str) -> List[int]:
    """
    ### Get tokens from text

    The module-level tokenizer is created with `get_tokenizer()` on the first call
    and reused on every later call.

    :param text: the text to tokenize
    :return: the list of token ids for `text`
    """
    global _TOKENIZER
    tokenizer = _TOKENIZER
    if tokenizer is None:
        # First use: build the shared tokenizer and cache it for subsequent calls
        tokenizer = _TOKENIZER = get_tokenizer()

    # Encode as a batch of one and take the single encoding's ids
    encodings = tokenizer.encode_batch([text])
    return encodings[0].ids
def print_token_outputs(ids: List[int], *xs: torch.Tensor):
    """
    ### Print tokens from model outputs

    Shows the input tokens next to the argmax token of each model output,
    using `print_tokens`.

    :param ids: the input token ids
    :param xs: model outputs. Only batch index `0` is used, and the argmax is taken over
        the last dimension. Presumably each is `[batch, seq_len, n_vocab]` logits (TODO confirm).
    """
    # Shift the two sequences by one position. Targets get an empty (-1) slot at the end and
    # predictions get one at the start, so row i pairs token i with the output at position i - 1
    # (presumably that position's next-token prediction).
    target = ids + [-1]

    predictions = []
    for logits in xs:
        top_ids = logits[0].max(dim=-1).indices
        predictions.append([-1] + top_ids.tolist())

    print_tokens(target, predictions)
def print_tokens(target: List[int], others: List[List[int]]):
    """
    ### Print target tokens alongside other token sequences

    Logs one line per position: the index, the decoded target token, and the decoded token
    at the same position in each sequence of `others`. For positions with a target, each
    other token is styled `Text.success` if it equals the target and `Text.danger` otherwise.
    A final line logs the number of scored positions, then the match count for each sequence.

    :param target: target token ids; `-1` marks a position with no target
        (shown as `---`, not scored)
    :param others: token id sequences to compare with `target`; each is indexed at every
        position of `target`, and `-1` entries are shown as `---`
    """
    # Load the tokenizer on first use
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = get_tokenizer()

    def _decode(token_id: int) -> str:
        # -1 is a placeholder for "no token"
        return '---' if token_id == -1 else _TOKENIZER.decode([token_id])

    def _quoted(label: str, style) -> list:
        # A quoted cell followed by a tab separator
        return [('"', Text.subtle), (label, style), ('"', Text.subtle), '\t']

    # Decode every cell before logging anything. Each row is [target, other_0, other_1, ...]
    decoded = [
        [_decode(target[pos])] + [_decode(seq[pos]) for seq in others]
        for pos in range(len(target))
    ]

    # Per-sequence match counts and number of positions that had a target
    hits = [0] * len(others)
    counted = 0

    for pos, (expected, row) in enumerate(zip(target, decoded)):
        line = [(f'{pos}: ', Text.meta)] + _quoted(row[0], Text.subtle)

        if expected == -1:
            # Nothing to score against: show the other tokens unstyled
            for label in row[1:]:
                line += _quoted(label, Text.subtle)
            logger.log(line)
            continue

        counted += 1
        for j, seq in enumerate(others):
            matched = seq[pos] == expected
            if matched:
                hits[j] += 1
            line += _quoted(row[j + 1], Text.success if matched else Text.danger)

        logger.log(line)

    # Summary: scored positions, then matches per sequence
    summary = [(f'{counted}', Text.highlight), '\t']
    for n_hits in hits:
        summary += [(f'{n_hits}', Text.value), '\t']
    logger.log(summary)
def balance_layers_simple(n_layers: int, n_chunks: int):
    """
    ### Split layers into near-equal chunks

    Each chunk takes the remaining layers divided (floor) by the chunks still left.
    The sizes are then reversed so the larger chunks come first.

    :param n_layers: total number of layers to distribute
    :param n_chunks: number of chunks
    :return: list of `n_chunks` chunk sizes (larger ones first); `[]` when `n_chunks <= 0`.
        Sizes can be `0` when `n_chunks > n_layers`.
    """
    remaining = n_layers
    sizes = []
    for chunks_left in range(n_chunks, 0, -1):
        size = remaining // chunks_left
        sizes.append(size)
        remaining -= size

    # Floor division fills the smaller chunks first; flip so the larger ones lead
    sizes.reverse()
    return sizes