Files
exllamav3/eval/diversity.py
2026-01-14 21:40:31 +01:00

234 lines
7.8 KiB
Python

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
from exllamav3.util.progress import ProgressBar
from exllamav3 import model_init, Generator, Job, FormatronFilter, GreedySampler
from formatron.formatter import FormatterBuilder
from formatron.schemas.dict_inference import infer_mapping
import torch
import json
from collections import Counter
import math
import re
"""
Sampling diversity test, highly scientific.
"""

# System prompt shared by the generation and the extraction passes
system_prompt = \
"""You are a creative writing assistant."""

# Each test case: a creative-writing prompt plus the facts to extract from the
# completion. Question tuples are (variable_name, extraction_prompt, example_value);
# the example value is only used to infer a JSON schema for constrained decoding.
prompts = [
    {
        "prompt": (
            """Write the opening paragraph to a short story about a cat and its owner. The story should at a minimum mention """
            """the owner's name and the cat's name and color. The owner is wearing a colorful dress. Make sure to also """
            """mention the color of the dress."""
        ),
        "questions": [
            ("cat_name", "What is the name of the cat in the paragraph above?", "x"),
            ("cat_color", "What is the color of the cat in the paragraph above?", "x"),
            ("owner_name", "What is the name of the cat's owner in the paragraph above?", "x"),
            ("dress_color", "What is the color of the owner's dress in the paragraph above?", "x"),
        ]
    },
    {
        "prompt": (
            """I'm writing a story. Give me the first paragraph, which should describe the main character. Make sure to """
            """include their name and occupation, and also make it clear which city the story takes place in."""
        ),
        "questions": [
            ("character_name", "What is the name of the main character in the paragraph above?", "x"),
            ("occupation", "What is the occupation of the main character in the paragraph above?", "x"),
            ("location", "In which town or city does the story take place?", "x"),
        ]
    }
]

# Suffix appended to each extraction question
post_prompt = \
""" Answer in JSON format."""

# Forced response prefix for the constrained extraction pass; stripped again
# before parsing the JSON payload
prefix_response = \
"""Here is the answer, in JSON format:\n\n"""

# ANSI codes
ESC = "\u001b"
col_default = "\u001b[0m"
col_yellow = "\u001b[33;1m"
col_blue = "\u001b[34;1m"
col_green = "\u001b[32;1m"
col_red = "\u001b[31;1m"
col_gray = "\u001b[37;1m"
def diversity_score(samples):
    """
    Compute score as (1 - P(X1 = X2)) ^ 2 where X1 and X2 are two random samples:
    0.0 means all samples are the same
    1.0 means all samples are unique
    """
    num = len(samples)
    if num < 2:
        # Fewer than two samples: no pairs to compare, define the score as 0
        return 0.0
    freq = Counter(samples)
    # Number of unordered pairs that match, over the total number of pairs
    matching_pairs = sum(count * (count - 1) for count in freq.values())
    possible_pairs = num * (num - 1)
    collision_prob = matching_pairs / possible_pairs
    return (1.0 - collision_prob) ** 2
def clean(text: str) -> str:
    """Remove chain-of-thought blocks (<think> / <seed:think>) and trim whitespace."""
    for tag in ("think", "seed:think"):
        text = re.sub(rf"<{tag}>.*?</{tag}>", "", text, flags = re.DOTALL)
    return text.strip()
