Evaluation

This is the code to evaluate the model with EleutherAI/lm-evaluation-harness.

15import math
16from typing import List
17
18import torch
19import torch.nn.functional as F
20from lm_eval import tasks, evaluator, utils
21from lm_eval.base import BaseLM
22from tokenizers import Tokenizer
23from torch import nn
24from tqdm import tqdm
25
26from labml import monit
27from labml_nn.neox.tokenizer import get_tokenizer

Evaluation Harness Adapter

This is based on the adapter from EleutherAI/gpt-neox

class EvalHarnessAdapter(BaseLM):
    """
    ## Evaluation Harness Adapter

    Exposes a tokenizer (and, through `_model_call` in subclasses, a model) via the
    `BaseLM` interface of EleutherAI/lm-evaluation-harness.
    This is based on the adapter from EleutherAI/gpt-neox.
    Only log-likelihood evaluation is supported; generation raises `RuntimeError`.
    """

    def __init__(self, tokenizer: Tokenizer, vocab_size: int, batch_size: int):
        """
        * `tokenizer` is the Huggingface Tokenizer
        * `vocab_size` is the size of the vocabulary
          (this differs from the tokenizer vocab size since neox adds some extra to make the
          embedding layer model parallel.)
        * `batch_size` is the batch size
        """
        super().__init__()
        self.tokenizer = tokenizer
        # Token id of `<|endoftext|>` as known to the tokenizer
        self._eot_token_id = self.tokenizer.token_to_id("<|endoftext|>")
        self._vocab_size = vocab_size

        self._batch_size = batch_size

    @property
    def device(self):
        """
        Device of the model.

        This base adapter is not bound to a model, so it has no device.
        Adapters that wrap a model should override this.
        """
        raise RuntimeError('EvalHarnessAdapter is not bound to a model device; override `device`')

    @property
    def vocab_size(self):
        """Size of the vocabulary"""
        return self._vocab_size

    @property
    def eot_token_id(self):
        """End-of-text token"""
        return self._eot_token_id

    @property
    def max_length(self):
        """Maximum sequence length"""
        return 2048

    @property
    def max_gen_toks(self):
        """Maximum number of tokens to generate"""
        return 128

    @property
    def batch_size(self):
        """Batch size"""
        return self._batch_size

    def tok_encode(self, string: str):
        """Encode a given text into a list of token ids"""
        return self.tokenizer.encode(string).ids

    def tok_decode(self, tokens: List[int]):
        """Decode text from token ids"""
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps: torch.Tensor):
        """
        Run the model on a `(batch, padded_length)` tensor of token ids and return its logits.

        Must be implemented by subclasses; the inputs are built on the CPU.
        """
        raise NotImplementedError

    def _model_generate(self, context, max_length, eos_token_id):
        """Generation is not supported by this adapter."""
        raise RuntimeError('`_model_generate` is not supported by this adapter')

    def greedy_until(self, requests):
        """Greedy generation is not supported by this adapter."""
        raise RuntimeError('`greedy_until` is not supported by this adapter')

    @torch.no_grad()
    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """
        ### Get log-likelihoods of the next tokens

        * `requests` List of requests containing the context and the expected continuation.
          Each request is a 3-tuple whose second and third elements are the context and
          continuation token id lists (the first element is not used here).
        * `disable_tqdm` If True, disable tqdm progress bar.

        Returns a list of `(total log-likelihood of the continuation, whether greedy decoding
        reproduces the continuation exactly)` tuples, in the original request order.
        """

        # For results
        res = []

        # Reorder the requests in the descending order of the lengths,
        # so that sequences with similar lengths are close
        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        reord = utils.Reorderer(requests, _collate)

        # Loop through requests with `batch_size` number of requests at a time
        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
            # To store the inputs for the batch
            inps = []
            # The continuations for the batch
            continuations = []
            # Lengths of the input sequences
            inplens = []
            # Padded length for the batch
            padded_length = None

            # Loop through each request in the chunk and collect them into PyTorch tensors with paddings
            for _, context_enc, continuation_enc in chunk:
                # Concatenate the context and continuation
                inp = context_enc + continuation_enc
                # Truncate from left if the size exceeds the `max_length`
                inp = inp[-(self.max_length + 1):]
                # Remove final token (it is only ever a prediction target, never an input)
                inp = inp[:-1]
                # Create a tensor
                inp = torch.tensor(inp, dtype=torch.long)
                # Input length
                inplen = inp.shape[0]

                # Determine the padded length. Requests are sorted by descending length, so the
                # first one in the chunk is the longest; round it up to a multiple of 32.
                # Shorter sequences will get padded.
                if padded_length is None:
                    padded_length = int(math.ceil(inplen / 32)) * 32

                # Padding with token id 0 on the right; positions before `inplen` are the only
                # ones read below (assumes a causal model, so right padding does not affect them)
                padding = torch.zeros(padded_length - inplen, dtype=torch.long)

                # Add padding
                inp = torch.cat([inp, padding], dim=0)

                inps.append(inp)
                continuations.append(continuation_enc)
                inplens.append(inplen)

            # Get model logits; expected shape `(batch, padded_length, vocab)`
            logits = self._model_call(torch.stack(inps))

            # Loop through the input/output pairs of the batch
            for seq_logits, inplen, cont_toks in zip(logits, inplens, continuations):
                # Get number of predicted tokens
                contlen = len(cont_toks)
                # Log-probabilities at the positions that predict the continuation.
                # `log_softmax` works per position, so slicing first gives the same values while
                # skipping all other (context and padding) positions. It is computed in float32 so
                # half-precision models do not lose precision in the sum below.
                log_probs = F.log_softmax(seq_logits[inplen - contlen: inplen].float(), dim=-1)
                # Get the tokens with the highest probabilities
                greedy_tokens = log_probs.argmax(dim=-1)
                # Get the target tokens
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).to(log_probs.device)
                # Whether there's an exact match
                max_equal = (greedy_tokens == cont_toks).all()
                # Log-likelihoods of the target tokens
                token_log_probs = torch.gather(log_probs, 1, cont_toks[:, None])
                # Add the total log-likelihoods and whether there was a match to the results
                res.append((float(token_log_probs.sum()), bool(max_equal)))

        # Re-order and return results
        return reord.get_original(res)

    @torch.no_grad()
    def run_eval(self, name: str, eval_tasks: List[str]):
        """
        ### Run given evaluations

        * `name` is recorded in the results under `results["config"]["name"]`
        * `eval_tasks` is the list of lm-evaluation-harness task names to run
        """
        # Run EleutherAI/lm-evaluation-harness evaluator
        results = evaluator.evaluate(lm=self, task_dict=tasks.get_task_dict(eval_tasks))

        # Add configs
        results["config"] = {
            "name": name,
        }

        return results

