mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-03-15 00:07:26 +00:00
Update MMLU test to use dynamic batching
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
from __future__ import annotations
|
||||
import sys, os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from human_eval.data import write_jsonl, read_problems
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, model_init
|
||||
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_8bit
|
||||
from exllamav2 import model_init
|
||||
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4
|
||||
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
|
||||
import torch, argparse, contextlib
|
||||
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
|
||||
import argparse, contextlib
|
||||
import util
|
||||
|
||||
# Args
|
||||
|
||||
@@ -61,7 +62,7 @@ else:
|
||||
print("\n".join(prompt_formats.keys()))
|
||||
sys.exit()
|
||||
|
||||
# Init model
|
||||
# Init model and cache
|
||||
|
||||
model_init.check_args(args)
|
||||
model_init.print_options(args)
|
||||
@@ -73,8 +74,6 @@ model, tokenizer = model_init.init(
|
||||
max_input_len = 2048
|
||||
)
|
||||
|
||||
# Create cache
|
||||
|
||||
if args.cache_q4: cache_type = ExLlamaV2Cache_Q4
|
||||
else: cache_type = ExLlamaV2Cache
|
||||
cache = cache_type(
|
||||
@@ -83,8 +82,6 @@ cache = cache_type(
|
||||
max_seq_len = args.cache_size or model.config.max_seq_len
|
||||
)
|
||||
|
||||
# Load model
|
||||
|
||||
if not model.loaded:
|
||||
model.load_autosplit(cache, progress = True)
|
||||
|
||||
@@ -112,13 +109,7 @@ num_samples_per_task = args.samples_per_task
|
||||
|
||||
# Create jobs
|
||||
|
||||
with Progress(
|
||||
TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
|
||||
BarColumn(bar_width = None),
|
||||
"[progress.percentage]{task.percentage:>3.0f}%",
|
||||
TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
|
||||
TimeRemainingColumn()
|
||||
) as progress:
|
||||
with util.get_progress() as progress:
|
||||
|
||||
task1 = progress.add_task("[red]Sample", total = len(problems) * num_samples_per_task, name = "Creating sample jobs")
|
||||
for problem_id, problem in problems.items():
|
||||
@@ -146,22 +137,10 @@ with Progress(
|
||||
|
||||
samples = []
|
||||
|
||||
# Progress bar
|
||||
|
||||
total_jobs = generator.num_remaining_jobs()
|
||||
if args.verbose:
|
||||
cm = contextlib.nullcontext()
|
||||
else:
|
||||
cm = Progress(
|
||||
TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
|
||||
BarColumn(bar_width = None),
|
||||
"[progress.percentage]{task.percentage:>3.0f}%",
|
||||
TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
|
||||
TimeRemainingColumn()
|
||||
)
|
||||
|
||||
# Work
|
||||
|
||||
total_jobs = generator.num_remaining_jobs()
|
||||
cm = contextlib.nullcontext() if args.verbose else util.get_progress()
|
||||
with cm as progress:
|
||||
|
||||
if not args.verbose:
|
||||
|
||||
185
eval/mmlu.py
Normal file
185
eval/mmlu.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
import sys, os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from exllamav2 import model_init
|
||||
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4
|
||||
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
|
||||
import argparse, contextlib
|
||||
import torch
|
||||
import util
|
||||
|
||||
# Args

# Command-line interface: cache sizing/quantization, subject selection, and
# fewshot count, plus the shared model-loading args from model_init.

parser = argparse.ArgumentParser(description = "Run MMLU evaluation on EXL2 model")
parser.add_argument("-cs", "--cache_size", type = int, default = None)
parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
parser.add_argument("-sub", "--subjects", type = str, default = "all", help = "Comma-separated list of categories to test, or 'all'")
parser.add_argument("-fs", "--fewshot_examples", type = int, default = 5, help = "Number of examples for fewshot examples, max 5")
model_init.add_args(parser)
args = parser.parse_args()
|
||||
|
||||
# Init model and cache
|
||||
|
||||
# Validate shared args and load model + tokenizer. max_output_len = 1 because
# every job generates exactly one answer token; max_input_len bounds the
# chunked prefill size, not the total prompt length.
model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(
    args,
    allow_auto_split = True,
    progress = True,
    max_output_len = 1,
    max_input_len = 2048
)
|
||||
|
||||
# Create the KV cache (optionally Q4-quantized). If the model was not fully
# loaded by model_init (auto-split path), the cache is created lazily and the
# model is loaded now, splitting layers across GPUs as the cache fills.
if args.cache_q4: cache_type = ExLlamaV2Cache_Q4
else: cache_type = ExLlamaV2Cache
cache = cache_type(
    model,
    lazy = not model.loaded,
    max_seq_len = args.cache_size or model.config.max_seq_len
)

if not model.loaded:
    model.load_autosplit(cache, progress = True)
|
||||
|
||||
# Generator

# Dynamic batching generator; max_q_size = 1 keeps each job at one queued
# token, which fits the single-token answers produced here.
generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
    max_batch_size = 1024,
    max_q_size = 1
)

# Choice labels, in answer-index order.
c_options = "ABCD"

# Sampler settings: no repetition penalty, temperature 1 with top-10/top-1.0.
# allow_tokens below restricts output to the four choice-label tokens anyway.
gen_settings = ExLlamaV2Sampler.Settings(
    token_repetition_penalty = 1.0,
    temperature = 1.0,
    top_k = 10,
    top_p = 1.0,
)

# token_map: answer index -> token id of " A"/" B"/" C"/" D" (leading space to
# match how the label follows "Correct answer:"). token_rmap inverts it.
token_map = [tokenizer.single_id(piece) for piece in [" " + c for c in c_options]]
token_rmap = { token_map[i]: i for i in range(len(c_options)) }
gen_settings.allow_tokens(tokenizer, token_map)
|
||||
|
||||
# Get dataset

# dev split supplies the fewshot examples, test split the scored questions.
# Both are sorted by subject so the per-subject scans below are predictable.

dataset_dev = util.get_dataset("cais/mmlu", "all", "dev")
dataset_all = util.get_dataset("cais/mmlu", "all", "test")
dataset_dev = sorted(dataset_dev, key = lambda q: q["subject"])
dataset_all = sorted(dataset_all, key = lambda q: q["subject"])

# Validate any user-selected subjects against the subjects actually present.
all_subjects = set([q["subject"] for q in dataset_dev])
if args.subjects != "all":
    sel_subjects = args.subjects.split(",")
    for s in sel_subjects:
        if s not in all_subjects:
            print(f"Subject: {s} is not present in dataset")
            sys.exit()
    all_subjects = set(sel_subjects)
|
||||
|
||||
# Format
|
||||
|
||||
def format_question(question: str, choices: list[str], answer: int | None):
|
||||
f = "Question:\n\n"
|
||||
f += question
|
||||
f += "\n\nChoices:\n\n"
|
||||
for i, c in enumerate(c_options):
|
||||
f += c + ": " + choices[i] + "\n"
|
||||
f += "\nCorrect answer:"
|
||||
if answer is not None:
|
||||
f += " " + c_options[answer] + "\n\n"
|
||||
f += "----\n\n"
|
||||
return f
|
||||
|
||||
# Fewshot preprompts

# Tokenize one preprompt per subject from the dev split, containing up to
# args.fewshot_examples solved example questions of that subject.

preprompt_ids = {}
with util.get_progress() as progress:
    task1 = progress.add_task("[red]Preprompts", total = len(all_subjects), name = "Preparing preprompts")
    for subject in all_subjects:

        preprompt = ""
        fewshots = 0
        for pq in dataset_dev:
            if fewshots == args.fewshot_examples: break
            if pq["subject"] != subject: continue
            preprompt += format_question(pq["question"], pq["choices"], pq["answer"])
            # Bug fix: the counter was never incremented, so the limit above
            # could not trigger and every matching dev question was included.
            # (MMLU's dev split has 5 per subject, masking this at the default
            # --fewshot_examples of 5.)
            fewshots += 1
        preprompt_ids[subject] = tokenizer.encode(preprompt, add_bos = True)
        progress.update(task1, advance = 1)
|
||||
|
||||
# Questions

# Count eligible questions first so the progress bar has a total, then create
# one generator job per question. Each job prepends the subject's tokenized
# fewshot preprompt, asks for a single token, and carries the question dict
# itself as the job identifier so results can be written back onto it.

total_jobs = 0
for q in dataset_all:
    if q["subject"] in all_subjects:
        total_jobs += 1

with util.get_progress() as progress:
    task1 = progress.add_task("[red]Questions", total=total_jobs, name="Preparing questions")

    for q in dataset_all:
        if q["subject"] not in all_subjects:
            continue

        # Question prompt ends with "Correct answer:" (answer = None).
        prompt = format_question(q["question"], q["choices"], None)
        prompt_ids = tokenizer.encode(prompt, add_bos = False)

        job = ExLlamaV2DynamicJob(
            input_ids = torch.cat([preprompt_ids[q["subject"]], prompt_ids], dim = -1),
            gen_settings = gen_settings,
            max_new_tokens = 1,
            return_top_tokens = 4,
            identifier = q,
        )

        generator.enqueue(job)
        progress.update(task1, advance = 1)
|
||||
|
||||
# Work

# Drain the generator. For each finished job, inspect the top-K candidates of
# the single generated token: the question is correct when the top token maps
# back to the correct choice, and the probability assigned to the correct
# choice's token (if present in the top K) is recorded as confidence.

with util.get_progress() as progress:
    task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Testing")

    while generator.num_remaining_jobs():

        results = generator.iterate()
        for result in results:

            if not result["eos"]:
                continue

            # Ignore completion and use top-K tokens only

            top_tokens = result["top_k_tokens"]
            top_probs = result["top_k_probs"]
            q = result["identifier"]

            correct_answer = q["answer"]
            # Bug fix: default to 0.0 so that when the correct answer's token
            # is not among the returned top-K tokens we don't reuse the stale
            # confidence from a previous result (or hit a NameError on the
            # first result).
            confidence = 0.0
            for i in range(top_tokens.shape[-1]):
                if top_tokens[0, 0, i].item() == token_map[correct_answer]:
                    confidence = top_probs[0, 0, i].item()

            q["correct_answer_confidence"] = confidence
            q["answer_correct"] = token_rmap[top_tokens[0, 0, 0].item()] == correct_answer

            progress.update(task1, advance = 1)
|
||||
|
||||
# Summarize

# Aggregate over every question that received a result ("answer_correct" is
# only set by the work loop above).

total = 0
correct = 0
confidence_sum = 0.0

for q in dataset_all:
    if not "answer_correct" in q:
        continue
    total += 1
    if q["answer_correct"]:
        correct += 1
    confidence_sum += q["correct_answer_confidence"]

# Guard against division by zero when nothing was scored (e.g. generation was
# interrupted) instead of crashing at the very end of a long run.
if total:
    print(f"Correct answers: {correct}/{total} = {correct/total*100:.2f}%")
    print(f"Confidence: {confidence_sum/total*100:.2f}%")
else:
    print("No questions were evaluated.")
|
||||
40
eval/util.py
Normal file
40
eval/util.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from datasets import load_dataset
|
||||
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
|
||||
import os, json
|
||||
|
||||
# Rich progress bar format

def get_progress():
    """
    Build the progress bar shared by the eval scripts: task label, full-width
    bar, percentage, completed/total counters, and estimated time remaining.
    """

    columns = (
        TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
        BarColumn(bar_width = None),
        "[progress.percentage]{task.percentage:>3.0f}%",
        TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
        TimeRemainingColumn(),
    )
    return Progress(*columns)
|
||||
|
||||
# Cached dataset loader

def get_dataset(ds_name, category, split):
    """
    Load a Hugging Face dataset split, caching the rows as a JSON file in a
    "dataset_cache" directory next to this script.

    :param ds_name:
        Dataset name, e.g. "cais/mmlu" (slashes become underscores in the
        cache filename).

    :param category:
        Dataset configuration/category name.

    :param split:
        Split name, e.g. "test".

    :return:
        List of row dicts.
    """

    cpath = os.path.dirname(os.path.abspath(__file__))
    cpath = os.path.join(cpath, "dataset_cache")
    # exist_ok avoids the race where two eval scripts start simultaneously and
    # both pass the original "not exists" check before one calls mkdir.
    os.makedirs(cpath, exist_ok = True)

    filename = ds_name + "-" + category + "-" + split + ".jsonl"
    filename = filename.replace("/", "_")
    filename = os.path.join(cpath, filename)

    if os.path.exists(filename):
        print(f" -- Loading dataset: {ds_name}/{category}/{split} (cached)...")
        with open(filename, "r") as f:
            # NOTE: despite the .jsonl extension the cache is one JSON array.
            return json.load(f)
    else:
        print(f" -- Loading dataset: {ds_name}/{category}/{split}...")
        dataset = load_dataset(ds_name, category, split = split)
        rows = [example for example in dataset]
        with open(filename, "w") as f:
            f.write(json.dumps(rows, indent = 4))
        return rows
|
||||
@@ -1,210 +0,0 @@
|
||||
|
||||
import sys, os, gc
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from exllamav2 import(
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Config,
|
||||
ExLlamaV2Cache,
|
||||
ExLlamaV2Tokenizer,
|
||||
)
|
||||
|
||||
from datasets import load_dataset
|
||||
import torch
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
# Models to test

# Each subdirectory of model_base is treated as one quantization variant to
# benchmark; fall back to the base directory itself when there are none.
model_base = "/mnt/str/models/llama3-70b-instruct-exl2"
variants = [v for v in os.listdir(model_base) if os.path.isdir(os.path.join(model_base, v))]
if not variants: variants = ["."]

# variants = \
# [
#     "2.4bpw",
#     "2.8bpw",
#     "4.0bpw",
# ]

# Manual per-GPU VRAM split in GB; None lets load_autosplit decide.
# gpu_split = (20, 21.3, 21, 24)
gpu_split = None #auto

# Dataset and evaluation scope.
qa_set = "cais/mmlu"
qa_split = "test"

categories = \
[
    "anatomy",
    "computer_security",
    "formal_logic",
    "logical_fallacies",
    "philosophy",
    "nutrition",
]

# First examples_per_category rows of each category become fewshot examples;
# the next questions_per_category rows are scored.
examples_per_category = 3
questions_per_category = 97
|
||||
|
||||
# Load model
|
||||
|
||||
def get_model(base, variant_, gpu_split_, batch_size_):
    """
    Load one quantization variant of the model.

    Returns (model, cache, tokenizer). The cache is None when an explicit GPU
    split is given, since only the autosplit path needs a cache during load.
    """

    path = os.path.join(base, variant_)

    cfg = ExLlamaV2Config()
    cfg.model_dir = path
    cfg.prepare()
    cfg.max_seq_len = 2048
    cfg.max_batch_size = batch_size_

    loaded = ExLlamaV2(cfg)
    print(" -- Loading model: " + path)

    if not gpu_split_:
        # Autosplit: create a lazy cache and let loading distribute layers.
        kv_cache = ExLlamaV2Cache(loaded, batch_size = batch_size_, lazy = True)
        loaded.load_autosplit(kv_cache)
    else:
        loaded.load(gpu_split_)
        kv_cache = None

    return loaded, kv_cache, ExLlamaV2Tokenizer(cfg)
|
||||
|
||||
|
||||
|
||||
# Prepare the prompts

def load_datasets():
    """
    Build fewshot-prefixed prompts and labels for every category.

    Reads the module-level categories, qa_set, qa_split,
    examples_per_category and questions_per_category.

    :return:
        dict mapping category name to {"prompts": [...], "labels": [...]},
        where each prompt is examples_per_category solved examples followed
        by one unsolved question, and each label is the correct answer index.
    """
    # Fix: the original declared "global categoriwes" — a typo for the
    # module-level "categories", which is only read here, so no global
    # statement is needed at all.

    def get_dataset(ds_name, category_, split_):
        # Thin wrapper around datasets.load_dataset with a progress message.
        print(f" -- Loading dataset: {ds_name}/{category_}...")
        dataset_ = load_dataset(ds_name, category_, split = split_)
        return dataset_

    def format_question(question, options, answer, ex=False):
        # Render one question with labeled choices and its answer appended.
        clabels = "ABCD"
        text = f"Question:\n"
        text += question
        text += "\n\nChoices:\n"
        for i, o in enumerate(options):
            text += clabels[i] + ": " + o + "\n"
        text += "\nAnswer: " + clabels[answer]
        return text

    prep_prompts_ = {}
    for category_ in categories:

        dataset = get_dataset(qa_set, category_, qa_split)

        # Take only as many rows as the fewshot examples + questions need.
        rows = []
        for example in dataset:
            rows.append(example)
            if len(rows) == questions_per_category + examples_per_category: break

        examples_prompt = ""
        for i_ in range(examples_per_category):
            examples_prompt += format_question(rows[i_]["question"], rows[i_]["choices"], rows[i_]["answer"], ex = True)
            examples_prompt += "\n\n"

        prompts_ = []
        labels_ = []
        for j_ in range(questions_per_category):
            i_ = j_ + examples_per_category
            q_prompt = format_question(rows[i_]["question"], rows[i_]["choices"], rows[i_]["answer"])
            prompts_.append(examples_prompt + q_prompt)
            labels_.append(rows[i_]["answer"])

        prep = {"prompts": prompts_,
                "labels": labels_}

        prep_prompts_[category_] = prep

    return prep_prompts_
|
||||
|
||||
|
||||
# Cache the prepared prompts on disk, keyed by a hash of the category list so
# changing the categories invalidates the cache file.
hash_obj = hashlib.sha256()
hash_obj.update(''.join(categories).encode())
h = hash_obj.hexdigest()
module_dir = os.path.dirname(os.path.abspath(__file__))
filename = os.path.join(module_dir, "mmlu_prompts_" + h[:12] + ".json")

if os.path.exists(filename):
    with open(filename, "r") as f:
        prep_prompts = json.load(f)
else:
    prep_prompts = load_datasets()
    with open(filename, "w") as f:
        f.write(json.dumps(prep_prompts, indent = 4))
|
||||
|
||||
# Do the test

# Results accumulate as a semicolon-separated table: header row of category
# names, then one row per model variant.
results = ";".join([""] + categories) + "\n"

for variant in variants:

    # Model

    # Drop references to the previous variant and flush CUDA memory before
    # loading the next one.
    model = None
    cache = None
    tokenizer = None

    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()

    model, cache, tokenizer = get_model(model_base, variant, gpu_split, 1)

    # Logit positions corresponding to valid answers

    # Token id of each label when it follows text — taking the last token of
    # "The answer is: X" yields the id of " X" (label with leading space).
    answer_logits = []
    llabels = "ABCD"
    for i in range(4):
        answer_ = "The answer is: " + llabels[i]
        answer_logits.append(tokenizer.tokenizer.encode(answer_)[-1])

    # Categories

    cat_results = []

    for category in categories:

        print(f" -- Testing: {category}...")

        prompts = prep_prompts[category]["prompts"]
        labels = prep_prompts[category]["labels"]

        # Evaluate prompts

        # Score is the mean probability mass the model puts on the correct
        # answer token, softmaxed over the four answer-token logits only.
        score = 0.0
        # for prompt_ids, mask in zip(prompt_ids_list, mask_list):

        for prompt, label in zip(prompts, labels):

            prompt_ids = tokenizer.encode(prompt)
            # Drop the final token (the answer appended by format_question) so
            # the model predicts at the answer position.
            prompt_ids = prompt_ids[:, :-1]

            logits = model.forward(prompt_ids, last_id_only = True)
            logits = logits.float()

            logits_ans = logits[:, :, answer_logits]
            prob_ans = torch.softmax(logits_ans, dim = -1)

            score += prob_ans[0, 0, label]

        score /= questions_per_category
        print(f" -- Score: {score:.4f}")

        cat_results.append(f"{score:.4f}");

    results += ";".join([variant] + cat_results) + "\n"

print(" -- Finished")
print()
print(results)
|
||||
Reference in New Issue
Block a user