Add Granite formatting to HumanEval test

Author: turboderp
Date:   2024-05-09 12:14:57 +02:00
parent 24b1b55a3a
commit d778ebc489


@@ -29,7 +29,9 @@ parser.add_argument("-bs", "--batch_size", type = int, default = 10)
 parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
 parser.add_argument("-c8", "--cache_8bit", action = "store_true", help = "Use 8-bit (FP8) cache")
 parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
-parser.add_argument("--max_tokens", type = int, default = 768)
+parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
+parser.add_argument("-f", "--format", type = str, help = "Instruct format to apply")
+parser.add_argument("-nv", "--nonverbose", action = "store_true", help = "Don't spam completions to console while generating")
 model_init.add_args(parser)
 args = parser.parse_args()
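With the new flags, a granite-formatted, non-verbose run might look like this (a sketch: the script path and the -m model flag come from model_init, not from this diff, and are assumed here):

    python eval/humaneval.py -m /path/to/granite-model -f granite -nv --max_tokens 768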
@@ -78,6 +80,14 @@ num_samples_per_task = args.samples_per_task
 samples = []
 sub_progress = num_samples_per_task > args.batch_size
+formats = {
+    "granite": (
+        "Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n" +
+        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} ",
+        " "
+    )
+}
+
 with Progress(
     TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
     BarColumn(bar_width = None),
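For readability, the granite template above renders roughly as follows for a toy problem (the real {{problem}} text is a full HumanEval prompt; note the placeholder is substituted twice, and the prompt ends with a single trailing space after the second substitution, the same " " stored as the completion prefix):

    Question:
    Complete the following Python function:

    def add(a, b):

    Answer:
    Sure! Here is how you might implement the function:

    ```python
    def add(a, b):

followed by one trailing space, which is where generation begins.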
@@ -90,14 +100,22 @@ with Progress(
     for task_id in problems:
         rem_samples = num_samples_per_task
         sample_idx = 0
         if sub_progress: task2 = progress.add_task("[red]Sample", total = num_samples_per_task, name = "Samples", parent = task1)
         while rem_samples:
             bs = min(args.batch_size, rem_samples)
             # Get problem and batch of completions
-            problem = [problems[task_id]["prompt"]] * bs
-            responses = gen.generate_simple(problem, gen_settings, args.max_tokens, stop_token = tokenizer.eos_token_id, add_bos = True)
+            b_problem = problems[task_id]["prompt"]
+            if args.format is None:
+                f, p = "{{problem}}", ""
+            else:
+                f, p = formats[args.format]
+            f_problem = f.replace("{{problem}}", b_problem)
+            problem = [f_problem] * bs
+            responses = gen.generate_simple(problem, gen_settings, args.max_tokens, stop_token = tokenizer.eos_token_id, add_bos = True, token_healing = True)
             for response in responses:
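The f/p pair works together with the next hunk: the completion is sliced off after the full formatted prompt, then p is glued back on so the recorded sample still lines up with the raw HumanEval prompt. A toy illustration (all values invented):

    f, p = formats["granite"]                          # p is a single space
    b_problem = "def add(a, b):\n"                     # stand-in HumanEval prompt
    f_problem = f.replace("{{problem}}", b_problem)    # fills both placeholders
    response = f_problem + "   return a + b"           # pretend model output
    r = p + response[len(f_problem):]                  # "    return a + b"

The newly passed token_healing = True fits this scheme: the prompt ends with a lone space after a newline, an awkward token boundary, and token healing lets the generator re-pick that final token so the model can emit natural indentation.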
@@ -105,7 +123,8 @@ with Progress(
                 # i.e. keep exactly one function
                 r = response[len(problem[0]):]
-                s =r.split("\n")
+                r = p + r
+                s = r.split("\n")
                 crop = len(s)
                 for l in range(1, len(s)):
                     if len(s[l]) > 0:
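The loop that follows is cut off by the hunk boundary, but per the comment it keeps exactly one function. A self-contained sketch of that idea; the exact condition beyond `if len(s[l]) > 0:` is not shown in this diff, so the column-0 check below is an assumption:

    # Assumed cropping logic: cut at the first non-empty line that is not
    # indented, i.e. where a new top-level statement would begin.
    r = "    return a + b\n\nprint(add(1, 2))\n"
    s = r.split("\n")
    crop = len(s)
    for l in range(1, len(s)):
        if len(s[l]) > 0 and not s[l][0].isspace():
            crop = l
            break
    r = "\n".join(s[:crop])    # keeps only "    return a + b\n"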
@@ -119,6 +138,16 @@ with Progress(
                 samples.append(dict(task_id = task_id, completion = r))
+                # Print to console
+                if not args.nonverbose:
+                    print("----------------------------------------------------------------------")
+                    print(f" ** Problem {task_id}, sample {sample_idx + 1} / {num_samples_per_task}")
+                    print("----------------------------------------------------------------------")
+                    print(b_problem + p + r)
                 sample_idx += 1
             rem_samples -= bs
             if sub_progress: progress.advance(task2, bs)
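Each dict appended to samples matches the schema expected by OpenAI's human-eval tooling, so scoring downstream would typically look like this (a sketch; the dump-and-score step is not part of this diff):

    from human_eval.data import write_jsonl

    write_jsonl("samples.jsonl", samples)
    # then, from a shell:
    #   evaluate_functional_correctness samples.jsonl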