Update HumanEval test to use the dynamic generator

This commit is contained in:
turboderp
2024-05-31 22:24:16 +02:00
parent c14dc314ba
commit e7cbb300ff
2 changed files with 211 additions and 160 deletions

211
eval/humaneval.py Normal file
View File

@@ -0,0 +1,211 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from human_eval.data import write_jsonl, read_problems
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, model_init
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_8bit
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
import torch, argparse, contextlib
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
# Args
# Command-line interface for the evaluation script. model_init contributes the
# shared model-loading arguments (model dir, gpu split, etc.).
parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
parser.add_argument("-cs", "--cache_size", type = int, default = None)
parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
model_init.add_args(parser)
args = parser.parse_args()

# Validate args
directory = os.path.dirname(args.output)
if directory and not os.path.isdir(directory):
    print(f" ## Directory for output file {args.output} does not exist.")
    # Was sys.exit(): a bare exit reports status 0, which makes this error
    # invisible to shell scripts and CI. Exit non-zero on failure.
    sys.exit(1)
if os.path.exists(args.output):
    print(f" !! Warning: Output file exists and will be overwritten.")
# Prompt formats
# Maps a format name to a (prompt_template, completion_prefix) pair. The
# literal "{{problem}}" placeholder in the template is later replaced with the
# raw HumanEval prompt; the prefix is prepended to the model's completion.
_granite_template = (
    "Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n"
    "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} "
)
_llama3_template = (
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful AI coding assistant.<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
    "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} "
)
prompt_formats = {
    "raw": ("```python\n{{problem}} ", " "),
    "granite": (_granite_template, " "),
    "llama3": (_llama3_template, " "),
}
# Resolve the selected prompt format into a template and completion prefix.
if args.prompt_format is None:
    # No format given: raw completion mode for base models.
    prompt_format, prefix = "{{problem}}", " "
elif args.prompt_format in prompt_formats:
    prompt_format, prefix = prompt_formats[args.prompt_format]
else:
    print("Prompt format is not supported. Available formats:")
    print("\n".join(prompt_formats.keys()))
    # Was sys.exit(): a bare exit reports status 0 even though this is an
    # error path. Exit non-zero so callers can detect the failure.
    sys.exit(1)
# Init model
# Validate and echo the model arguments, then load tokenizer + (possibly lazy)
# model. Small max_output_len because we only ever sample token by token.
model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(
    args,
    allow_auto_split = True,
    progress = True,
    max_output_len = 4,
    max_input_len = 2048
)

# Create cache
# Q4 quantized cache when requested, otherwise full-precision FP16 cache.
cache_type = ExLlamaV2Cache_Q4 if args.cache_q4 else ExLlamaV2Cache
cache = cache_type(
    model,
    lazy = not model.loaded,
    max_seq_len = args.cache_size or model.config.max_seq_len
)

# Load model
# If init() deferred loading (auto split), load now, splitting across GPUs.
if not model.loaded:
    model.load_autosplit(cache, progress = True)

# Generator
generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
    max_batch_size = 256,
    max_q_size = 4
)

# Sampling settings shared by every job.
gen_settings = ExLlamaV2Sampler.Settings(
    token_repetition_penalty = 1.0,
    temperature = 0.8,
    top_k = 100,
    top_p = 0.8
)
# Get problems
problems = read_problems()
num_samples_per_task = args.samples_per_task

# Create jobs
# Enqueue one generation job per (problem, sample) pair up front so the
# dynamic generator is free to batch and schedule them as cache space allows.
with Progress(
    TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
    BarColumn(bar_width = None),
    "[progress.percentage]{task.percentage:>3.0f}%",
    TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
    TimeRemainingColumn()
) as progress:
    task1 = progress.add_task("[red]Sample", total = len(problems) * num_samples_per_task, name = "Creating sample jobs")
    for problem_id, problem in problems.items():
        # Apply the instruct template to the raw HumanEval prompt, then
        # tokenize once and reuse the ids for every sample of this problem.
        formatted = prompt_format.replace("{{problem}}", problem["prompt"])
        input_ids = tokenizer.encode(formatted, encode_special_tokens = True, add_bos = True)
        for sample_idx in range(num_samples_per_task):
            generator.enqueue(ExLlamaV2DynamicJob(
                input_ids = input_ids,
                gen_settings = gen_settings,
                max_new_tokens = args.max_tokens,
                min_new_tokens = 6,
                stop_conditions = [tokenizer.eos_token_id],
                token_healing = True,
                identifier = (problem_id, sample_idx)
            ))
            progress.update(task1, advance = 1)
