Update MMLU test to use dynamic batching

This commit is contained in:
turboderp
2024-06-01 03:31:40 +02:00
parent 74b49ba28b
commit 823bf11c68
4 changed files with 234 additions and 240 deletions

View File

@@ -1,11 +1,12 @@
from __future__ import annotations
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from human_eval.data import write_jsonl, read_problems
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, model_init
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_8bit
from exllamav2 import model_init
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
import torch, argparse, contextlib
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
import argparse, contextlib
import util
# Args
@@ -61,7 +62,7 @@ else:
print("\n".join(prompt_formats.keys()))
sys.exit()
# Init model
# Init model and cache
model_init.check_args(args)
model_init.print_options(args)
@@ -73,8 +74,6 @@ model, tokenizer = model_init.init(
max_input_len = 2048
)
# Create cache
if args.cache_q4: cache_type = ExLlamaV2Cache_Q4
else: cache_type = ExLlamaV2Cache
cache = cache_type(
@@ -83,8 +82,6 @@ cache = cache_type(
max_seq_len = args.cache_size or model.config.max_seq_len
)
# Load model
if not model.loaded:
model.load_autosplit(cache, progress = True)
@@ -112,13 +109,7 @@ num_samples_per_task = args.samples_per_task
# Create jobs
with Progress(
TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
BarColumn(bar_width = None),
"[progress.percentage]{task.percentage:>3.0f}%",
TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
TimeRemainingColumn()
) as progress:
with util.get_progress() as progress:
task1 = progress.add_task("[red]Sample", total = len(problems) * num_samples_per_task, name = "Creating sample jobs")
for problem_id, problem in problems.items():
@@ -146,22 +137,10 @@ with Progress(
samples = []
# Progress bar
total_jobs = generator.num_remaining_jobs()
if args.verbose:
cm = contextlib.nullcontext()
else:
cm = Progress(
TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
BarColumn(bar_width = None),
"[progress.percentage]{task.percentage:>3.0f}%",
TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
TimeRemainingColumn()
)
# Work
total_jobs = generator.num_remaining_jobs()
cm = contextlib.nullcontext() if args.verbose else util.get_progress()
with cm as progress:
if not args.verbose:

185
eval/mmlu.py Normal file
View File

@@ -0,0 +1,185 @@
from __future__ import annotations
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import model_init
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
import argparse, contextlib
import torch
import util
# Args
parser = argparse.ArgumentParser(description = "Run MMLU evaluation on EXL2 model")
parser.add_argument("-cs", "--cache_size", type = int, default = None)
parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
parser.add_argument("-sub", "--subjects", type = str, default = "all", help = "Comma-separated list of categories to test, or 'all'")
parser.add_argument("-fs", "--fewshot_examples", type = int, default = 5, help = "Number of examples for fewshot examples, max 5")
model_init.add_args(parser)
args = parser.parse_args()
# Init model and cache
model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(
args,
allow_auto_split = True,
progress = True,
max_output_len = 1,
max_input_len = 2048
)
if args.cache_q4: cache_type = ExLlamaV2Cache_Q4
else: cache_type = ExLlamaV2Cache
cache = cache_type(
model,
lazy = not model.loaded,
max_seq_len = args.cache_size or model.config.max_seq_len
)
if not model.loaded:
model.load_autosplit(cache, progress = True)
# Generator
generator = ExLlamaV2DynamicGenerator(
model = model,
cache = cache,
tokenizer = tokenizer,
max_batch_size = 1024,
max_q_size = 1
)
c_options = "ABCD"
gen_settings = ExLlamaV2Sampler.Settings(
token_repetition_penalty = 1.0,
temperature = 1.0,
top_k = 10,
top_p = 1.0,
)
token_map = [tokenizer.single_id(piece) for piece in [" " + c for c in c_options]]
token_rmap = { token_map[i]: i for i in range(len(c_options)) }
gen_settings.allow_tokens(tokenizer, token_map)
# Get dataset
dataset_dev = util.get_dataset("cais/mmlu", "all", "dev")
dataset_all = util.get_dataset("cais/mmlu", "all", "test")
dataset_dev = sorted(dataset_dev, key = lambda q: q["subject"])
dataset_all = sorted(dataset_all, key = lambda q: q["subject"])
all_subjects = set([q["subject"] for q in dataset_dev])
if args.subjects != "all":
sel_subjects = args.subjects.split(",")
for s in sel_subjects:
if s not in all_subjects:
print(f"Subject: {s} is not present in dataset")
sys.exit()
all_subjects = set(sel_subjects)
# Format
def format_question(question: str, choices: list[str], answer: int | None):
f = "Question:\n\n"
f += question
f += "\n\nChoices:\n\n"
for i, c in enumerate(c_options):
f += c + ": " + choices[i] + "\n"
f += "\nCorrect answer:"
if answer is not None:
f += " " + c_options[answer] + "\n\n"
f += "----\n\n"
return f
# Fewshot preprompts
preprompt_ids = {}
with util.get_progress() as progress:
task1 = progress.add_task("[red]Preprompts", total = len(all_subjects), name = "Preparing preprompts")
for subject in all_subjects:
preprompt = ""
fewshots = 0
for pq in dataset_dev:
if fewshots == args.fewshot_examples: break
if pq["subject"] != subject: continue
preprompt += format_question(pq["question"], pq["choices"], pq["answer"])
preprompt_ids[subject] = tokenizer.encode(preprompt, add_bos = True)
progress.update(task1, advance = 1)
# Questions
total_jobs = 0
for q in dataset_all:
if q["subject"] in all_subjects:
total_jobs += 1
with util.get_progress() as progress:
task1 = progress.add_task("[red]Questions", total=total_jobs, name="Preparing questions")
for q in dataset_all:
if q["subject"] not in all_subjects:
continue
prompt = format_question(q["question"], q["choices"], None)
prompt_ids = tokenizer.encode(prompt, add_bos = False)
job = ExLlamaV2DynamicJob(
input_ids = torch.cat([preprompt_ids[q["subject"]], prompt_ids], dim = -1),
gen_settings = gen_settings,
max_new_tokens = 1,
return_top_tokens = 4,
identifier = q,
)
generator.enqueue(job)
progress.update(task1, advance = 1)
# Work
with util.get_progress() as progress:
task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Testing")
while generator.num_remaining_jobs():
results = generator.iterate()
for result in results:
if not result["eos"]:
continue
# Ignore completion and use top-K tokens only
top_tokens = result["top_k_tokens"]
top_probs = result["top_k_probs"]
q = result["identifier"]
correct_answer = q["answer"]
for i in range(top_tokens.shape[-1]):
if top_tokens[0, 0, i].item() == token_map[correct_answer]:
confidence = top_probs[0, 0, i].item()
q["correct_answer_confidence"] = confidence
q["answer_correct"] = token_rmap[top_tokens[0, 0, 0].item()] == correct_answer
progress.update(task1, advance = 1)
# Summarize
total = 0
correct = 0
confidence_sum = 0.0
for q in dataset_all:
if not "answer_correct" in q:
continue
total += 1
if q["answer_correct"]:
correct += 1
confidence_sum += q["correct_answer_confidence"]
print(f"Correct answers: {correct}/{total} = {correct/total*100:.2f}%")
print(f"Confidence: {confidence_sum/total*100:.2f}%")

40
eval/util.py Normal file
View File

@@ -0,0 +1,40 @@
from datasets import load_dataset
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
import os, json
# Rich progress bar format
def get_progress():
return Progress(
TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
BarColumn(bar_width = None),
"[progress.percentage]{task.percentage:>3.0f}%",
TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
TimeRemainingColumn()
)
# Cached dataset loader
def get_dataset(ds_name, category, split):
cpath = os.path.dirname(os.path.abspath(__file__))
cpath = os.path.join(cpath, "dataset_cache")
if not os.path.exists(cpath):
os.mkdir(cpath)
filename = ds_name + "-" + category + "-" + split + ".jsonl"
filename = filename.replace("/", "_")
filename = os.path.join(cpath, filename)
if os.path.exists(filename):
print(f" -- Loading dataset: {ds_name}/{category}/{split} (cached)...")
with open(filename, "r") as f:
return json.load(f)
else:
print(f" -- Loading dataset: {ds_name}/{category}/{split}...")
dataset = load_dataset(ds_name, category, split = split)
rows = [example for example in dataset]
with open(filename, "w") as f:
f.write(json.dumps(rows, indent = 4))
return rows

View File

@@ -1,210 +0,0 @@
import sys, os, gc
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import(
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Tokenizer,
)
from datasets import load_dataset
import torch
import hashlib
import json
# Models to test
model_base = "/mnt/str/models/llama3-70b-instruct-exl2"
variants = [v for v in os.listdir(model_base) if os.path.isdir(os.path.join(model_base, v))]
if not variants: variants = ["."]
# variants = \
# [
# "2.4bpw",
# "2.8bpw",
# "4.0bpw",
# ]
# gpu_split = (20, 21.3, 21, 24)
gpu_split = None #auto
qa_set = "cais/mmlu"
qa_split = "test"
categories = \
[
"anatomy",
"computer_security",
"formal_logic",
"logical_fallacies",
"philosophy",
"nutrition",
]
examples_per_category = 3
questions_per_category = 97
# Load model
def get_model(base, variant_, gpu_split_, batch_size_):
model_dir = os.path.join(base, variant_)
config = ExLlamaV2Config()
config.model_dir = model_dir
config.prepare()
config.max_seq_len = 2048
config.max_batch_size = batch_size_
model_ = ExLlamaV2(config)
print(" -- Loading model: " + model_dir)
if gpu_split_:
model_.load(gpu_split_)
cache_ = None
else:
cache_ = ExLlamaV2Cache(model_, batch_size = batch_size_, lazy = True)
model_.load_autosplit(cache_)
tokenizer_ = ExLlamaV2Tokenizer(config)
return model_, cache_, tokenizer_
# Prepare the prompts
def load_datasets():
global categoriwes
def get_dataset(ds_name, category_, split_):
print(f" -- Loading dataset: {ds_name}/{category_}...")
dataset_ = load_dataset(ds_name, category_, split = split_)
return dataset_
def format_question(question, options, answer, ex=False):
clabels = "ABCD"
text = f"Question:\n"
text += question
text += "\n\nChoices:\n"
for i, o in enumerate(options):
text += clabels[i] + ": " + o + "\n"
text += "\nAnswer: " + clabels[answer]
# if ex:
# text += ", " + options[answer]
return text
prep_prompts_ = {}
for category_ in categories:
dataset = get_dataset(qa_set, category_, qa_split)
rows = []
for example in dataset:
rows.append(example)
if len(rows) == questions_per_category + examples_per_category: break
examples_prompt = ""
for i_ in range(examples_per_category):
examples_prompt += format_question(rows[i_]["question"], rows[i_]["choices"], rows[i_]["answer"], ex = True)
examples_prompt += "\n\n"
prompts_ = []
labels_ = []
for j_ in range(questions_per_category):
i_ = j_ + examples_per_category
q_prompt = format_question(rows[i_]["question"], rows[i_]["choices"], rows[i_]["answer"])
prompts_.append(examples_prompt + q_prompt)
labels_.append(rows[i_]["answer"])
prep = {"prompts": prompts_,
"labels": labels_}
prep_prompts_[category_] = prep
return prep_prompts_
hash_obj = hashlib.sha256()
hash_obj.update(''.join(categories).encode())
h = hash_obj.hexdigest()
module_dir = os.path.dirname(os.path.abspath(__file__))
filename = os.path.join(module_dir, "mmlu_prompts_" + h[:12] + ".json")
if os.path.exists(filename):
with open(filename, "r") as f:
prep_prompts = json.load(f)
else:
prep_prompts = load_datasets()
with open(filename, "w") as f:
f.write(json.dumps(prep_prompts, indent = 4))
# Do the test
results = ";".join([""] + categories) + "\n"
for variant in variants:
# Model
model = None
cache = None
tokenizer = None
gc.collect()
torch.cuda.empty_cache()
gc.collect()
model, cache, tokenizer = get_model(model_base, variant, gpu_split, 1)
# Logit positions corresponding to valid answers
answer_logits = []
llabels = "ABCD"
for i in range(4):
answer_ = "The answer is: " + llabels[i]
answer_logits.append(tokenizer.tokenizer.encode(answer_)[-1])
# Categories
cat_results = []
for category in categories:
print(f" -- Testing: {category}...")
prompts = prep_prompts[category]["prompts"]
labels = prep_prompts[category]["labels"]
# Evaluate prompts
score = 0.0
# for prompt_ids, mask in zip(prompt_ids_list, mask_list):
for prompt, label in zip(prompts, labels):
prompt_ids = tokenizer.encode(prompt)
prompt_ids = prompt_ids[:, :-1]
logits = model.forward(prompt_ids, last_id_only = True)
logits = logits.float()
logits_ans = logits[:, :, answer_logits]
prob_ans = torch.softmax(logits_ans, dim = -1)
score += prob_ans[0, 0, label]
score /= questions_per_category
print(f" -- Score: {score:.4f}")
cat_results.append(f"{score:.4f}");
results += ";".join([variant] + cat_results) + "\n"
print(" -- Finished")
print()
print(results)