import typing
from typing import List, Optional

import torch

from labml import logger
from labml.logger import Text
from labml_nn.neox.tokenizer import get_tokenizer

if typing.TYPE_CHECKING:
    from tokenizers import Tokenizer
Tokenizer singleton
_TOKENIZER: Optional['Tokenizer'] = None
Get the token ids for a given piece of text.
text is the text to tokenize.

def get_tokens(text: str) -> List[int]:
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = get_tokenizer()
    return _TOKENIZER.encode_batch([text])[0].ids
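A minimal usage sketch, assuming get_tokenizer can download or load the GPT-NeoX tokenizer:

# Encode a piece of text; the first call lazily creates the tokenizer singleton
tokens = get_tokens('Hello, world!')
print(tokens)  # a list of integer token ids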
Pretty prints target tokens alongside outputs from the model(s).
ids are the target token ids.
xs are the model(s) outputs.

def print_token_outputs(ids: List[int], *xs: torch.Tensor):
    # Pad the targets with a trailing -1 so they align with the shifted predictions
    ids = ids + [-1]
    # Greedy predictions from each output, shifted right by one with a leading -1
    xs = [[-1] + x[0].max(dim=-1)[1].tolist() for x in xs]

    print_tokens(ids, xs)
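A usage sketch; the logits below are random placeholders, whereas a real call would pass model outputs of shape [batch, seq_len, n_vocab]:

ids = get_tokens('a short prompt')
fake_logits_a = torch.randn(1, len(ids), 1000)  # 1000 is a made-up vocabulary size for the sketch
fake_logits_b = torch.randn(1, len(ids), 1000)
print_token_outputs(ids, fake_logits_a, fake_logits_b)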
Pretty prints tokens for comparison.
target are the target token ids.
others are the sampled outputs from the model(s).

def print_tokens(target: List[int], others: List[List[int]]):
Load tokenizer
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = get_tokenizer()
Convert the tokens to a list of strings
    text = []
    for i in range(len(target)):
        tokens = [_TOKENIZER.decode([target[i]]) if target[i] != -1 else '---']
        for j in range(len(others)):
            tokens.append(_TOKENIZER.decode([others[j][i]]) if others[j][i] != -1 else '---')

        text.append(tokens)
Stats
    correct = [0 for _ in others]
    total = 0
Iterate through tokens
    for i in range(len(target)):
        parts = [(f'{i}: ', Text.meta)]
        parts += [('"', Text.subtle), (text[i][0], Text.subtle), ('"', Text.subtle), '\t']
Empty target
        if target[i] == -1:
            for j in range(len(others)):
                parts += [('"', Text.subtle), (text[i][j + 1], Text.subtle), ('"', Text.subtle), '\t']

            logger.log(parts)
            continue
Number of tokens
        total += 1
Other outputs
        for j in range(len(others)):
            correct[j] += 1 if others[j][i] == target[i] else 0

            parts += [('"', Text.subtle),
                      (text[i][j + 1], Text.success if others[j][i] == target[i] else Text.danger),
                      ('"', Text.subtle), '\t']

        logger.log(parts)
Print the stats
    parts = [(f'{total}', Text.highlight), '\t']
    for j in range(len(others)):
        parts += [(f'{correct[j]}', Text.value), '\t']
    logger.log(parts)
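A usage sketch; sampled_ids below is a made-up prediction list of the same length as the targets. -1 entries are rendered as '---', and a -1 target is excluded from the total count:

target_ids = get_tokens('a small test sentence')
sampled_ids = list(target_ids)  # pretend the model predicted every token correctly
sampled_ids[0] = -1             # mark the first position as having no prediction
print_tokens(target_ids, [sampled_ids])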
Split n_layers into n_chunks. This is used for pipeline parallel training.
n_layers is the number of layers.
n_chunks is the number of chunks.
Returns a list with the number of layers for each chunk.

def balance_layers_simple(n_layers: int, n_chunks: int):
    balance = []
    for i in range(n_chunks):
        balance.append((n_layers - sum(balance)) // (n_chunks - i))

    return list(reversed(balance))
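For example, splitting 44 layers across 3 chunks, or 10 layers across 4 chunks, puts the larger counts first:

assert balance_layers_simple(44, 3) == [15, 15, 14]
assert balance_layers_simple(10, 4) == [3, 3, 2, 2]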