# Collect samples here
samples = []

# Progress bar
# In verbose mode the completions themselves are printed, so the progress
# display is replaced with a no-op context manager.
total_jobs = generator.num_remaining_jobs()
if args.verbose:
    cm = contextlib.nullcontext()
else:
    cm = Progress(
        TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
        BarColumn(bar_width = None),
        "[progress.percentage]{task.percentage:>3.0f}%",
        TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
        TimeRemainingColumn()
    )

# Work
# Drive the generator until every queued job has either finished naturally
# (EOS / max_new_tokens) or been cancelled by the early-stop heuristic below.
with cm as progress:
    if not args.verbose:
        task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Generating samples")
    while generator.num_remaining_jobs():
        results = generator.iterate()
        for result in results:

            # End sample if generator says EOS or if there is a non-indented line at the end of the output
            job = result["job"]
            eos = False
            completion = job.full_completion
            last_newline_index = completion.rfind("\n")
            if last_newline_index >= 0:
                last_line = completion[last_newline_index + 1:]
                if last_line != "" and not last_line[0].isspace():
                    # A line starting at column 0 means generation has left the
                    # function body; crop it off and treat the sample as done.
                    completion = completion[:last_newline_index]
                    eos = True
            eos = eos or result["eos"]

            # Collect completed sample
            if eos:
                identifier = result["identifier"]
                # Full source shown in verbose mode: original prompt + prefix + cleaned completion.
                sample = problems[identifier[0]]["prompt"] + prefix + completion.strip()
                if not result["eos"]:
                    # We stopped before the generator did; cancel the job so it
                    # releases its slot and cache pages.
                    generator.cancel(job)
                if args.verbose:
                    print("----------------------------------------------------------------------")
                    print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {num_samples_per_task}")
                    print("----------------------------------------------------------------------")
                    print(sample)
                    print()
                else:
                    progress.update(task1, advance = 1)
                samples.append(dict(task_id = identifier[0], completion = prefix + completion.strip()))

# Save output
print(f" -- Saving: {args.output}")
write_jsonl(args.output, samples)

View File

@@ -1,160 +0,0 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from human_eval.data import write_jsonl, read_problems
from exllamav2 import(
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Cache_Q4,
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
model_init
)
from exllamav2.generator import(
ExLlamaV2BaseGenerator,
ExLlamaV2Sampler
)
import torch, argparse
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
# Args
# Command-line interface; model_init contributes the shared model-loading args.
parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
parser.add_argument("-bs", "--batch_size", type = int, default = 10)
parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
parser.add_argument("-c8", "--cache_8bit", action = "store_true", help = "Use 8-bit (FP8) cache")
parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
parser.add_argument("-f", "--format", type = str, help = "Instruct format to apply")
parser.add_argument("-nv", "--nonverbose", action = "store_true", help = "Don't spam completions to console while generating")
model_init.add_args(parser)
args = parser.parse_args()

# Validate args
directory = os.path.dirname(args.output)
if directory and not os.path.isdir(directory):
    print(f" ## Directory for output file {args.output} does not exist.")
    # Was sys.exit(): a bare exit reports status 0 on this error path; exit
    # non-zero so shell scripts and CI can detect the failure.
    sys.exit(1)
if os.path.exists(args.output):
    print(f" !! Warning: Output file exists and will be overwritten.")
# Init model
# Validate and echo the model arguments, then load tokenizer + (possibly lazy)
# model sized for the requested batch size.
model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(args, allow_auto_split = True, max_batch_size = args.batch_size)

