From 823bf11c6833c0fd5ddd7c13ee264426fe80272e Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 1 Jun 2024 03:31:40 +0200 Subject: [PATCH] Update MMLU test to use dynamic batching --- eval/humaneval.py | 39 ++------- eval/mmlu.py | 185 +++++++++++++++++++++++++++++++++++++++ eval/util.py | 40 +++++++++ tests/test_mmlu.py | 210 --------------------------------------------- 4 files changed, 234 insertions(+), 240 deletions(-) create mode 100644 eval/mmlu.py create mode 100644 eval/util.py delete mode 100644 tests/test_mmlu.py diff --git a/eval/humaneval.py b/eval/humaneval.py index d12a4cc..f25354f 100644 --- a/eval/humaneval.py +++ b/eval/humaneval.py @@ -1,11 +1,12 @@ +from __future__ import annotations import sys, os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from human_eval.data import write_jsonl, read_problems -from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, model_init -from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_8bit +from exllamav2 import model_init +from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4 from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler -import torch, argparse, contextlib -from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn +import argparse, contextlib +import util # Args @@ -61,7 +62,7 @@ else: print("\n".join(prompt_formats.keys())) sys.exit() -# Init model +# Init model and cache model_init.check_args(args) model_init.print_options(args) @@ -73,8 +74,6 @@ model, tokenizer = model_init.init( max_input_len = 2048 ) -# Create cache - if args.cache_q4: cache_type = ExLlamaV2Cache_Q4 else: cache_type = ExLlamaV2Cache cache = cache_type( @@ -83,8 +82,6 @@ cache = cache_type( max_seq_len = args.cache_size or model.config.max_seq_len ) -# Load model - if not model.loaded: model.load_autosplit(cache, progress = True) @@ -112,13 +109,7 @@ 
num_samples_per_task = args.samples_per_task # Create jobs -with Progress( - TextColumn("[bold blue]{task.fields[name]}", justify = "left"), - BarColumn(bar_width = None), - "[progress.percentage]{task.percentage:>3.0f}%", - TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"), - TimeRemainingColumn() -) as progress: +with util.get_progress() as progress: task1 = progress.add_task("[red]Sample", total = len(problems) * num_samples_per_task, name = "Creating sample jobs") for problem_id, problem in problems.items(): @@ -146,22 +137,10 @@ with Progress( samples = [] -# Progress bar - -total_jobs = generator.num_remaining_jobs() -if args.verbose: - cm = contextlib.nullcontext() -else: - cm = Progress( - TextColumn("[bold blue]{task.fields[name]}", justify = "left"), - BarColumn(bar_width = None), - "[progress.percentage]{task.percentage:>3.0f}%", - TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"), - TimeRemainingColumn() - ) - # Work +total_jobs = generator.num_remaining_jobs() +cm = contextlib.nullcontext() if args.verbose else util.get_progress() with cm as progress: if not args.verbose: diff --git a/eval/mmlu.py b/eval/mmlu.py new file mode 100644 index 0000000..8d988a7 --- /dev/null +++ b/eval/mmlu.py @@ -0,0 +1,185 @@ +from __future__ import annotations +import sys, os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from exllamav2 import model_init +from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4 +from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler +import argparse, contextlib +import torch +import util + +# Args + +parser = argparse.ArgumentParser(description = "Run MMLU evaluation on EXL2 model") +parser.add_argument("-cs", "--cache_size", type = int, default = None) +parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache") +parser.add_argument("-sub", "--subjects", type = str, default = "all", help = 
"Comma-separated list of categories to test, or 'all'") +parser.add_argument("-fs", "--fewshot_examples", type = int, default = 5, help = "Number of examples for fewshot examples, max 5") +model_init.add_args(parser) +args = parser.parse_args() + +# Init model and cache + +model_init.check_args(args) +model_init.print_options(args) +model, tokenizer = model_init.init( + args, + allow_auto_split = True, + progress = True, + max_output_len = 1, + max_input_len = 2048 +) + +if args.cache_q4: cache_type = ExLlamaV2Cache_Q4 +else: cache_type = ExLlamaV2Cache +cache = cache_type( + model, + lazy = not model.loaded, + max_seq_len = args.cache_size or model.config.max_seq_len +) + +if not model.loaded: + model.load_autosplit(cache, progress = True) + +# Generator + +generator = ExLlamaV2DynamicGenerator( + model = model, + cache = cache, + tokenizer = tokenizer, + max_batch_size = 1024, + max_q_size = 1 +) + +c_options = "ABCD" + +gen_settings = ExLlamaV2Sampler.Settings( + token_repetition_penalty = 1.0, + temperature = 1.0, + top_k = 10, + top_p = 1.0, +) + +token_map = [tokenizer.single_id(piece) for piece in [" " + c for c in c_options]] +token_rmap = { token_map[i]: i for i in range(len(c_options)) } +gen_settings.allow_tokens(tokenizer, token_map) + +# Get dataset + +dataset_dev = util.get_dataset("cais/mmlu", "all", "dev") +dataset_all = util.get_dataset("cais/mmlu", "all", "test") +dataset_dev = sorted(dataset_dev, key = lambda q: q["subject"]) +dataset_all = sorted(dataset_all, key = lambda q: q["subject"]) + +all_subjects = set([q["subject"] for q in dataset_dev]) +if args.subjects != "all": + sel_subjects = args.subjects.split(",") + for s in sel_subjects: + if s not in all_subjects: + print(f"Subject: {s} is not present in dataset") + sys.exit() + all_subjects = set(sel_subjects) + +# Format + +def format_question(question: str, choices: list[str], answer: int | None): + f = "Question:\n\n" + f += question + f += "\n\nChoices:\n\n" + for i, c in 
# Fewshot preprompts
#
# For each selected subject, build a preprompt of up to args.fewshot_examples
# solved questions from the dev split and tokenize it once, so every test
# question for that subject reuses the same prefix IDs (and the generator can
# dedupe/cache the shared prefix across jobs).

