Mirror of https://github.com/turboderp-org/exllamav3.git
MMLU eval: More feedback during eval
eval/mmlu.py | 58
@@ -50,6 +50,11 @@ def main(args):
         sys.exit()
     all_subjects = set(sel_subjects)
 
+    # Skip
+    all_subjects = sorted(list(all_subjects))
+    all_subjects = all_subjects[args.skip_subjects:]
+    all_subjects = set(all_subjects)
+
     # Optionally shuffle
     if args.shuffle:
         for problem in dataset_all:
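A note on the slicing just added: Python sets are unordered, so sorting before applying args.skip_subjects is what makes the skip deterministic between runs. A minimal illustration, with made-up subject names:

    # Hypothetical subjects standing in for the MMLU categories
    subjects = {"virology", "anatomy", "marketing"}
    ordered = sorted(subjects)      # ['anatomy', 'marketing', 'virology']
    remaining = set(ordered[1:])    # skip_subjects = 1 -> {'marketing', 'virology'}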
@@ -88,22 +93,49 @@ def main(args):
         if q["subject"] in all_subjects:
             total_jobs += 1
 
+    num_remaining = {s: 0 for s in all_subjects}
+
     with ProgressBar("Questions", total_jobs, transient = False) as progress:
         for q in dataset_all:
-            if q["subject"] not in all_subjects:
+            sub = q["subject"]
+            if sub not in all_subjects:
                 continue
+            if args.max_q_per_subject and num_remaining[sub] >= args.max_q_per_subject:
+                continue
             prompt = format_question(q["question"], q["choices"], None)
             prompt_ids = tokenizer.encode(prompt, add_bos = False)
             job = Job(
-                input_ids = torch.cat([preprompt_ids[q["subject"]], prompt_ids], dim = -1),
+                input_ids = torch.cat([preprompt_ids[sub], prompt_ids], dim = -1),
                 max_new_tokens = 1,
                 return_logits = True,
                 identifier = q,
             )
             generator.enqueue(job)
+            num_remaining[sub] += 1
             progress.update(generator.num_remaining_jobs())
 
     # Evaluate
+    def print_results(p_subject):
+        nonlocal dataset_all
+        total = 0
+        correct = 0
+        confidence_sum = 0.0
+        for q in dataset_all:
+            if not "answer_correct" in q:
+                continue
+            if p_subject is not None and q["subject"] != p_subject:
+                continue
+            total += 1
+            if q["answer_correct"]:
+                correct += 1
+            confidence_sum += q["correct_answer_confidence"]
+        if p_subject is None:
+            p_subject = "all"
+        print(
+            f"{p_subject:40}: {correct: 5}/{total: 5} = {correct/total*100:6.2f}% correct, "
+            f"({confidence_sum/total*100:6.2f}% confidence)"
+        )
+
     with ProgressBar("Testing", total_jobs, transient = False) as progress:
         while generator.num_remaining_jobs():
            results = generator.iterate()
@@ -119,23 +151,13 @@ def main(args):
                 confidence = model_probs[correct_answer]
                 q["correct_answer_confidence"] = confidence
                 q["answer_correct"] = favored_anwser == correct_answer
+                sub = q["subject"]
+                num_remaining[sub] -= 1
+                if num_remaining[sub] == 0:
+                    print_results(sub)
                 progress.update(total_jobs - generator.num_remaining_jobs())
 
     # Summarize
-    total = 0
-    correct = 0
-    confidence_sum = 0.0
-
-    for q in dataset_all:
-        if not "answer_correct" in q:
-            continue
-        total += 1
-        if q["answer_correct"]:
-            correct += 1
-        confidence_sum += q["correct_answer_confidence"]
-
-    print(f"Correct answers: {correct}/{total} = {correct/total*100:.2f}%")
-    print(f"Avg. confidence: {confidence_sum/total*100:.2f}%")
+    print_results(None)
 
 
 if __name__ == "__main__":
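Taken together, the two hunks above implement a countdown per subject: num_remaining is incremented once per enqueued question and decremented as results arrive, and a subject's summary is printed the moment its count reaches zero, rather than only at the end of the run. A stripped-down, self-contained sketch of the same pattern (the results list and the print_results body here are stand-ins, not the script's real data):

    # Sketch of the per-subject countdown; values are illustrative only.
    results = [{"subject": "anatomy"}, {"subject": "virology"}, {"subject": "anatomy"}]
    num_remaining = {"anatomy": 2, "virology": 1}   # filled while enqueuing

    def print_results(subject):
        # Stand-in for the real per-subject summary printer
        print(f"{subject:40}: all questions answered")

    for r in results:
        sub = r["subject"]
        num_remaining[sub] -= 1
        if num_remaining[sub] == 0:   # last outstanding question for this subject
            print_results(sub)        # immediate mid-eval feedback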
@@ -145,5 +167,7 @@ if __name__ == "__main__":
     parser.add_argument("-sub", "--subjects", type = str, default = "all", help = "Comma-separated list of categories to test, or 'all'")
     parser.add_argument("-shf", "--shuffle", action = "store_true", help = "Shuffle choices randomly")
     parser.add_argument("-vis", "--visualize_cache", action = "store_true", help = "Show cache visualizer (slow)")
+    parser.add_argument("-skip", "--skip_subjects", type = int, default = 0, help = "Skip number of categories")
+    parser.add_argument("-mqps", "--max_q_per_subject", type = int, default = None, help = "Max questions per subject (default: unlimited)")
     _args = parser.parse_args()
     main(_args)
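A hypothetical invocation combining the two new flags (model and other required arguments are omitted here and depend on the rest of the script):

    python eval/mmlu.py -skip 10 -mqps 50

This would drop the first ten categories in sorted order and cap each remaining subject at fifty questions.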