# Create cache
# Pick the cache precision from the command-line flags (FP8, Q4 or FP16).
if args.cache_8bit:
    cache_type = ExLlamaV2Cache_8bit
elif args.cache_q4:
    cache_type = ExLlamaV2Cache_Q4
else:
    cache_type = ExLlamaV2Cache
cache = cache_type(model, lazy = not model.loaded, batch_size = args.batch_size)

# Load model
# If init() deferred loading (auto split), load now, splitting across GPUs.
if not model.loaded:
    print(" -- Loading model...")
    model.load_autosplit(cache)

# Generator
gen = ExLlamaV2BaseGenerator(model, cache, tokenizer)

# Sampling settings shared by every batch.
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.token_repetition_penalty = 1.0
gen_settings.temperature = 0.8
gen_settings.top_k = 100
gen_settings.top_p = 0.8

# Get problems
problems = read_problems()
num_samples_per_task = args.samples_per_task
samples = []
# Show a per-problem sample bar only when a problem spans multiple batches.
sub_progress = num_samples_per_task > args.batch_size
# Instruct formats: name -> (prompt_template, completion_prefix). The literal
# "{{problem}}" placeholder is replaced with the raw HumanEval prompt.
formats = {
    "granite": (
        "Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n"
        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} ",
        " "
    )
}
# Generate num_samples_per_task completions per problem in batches of at most
# batch_size, clean each completion down to a single function, and collect
# them for the output file.
with Progress(
    TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
    BarColumn(bar_width = None),
    "[progress.percentage]{task.percentage:>3.0f}%",
    TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
    TimeRemainingColumn(),
) as progress:
    task1 = progress.add_task("[red]Problem", total = len(problems), name = "Problems")
    for task_id in problems:
        rem_samples = num_samples_per_task
        sample_idx = 0
        if sub_progress: task2 = progress.add_task("[red]Sample", total = num_samples_per_task, name = "Samples", parent = task1)
        while rem_samples:
            # Last batch of a problem may be smaller than batch_size.
            bs = min(args.batch_size, rem_samples)

            # Get problem and batch of completions
            b_problem = problems[task_id]["prompt"]
            if args.format is None:
                f, p = "{{problem}}", ""
            else:
                # NOTE(review): an unknown -f value raises a bare KeyError here;
                # there is no friendly validation like the directory check above.
                f, p = formats[args.format]
            f_problem = f.replace("{{problem}}", b_problem)
            # Same prompt repeated bs times: one identical prompt per batch lane.
            problem = [f_problem] * bs
            responses = gen.generate_simple(problem, gen_settings, args.max_tokens, stop_token = tokenizer.eos_token_id, add_bos = True, token_healing = True)
            for response in responses:

                # Simplified cleanup of response: remove all lines starting from the first line with no indentation,
                # i.e. keep exactly one function
                # Strip the echoed prompt; what remains is the raw completion.
                r = response[len(problem[0]):]
                r = p + r
                s = r.split("\n")
                crop = len(s)
                for l in range(1, len(s)):
                    if len(s[l]) > 0:
                        b = s[l][0:1]
                        # Lines starting with whitespace or "#" are still part
                        # of the function body; anything else ends it.
                        if b != " " and b != "\t" and b != "#":
                            crop = l
                            break
                r = "\n".join(s[:crop])

                # Store sample
                samples.append(dict(task_id = task_id, completion = r))

                # Print to console
                if not args.nonverbose:
                    print("----------------------------------------------------------------------")
                    print(f" ** Problem {task_id}, sample {sample_idx + 1} / {num_samples_per_task}")
                    print("----------------------------------------------------------------------")
                    print(b_problem + p + r)
                sample_idx += 1
            rem_samples -= bs
            if sub_progress: progress.advance(task2, bs)
        if sub_progress: progress.remove_task(task2)
        progress.update(task1, advance = 1)

# Save output
print(f" -- Saving: {args.output}")
write_jsonl(args.output, samples)