mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Update HumanEval test to dynamic generator
This commit is contained in:
211
eval/humaneval.py
Normal file
211
eval/humaneval.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import sys, os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from human_eval.data import write_jsonl, read_problems
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, model_init
|
||||
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_8bit
|
||||
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
|
||||
import torch, argparse, contextlib
|
||||
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
|
||||
|
||||
# Args
# Command-line interface: output path and sampling options are local to this
# script; model-loading flags (-m, gpu split, etc.) come from model_init.add_args.

parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
# Total cache size in tokens; defaults to the model's max_seq_len (see cache creation below).
parser.add_argument("-cs", "--cache_size", type = int, default = None)
# HumanEval pass@k estimation needs multiple samples per problem; 200 is the paper's default.
parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
# Adds the shared model-loading arguments (model dir, gpu split, etc.).
model_init.add_args(parser)
args = parser.parse_args()
|
||||
|
||||
# Validate args
# Fail fast on an unwritable output location before spending time loading the model.

directory = os.path.dirname(args.output)
if directory and not os.path.isdir(directory):
    print(f" ## Directory for output file {args.output} does not exist.")
    # Exit with a non-zero status so shell scripts and CI can detect the failure
    # (bare sys.exit() would report success with status 0).
    sys.exit(1)
if os.path.exists(args.output):
    # Plain string: no placeholders, so the f-prefix was unnecessary.
    print(" !! Warning: Output file exists and will be overwritten.")
|
||||
|
||||
# Prompt formats
# Each entry maps a format name to a (template, prefix) pair:
#   - template: full prompt with "{{problem}}" substituted by the HumanEval
#     function stub (for instruct formats the stub appears twice: once in the
#     question and once at the start of the assistant's answer).
#   - prefix: string prepended to the model's completion before it is stored.

prompt_formats = {
    # Raw completion inside a Python code fence (for base models).
    "raw": (
        "```python\n{{problem}} ",
        "    "
    ),
    # IBM Granite instruct format.
    "granite": (
        "Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n"
        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} ",
        "    "
    ),
    # Llama 3 chat format with special header/eot tokens.
    "llama3": (
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI coding assistant.<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} ",
        "    "
    )
}
|
||||
|
||||
# Resolve the prompt template and completion prefix from --prompt_format.
if args.prompt_format is None:
    # Default: raw completion for base models — feed the problem stub verbatim.
    prompt_format, prefix = "{{problem}}", " "
elif args.prompt_format in prompt_formats:
    prompt_format, prefix = prompt_formats[args.prompt_format]
else:
    print("Prompt format is not supported. Available formats:")
    print("\n".join(prompt_formats.keys()))
    # Non-zero status: an unsupported format is an error, not a clean exit
    # (bare sys.exit() would report success with status 0).
    sys.exit(1)
|
||||
|
||||
# Init model
# Delegates model construction to the shared model_init helpers so this script
# accepts the same loading flags as the other example/eval scripts.

model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(
    args,
    # Allow the weights to be split automatically across available GPUs.
    allow_auto_split = True,
    progress = True,
    # Small output window and modest input chunk; the dynamic generator
    # batches many jobs, so per-forward sizes can stay small.
    max_output_len = 4,
    max_input_len = 2048
)
|
||||
|
||||
# Create cache
# Q4-quantized KV cache when requested, otherwise the default FP16 cache.
cache_type = ExLlamaV2Cache_Q4 if args.cache_q4 else ExLlamaV2Cache

# The cache is allocated lazily when the model isn't loaded yet (autosplit
# needs to size it during loading). Total capacity is shared by all active
# jobs in the dynamic generator.
cache = cache_type(
    model,
    lazy = not model.loaded,
    max_seq_len = args.cache_size or model.config.max_seq_len
)
|
||||
|
||||
# Load model
# Only needed when model_init didn't fully load it already (lazy/autosplit path).

if not model.loaded:
    model.load_autosplit(cache, progress = True)
|
||||
|
||||
# Generator
# Dynamic generator: all sample jobs are enqueued up front and batched
# automatically as cache space allows.

generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
    # Upper bound on concurrently active jobs per forward pass.
    max_batch_size = 256,
    max_q_size = 4
)

# Sampling settings shared by every job. Repetition penalty is disabled
# (1.0) since code completion legitimately repeats tokens.
gen_settings = ExLlamaV2Sampler.Settings(
    token_repetition_penalty = 1.0,
    temperature = 0.8,
    top_k = 100,
    top_p = 0.8
)
|
||||
|
||||
# Get problems
# read_problems() comes from the human_eval package and returns a dict of
# task_id -> problem (each problem has at least a "prompt" field used below).

problems = read_problems()
num_samples_per_task = args.samples_per_task
|
||||
|
||||
# Create jobs
# Enqueue one generation job per (problem, sample) pair. The identifier tuple
# lets the collection loop below map results back to their problem and sample
# index. Jobs only start running once generator.iterate() is called.