NeoX Evaluation Harness Adapter

This is based on the adapter from EleutherAI/gpt-neox. It wraps a given model and runs it on the specified device.

class NoeXEvalHarnessAdapter(EvalHarnessAdapter):
    """
    ## Evaluation Harness Adapter for a model

    Runs `model` on `device` to provide the logits used by the evaluation harness.
    This is based on the adapter from EleutherAI/gpt-neox.
    """

    def __init__(self, model: nn.Module, tokenizer: Tokenizer, vocab_size: int, batch_size: int, device: torch.device):
        """
        * `model` is model
        * `tokenizer` is the Huggingface Tokenizer
        * `vocab_size` is the size of the vocabulary
          (this differs from the tokenizer vocab size since neox adds some extra to make the
          embedding layer model parallel.)
        * `batch_size` is the batch size
        * `device` is the device of the model
        """
        super().__init__(tokenizer, vocab_size, batch_size)
        self.model = model
        self._device = device

    @property
    def device(self):
        """Device of the model (the parent adapter has none and raises instead)"""
        return self._device

    def _model_call(self, inps: torch.Tensor):
        """
        Call the model

        Moves the batch of token ids to the model's device and returns the model's output.
        """
        return self.model(inps.to(self._device))

Run the evaluation harness with a given model

def run_eval_harness(model: nn.Module, name: str, eval_tasks: List[str], device: torch.device, batch_size: int = 8):
    """
    ## Run the evaluation harness with a given model

    * `model` is the model to evaluate
    * `name` is recorded in the results under `results["config"]["name"]`
    * `eval_tasks` is the list of task names; when it is empty (or `None`) a default set is used
    * `device` is the device the model runs on
    * `batch_size` is the evaluation batch size

    Returns the results dictionary produced by the adapter's `run_eval`.
    """
    # Load the tokenizer
    with monit.section('Load tokenizer'):
        tok = get_tokenizer()

    # All tasks, used when nothing is specified
    default_tasks = [
        "anli_r1",
        "anli_r2",
        "anli_r3",
        "hellaswag",
        "lambada",
        "piqa",
        "winogrande",
        "wsc",
        "mathqa",
    ]
    selected_tasks = eval_tasks if eval_tasks else default_tasks

    # Create the adapter. `50_432` is the model's vocabulary size, which NeoX pads beyond
    # the tokenizer's own vocabulary.
    harness = NoeXEvalHarnessAdapter(model, tok, 50_432, batch_size, device)

    # Run
    return harness.run_eval(name, selected_tasks)