This is the code to evaluate the model on [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness).
15import math
16from typing import List
17
18import torch
19import torch.nn.functional as F
20from lm_eval import tasks, evaluator, utils
21from lm_eval.base import BaseLM
22from tokenizers import Tokenizer
23from torch import nn
24from tqdm import tqdm
25
26from labml import monit
27from labml_nn.neox.tokenizer import get_tokenizer
class EvalHarnessAdapter(BaseLM):
    """
    ## Evaluation Harness Adapter

    Adapts a language model to the ``BaseLM`` interface of
    EleutherAI/lm-evaluation-harness. Only log-likelihood scoring is
    implemented here; subclasses must provide ``_model_call``.
    Generation (``greedy_until`` / ``_model_generate``) is not supported.
    """

    def __init__(self, tokenizer: Tokenizer, vocab_size: int, batch_size: int):
        """
        :param tokenizer: is the Huggingface Tokenizer
        :param vocab_size: is the size of the vocabulary
         (this differs from the tokenizer vocab size since neox adds some extra to make the
         embedding layer model parallel.)
        :param batch_size: is the batch size
        """
        super().__init__()
        self.tokenizer = tokenizer
        # Id of the `<|endoftext|>` token in the tokenizer's vocabulary
        self._eot_token_id = self.tokenizer.token_to_id("<|endoftext|>")
        self._vocab_size = vocab_size

        self._batch_size = batch_size

    @property
    def device(self):
        """The base adapter has no device of its own; accessing this raises ``RuntimeError``."""
        raise RuntimeError()

    @property
    def vocab_size(self):
        """Size of the vocabulary"""
        return self._vocab_size

    @property
    def eot_token_id(self):
        """End-of-text token"""
        return self._eot_token_id

    @property
    def max_length(self):
        """Maximum sequence length"""
        return 2048

    @property
    def max_gen_toks(self):
        """Maximum number of tokens to generate"""
        return 128

    @property
    def batch_size(self):
        """Batch size"""
        return self._batch_size

    def tok_encode(self, string: str):
        """Encode a given text into a list of token ids"""
        return self.tokenizer.encode(string).ids

    def tok_decode(self, tokens: List[int]):
        """Decode text from token ids"""
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps: torch.Tensor):
        # Must be implemented by a subclass: take a batch of token ids and return logits
        raise NotImplementedError

    def _model_generate(self, context, max_length, eos_token_id):
        # Generation is not supported by this adapter
        raise RuntimeError()

    def greedy_until(self, requests):
        # Generation is not supported by this adapter
        raise RuntimeError()

    @torch.no_grad()
    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """
        ### Get log-likelihoods of the continuations

        :param requests: List of requests containing the context and the expected continuation.
         Each request is unpacked as ``(_, context_enc, continuation_enc)`` with token-id lists.
        :param disable_tqdm: If True, disable tqdm progress bar.
        :return: a list of ``(total_log_likelihood, is_greedy_match)`` tuples, in the original
         request order
        """

        # For results
        res = []

        # Reorder the requests in the descending order of the lengths,
        # so that sequences with similar lengths are close
        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        reord = utils.Reorderer(requests, _collate)

        # Loop through requests with `batch_size` number of requests at a time
        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
            # Unpadded input tensors for the batch
            unpadded = []
            # The continuations for the batch
            continuations = []
            # Lengths of the input sequences
            inplens = []

            for _, context_enc, continuation_enc in chunk:
                # Concatenate the context and continuation
                inp = context_enc + continuation_enc
                # Truncate from left if the size exceeds the `max_length`
                inp = inp[-(self.max_length + 1):]
                # Remove final token; the model predicts token `i + 1` from position `i`
                inp = inp[:-1]
                # Create a tensor
                inp = torch.tensor(inp, dtype=torch.long)

                unpadded.append(inp)
                continuations.append(continuation_enc)
                inplens.append(inp.shape[0])

            # Pad every sequence to the longest length in the chunk, rounded up to a multiple of 32.
            # Requests are sorted longest-first, so this equals the first sequence's rounded length;
            # taking the maximum simply guarantees the padding size is never negative.
            padded_length = int(math.ceil(max(inplens) / 32)) * 32
            inps = [
                torch.cat([inp, torch.zeros(padded_length - inp.shape[0], dtype=torch.long)], dim=0)
                for inp in unpadded
            ]

            # Get model logits
            batch_logits = self._model_call(torch.stack(inps))

            # Get log softmaxes
            multi_log_probs = F.log_softmax(batch_logits, dim=-1)

            # Loop through the input/output pairs of the batch
            for seq_log_probs, inplen, cont_toks in zip(multi_log_probs, inplens, continuations):
                # Get number of predicted tokens
                contlen = len(cont_toks)
                # Log-probabilities at the positions that predict the continuation tokens
                cont_log_probs = seq_log_probs[inplen - contlen: inplen]
                # Get the tokens with the highest probabilities
                greedy_tokens = cont_log_probs.argmax(dim=-1)
                # Get the target tokens
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).to(cont_log_probs.device)
                # Whether there's an exact match
                max_equal = (greedy_tokens == cont_toks).all()
                # Log-likelihoods of the target tokens
                target_log_probs = torch.gather(cont_log_probs, 1, cont_toks[:, None])
                # Add the total log-likelihood and whether there was a match to the results.
                # Sum in float32: if the model returns half-precision logits, a float16 sum
                # would round the total coarsely (no-op for float32 logits).
                res.append((float(target_log_probs.float().sum()), bool(max_equal)))

        # Re-order and return results
        return reord.get_original(res)

    @torch.no_grad()
    def run_eval(self, name: str, eval_tasks: List[str]):
        """
        ### Run given evaluations

        :param name: is the name recorded under ``results["config"]["name"]``
        :param eval_tasks: is the list of evaluation task names
        :return: the results returned by the harness evaluator, with a ``config`` entry added
        """
        # Run EleutherAI/lm-evaluation-harness evaluator
        results = evaluator.evaluate(lm=self, task_dict=tasks.get_task_dict(eval_tasks))

        # Add configs
        results["config"] = {
            "name": name,
        }

        return results
