import sys, os
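# Add the directory one level above this script to sys.path so that a local
# exllamav2 checkout is importable without installing the package.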
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from human_eval.data import write_jsonl, read_problems

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Tokenizer,
    model_init,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler,
)

import torch, argparse
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn

# Args

parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
parser.add_argument("-bs", "--batch_size", type = int, default = 10, help = "Number of completions to generate in parallel")
parser.add_argument("-spt", "--samples_per_task", type = int, default = 200, help = "Number of samples to generate per problem")
parser.add_argument("-c8", "--cache_8bit", action = "store_true", help = "Use 8-bit (FP8) cache")
parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens to generate per sample")
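
# model_init.add_args() attaches exllamav2's shared model-loading arguments
# (model directory, GPU split, context length, etc.), so this script accepts
# the same -m/--model_dir style flags as the repo's other example scripts.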
model_init.add_args(parser)
args = parser.parse_args()

# Validate args

directory = os.path.dirname(args.output)
if directory and not os.path.isdir(directory):
    print(f" ## Directory for output file {args.output} does not exist.")
    sys.exit()
if os.path.exists(args.output):
    print(f" !! Warning: Output file {args.output} exists and will be overwritten.")

# Init model

model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(args, allow_auto_split = True, max_batch_size = args.batch_size)

# Create cache

if args.cache_8bit: cache_type = ExLlamaV2Cache_8bit
elif args.cache_q4: cache_type = ExLlamaV2Cache_Q4
else: cache_type = ExLlamaV2Cache
cache = cache_type(model, lazy = not model.loaded, batch_size = args.batch_size)

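# The cache is created lazily when the model isn't loaded yet (auto split) and
# gets its device allocation during load_autosplit() below. The quantized
# variants (FP8, Q4) trade a little accuracy for a smaller VRAM footprint.
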
# Load model

if not model.loaded:

    print(" -- Loading model...")
    model.load_autosplit(cache)

# Generator

gen = ExLlamaV2BaseGenerator(model, cache, tokenizer)
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.token_repetition_penalty = 1.0
gen_settings.temperature = 0.8
gen_settings.top_k = 100
gen_settings.top_p = 0.8

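# Sampling settings: repetition penalty is left neutral (1.0), since code
# naturally repeats tokens (indentation, identifiers) and penalizing that can
# distort completions; temperature/top-K/top-P keep enough diversity across
# the many samples drawn per task for pass@k estimation.
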
# Get problems

problems = read_problems()
num_samples_per_task = args.samples_per_task
samples = []
sub_progress = num_samples_per_task > args.batch_size

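# A nested "Samples" bar is only shown when a single task needs more than one
# batch, i.e. when samples_per_task exceeds batch_size.
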
with Progress(
    TextColumn("[bold blue]{task.fields[name]}", justify = "left"),
    BarColumn(bar_width = None),
    "[progress.percentage]{task.percentage:>3.0f}%",
    TextColumn("{task.completed: 4} of {task.total: 4}", justify = "right"),
    TimeRemainingColumn(),
) as progress:

    task1 = progress.add_task("[red]Problem", total = len(problems), name = "Problems")
    for task_id in problems:

        rem_samples = num_samples_per_task
        if sub_progress: task2 = progress.add_task("[red]Sample", total = num_samples_per_task, name = "Samples", parent = task1)
        while rem_samples:
            bs = min(args.batch_size, rem_samples)

            # Get problem and batch of completions

            problem = [problems[task_id]["prompt"]] * bs
            responses = gen.generate_simple(problem, gen_settings, args.max_tokens, stop_token = tokenizer.eos_token_id, add_bos = True)

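            # generate_simple() generates the whole batch of identical prompts
            # in parallel; each returned string includes the prompt text, which
            # is sliced off below before the completion is cleaned up.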
            for response in responses:

                # Simplified cleanup of response: remove all lines starting from the first line with no indentation,
                # i.e. keep exactly one function

                r = response[len(problem[0]):]
                s = r.split("\n")
                crop = len(s)
                for l in range(1, len(s)):
                    if len(s[l]) > 0:
                        b = s[l][0:1]
                        if b != " " and b != "\t" and b != "#":
                            crop = l
                            break
                r = "\n".join(s[:crop])

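                # For example (hypothetical completion), the cropping turns
                #   "    return x + 1\n\nprint(check())\n"
                # into
                #   "    return x + 1\n"
                # dropping the stray top-level call after the function body.
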
                # Store sample

                samples.append(dict(task_id = task_id, completion = r))

            rem_samples -= bs
            if sub_progress: progress.advance(task2, bs)

        if sub_progress: progress.remove_task(task2)
        progress.update(task1, advance = 1)

# Save output

print(f" -- Saving: {args.output}")
write_jsonl(args.output, samples)
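
# The samples file can then be scored with the human-eval harness, e.g. (if the
# package's CLI entry point is installed):
#
#   evaluate_functional_correctness <output>.jsonl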