with Progress(
    TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
    BarColumn(bar_width = None),
    "[progress.percentage]{task.percentage:>3.0f}%",
    TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
    TimeRemainingColumn()
) as progress:

    task1 = progress.add_task("[red]Sample", total = len(problems) * num_samples_per_task, name = "Creating sample jobs")
    for problem_id, problem in problems.items():

        # Apply the chosen instruct template to the raw HumanEval prompt, then
        # tokenize once per problem — all samples for it share the same prompt,
        # so the generator can dedupe/reuse the prefix cache.
        b_problem = problem["prompt"]
        f_problem = prompt_format.replace("{{problem}}", b_problem)
        input_ids = tokenizer.encode(f_problem, encode_special_tokens=True, add_bos=True)

        for s in range(num_samples_per_task):

            job = ExLlamaV2DynamicJob(
                input_ids = input_ids,
                gen_settings = gen_settings,
                max_new_tokens = args.max_tokens,
                stop_conditions = [tokenizer.eos_token_id],
                # Heal the boundary token so the completion continues the
                # prompt's trailing " " naturally.
                token_healing = True,
                identifier = (problem_id, s),
                # Force at least a few tokens so an immediate EOS doesn't
                # produce an empty sample.
                min_new_tokens = 6
            )

            generator.enqueue(job)
            progress.update(task1, advance = 1)
|
||||
|
||||
# Accumulator for finished samples (dicts of task_id/completion).
samples = []

# Progress display. In verbose mode every completion is printed to the
# console, so the bar is replaced with a no-op context manager.
total_jobs = generator.num_remaining_jobs()

progress_columns = (
    TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
    BarColumn(bar_width = None),
    "[progress.percentage]{task.percentage:>3.0f}%",
    TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
    TimeRemainingColumn(),
)
cm = contextlib.nullcontext() if args.verbose else Progress(*progress_columns)
|
||||
|
||||
# Work
# Drain the generator: each iterate() call advances all active jobs one step
# and yields per-job result dicts. A sample is considered finished either when
# the generator reports EOS or when the completion's last line is non-indented
# (heuristic: the model has left the function body and started new top-level
# code, so everything from that line on is discarded).

with cm as progress:

    if not args.verbose:
        task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Generating samples")

    while generator.num_remaining_jobs():

        results = generator.iterate()
        for result in results:

            # End sample if generator says EOS or if there is a non-indented line at the end of the output

            job = result["job"]
            eos = False
            # full_completion is the text generated so far for this job,
            # re-examined on every iteration until it terminates.
            completion = job.full_completion
            last_newline_index = completion.rfind("\n")
            if last_newline_index >= 0:
                last_line = completion[last_newline_index + 1:]
                if last_line != "" and not last_line[0].isspace():
                    # Crop the stray top-level line and terminate early.
                    completion = completion[:last_newline_index]
                    eos = True
            eos = eos or result["eos"]

            # Collect completed sample

            if eos:
                identifier = result["identifier"]
                # Full reconstructed source, used only for verbose display below.
                sample = problems[identifier[0]]["prompt"] + prefix + completion.strip()
                # Early termination by the heuristic: cancel the job so it
                # stops consuming batch/cache capacity.
                if not result["eos"]:
                    generator.cancel(job)

                if args.verbose:
                    print("----------------------------------------------------------------------")
                    print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {num_samples_per_task}")
                    print("----------------------------------------------------------------------")
                    print(sample)
                    print()
                else:
                    progress.update(task1, advance = 1)

                # Stored completion excludes the prompt, as expected by the
                # HumanEval evaluation harness.
                samples.append(dict(task_id = identifier[0], completion = prefix + completion.strip()))
|
||||
|
||||
# Save output
# write_jsonl (from human_eval.data) emits one JSON record per sample, the
# format consumed by HumanEval's evaluate_functional_correctness.

print(f" -- Saving: {args.output}")
write_jsonl(args.output, samples)
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
import sys, os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from human_eval.data import write_jsonl, read_problems
|
||||
|
||||
from exllamav2 import(
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Config,
|
||||
ExLlamaV2Cache,
|
||||
ExLlamaV2Cache_Q4,
|
||||
ExLlamaV2Cache_8bit,
|
||||
ExLlamaV2Tokenizer,
|
||||
model_init
|
||||
)
|
||||
|
||||
from exllamav2.generator import(
|
||||
ExLlamaV2BaseGenerator,
|
||||
ExLlamaV2Sampler
|
||||
)
|
||||
|
||||
import torch, argparse
|
||||
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
|
||||
|
||||
# Args
# NOTE(review): this is the REMOVED pre-change version of the script shown by
# the diff viewer (batch-based ExLlamaV2BaseGenerator implementation),
# superseded by the dynamic-generator version.

parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
# Fixed batch size: samples for one problem are generated batch_size at a time.
parser.add_argument("-bs", "--batch_size", type = int, default = 10)
parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
parser.add_argument("-c8", "--cache_8bit", action = "store_true", help = "Use 8-bit (FP8) cache")
parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
parser.add_argument("-f", "--format", type = str, help = "Instruct format to apply")
parser.add_argument("-nv", "--nonverbose", action = "store_true", help = "Don't spam completions to console while generating")
model_init.add_args(parser)
args = parser.parse_args()

# Validate args

directory = os.path.dirname(args.output)
if directory and not os.path.isdir(directory):
    print(f" ## Directory for output file {args.output} does not exist.")
    sys.exit()
if os.path.exists(args.output):
    print(f" !! Warning: Output file exists and will be overwritten.")
|
||||
|
||||
# Init model
# (Removed pre-change version — batch-based implementation.)

model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(args, allow_auto_split = True, max_batch_size = args.batch_size)

# Create cache
# Cache precision selected by flag; batch_size must match the generator's.

if args.cache_8bit: cache_type = ExLlamaV2Cache_8bit
elif args.cache_q4: cache_type = ExLlamaV2Cache_Q4
else: cache_type = ExLlamaV2Cache
cache = cache_type(model, lazy = not model.loaded, batch_size = args.batch_size)

# Load model

if not model.loaded:

    print(" -- Loading model...")
    model.load_autosplit(cache)

# Generator
# Plain batch generator; sampling settings assigned attribute-by-attribute.

gen = ExLlamaV2BaseGenerator(model, cache, tokenizer)
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.token_repetition_penalty = 1.0
gen_settings.temperature = 0.8
gen_settings.top_k = 100
gen_settings.top_p = 0.8

# Get problems

problems = read_problems()
num_samples_per_task = args.samples_per_task
samples = []
# Show a per-problem sub progress bar only when several batches are needed.
sub_progress = num_samples_per_task > args.batch_size

# Instruct templates: (prompt template, completion prefix) per format name.
formats = {
    "granite": (
        "Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n" +
        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} ",
        "    "
    )
}
|
||||
|
||||
# (Removed pre-change version.) Generate samples problem-by-problem in fixed
# batches, trim each response to a single function, and collect the results.
with Progress(
    TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
    BarColumn(bar_width = None),
    "[progress.percentage]{task.percentage:>3.0f}%",
    TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
    TimeRemainingColumn(),
) as progress:

    task1 = progress.add_task("[red]Problem", total = len(problems), name = "Problems")
    for task_id in problems:

        rem_samples = num_samples_per_task
        sample_idx = 0
        if sub_progress: task2 = progress.add_task("[red]Sample", total = num_samples_per_task, name = "Samples", parent = task1)
        while rem_samples:
            # Last batch may be smaller than batch_size.
            bs = min(args.batch_size, rem_samples)

            # Get problem and batch of completions

            b_problem = problems[task_id]["prompt"]
            if args.format is None:
                f, p = "{{problem}}", ""
            else:
                # NOTE(review): an unknown --format raises KeyError here; the
                # replacement version validates the format name up front.
                f, p = formats[args.format]
            f_problem = f.replace("{{problem}}", b_problem)

            # The same prompt is replicated bs times to fill the batch.
            problem = [f_problem] * bs
            responses = gen.generate_simple(problem, gen_settings, args.max_tokens, stop_token = tokenizer.eos_token_id, add_bos = True, token_healing = True)

            for response in responses:

                # Simplified cleanup of response: remove all lines starting from the first line with no indentation,
                # i.e. keep exactly one function

                r = response[len(problem[0]):]
                r = p + r
                s = r.split("\n")
                crop = len(s)
                for l in range(1, len(s)):
                    if len(s[l]) > 0:
                        b = s[l][0:1]
                        # Keep indented lines and top-level comments; stop at the
                        # first other top-level line.
                        if b != " " and b != "\t" and b != "#":
                            crop = l
                            break
                r = "\n".join(s[:crop])

                # Store sample

                samples.append(dict(task_id = task_id, completion = r))

                # Print to console

                if not args.nonverbose:
                    print("----------------------------------------------------------------------")
                    print(f" ** Problem {task_id}, sample {sample_idx + 1} / {num_samples_per_task}")
                    print("----------------------------------------------------------------------")
                    print(b_problem + p + r)

                sample_idx += 1

            rem_samples -= bs
            if sub_progress: progress.advance(task2, bs)

        if sub_progress: progress.remove_task(task2)
        progress.update(task1, advance = 1)
|
||||
|
||||
# Save output
# (Removed pre-change version.) Writes the HumanEval-format .jsonl file.

print(f" -- Saving: {args.output}")
write_jsonl(args.output, samples)
|
||||
Reference in New Issue
Block a user