# exllamav3/eval/bbeh_mini.py

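# Example invocation (a sketch: flags other than the ones defined below come from
# model_init.add_args; -m/--model_dir is assumed to be the model path flag):
#
#   python eval/bbeh_mini.py -m /path/to/model -o results.jsonl -mt 8192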
from __future__ import annotations
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav3.util.file import disk_lru_cache, disk_lru_cache_clear
from exllamav3 import model_init, Generator, Job, ComboSampler
from exllamav3.util.progress import ProgressBar
import argparse
import torch
import time, random, json
from pathlib import Path
from urllib import request
from collections import deque
import math
def evaluate_correctness(_sample: str, _reference: str) -> bool:
"""
Canonical evaluation, adapted from https://github.com/google-deepmind/bbeh/blob/main/bbeh/evaluate.py
"""
def strip_latex(response: str) -> str:
if response.startswith("$") and response.endswith("$"):
response = response[1:-1]
if "boxed{" in response and response.endswith("}"):
response = response[0:-1].split("boxed{")[1]
if "text{" in response and response.endswith("}"):
response = response[0:-1].split("text{")[1]
if "texttt{" in response and response.endswith("}"):
response = response[0:-1].split("texttt{")[1]
return response
def extract_answer(sample: str) -> str:
"""Extracts the final answer from the sample."""
answer_prefixes = [
"The answer is:",
"The final answer is ",
"The final answer is: ",
"The answer is "
]
answer = sample
for answer_prefix in answer_prefixes:
if answer_prefix in answer:
answer = answer.split(answer_prefix)[-1].strip()
if answer.endswith("."):
answer = answer[:-1]
return strip_latex(answer)
def fuzzy_match(prediction: str, reference: str) -> bool:
"""Fuzzy match function for BigBench Extra Hard."""
if prediction == reference:
return True
# (a) vs a
if len(prediction) == 3 and prediction[0] == "(" and prediction[-1] == ")":
return prediction[1] == reference
if len(reference) == 3 and reference[0] == "(" and reference[-1] == ")":
return reference[1] == prediction
# Numbers
try:
if float(prediction) == float(reference):
return True
except ValueError:
pass
# quote issues
if prediction.replace("'", "") == reference.replace("'", ""):
return True
# Bracket issues
if f"[{reference}]" == prediction or f"[{prediction}]" == reference:
return True
# Question mark issues
if prediction.endswith("?") and prediction[:-1] == reference:
return True
return False
def preprocess_sample(sample: str) -> str:
prediction = extract_answer(sample.strip()).lower()
prediction = prediction.replace(", ", ",").replace("**", "")
prediction = prediction.split("\n")[0]
prediction = prediction[0:-1] if prediction.endswith(".") else prediction
return prediction
def preprocess_reference(reference: str) -> str:
reference = reference.strip().lower()
reference = reference.replace(", ", ",")
return reference
_prediction = preprocess_sample(_sample)
_reference = preprocess_reference(_reference)
return fuzzy_match(_prediction, _reference)
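# Illustrative example of the matching rules above:
#   evaluate_correctness("Reasoning... The answer is (a).", "a")  -> True
# extract_answer yields "(a)", and fuzzy_match unwraps single-letter options.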
@disk_lru_cache("fetch_bbeh_mini_test_data_git")
def fetch_bbeh_mini_test_data_git() -> list[dict]:
url = "https://raw.githubusercontent.com/google-deepmind/bbeh/refs/heads/main/bbeh/mini/data.json"
with request.urlopen(url) as response:
raw = response.read().decode("utf-8")
data = json.loads(raw)
return data["examples"]
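# The download above is cached on disk by the decorator; presumably calling
# disk_lru_cache_clear("fetch_bbeh_mini_test_data_git") (imported above) would
# force a fresh fetch if the upstream dataset changes.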
def strip_reasoning(s: str, args) -> str:
# TODO: Respect other think tags
think_start, think_end = args.thinktags
l = s.find(think_start)
r = s.rfind(think_end)
if l >= 0 and r >= 0:
s = s[:l] + s[r+len(think_end):]
elif r >= 0:
s = s[r + len(think_end):]
return s.strip()
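# Example with the default tags: "<think>chain of thought</think>The answer is X"
# reduces to "The answer is X". A completion that begins mid-thought and only
# contains the closing "</think>" is handled by the elif branch.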
def write_jsonl(rows: list[dict], path: Path) -> None:
    with open(path, "w", encoding = "utf-8") as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=False) + "\n")
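# One JSON object per line (JSON Lines); ensure_ascii=False keeps any non-ASCII
# characters in completions human-readable in the output file.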
@torch.inference_mode()
def main(args):
    # Fetch the BBEH Mini dataset
bbeh = fetch_bbeh_mini_test_data_git()
if args.limit:
bbeh = bbeh[:args.limit]
rng = random.Random(123)
rng.shuffle(bbeh)
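    # Fixed seed so the (possibly limited) question order is reproducible across runs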
# Initialize
model, config, cache, tokenizer = model_init.init(args)
generator = Generator(
model = model,
cache = cache,
max_batch_size = args.max_batch_size,
tokenizer = tokenizer,
show_visualizer = args.visualize_cache,
)
sampler = model_init.get_arg_sampler(args)
# Create prompts
template_args = {}
if args.think: template_args["enable_thinking"] = True
if args.nothink: template_args["enable_thinking"] = False
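    # enable_thinking is forwarded to the chat template as a kwarg; it only has an
    # effect with templates that define it (e.g. Qwen3-style reasoning models)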
all_results: list[dict] = []
with ProgressBar("Prompts", len(bbeh), transient = False) as progress:
for idx, bp in enumerate(bbeh):
prompt = bp["input"]
input_ids = tokenizer.hf_chat_template(
[{ "role": "user", "content": prompt }],
add_generation_prompt = True,
**template_args,
)
all_results.append({"input": prompt, "target": bp["target"], "raw_prompt": tokenizer.decode(input_ids[0], decode_special_tokens = True)})
job = Job(
input_ids = input_ids,
max_new_tokens = args.max_tokens,
stop_conditions = config.eos_token_id_list,
sampler = sampler,
identifier = idx,
max_rq_tokens = 512,
stop_on_loop = (300, 3),
)
generator.enqueue(job)
progress.update(idx + 1)
# Generate
tps_hist = deque()
sampled_tokens = 0
last_update = time.time()
total_jobs = generator.num_remaining_jobs()
done_jobs = 0
total_answered = 0
total_correct = 0
eval_string = "(evaluating)"
    with ProgressBar("Generating samples", total_jobs, transient = False) as progress:
while generator.num_remaining_jobs():
results = generator.iterate()
# Some feedback
now = time.time()
if now > last_update + 1:
tps = sampled_tokens / (now - last_update)
sampled_tokens = 0
tps_hist.append(tps)
if len(tps_hist) > 3:
tps_hist.popleft()
tps = round(sum(tps_hist) / len(tps_hist))
num_pend = generator.num_pending_jobs()
num_act = generator.num_active_jobs()
print(f" -- result: {eval_string:32} pending: {num_pend:4} active {num_act:4} {tps:6} tokens/s", end = "")
if num_act:
sjob = random.choice(generator.active_jobs)
snum = "#" + str(sjob.identifier)
ssamp = repr(sjob.full_completion)[-64:-1]
print(f" sample from {snum:>4}: {ssamp}")
else:
print()
last_update = now
# Collect results
for result in results:
if "token_ids" in result:
sampled_tokens += result["token_ids"].shape[-1]
if result.get("eos"):
idx = result["identifier"]
completion = result["full_completion"]
match result["eos_reason"]:
case "max_new_tokens":
print(f" !! Job #{idx} exceeded token limit, ends in: {repr(completion)[-100:-1]}")
case "loop_detected":
print(f" !! Job #{idx} loop detected, ends in: {repr(completion)[-100:-1]}")
case _:
pass
# Running evaluation
total_answered += 1
answer = strip_reasoning(completion, args)
reference = bbeh[idx]["target"]
correct = evaluate_correctness(answer, reference)
if correct:
total_correct += 1
score = total_correct / total_answered
if total_answered == total_jobs:
eval_string = f"{total_correct: 4}/{total_answered: 4} = {score * 100:6.2f}%"
else:
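                        # 95% Wald interval with a finite population correction: the question
                        # pool is fixed, so the margin shrinks to zero as the last jobs finish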
interval = 1.96 * math.sqrt(score * (1 - score) / total_answered * (total_jobs - total_answered) / (total_jobs - 1))
eval_string = f"{total_correct: 4}/{total_answered: 4} = {score * 100:6.2f}% +/- {interval * 100: 6.2f}%"
# Save answer
all_results[idx]["full_completion"] = completion
all_results[idx]["answer"] = answer
all_results[idx]["correct"] = correct
done_jobs += 1
progress.update(done_jobs)
    # Print final score (every job has been answered by now, so no interval is shown)
    print(f" -- Final result: {eval_string}")
# Save results
    if args.output:
        write_jsonl(all_results, args.output)
        print(f" -- Responses written to {args.output}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description = "Run BigBench Extra Hard evaluation (mini sample set)")
model_init.add_args(
parser,
add_sampling_args = True,
default_cache_size = 65536,
default_sampling_args = {
"temperature": 0.8,
"repetition_penalty": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"penalty_range": 1024,
"min_p": 0.0,
"top_k": 0,
"top_p": 0.8,
"adaptive_target": 1.0,
"adaptive_decay": 0.9,
}
)
parser.add_argument("-vis", "--visualize_cache", action = "store_true", help = "Show cache visualizer (slow)")
parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", default = None)
parser.add_argument("-mt", "--max_tokens", type = int, default = 16384, help = "Max number of tokens for each completion")
parser.add_argument("-mbs", "--max_batch_size", type = int, default = 64, help = "Max batch size")
parser.add_argument("-think", "--think", action = "store_true", help = "explicitly set template_arg enable_thinking=true")
parser.add_argument("-nothink", "--nothink", action = "store_true", help = "explicitly set template_arg enable_thinking=false")
parser.add_argument("-thinktags", "--thinktags", nargs = 2, help = 'Think tags for reasoning models, default: "<think>" "</think>"', default = ["<think>", "</think>"])
parser.add_argument("-limit", "--limit", type = int, help = "Limit number of questions (creates incomplete results file)", default = 0)
_args = parser.parse_args()
    # Validate args
    if _args.output:
        directory = os.path.dirname(_args.output)
        if directory and not os.path.isdir(directory):
            os.makedirs(directory, exist_ok = True)
        if os.path.exists(_args.output):
            print(" !! Warning: Output file exists and will be overwritten.")
    assert not (_args.think and _args.nothink), "--think and --nothink are mutually exclusive"
main(_args)