mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-24 16:29:15 +00:00
Cache prompts for MMLU test
This commit is contained in:
@@ -11,12 +11,15 @@ from exllamav2 import(
|
||||
|
||||
from datasets import load_dataset
|
||||
import torch
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
# Models to test
|
||||
|
||||
# model_base = "/mnt/str/models/_exl2"
|
||||
# model_base = "/mnt/str/models/mixtral-8x7b-instruct-exl2/"
|
||||
model_base = "/mnt/str/models/tiefighter-13b-exl4/"
|
||||
model_base = "/mnt/str/models/llama3-8b-exl2"
|
||||
# variants = ["x3-8b"]
|
||||
|
||||
variants = [v for v in os.listdir(model_base) if os.path.isdir(os.path.join(model_base, v))]
|
||||
|
||||
@@ -72,59 +75,75 @@ def get_model(base, variant_, gpu_split_, batch_size_):
|
||||
return model_, cache_, tokenizer_
|
||||
|
||||
|
||||
# Load questions
|
||||
|
||||
def format_question(question, options, answer, ex = False):
|
||||
|
||||
clabels = "ABCD"
|
||||
text = f"Question:\n"
|
||||
text += question
|
||||
text += "\n\nChoices:\n"
|
||||
for i, o in enumerate(options):
|
||||
text += clabels[i] + ": " + o + "\n"
|
||||
text += "\nAnswer: " + clabels[answer]
|
||||
# if ex:
|
||||
# text += ", " + options[answer]
|
||||
return text
|
||||
|
||||
|
||||
def get_dataset(ds_name, category_, split_):
|
||||
|
||||
print(f" -- Loading dataset: {ds_name}/{category_}...")
|
||||
dataset_ = load_dataset(ds_name, category_, split = split_)
|
||||
return dataset_
|
||||
|
||||
|
||||
# Prepare the prompts
|
||||
|
||||
prep_prompts = {}
|
||||
for category in categories:
|
||||
def load_datasets():
|
||||
global categoriwes
|
||||
|
||||
dataset = get_dataset(qa_set, category, qa_split)
|
||||
def get_dataset(ds_name, category_, split_):
|
||||
|
||||
rows = []
|
||||
for example in dataset:
|
||||
rows.append(example)
|
||||
if len(rows) == questions_per_category + examples_per_category: break
|
||||
print(f" -- Loading dataset: {ds_name}/{category_}...")
|
||||
dataset_ = load_dataset(ds_name, category_, split = split_)
|
||||
return dataset_
|
||||
|
||||
examples_prompt = ""
|
||||
for i in range(examples_per_category):
|
||||
examples_prompt += format_question(rows[i]["question"], rows[i]["choices"], rows[i]["answer"], ex = True)
|
||||
examples_prompt += "\n\n"
|
||||
def format_question(question, options, answer, ex=False):
|
||||
|
||||
prompts = []
|
||||
labels = []
|
||||
for j in range(questions_per_category):
|
||||
i = j + examples_per_category
|
||||
q_prompt = format_question(rows[i]["question"], rows[i]["choices"], rows[i]["answer"])
|
||||
prompts.append(examples_prompt + q_prompt)
|
||||
labels.append(rows[i]["answer"])
|
||||
clabels = "ABCD"
|
||||
text = f"Question:\n"
|
||||
text += question
|
||||
text += "\n\nChoices:\n"
|
||||
for i, o in enumerate(options):
|
||||
text += clabels[i] + ": " + o + "\n"
|
||||
text += "\nAnswer: " + clabels[answer]
|
||||
# if ex:
|
||||
# text += ", " + options[answer]
|
||||
return text
|
||||
|
||||
prep = {"prompts": prompts,
|
||||
"labels": labels}
|
||||
prep_prompts_ = {}
|
||||
for category_ in categories:
|
||||
|
||||
prep_prompts[category] = prep
|
||||
dataset = get_dataset(qa_set, category_, qa_split)
|
||||
|
||||
rows = []
|
||||
for example in dataset:
|
||||
rows.append(example)
|
||||
if len(rows) == questions_per_category + examples_per_category: break
|
||||
|
||||
examples_prompt = ""
|
||||
for i_ in range(examples_per_category):
|
||||
examples_prompt += format_question(rows[i_]["question"], rows[i_]["choices"], rows[i_]["answer"], ex = True)
|
||||
examples_prompt += "\n\n"
|
||||
|
||||
prompts_ = []
|
||||
labels_ = []
|
||||
for j_ in range(questions_per_category):
|
||||
i_ = j_ + examples_per_category
|
||||
q_prompt = format_question(rows[i_]["question"], rows[i_]["choices"], rows[i_]["answer"])
|
||||
prompts_.append(examples_prompt + q_prompt)
|
||||
labels_.append(rows[i_]["answer"])
|
||||
|
||||
prep = {"prompts": prompts_,
|
||||
"labels": labels_}
|
||||
|
||||
prep_prompts_[category_] = prep
|
||||
|
||||
return prep_prompts_
|
||||
|
||||
|
||||
hash_obj = hashlib.sha256()
|
||||
hash_obj.update(''.join(categories).encode())
|
||||
h = hash_obj.hexdigest()
|
||||
module_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
filename = os.path.join(module_dir, "mmlu_prompts_" + h[:12] + ".json")
|
||||
|
||||
if os.path.exists(filename):
|
||||
with open(filename, "r") as f:
|
||||
prep_prompts = json.load(f)
|
||||
else:
|
||||
prep_prompts = load_datasets()
|
||||
with open(filename, "w") as f:
|
||||
f.write(json.dumps(prep_prompts, indent = 4))
|
||||
|
||||
# Do the test
|
||||
|
||||
|
||||
Reference in New Issue
Block a user