MMLU eval: More feedback during eval

This commit is contained in:
turboderp
2025-07-12 18:31:32 +02:00
parent 48ade457da
commit 415a55cc2d

View File

@@ -50,6 +50,11 @@ def main(args):
sys.exit()
all_subjects = set(sel_subjects)
# Skip
all_subjects = sorted(list(all_subjects))
all_subjects = all_subjects[args.skip_subjects:]
all_subjects = set(all_subjects)
# Optionally shuffle
if args.shuffle:
for problem in dataset_all:
@@ -88,22 +93,49 @@ def main(args):
if q["subject"] in all_subjects:
total_jobs += 1
num_remaining = {s: 0 for s in all_subjects}
with ProgressBar("Questions", total_jobs, transient = False) as progress:
for q in dataset_all:
if q["subject"] not in all_subjects:
sub = q["subject"]
if sub not in all_subjects:
continue
if args.max_q_per_subject and num_remaining[sub] >= args.max_q_per_subject:
continue
prompt = format_question(q["question"], q["choices"], None)
prompt_ids = tokenizer.encode(prompt, add_bos = False)
job = Job(
input_ids = torch.cat([preprompt_ids[q["subject"]], prompt_ids], dim = -1),
input_ids = torch.cat([preprompt_ids[sub], prompt_ids], dim = -1),
max_new_tokens = 1,
return_logits = True,
identifier = q,
)
generator.enqueue(job)
num_remaining[sub] += 1
progress.update(generator.num_remaining_jobs())
# Evaluate
def print_results(p_subject):
nonlocal dataset_all
total = 0
correct = 0
confidence_sum = 0.0
for q in dataset_all:
if not "answer_correct" in q:
continue
if p_subject is not None and q["subject"] != p_subject:
continue
total += 1
if q["answer_correct"]:
correct += 1
confidence_sum += q["correct_answer_confidence"]
if p_subject is None:
p_subject = "all"
print(
f"{p_subject:40}: {correct: 5}/{total: 5} = {correct/total*100:6.2f}% correct, "
f"({confidence_sum/total*100:6.2f}% confidence)"
)
with ProgressBar("Testing", total_jobs, transient = False) as progress:
while generator.num_remaining_jobs():
results = generator.iterate()
@@ -119,23 +151,13 @@ def main(args):
confidence = model_probs[correct_answer]
q["correct_answer_confidence"] = confidence
q["answer_correct"] = favored_anwser == correct_answer
sub = q["subject"]
num_remaining[sub] -= 1
if num_remaining[sub] == 0:
print_results(sub)
progress.update(total_jobs - generator.num_remaining_jobs())
# Summarize
total = 0
correct = 0
confidence_sum = 0.0
for q in dataset_all:
if not "answer_correct" in q:
continue
total += 1
if q["answer_correct"]:
correct += 1
confidence_sum += q["correct_answer_confidence"]
print(f"Correct answers: {correct}/{total} = {correct/total*100:.2f}%")
print(f"Avg. confidence: {confidence_sum/total*100:.2f}%")
print_results(None)
if __name__ == "__main__":
@@ -145,5 +167,7 @@ if __name__ == "__main__":
parser.add_argument("-sub", "--subjects", type = str, default = "all", help = "Comma-separated list of categories to test, or 'all'")
parser.add_argument("-shf", "--shuffle", action = "store_true", help = "Shuffle choices randomly")
parser.add_argument("-vis", "--visualize_cache", action = "store_true", help = "Show cache visualizer (slow)")
parser.add_argument("-skip", "--skip_subjects", type = int, default = 0, help = "Skip number of categories")
parser.add_argument("-mqps", "--max_q_per_subject", type = int, default = None, help = "Max questions per subject (default: unlimited)")
_args = parser.parse_args()
main(_args)