preprompt_ids = {}
with util.get_progress() as progress:
    task1 = progress.add_task("[red]Preprompts", total = len(all_subjects), name = "Preparing preprompts")
    for subject in all_subjects:

        preprompt = ""
        fewshots = 0
        for pq in dataset_dev:
            if fewshots == args.fewshot_examples: break
            if pq["subject"] != subject: continue
            preprompt += format_question(pq["question"], pq["choices"], pq["answer"])
            # Fix: the counter was never incremented, so the -fs limit was
            # silently ignored and every dev example was always included.
            fewshots += 1
        preprompt_ids[subject] = tokenizer.encode(preprompt, add_bos = True)
        progress.update(task1, advance = 1)

# Questions
#
# Enqueue one single-token generation job per selected test question. The raw
# dataset row rides along as the job identifier so results can be written back
# into it and summarized afterwards.

total_jobs = 0
for q in dataset_all:
    if q["subject"] in all_subjects:
        total_jobs += 1

with util.get_progress() as progress:
    task1 = progress.add_task("[red]Questions", total = total_jobs, name = "Preparing questions")

    for q in dataset_all:
        if q["subject"] not in all_subjects:
            continue

        prompt = format_question(q["question"], q["choices"], None)
        prompt_ids = tokenizer.encode(prompt, add_bos = False)

        job = ExLlamaV2DynamicJob(
            input_ids = torch.cat([preprompt_ids[q["subject"]], prompt_ids], dim = -1),
            gen_settings = gen_settings,
            max_new_tokens = 1,
            return_top_tokens = 4,
            identifier = q,  # dataset row; mutated below with the result
        )

        generator.enqueue(job)
        progress.update(task1, advance = 1)

# Work
#
# Drain the generator. Each finished job yields the top-K candidate tokens for
# the single generated position; the sampled completion itself is ignored.

with util.get_progress() as progress:
    task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Testing")

    while generator.num_remaining_jobs():

        results = generator.iterate()
        for result in results:

            if not result["eos"]:
                continue

            # Ignore completion and use top-K tokens only

            top_tokens = result["top_k_tokens"]
            top_probs = result["top_k_probs"]
            q = result["identifier"]

            correct_answer = q["answer"]
            # Fix: 'confidence' was only assigned when the correct answer's
            # token appeared in the returned top-K, causing a NameError on the
            # first miss and a stale carry-over from the previous question
            # afterwards. Default to 0.0 when the correct token is absent.
            confidence = 0.0
            for i in range(top_tokens.shape[-1]):
                if top_tokens[0, 0, i].item() == token_map[correct_answer]:
                    confidence = top_probs[0, 0, i].item()

            q["correct_answer_confidence"] = confidence
            # NOTE(review): top_k_tokens presumably reflects the sampler's
            # allow_tokens filter (A-D only); .get guards against a KeyError
            # if an out-of-set token can still surface — confirm against the
            # generator's top-K semantics.
            q["answer_correct"] = token_rmap.get(top_tokens[0, 0, 0].item(), -1) == correct_answer

            progress.update(task1, advance = 1)

# Summarize
#
# Only rows that actually received a result carry the keys written above.

total = 0
correct = 0
confidence_sum = 0.0

for q in dataset_all:
    if "answer_correct" not in q:
        continue
    total += 1
    if q["answer_correct"]:
        correct += 1
    confidence_sum += q["correct_answer_confidence"]

# Guard the division: with an empty selection nothing was evaluated.
if total:
    print(f"Correct answers: {correct}/{total} = {correct/total*100:.2f}%")
    print(f"Confidence: {confidence_sum/total*100:.2f}%")
else:
    print("No questions were evaluated.")
