mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-19 22:08:58 +00:00
234 lines
7.8 KiB
Python
import sys, os
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
import argparse
|
|
from exllamav3.util.progress import ProgressBar
|
|
from exllamav3 import model_init, Generator, Job, FormatronFilter, GreedySampler
|
|
from formatron.formatter import FormatterBuilder
|
|
from formatron.schemas.dict_inference import infer_mapping
|
|
import torch
|
|
import json
|
|
from collections import Counter
|
|
import math
|
|
import re
|
|
|
|
"""
|
|
Sampling diversity test, highly scientific.
|
|
"""
|
|
|
|
# System prompt shared by the generation pass and the extraction pass.
system_prompt = \
"""You are a creative writing assistant."""

# Test cases. Each entry has a free-form generation prompt and a list of
# (key, question, example) triples used to extract the varying details from
# every generated sample. The example value ("x") only serves to fix the
# inferred JSON schema (string-valued fields) for constrained decoding.
prompts = [
    {
        "prompt": (
            """Write the opening paragraph to a short story about a cat and its owner. The story should at a minimum mention """
            # Fixed typo: "ownwer" -> "owner" (this text is sent to the model)
            """the owner's name and the cat's name and color. The owner is wearing a colorful dress. Make sure to also """
            """mention the color of the dress."""
        ),
        "questions": [
            ("cat_name", "What is the name of the cat in the paragraph above?", "x"),
            ("cat_color", "What is the color of the cat in the paragraph above?", "x"),
            ("owner_name", "What is the name of the cat's owner in the paragraph above?", "x"),
            ("dress_color", "What is the color of the owner's dress in the paragraph above?", "x"),
        ]
    },
    {
        "prompt": (
            """I'm writing a story. Give me the first paragraph, which should describe the main character. Make sure to """
            """include their name and occupation, and also make it clear which city the story takes place in."""
        ),
        "questions": [
            ("character_name", "What is the name of the main character in the paragraph above?", "x"),
            ("occupation", "What is the occupation of the main character in the paragraph above?", "x"),
            ("location", "In which town or city does the story take place?", "x"),
        ]
    }
]
|
|
|
|
# Suffix requesting a JSON-format answer for the extraction questions.
# NOTE(review): the extraction loop in main() builds its own copy of this
# string instead of referencing this constant — verify which is intended.
post_prompt = \
""" Answer in JSON format."""

# Forced prefix for the constrained JSON extraction responses; main() slices
# len(prefix_response) characters back off before json.loads().
prefix_response = \
"""Here is the answer, in JSON format:\n\n"""


# ANSI codes
ESC = "\u001b"               # escape character
col_default = "\u001b[0m"    # reset all attributes
col_yellow = "\u001b[33;1m"  # bright yellow
col_blue = "\u001b[34;1m"    # bright blue
col_green = "\u001b[32;1m"   # bright green
col_red = "\u001b[31;1m"     # bright red
col_gray = "\u001b[37;1m"    # bright white (used as "gray" here)
|
|
|
|
|
|
def diversity_score(samples):
    """
    Score the diversity of a list of hashable samples as (1 - theta) ** 2,
    where theta = P(X1 == X2) for two distinct samples X1, X2 drawn at
    random without replacement:

        0.0 means all samples are the same
        1.0 means all samples are unique
    """
    total = len(samples)
    if total < 2:
        return 0.0

    # Unordered pairs of equal samples vs. all unordered pairs; the ratio is
    # identical to the ordered-pair formulation.
    matching_pairs = sum(math.comb(count, 2) for count in Counter(samples).values())
    theta = matching_pairs / math.comb(total, 2)
    return (1.0 - theta) ** 2
|
|
|
|
|
|
def clean(text: str) -> str:
    """
    Remove reasoning traces (<think>...</think> and <seed:think>...</seed:think>
    spans, including across newlines) from a completion, then trim surrounding
    whitespace.
    """
    for tag in ("think", "seed:think"):
        text = re.sub(rf"<{tag}>.*?</{tag}>", "", text, flags = re.DOTALL)
    return text.strip()
