from __future__ import annotations
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav3.util.file import disk_lru_cache, disk_lru_cache_clear
from exllamav3 import model_init, Generator, Job, ComboSampler
from exllamav3.util.progress import ProgressBar
import argparse
import torch
import time, random, json
from pathlib import Path
from urllib import request
from collections import deque
import math


def evaluate_correctness(_sample: str, _reference: str) -> bool:
    """
    Canonical evaluation, adapted from https://github.com/google-deepmind/bbeh/blob/main/bbeh/evaluate.py
    """

    def strip_latex(response: str) -> str:
        if response.startswith("$") and response.endswith("$"):
            response = response[1:-1]
        if "boxed{" in response and response.endswith("}"):
            response = response[0:-1].split("boxed{")[1]
        if "text{" in response and response.endswith("}"):
            response = response[0:-1].split("text{")[1]
        if "texttt{" in response and response.endswith("}"):
            response = response[0:-1].split("texttt{")[1]
        return response

    def extract_answer(sample: str) -> str:
        """Extracts the final answer from the sample."""
        answer_prefixes = [
            "The answer is:",
            "The final answer is ",
            "The final answer is: ",
            "The answer is "
        ]
        answer = sample
        for answer_prefix in answer_prefixes:
            if answer_prefix in answer:
                answer = answer.split(answer_prefix)[-1].strip()
        if answer.endswith("."):
            answer = answer[:-1]
        return strip_latex(answer)

    def fuzzy_match(prediction: str, reference: str) -> bool:
        """Fuzzy match function for BigBench Extra Hard."""
        if prediction == reference:
            return True
        # (a) vs a
        if len(prediction) == 3 and prediction[0] == "(" and prediction[-1] == ")":
            return prediction[1] == reference
        if len(reference) == 3 and reference[0] == "(" and reference[-1] == ")":
            return reference[1] == prediction
        # Numbers
        try:
            if float(prediction) == float(reference):
                return True
        except ValueError:
            pass
        # Quote issues
        if prediction.replace("'", "") == reference.replace("'", ""):
            return True
        # Bracket issues
        if f"[{reference}]" == prediction or f"[{prediction}]" == reference:
            return True
        # Question mark issues
        if prediction.endswith("?") and prediction[:-1] == reference:
            return True
        return False

    def preprocess_sample(sample: str) -> str:
        prediction = extract_answer(sample.strip()).lower()
        prediction = prediction.replace(", ", ",").replace("**", "")
        prediction = prediction.split("\n")[0]
        prediction = prediction[0:-1] if prediction.endswith(".") else prediction
        return prediction

    def preprocess_reference(reference: str) -> str:
        reference = reference.strip().lower()
        reference = reference.replace(", ", ",")
        return reference

    _prediction = preprocess_sample(_sample)
    _reference = preprocess_reference(_reference)
    return fuzzy_match(_prediction, _reference)


# Cached to disk so repeated runs don't re-download the dataset
@disk_lru_cache("fetch_bbeh_mini_test_data_git")
def fetch_bbeh_mini_test_data_git() -> list[dict]:
    url = "https://raw.githubusercontent.com/google-deepmind/bbeh/refs/heads/main/bbeh/mini/data.json"
    with request.urlopen(url) as response:
        raw = response.read().decode("utf-8")
    data = json.loads(raw)
    return data["examples"]


def strip_reasoning(s: str, args) -> str:
    # Remove the reasoning segment enclosed in the configured think tags from a completion
    # TODO: Respect other think tags
    think_start, think_end = args.thinktags
    l = s.find(think_start)
    r = s.rfind(think_end)
    if l >= 0 and r >= 0:
        s = s[:l] + s[r + len(think_end):]
    elif r >= 0:
        s = s[r + len(think_end):]
    return s.strip()


def write_jsonl(rows: list[dict], path: Path) -> None:
    with open(path, "w") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
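
# Illustrative only: a few hypothetical completions (not BBEH data) showing how the fuzzy
# matcher above treats common answer formats. Uncomment to sanity-check evaluate_correctness
# in isolation.
#
# assert evaluate_correctness("Reasoning... The answer is (A).", "a")      # (a) vs a
# assert evaluate_correctness("The final answer is 42.0", "42")            # numeric equivalence
# assert evaluate_correctness("The answer is: 'yes'", "yes")               # quote stripping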

@torch.inference_mode()
def main(args):

    # Get the stuff
    bbeh = fetch_bbeh_mini_test_data_git()
    if args.limit:
        bbeh = bbeh[:args.limit]
    rng = random.Random(123)
    rng.shuffle(bbeh)

    # Initialize
    model, config, cache, tokenizer = model_init.init(args)
    generator = Generator(
        model = model,
        cache = cache,
        max_batch_size = args.max_batch_size,
        tokenizer = tokenizer,
        show_visualizer = args.visualize_cache,
    )
    sampler = model_init.get_arg_sampler(args)

    # Create prompts
    template_args = {}
    if args.think:
        template_args["enable_thinking"] = True
    if args.nothink:
        template_args["enable_thinking"] = False

    all_results: list[dict] = []
    with ProgressBar("Prompts", len(bbeh), transient = False) as progress:
        for idx, bp in enumerate(bbeh):
            prompt = bp["input"]
            input_ids = tokenizer.hf_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt = True,
                **template_args,
            )
            all_results.append({
                "input": prompt,
                "target": bp["target"],
                "raw_prompt": tokenizer.decode(input_ids[0], decode_special_tokens = True),
            })
            job = Job(
                input_ids = input_ids,
                max_new_tokens = args.max_tokens,
                stop_conditions = config.eos_token_id_list,
                sampler = sampler,
                identifier = idx,
                max_rq_tokens = 512,
                stop_on_loop = (300, 3),
            )
            generator.enqueue(job)
            progress.update(idx + 1)

    # Generate
    tps_hist = deque()
    sampled_tokens = 0
    last_update = time.time()
    total_jobs = generator.num_remaining_jobs()
    done_jobs = 0
    total_answered = 0
    total_correct = 0
    eval_string = "(evaluating)"

    with ProgressBar("Generating samples", total_jobs, transient = False) as progress:
        while generator.num_remaining_jobs():
            results = generator.iterate()

            # Some feedback
            now = time.time()
            if now > last_update + 1:
                tps = sampled_tokens / (now - last_update)
                sampled_tokens = 0
                tps_hist.append(tps)
                if len(tps_hist) > 3:
                    tps_hist.popleft()
                tps = round(sum(tps_hist) / len(tps_hist))
                num_pend = generator.num_pending_jobs()
                num_act = generator.num_active_jobs()
                print(f" -- result: {eval_string:32} pending: {num_pend:4} active {num_act:4} {tps:6} tokens/s", end = "")
                if num_act:
                    sjob = random.choice(generator.active_jobs)
                    snum = "#" + str(sjob.identifier)
                    ssamp = repr(sjob.full_completion)[-64:-1]
                    print(f" sample from {snum:>4}: {ssamp}")
                else:
                    print()
                last_update = now

            # Collect results
            for result in results:
                if "token_ids" in result:
                    sampled_tokens += result["token_ids"].shape[-1]
                if result.get("eos"):
                    idx = result["identifier"]
                    completion = result["full_completion"]
                    match result["eos_reason"]:
                        case "max_new_tokens":
                            print(f" !! Job #{idx} exceeded token limit, ends in: {repr(completion)[-100:-1]}")
                        case "loop_detected":
                            print(f" !! Job #{idx} loop detected, ends in: {repr(completion)[-100:-1]}")
                        case _:
                            pass

                    # Running evaluation
                    total_answered += 1
                    answer = strip_reasoning(completion, args)
                    reference = bbeh[idx]["target"]
                    correct = evaluate_correctness(answer, reference)
                    if correct:
                        total_correct += 1
                    score = total_correct / total_answered
                    if total_answered == total_jobs:
                        eval_string = f"{total_correct: 4}/{total_answered: 4} = {score * 100:6.2f}%"
                    else:
                        # Normal-approximation 95% CI with finite population correction over the full job set
                        interval = 1.96 * math.sqrt(score * (1 - score) / total_answered * (total_jobs - total_answered) / (total_jobs - 1))
                        eval_string = f"{total_correct: 4}/{total_answered: 4} = {score * 100:6.2f}% +/- {interval * 100: 6.2f}%"

                    # Save answer
                    all_results[idx]["full_completion"] = completion
                    all_results[idx]["answer"] = answer
                    all_results[idx]["correct"] = correct

                    done_jobs += 1
                    progress.update(done_jobs)

    # Print result
    print(f" -- Final result: {eval_string} (95% CI)")

    # Save results
    if args.output:
        write_jsonl(all_results, args.output)
        print(f" -- Responses written to {args.output}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description = "Run BigBench Extra Hard evaluation (mini sample set)")
    model_init.add_args(
        parser,
        add_sampling_args = True,
        default_cache_size = 65536,
        default_sampling_args = {
            "temperature": 0.8,
            "repetition_penalty": 1.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "penalty_range": 1024,
            "min_p": 0.0,
            "top_k": 0,
            "top_p": 0.8,
            "adaptive_target": 1.0,
            "adaptive_decay": 0.9,
        }
    )
    parser.add_argument("-vis", "--visualize_cache", action = "store_true", help = "Show cache visualizer (slow)")
    parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", default = None)
    parser.add_argument("-mt", "--max_tokens", type = int, default = 16384, help = "Max number of tokens for each completion")
    parser.add_argument("-mbs", "--max_batch_size", type = int, default = 64, help = "Max batch size")
    parser.add_argument("-think", "--think", action = "store_true", help = "Explicitly set template_arg enable_thinking=true")
    parser.add_argument("-nothink", "--nothink", action = "store_true", help = "Explicitly set template_arg enable_thinking=false")
    parser.add_argument("-thinktags", "--thinktags", nargs = 2, help = 'Think tags for reasoning models, default: "<think>" "</think>"', default = ["<think>", "</think>"])
    parser.add_argument("-limit", "--limit", type = int, help = "Limit number of questions (creates incomplete results file)", default = 0)
    _args = parser.parse_args()

    # Validate args
    if _args.output:
        directory = os.path.dirname(_args.output)
        if os.path.exists(_args.output):
            print(" !! Warning: Output file exists and will be overwritten.")

    assert not (_args.think and _args.nothink), "Be nice"
    main(_args)
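
# Example invocation (a sketch, not a verified command line): the script filename and the
# -m/--model_dir flag are assumptions -- model and sampling flags are added by
# model_init.add_args and may differ between exllamav3 versions; the remaining flags are
# the ones defined above.
#
#   python eval_bbeh_mini.py -m /path/to/exl3-model -o bbeh_results.jsonl \
#       -mt 16384 -mbs 64 -think -thinktags "<think>" "</think>"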