class NoeXEvalHarnessAdapter(EvalHarnessAdapter):
    """
    ## Evaluation Harness Adapter for a GPT-NeoX model

    Runs ``model`` on ``device`` to produce the logits used for log-likelihood scoring.
    """

    def __init__(self, model: nn.Module, tokenizer: Tokenizer, vocab_size: int, batch_size: int,
                 device: torch.device):
        """
        :param model: is model
        :param tokenizer: is the Huggingface Tokenizer
        :param vocab_size: is the size of the vocabulary
         (this differs from the tokenizer vocab size since neox adds some extra to make the
         embedding layer model parallel.)
        :param batch_size: is the batch size
        :param device: is the device of the model
        """
        super().__init__(tokenizer, vocab_size, batch_size)
        self.model = model
        self._device = device

    @property
    def device(self):
        """Device of the model (this adapter knows it, so it is exposed instead of raising)"""
        return self._device

    def _model_call(self, inps: torch.Tensor):
        """
        Call the model

        :param inps: is a batch of token ids; it is moved to the model's device before the call
        :return: whatever the model returns, used by the caller as logits
        """
        return self.model(inps.to(self._device))
def run_eval_harness(model: nn.Module, name: str, eval_tasks: List[str], device: torch.device,
                     batch_size: int = 8, vocab_size: int = 50_432):
    """
    ## Run evaluation harness with a given model

    :param model: is the model to evaluate
    :param name: is the name recorded in the results' ``config``
    :param eval_tasks: is the list of task names; if empty (or otherwise falsy) a default set is used
    :param device: is the device of the model
    :param batch_size: is the batch size
    :param vocab_size: is the size of the model's (padded) vocabulary; defaults to 50,432,
     the value previously hard-coded here
    :return: the evaluation results
    """
    # Load the tokenizer
    with monit.section('Load tokenizer'):
        tokenizer = get_tokenizer()

    # Default set of tasks if nothing is specified
    if not eval_tasks:
        eval_tasks = [
            "anli_r1",
            "anli_r2",
            "anli_r3",
            "hellaswag",
            "lambada",
            "piqa",
            "winogrande",
            "wsc",
            "mathqa",
        ]

    # Create the adapter
    adapter = NoeXEvalHarnessAdapter(model, tokenizer, vocab_size, batch_size, device)

    # Run
    return adapter.run_eval(name, eval_tasks)