from __future__ import annotations

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav3.util.file import disk_lru_cache, disk_lru_cache_clear
from exllamav3 import model_init, Generator, Job, ComboSampler
from exllamav3.util.progress import ProgressBar
import argparse
import torch
import time, random, json
from pathlib import Path
from urllib import request
from collections import deque
import math

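# Runs the BigBench Extra Hard "mini" set through an exllamav3 model and scores the answers with
# the benchmark's canonical fuzzy matcher. Example invocation (a sketch only: the model/cache flags
# come from model_init.add_args and the script path may differ in your checkout):
#
#   python eval/bbeh.py -m /path/to/model -o bbeh_results.jsonl -mbs 32 -nothink
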

def evaluate_correctness(_sample: str, _reference: str) -> bool:
    """
    Canonical evaluation, adapted from https://github.com/google-deepmind/bbeh/blob/main/bbeh/evaluate.py
    """

    def strip_latex(response: str) -> str:
        if response.startswith("$") and response.endswith("$"):
            response = response[1:-1]
        if "boxed{" in response and response.endswith("}"):
            response = response[0:-1].split("boxed{")[1]
        if "text{" in response and response.endswith("}"):
            response = response[0:-1].split("text{")[1]
        if "texttt{" in response and response.endswith("}"):
            response = response[0:-1].split("texttt{")[1]
        return response

    def extract_answer(sample: str) -> str:
        """Extracts the final answer from the sample."""
        answer_prefixes = [
            "The answer is:",
            "The final answer is ",
            "The final answer is: ",
            "The answer is "
        ]
        answer = sample
        for answer_prefix in answer_prefixes:
            if answer_prefix in answer:
                answer = answer.split(answer_prefix)[-1].strip()
        if answer.endswith("."):
            answer = answer[:-1]
        return strip_latex(answer)

    def fuzzy_match(prediction: str, reference: str) -> bool:
        """Fuzzy match function for BigBench Extra Hard."""
        if prediction == reference:
            return True

        # (a) vs a
        if len(prediction) == 3 and prediction[0] == "(" and prediction[-1] == ")":
            return prediction[1] == reference
        if len(reference) == 3 and reference[0] == "(" and reference[-1] == ")":
            return reference[1] == prediction

        # Numbers
        try:
            if float(prediction) == float(reference):
                return True
        except ValueError:
            pass

        # Quote issues
        if prediction.replace("'", "") == reference.replace("'", ""):
            return True

        # Bracket issues
        if f"[{reference}]" == prediction or f"[{prediction}]" == reference:
            return True

        # Question mark issues
        if prediction.endswith("?") and prediction[:-1] == reference:
            return True

        return False

    def preprocess_sample(sample: str) -> str:
        prediction = extract_answer(sample.strip()).lower()
        prediction = prediction.replace(", ", ",").replace("**", "")
        prediction = prediction.split("\n")[0]
        prediction = prediction[0:-1] if prediction.endswith(".") else prediction
        return prediction

    def preprocess_reference(reference: str) -> str:
        reference = reference.strip().lower()
        reference = reference.replace(", ", ",")
        return reference

    _prediction = preprocess_sample(_sample)
    _reference = preprocess_reference(_reference)
    return fuzzy_match(_prediction, _reference)

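
# disk_lru_cache caches the fetched JSON on disk, keyed by the given name, so repeated runs reuse
# the local copy instead of re-downloading (disk_lru_cache_clear can be used to drop the cache).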
@disk_lru_cache("fetch_bbeh_mini_test_data_git")
def fetch_bbeh_mini_test_data_git() -> list[dict]:
    url = "https://raw.githubusercontent.com/google-deepmind/bbeh/refs/heads/main/bbeh/mini/data.json"
    with request.urlopen(url) as response:
        raw = response.read().decode("utf-8")
    data = json.loads(raw)
    return data["examples"]

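
# Removes the reasoning segment from a completion: if both think tags are present, everything
# between them (tags included) is cut out; if only the closing tag appears (e.g. the opening tag
# was part of the prompt), everything up to and including it is dropped.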
def strip_reasoning(s: str, args) -> str:
    # TODO: Respect other think tags
    think_start, think_end = args.thinktags
    l = s.find(think_start)
    r = s.rfind(think_end)
    if l >= 0 and r >= 0:
        s = s[:l] + s[r + len(think_end):]
    elif r >= 0:
        s = s[r + len(think_end):]
    return s.strip()


def write_jsonl(rows: list[dict], path: Path) -> None:
    with open(path, "w") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii = False) + "\n")


@torch.inference_mode()
def main(args):

    # Fetch the BBEH mini data
    bbeh = fetch_bbeh_mini_test_data_git()
    if args.limit:
        bbeh = bbeh[:args.limit]

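    # Shuffle deterministically (fixed seed) so examples from the different BBEH tasks are
    # interleaved; the running score printed during generation is then roughly representative.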
    rng = random.Random(123)
    rng.shuffle(bbeh)

    # Initialize
    model, config, cache, tokenizer = model_init.init(args)
    generator = Generator(
        model = model,
        cache = cache,
        max_batch_size = args.max_batch_size,
        tokenizer = tokenizer,
        show_visualizer = args.visualize_cache,
    )
    sampler = model_init.get_arg_sampler(args)

    # Create prompts
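    # -think / -nothink explicitly set enable_thinking for chat templates that support it;
    # otherwise the template's default is used.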
    template_args = {}
    if args.think: template_args["enable_thinking"] = True
    if args.nothink: template_args["enable_thinking"] = False

    all_results: list[dict] = []
    with ProgressBar("Prompts", len(bbeh), transient = False) as progress:
        for idx, bp in enumerate(bbeh):
            prompt = bp["input"]
            input_ids = tokenizer.hf_chat_template(
                [{ "role": "user", "content": prompt }],
                add_generation_prompt = True,
                **template_args,
            )
            all_results.append({
                "input": prompt,
                "target": bp["target"],
                "raw_prompt": tokenizer.decode(input_ids[0], decode_special_tokens = True),
            })
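            # stop_on_loop aborts jobs that appear to be stuck repeating themselves; such jobs are
            # reported below with eos_reason "loop_detected".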
            job = Job(
                input_ids = input_ids,
                max_new_tokens = args.max_tokens,
                stop_conditions = config.eos_token_id_list,
                sampler = sampler,
                identifier = idx,
                max_rq_tokens = 512,
                stop_on_loop = (300, 3),
            )
            generator.enqueue(job)
            progress.update(idx + 1)

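    # All jobs are queued up front; generator.iterate() then serves them with continuous batching,
    # up to max_batch_size at a time, streaming back per-job results as they arrive.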
    # Generate
    tps_hist = deque()
    sampled_tokens = 0
    last_update = time.time()
    total_jobs = generator.num_remaining_jobs()
    done_jobs = 0
    total_answered = 0
    total_correct = 0
    eval_string = "(evaluating)"

    with ProgressBar("Generating samples", total_jobs, transient = False) as progress:

        while generator.num_remaining_jobs():
            results = generator.iterate()

            # Some feedback
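            # Roughly once a second: print the running score, queue depths and a smoothed tokens/s
            # figure (mean of the last three one-second samples), plus a snippet from one active job.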
            now = time.time()
            if now > last_update + 1:
                tps = sampled_tokens / (now - last_update)
                sampled_tokens = 0
                tps_hist.append(tps)
                if len(tps_hist) > 3:
                    tps_hist.popleft()
                tps = round(sum(tps_hist) / len(tps_hist))
                num_pend = generator.num_pending_jobs()
                num_act = generator.num_active_jobs()
                print(f" -- result: {eval_string:32} pending: {num_pend:4} active {num_act:4} {tps:6} tokens/s", end = "")
                if num_act:
                    sjob = random.choice(generator.active_jobs)
                    snum = "#" + str(sjob.identifier)
                    ssamp = repr(sjob.full_completion)[-64:-1]
                    print(f" sample from {snum:>4}: {ssamp}")
                else:
                    print()
                last_update = now

            # Collect results
            for result in results:
                if "token_ids" in result:
                    sampled_tokens += result["token_ids"].shape[-1]
                if result.get("eos"):
                    idx = result["identifier"]
                    completion = result["full_completion"]
                    match result["eos_reason"]:
                        case "max_new_tokens":
                            print(f" !! Job #{idx} exceeded token limit, ends in: {repr(completion)[-100:-1]}")
                        case "loop_detected":
                            print(f" !! Job #{idx} loop detected, ends in: {repr(completion)[-100:-1]}")
                        case _:
                            pass

                    # Running evaluation
                    total_answered += 1
                    answer = strip_reasoning(completion, args)
                    reference = bbeh[idx]["target"]
                    correct = evaluate_correctness(answer, reference)
                    if correct:
                        total_correct += 1
                    score = total_correct / total_answered
                    if total_answered == total_jobs:
                        eval_string = f"{total_correct: 4}/{total_answered: 4} = {score * 100:6.2f}%"
                    else:
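                        # 95% normal-approximation interval for the running score, with a finite
                        # population correction since the answered jobs are drawn without
                        # replacement from the fixed question set.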
                        interval = 1.96 * math.sqrt(score * (1 - score) / total_answered * (total_jobs - total_answered) / (total_jobs - 1))
                        eval_string = f"{total_correct: 4}/{total_answered: 4} = {score * 100:6.2f}% +/- {interval * 100: 6.2f}%"

                    # Save answer
                    all_results[idx]["full_completion"] = completion
                    all_results[idx]["answer"] = answer
                    all_results[idx]["correct"] = correct

                    done_jobs += 1
                    progress.update(done_jobs)

    # Print result
    print(f" -- Final result: {eval_string} (95% CI)")

    # Save results
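    # Each JSONL row holds the original question, its target, the rendered prompt, the raw
    # completion, the extracted answer and whether it was judged correct.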
    if args.output:
        write_jsonl(all_results, args.output)
        print(f" -- Responses written to {args.output}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = "Run BigBench Extra Hard evaluation (mini sample set)")
    model_init.add_args(
        parser,
        add_sampling_args = True,
        default_cache_size = 65536,
        default_sampling_args = {
            "temperature": 0.8,
            "repetition_penalty": 1.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "penalty_range": 1024,
            "min_p": 0.0,
            "top_k": 0,
            "top_p": 0.8,
            "adaptive_target": 1.0,
            "adaptive_decay": 0.9,
        }
    )
    parser.add_argument("-vis", "--visualize_cache", action = "store_true", help = "Show cache visualizer (slow)")
    parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", default = None)
    parser.add_argument("-mt", "--max_tokens", type = int, default = 16384, help = "Max number of tokens for each completion")
    parser.add_argument("-mbs", "--max_batch_size", type = int, default = 64, help = "Max batch size")
    parser.add_argument("-think", "--think", action = "store_true", help = "Explicitly set template arg enable_thinking=True")
    parser.add_argument("-nothink", "--nothink", action = "store_true", help = "Explicitly set template arg enable_thinking=False")
    parser.add_argument("-thinktags", "--thinktags", nargs = 2, help = 'Think tags for reasoning models, default: "<think>" "</think>"', default = ["<think>", "</think>"])
    parser.add_argument("-limit", "--limit", type = int, help = "Limit number of questions (creates incomplete results file)", default = 0)
    _args = parser.parse_args()

    # Validate args
    if _args.output:
        directory = os.path.dirname(_args.output)
        if os.path.exists(_args.output):
            print(" !! Warning: Output file exists and will be overwritten.")

    assert not (_args.think and _args.nothink), "Cannot set both -think and -nothink"
    main(_args)