# exllamav3/eval/compare_q.py
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.nn.functional as F
from exllamav3.util.file import disk_lru_cache, disk_lru_cache_clear
from exllamav3.util.progress import ProgressBar
from exllamav3.util.memory import free_mem
from datasets import load_dataset
import math
import argparse
import json
import matplotlib.pyplot as plt
from adjustText import adjust_text
import glob
from safetensors.torch import save_file
from safetensors import safe_open
import gc
torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 200)

# Per-backend loaders, forward passes and tokenizers
from compare_q_transformers import (
    load_transformers_auto_bf16,
    load_transformers_auto,
    load_transformers,
    fwd_transformers,
    tokenize_transformers
)
from compare_q_exllamav2 import (
    load_exllamav2,
    fwd_exllamav2
)
from compare_q_exllamav3 import (
    load_exllamav3,
    fwd_exllamav3
)
from compare_q_llamacpp import (
    load_llamacpp,
    fwd_llamacpp
)
from compare_q_anyprecision import (
    load_anyprecision,
    fwd_anyprecision,
)
from compare_q_qtip import (
    load_qtip,
    fwd_qtip,
)

# Lookup tables keyed by name, so that specs can reference functions by string
# and the cached test functions remain hashable
load_fns = {
    "transformers_auto_bf16": load_transformers_auto_bf16,
    "transformers_auto": load_transformers_auto,
    "transformers": load_transformers,
    "exllamav2": load_exllamav2,
    "exllamav3": load_exllamav3,
    "llamacpp": load_llamacpp,
    "anyprecision": load_anyprecision,
    "qtip": load_qtip,
}

fwd_fns = {
    "transformers": fwd_transformers,
    "exllamav2": fwd_exllamav2,
    "exllamav3": fwd_exllamav3,
    "llamacpp": fwd_llamacpp,
    "anyprecision": fwd_anyprecision,
    "qtip": fwd_qtip,
}

tokenize_fns = {
    "transformers": tokenize_transformers,
}
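
# Illustrative model spec entry (field names as consumed by test_ppl below;
# the label and model_dir values are hypothetical examples):
#
#   {
#       "label": "EXL3 4.0bpw",
#       "load_fn": "exllamav3",
#       "fwd_fn": "exllamav3",
#       "model_dir": "/models/llama3-8b-exl3-4.0bpw"
#   }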

# Util fns
def load_tensor(filename):
    with safe_open(filename, framework = "pt", device = "cpu") as f:
        if "tensor" in f.keys():
            return f.get_tensor("tensor")
        else:
            tensors = []
            i = 0
            while f"tensor.{i}" in f.keys():
                tensors.append(f.get_tensor(f"tensor.{i}"))
                i += 1
            return tensors

def save_tensor(tensor, filename: str):
    if isinstance(tensor, dict):
        save_file(tensor, filename)
    elif isinstance(tensor, list):
        save_file({f"tensor.{i}": t for i, t in enumerate(tensor)}, filename)
    else:
        save_file({"tensor": tensor}, filename)

# Tokenize ppl test data
@disk_lru_cache("get_dataset")
def get_test_data(spec: dict):
    tokenize_fn = tokenize_fns[spec["tokenize_fn"]]
    assert spec["dataset"] == "wiki2", "Only wiki2 implemented atm"
    eval_stride = spec["eval_stride"]
    eval_len = spec["eval_len"]
    max_rows = spec.get("max_rows", 0)
    eval_tokens = tokenize_fn(
        spec["tokenizer_dir"],
        "\n\n".join(
            load_dataset("wikitext", "wikitext-2-raw-v1", split = "test")["text"]
        )
    )
    # Slice the token stream into (possibly overlapping) rows of eval_len tokens
    num_tokens = eval_tokens.shape[-1]
    seqs = []
    for a in range(0, num_tokens - eval_len, eval_stride):
        b = a + eval_len
        seqs.append(eval_tokens[:, a:b])
        if max_rows and len(seqs) >= max_rows:
            break
    eval_tokens = torch.cat(seqs, dim = 0)
    return eval_tokens
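
# Illustrative data spec (field names as consumed above; the values shown are
# hypothetical examples, not defaults):
#
#   {
#       "dataset": "wiki2",
#       "tokenize_fn": "transformers",
#       "tokenizer_dir": "/models/llama3-8b",
#       "eval_len": 2048,
#       "eval_stride": 512,
#       "max_rows": 100
#   }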

# Run ppl test
@disk_lru_cache("test_ppl")
def test_ppl(data_spec: dict, spec: dict, logits_file: str):
    load_fn = load_fns[spec["load_fn"]]
    fwd_fn = fwd_fns[spec["fwd_fn"]]
    model_dir = spec["model_dir"]

    print(f"Loading dataset: {data_spec['dataset']}")
    eval_ids = get_test_data(data_spec)
    rows = eval_ids.shape[0]

    print(f"Loading: {model_dir}")
    model_instance, bpw_layer, bpw_head, vram_bits = load_fn(model_dir)
    vram_gb = vram_bits / 8 / 1024**3

    logprob_sum = 0.0
    logprob_count = 0
    kl_div_sum_ab = 0.0
    kl_div_count = 0

    print(f"Testing: {model_dir} ({spec['label']})")

    # A spec with "out_logits" produces the reference logits; every other spec
    # compares against the stored reference
    collect_logits = False
    if logits_file:
        if "out_logits" in spec:
            collect_logits = True
            ref_logits = []
        else:
            collect_logits = False
            ref_logits = load_tensor(logits_file)
            if not isinstance(ref_logits, list):
                ref_logits = ref_logits.split(1, 0)

    with ProgressBar("Evaluating", rows) as pb:
        for row in range(rows):
            pb.update(row)
            input_ids = eval_ids[row:row + 1, :]
            logits = fwd_fn(model_instance, input_ids)
            logits = logits.float()

            # kld, measured on the first ten rows only
            if logits_file and row < 10:
                probs_a = torch.softmax(logits, dim = -1)
                if collect_logits:
                    ref_logits.append(logits.cpu())
                    kl_div_count += 1
                else:
                    probs_b = torch.softmax(ref_logits[row].to(logits.device), dim = -1)
                    # Truncate to the smaller vocab in case the models pad the
                    # output layer differently
                    vs = min(probs_a.shape[-1], probs_b.shape[-1])
                    probs_a = probs_a[..., :vs]
                    probs_b = probs_b[..., :vs]
                    for r in range(probs_a.shape[1]):
                        # F.kl_div(log_q, p) computes KL(p || q), here KL(reference || model)
                        kl_div = F.kl_div(torch.log(probs_a[:, r:r+1, :] + 1e-10), probs_b[:, r:r+1, :], reduction = 'sum')
                        kl_div_sum_ab += kl_div.item()
                        kl_div_count += 1
                        del kl_div
                    del probs_b
                del probs_a

            # ppl
            logits = logits[:, :-1, :]
            logits += 1e-10
            log_probs = F.log_softmax(logits, dim = -1)
            del logits
            target_ids = input_ids[:, 1:].to(log_probs.device)
            del input_ids
            target_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
            del log_probs
            logprob_sum += target_log_probs.sum().item()
            logprob_count += target_ids.numel()
            del target_log_probs
            del target_ids
            torch.cuda.empty_cache()
        pb.update(rows)

    mean_log_prob = logprob_sum / logprob_count
    perplexity = math.exp(-mean_log_prob)

    if logits_file:
        kl_div = kl_div_sum_ab / kl_div_count
        print(f"KL div: {kl_div:.6f}")
        if collect_logits:
            save_tensor(ref_logits, logits_file)

    print(f"Perplexity: {perplexity:.6f}")

    del model_instance
    del eval_ids
    free_mem()

    res = {
        "label": spec.get("label", spec.get("model_dir")),
        "layer_bpw": bpw_layer,
        "head_bpw": bpw_head,
        "vram_gb": vram_gb,
        "ppl": perplexity
    }
    if logits_file:
        res.update({
            "kld": kl_div
        })
    return res
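
# Metrics recap: with per-token targets t_i and model log-probs log p(t_i),
#
#   perplexity = exp(-(1/N) * sum_i log p(t_i))
#
# and the reported KLD is the mean, over the scored positions, of
#
#   KL(ref || model) = sum_v p_ref(v) * (log p_ref(v) - log p_model(v))
#
# which is what F.kl_div(log_p_model, p_ref, reduction = 'sum') returns per
# position above.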

def plot(results, args):

    def col(light, dark):
        return dark if args.dark else light

    if args.dark:
        plt.style.use('dark_background')

    def get_color(s):
        d = {
            "EXL2": col("green", "greenyellow"),
            "EXL3": col("purple", "palevioletred"),
            "AWQ": col("olive", "tan"),
            "imat": col("brown", "darkorange"),
            "GGUF": col("red", "tomato"),
            "VPTQ": col("blue", "cornflowerblue"),
            "QTIP": col("teal", "lightseagreen"),
            "****": col("black", "silver"),
        }
        # A label can force a color explicitly, e.g. "Some model [green]"
        for k, v in d.items():
            if f"[{v}]" in s:
                return v
        # Otherwise color by quantization family
        for k, v in d.items():
            if k in s:
                return v
        return col("black", "silver")

    plt.rcParams["figure.figsize"] = (14, 11)
    plt.subplots_adjust(left = 0.05, right = 0.95, top = 0.95, bottom = 0.05)

    lpoints = {}
    x = []
    y = []
    labels = []
    colors = []
    for r in results:
        x_ = r["vram_gb"] if args.vram else r["layer_bpw"]
        y_ = r["kld"] if args.kld else r["ppl"]
        if x_ > args.max_x or y_ > args.max_y:
            continue
        x.append(x_)
        y.append(y_)
        labels.append(r["label"].split("[")[0].strip() + f"\n{y_:.3f}")
        color = get_color(r["label"])
        colors.append(color)
        if color != col("black", "silver"):
            if color not in lpoints:
                lpoints[color] = []
            lpoints[color].append((x_, y_))

    plt.scatter(x, y, c = colors, marker = "o")

    texts = []
    for i, label in enumerate(labels):
        texts.append(
            plt.text(
                x[i],
                y[i],
                label,
                fontsize = 8.5,
                ha = "left",
                va = "bottom",
                color = colors[i],
            )
        )
    adjust_text(
        texts,
        x = x,
        y = y,
        arrowprops = {"arrowstyle": "->", "color": col("lightgray", "dimgray")},
        expand = (1.35, 2.3),
        ensure_inside_axes = True,
        min_arrow_len = 0.10,
        prevent_crossings = False,
        pull_threshold = 0.20,
        # force_explode = (0.2, 0.6),
        max_move = 100
    )

    # Dotted line through each family's points (renamed from col/x/y to avoid
    # shadowing the helper and the scatter data above)
    for color, points in lpoints.items():
        px, py = zip(*sorted(points))
        plt.plot(px, py, color = color, linestyle = ':')

    plt.xlabel("VRAM // GB (decoder + head)" if args.vram else "bits per weight (decoder only)")
    plt.ylabel("KL divergence" if args.kld else "Perplexity")
    plt.title(args.title)
    if args.dark:
        plt.grid(color = 'dimgray', linestyle = '--', linewidth = 0.5)
    else:
        plt.grid(True)

    if args.plot_file:
        plt.savefig(args.plot_file)
    else:
        plt.show()

def dict_hash(x: dict) -> str:
    import hashlib
    key = json.dumps(x, sort_keys = True)
    return hashlib.sha256(key.encode('utf-8')).hexdigest()
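
# Used below to key the reference logits file to the exact data spec, so that
# changing e.g. eval_len or eval_stride produces a new .safetensors file
# rather than comparing against stale logits.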

@torch.inference_mode()
def main(args):
    with open(args.dataspec, "r", encoding = "utf8") as f:
        test_data_spec = json.load(f)

    # Expand wildcards and concatenate all model spec files
    models_files = args.modelspec
    models_files_g = []
    models_spec = []
    for filename in models_files:
        if "*" in filename:
            models_files_g += glob.glob(filename)
        else:
            models_files_g.append(filename)
    for filename in models_files_g:
        with open(filename, "r", encoding = "utf8") as f:
            m = json.load(f)
            models_spec += m

    if args.logits_file:
        logits_file = args.logits_file
    else:
        logits_file = None

    # If a spec produces the reference logits, move it to the front so the
    # other models have something to compare against. (logits_idx must be
    # initialized here: previously it was unbound, and passing --logits_file
    # without any "out_logits" spec raised a NameError.)
    logits_idx = None
    for idx, spec in enumerate(models_spec):
        if "out_logits" in spec:
            logits_dir = spec["out_logits"]
            if not os.path.exists(logits_dir):
                os.makedirs(logits_dir)
            logits_file = os.path.join(logits_dir, dict_hash(test_data_spec) + ".safetensors")
            logits_idx = idx
    if logits_idx is not None:
        models_spec = [models_spec[logits_idx]] + models_spec[:logits_idx] + models_spec[logits_idx + 1:]

    # Optionally filter specs by label substring(s)
    if args.mask:
        masks = args.mask.split(";")
        ms = []
        for spec in models_spec:
            if any(m.upper() in spec["label"].upper() for m in masks):
                ms.append(spec)
        models_spec = ms

    if args.clear_cache:
        for spec in models_spec:
            disk_lru_cache_clear("test_ppl", test_data_spec, spec, logits_file)

    results = []
    for spec in models_spec:
        r = test_ppl(test_data_spec, spec, logits_file)
        print(r)
        results.append(r)
        torch.cuda.empty_cache()
        gc.collect()

    print("------")
    print(json.dumps(results, indent = 4))

    if args.plot:
        plot(results, args)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataspec", type = str, help = "Data specification (JSON file)")
    parser.add_argument("-m", "--modelspec", type = str, nargs = "+", help = "Model specification(s) (JSON file containing a list of specs), accepts wildcards")
    parser.add_argument("-cc", "--clear_cache", action = "store_true", help = "Clear cached results for the selected models")
    parser.add_argument("-p", "--plot", action = "store_true", help = "Show scatter plot")
    parser.add_argument("-v", "--vram", action = "store_true", help = "Use VRAM footprint as scatter plot X axis")
    parser.add_argument("-mx", "--max_x", type = float, default = 999999, help = "Don't plot results beyond this X value")
    parser.add_argument("-my", "--max_y", type = float, default = 999999, help = "Don't plot results beyond this Y value")
    parser.add_argument("-t", "--title", type = str, default = "Very plot", help = "Plot title")
    parser.add_argument("-kld", "--kld", action = "store_true", help = "Plot KL divergence instead of perplexity")
    parser.add_argument("-mask", "--mask", type = str, help = "Semicolon-separated list of strings to match against model labels for inclusion")
    parser.add_argument("-lf", "--logits_file", type = str, help = "Reference logits file for KLD", required = False)
    parser.add_argument("-dark", "--dark", action = "store_true", help = "Dark mode")
    parser.add_argument("-pf", "--plot_file", type = str, help = "Write the plot to a file instead of showing it")
    _args = parser.parse_args()
    main(_args)
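
# Example invocation (illustrative; the spec filenames are hypothetical):
#
#   python compare_q.py -d dataspec.json -m "modelspecs/*.json" -p -t "Llama3 8B"
#
# Each model spec file holds a JSON list of specs; results are printed as JSON
# and, with -p, plotted as perplexity (or KLD with -kld) against bits per
# weight (or VRAM with -v).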