@torch.inference_mode()
def main(args):
    """
    Run the sampling-diversity benchmark.

    For every test prompt, generate `args.num_samples` completions with the
    CLI-configured sampler, then run a second, schema-constrained greedy pass
    that extracts the key facts (names, colors, ...) from each completion as
    JSON. Per-fact diversity scores and their mean are printed at the end.
    """

    # Load model
    model, config, cache, tokenizer = model_init.init(args)
    generator = Generator(
        model,
        cache,
        tokenizer,
        show_visualizer = args.visualize_cache,
        max_batch_size = args.max_batch_size,
    )
    bpw_layer, bpw_head, vram_bits = model.get_storage_info()
    print(f" -- Model: {args.model_dir}")
    print(f" -- Bitrate: {bpw_layer:.2f} bpw / {bpw_head:.2f} bpw (head)")

    def make_job(messages, ex_dict = None, prefix = ""):
        # Build one generation job. When ex_dict is given, constrain the output
        # to a JSON object matching its inferred schema and decode greedily
        # (deterministic extraction); otherwise sample freely with the
        # user-selected sampler.
        nonlocal args
        input_ids = tokenizer.hf_chat_template(messages, add_generation_prompt = True)
        if ex_dict is not None:
            f = FormatterBuilder()
            schema = infer_mapping(ex_dict)
            f.append_line(f"{prefix}{f.json(schema, capture_name = 'json')}")
            filters = [FormatronFilter(tokenizer, eos_after_completed = True, formatter_builder = f)]
            sampler = GreedySampler()
        else:
            filters = None
            sampler = model_init.get_arg_sampler(args)
        job = Job(
            input_ids = input_ids,
            max_new_tokens = args.max_tokens,
            stop_conditions = config.eos_token_id_list,
            sampler = sampler,
            filters = filters,
            max_rq_tokens = 512,
        )
        return job

    all_sets = {}
    for p in prompts:

        # Generate samples
        jobs = []
        for i in range(args.num_samples):
            jobs.append(make_job([
                {
                    "role": "system",
                    "content": system_prompt
                },
                {
                    "role": "user",
                    "content": p["prompt"]
                }
            ]))
        generator.enqueue(jobs)
        with ProgressBar("Inference", len(jobs)) as pb:
            while j := generator.num_remaining_jobs():
                generator.iterate()
                pb.update(len(jobs) - j)
        samples = [clean(job.full_completion) for job in jobs]

        # Some feedback
        print(f"{col_yellow}\nSample:{col_default}")
        print(f"{col_gray}{samples[0]}{col_default}")

        # Extract information: one constrained job per (sample, question) pair
        jobs = []
        for i in range(args.num_samples):
            for j in range(len(p["questions"])):
                var, prompt, ex = p["questions"][j]
                # Use the shared post_prompt constant (previously a hard-coded,
                # misspelled "Anwser in JSON format." suffix)
                prompt += post_prompt
                jobs.append(make_job([
                    {
                        "role": "system",
                        "content": system_prompt
                    },
                    {
                        "role": "user",
                        "content": samples[i] + "\n\n" + prompt
                    }
                ], {var: ex}, prefix_response))
        generator.enqueue(jobs)
        with ProgressBar("Inference", len(jobs)) as pb:
            while j := generator.num_remaining_jobs():
                generator.iterate()
                pb.update(len(jobs) - j)
        results = [job.full_completion.strip() for job in jobs]

        # Parse results, skipping any response that is not valid JSON
        sets = {v: [] for v, _, _ in p["questions"]}
        for result in results:
            # Drop the forced response prefix before parsing
            r = result[len(prefix_response):]
            try:
                j = json.loads(r)
            except json.JSONDecodeError:
                continue
            for k, v in j.items():
                # str() guards against a non-string JSON value slipping through
                # the schema (e.g. a bare number)
                sets[k].append(str(v).strip().lower())

        # Even more feedback
        print(f"\n{col_yellow}Extracted:{col_default}")
        for k, v in sets.items():
            print(f"{k:20}", end = "")
            c = Counter(v)
            for s, t in c.most_common():
                print(f"{col_blue}{s}{col_gray}: {t}, ", end = "")
            print(f"{col_default}")
        all_sets.update(sets)

    # Compute and print scores
    print(f"\n{col_yellow}Scores:{col_default}")
    score_sum = 0.0
    for k, v in all_sets.items():
        score = diversity_score(v)
        print(f"{k:20} {col_green}{score:8.6f}{col_default}")
        score_sum += score
    mean = score_sum / len(all_sets)
    print("-" * 29)
    print(f"{'mean':20} {col_green}{mean:8.6f}{col_default}")
if __name__ == "__main__":
    # Command-line entry point: model/sampling args come from model_init,
    # plus the benchmark-specific options below.
    arg_parser = argparse.ArgumentParser()
    model_init.add_args(
        arg_parser,
        default_cache_size = 32768,
        add_sampling_args = True,
        default_sampling_args = {
            "temperature": 0.8,
            "min_p": 0.08,
        }
    )
    arg_parser.add_argument("-samples", "--num_samples", type = int, help = "Number of samples (default: 50)", default = 50)
    arg_parser.add_argument("-vis", "--visualize_cache", action = "store_true", help = "Show cache visualizer (slow)")
    arg_parser.add_argument("-max_tokens", "--max_tokens", type = int, help = "Max number of tokens per sample (default: 2048)", default = 2048)
    arg_parser.add_argument("-mbs", "--max_batch_size", type = int, help = "Max batch size (default: 16)", default = 16)
    main(arg_parser.parse_args())