= 4)) + return rows diff --git a/tests/test_mmlu.py b/tests/test_mmlu.py deleted file mode 100644 index 56c2bab..0000000 --- a/tests/test_mmlu.py +++ /dev/null @@ -1,210 +0,0 @@ - -import sys, os, gc -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from exllamav2 import( - ExLlamaV2, - ExLlamaV2Config, - ExLlamaV2Cache, - ExLlamaV2Tokenizer, -) - -from datasets import load_dataset -import torch -import hashlib -import json - -# Models to test - -model_base = "/mnt/str/models/llama3-70b-instruct-exl2" -variants = [v for v in os.listdir(model_base) if os.path.isdir(os.path.join(model_base, v))] -if not variants: variants = ["."] - -# variants = \ -# [ -# "2.4bpw", -# "2.8bpw", -# "4.0bpw", -# ] - -# gpu_split = (20, 21.3, 21, 24) -gpu_split = None #auto - -qa_set = "cais/mmlu" -qa_split = "test" - -categories = \ -[ - "anatomy", - "computer_security", - "formal_logic", - "logical_fallacies", - "philosophy", - "nutrition", -] - -examples_per_category = 3 -questions_per_category = 97 - -# Load model - -def get_model(base, variant_, gpu_split_, batch_size_): - - model_dir = os.path.join(base, variant_) - - config = ExLlamaV2Config() - config.model_dir = model_dir - config.prepare() - config.max_seq_len = 2048 - config.max_batch_size = batch_size_ - - model_ = ExLlamaV2(config) - print(" -- Loading model: " + model_dir) - - if gpu_split_: - model_.load(gpu_split_) - cache_ = None - else: - cache_ = ExLlamaV2Cache(model_, batch_size = batch_size_, lazy = True) - model_.load_autosplit(cache_) - - tokenizer_ = ExLlamaV2Tokenizer(config) - - return model_, cache_, tokenizer_ - - - -# Prepare the prompts - -def load_datasets(): - global categoriwes - - def get_dataset(ds_name, category_, split_): - - print(f" -- Loading dataset: {ds_name}/{category_}...") - dataset_ = load_dataset(ds_name, category_, split = split_) - return dataset_ - - def format_question(question, options, answer, ex=False): - - clabels = "ABCD" - text = f"Question:\n" - 
text += question - text += "\n\nChoices:\n" - for i, o in enumerate(options): - text += clabels[i] + ": " + o + "\n" - text += "\nAnswer: " + clabels[answer] - # if ex: - # text += ", " + options[answer] - return text - - prep_prompts_ = {} - for category_ in categories: - - dataset = get_dataset(qa_set, category_, qa_split) - - rows = [] - for example in dataset: - rows.append(example) - if len(rows) == questions_per_category + examples_per_category: break - - examples_prompt = "" - for i_ in range(examples_per_category): - examples_prompt += format_question(rows[i_]["question"], rows[i_]["choices"], rows[i_]["answer"], ex = True) - examples_prompt += "\n\n" - - prompts_ = [] - labels_ = [] - for j_ in range(questions_per_category): - i_ = j_ + examples_per_category - q_prompt = format_question(rows[i_]["question"], rows[i_]["choices"], rows[i_]["answer"]) - prompts_.append(examples_prompt + q_prompt) - labels_.append(rows[i_]["answer"]) - - prep = {"prompts": prompts_, - "labels": labels_} - - prep_prompts_[category_] = prep - - return prep_prompts_ - - -hash_obj = hashlib.sha256() -hash_obj.update(''.join(categories).encode()) -h = hash_obj.hexdigest() -module_dir = os.path.dirname(os.path.abspath(__file__)) -filename = os.path.join(module_dir, "mmlu_prompts_" + h[:12] + ".json") - -if os.path.exists(filename): - with open(filename, "r") as f: - prep_prompts = json.load(f) -else: - prep_prompts = load_datasets() - with open(filename, "w") as f: - f.write(json.dumps(prep_prompts, indent = 4)) - -# Do the test - -results = ";".join([""] + categories) + "\n" - -for variant in variants: - - # Model - - model = None - cache = None - tokenizer = None - - gc.collect() - torch.cuda.empty_cache() - gc.collect() - - model, cache, tokenizer = get_model(model_base, variant, gpu_split, 1) - - # Logit positions corresponding to valid answers - - answer_logits = [] - llabels = "ABCD" - for i in range(4): - answer_ = "The answer is: " + llabels[i] - 
answer_logits.append(tokenizer.tokenizer.encode(answer_)[-1]) - - # Categories - - cat_results = [] - - for category in categories: - - print(f" -- Testing: {category}...") - - prompts = prep_prompts[category]["prompts"] - labels = prep_prompts[category]["labels"] - - # Evaluate prompts - - score = 0.0 - # for prompt_ids, mask in zip(prompt_ids_list, mask_list): - - for prompt, label in zip(prompts, labels): - - prompt_ids = tokenizer.encode(prompt) - prompt_ids = prompt_ids[:, :-1] - - logits = model.forward(prompt_ids, last_id_only = True) - logits = logits.float() - - logits_ans = logits[:, :, answer_logits] - prob_ans = torch.softmax(logits_ans, dim = -1) - - score += prob_ans[0, 0, label] - - score /= questions_per_category - print(f" -- Score: {score:.4f}") - - cat_results.append(f"{score:.4f}"); - - results += ";".join([variant] + cat_results) + "\n" - -print(" -- Finished") -print() -print(results)