|
|
|
|
@torch.inference_mode()
def main(args):
    """
    Run the sampling diversity test.

    For each entry in `prompts`: generate args.num_samples completions with
    the user-configured sampler, then for every (sample, question) pair run a
    constrained greedy extraction job that answers the question in JSON, and
    finally score how diverse each extracted detail is across the samples.
    """

    # Load model
    model, config, cache, tokenizer = model_init.init(args)
    generator = Generator(
        model,
        cache,
        tokenizer,
        show_visualizer = args.visualize_cache,
        max_batch_size = args.max_batch_size,
    )
    bpw_layer, bpw_head, vram_bits = model.get_storage_info()

    print(f" -- Model: {args.model_dir}")
    print(f" -- Bitrate: {bpw_layer:.2f} bpw / {bpw_head:.2f} bpw (head)")

    def make_job(messages, ex_dict = None, prefix = ""):
        # Build one generation job. When ex_dict is given, decoding is
        # constrained (via Formatron) to JSON matching the schema inferred
        # from ex_dict and uses greedy sampling (extraction pass); otherwise
        # the sampler from the command-line args is used (generation pass).
        nonlocal args
        input_ids = tokenizer.hf_chat_template(messages, add_generation_prompt = True)
        if ex_dict is not None:
            f = FormatterBuilder()
            schema = infer_mapping(ex_dict)
            f.append_line(f"{prefix}{f.json(schema, capture_name = 'json')}")
            filters = [FormatronFilter(tokenizer, eos_after_completed = True, formatter_builder = f)]
            sampler = GreedySampler()
        else:
            filters = None
            sampler = model_init.get_arg_sampler(args)
        job = Job(
            input_ids = input_ids,
            max_new_tokens = args.max_tokens,
            stop_conditions = config.eos_token_id_list,
            sampler = sampler,
            filters = filters,
            max_rq_tokens = 512,
        )
        return job

    all_sets = {}
    for p in prompts:

        # Generate samples
        jobs = []
        for i in range(args.num_samples):
            jobs.append(make_job([
                {
                    "role": "system",
                    "content": system_prompt
                },
                {
                    "role": "user",
                    "content": p["prompt"]
                }
            ]))
        generator.enqueue(jobs)
        with ProgressBar("Inference", len(jobs)) as pb:
            while j := generator.num_remaining_jobs():
                generator.iterate()
                pb.update(len(jobs) - j)
        samples = [clean(job.full_completion) for job in jobs]

        # Some feedback
        print(f"{col_yellow}\nSample:{col_default}")
        print(f"{col_gray}{samples[0]}{col_default}")

        # Extract information: one constrained job per (sample, question)
        jobs = []
        for i in range(args.num_samples):
            for j in range(len(p["questions"])):
                var, prompt, ex = p["questions"][j]
                # Use the shared post_prompt constant (the previous inline
                # copy of this string misspelled "Answer" as "Anwser")
                prompt += post_prompt
                jobs.append(make_job([
                    {
                        "role": "system",
                        "content": system_prompt
                    },
                    {
                        "role": "user",
                        "content": samples[i] + "\n\n" + prompt
                    }
                ], {var: ex}, prefix_response))
        generator.enqueue(jobs)
        with ProgressBar("Inference", len(jobs)) as pb:
            while j := generator.num_remaining_jobs():
                generator.iterate()
                pb.update(len(jobs) - j)
        results = [job.full_completion.strip() for job in jobs]

        # Parse results. The constrained output begins with prefix_response,
        # so slice it off before parsing; skip any result that still fails to
        # parse as JSON (best effort).
        sets = {v: [] for v, _, _ in p["questions"]}
        for result in results:
            r = result[len(prefix_response):]
            try:
                j = json.loads(r)
            except json.JSONDecodeError:
                continue
            for k, v in j.items():
                # Schema inferred from "x" constrains values to strings
                sets[k].append(v.strip().lower())

        # Even more feedback
        print(f"\n{col_yellow}Extracted:{col_default}")
        for k, v in sets.items():
            print(f"{k:20}", end = "")
            c = Counter(v)
            for s, t in c.most_common():
                print(f"{col_blue}{s}{col_gray}: {t}, ", end = "")
            print(f"{col_default}")

        # Question keys are unique across prompts, so no collisions here
        all_sets.update(sets)

    # Compute and print scores
    print(f"\n{col_yellow}Scores:{col_default}")
    score_sum = 0.0
    for k, v in all_sets.items():
        score = diversity_score(v)
        print(f"{k:20} {col_green}{score:8.6f}{col_default}")
        score_sum += score
    mean = score_sum / len(all_sets)
    print("-" * 29)
    print(f"{'mean':20} {col_green}{mean:8.6f}{col_default}")
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Standard exllamav3 model-loading and sampling arguments; the sampling
    # defaults below are what the diversity test measures against.
    model_init.add_args(
        parser,
        default_cache_size = 32768,
        add_sampling_args = True,
        default_sampling_args = {
            "temperature": 0.8,
            "min_p": 0.08,
        }
    )
    # Script-specific arguments
    parser.add_argument("-samples", "--num_samples", type = int, help = "Number of samples (default: 50)", default = 50)
    parser.add_argument("-vis", "--visualize_cache", action = "store_true", help = "Show cache visualizer (slow)")
    parser.add_argument("-max_tokens", "--max_tokens", type = int, help = "Max number of tokens per sample (default: 2048)", default = 2048)
    parser.add_argument("-mbs", "--max_batch_size", type = int, help = "Max batch size (default: 16)", default = 16)
    _args = parser.parse_args()
    main(_args)
|