mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
compare_q.py: Fix some logic for KLD test
This commit is contained in:
@@ -130,6 +130,7 @@ def test_ppl(data_spec: dict, spec: dict, logits_file: str):
|
||||
|
||||
print(f"Testing: {model_dir} ({spec['label']})")
|
||||
|
||||
collect_logits = False
|
||||
if logits_file:
|
||||
if "out_logits" in spec:
|
||||
collect_logits = True
|
||||
@@ -180,15 +181,17 @@ def test_ppl(data_spec: dict, spec: dict, logits_file: str):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
pb.update(rows)
|
||||
mean_log_prob = logprob_sum / logprob_count
|
||||
perplexity = math.exp(-mean_log_prob)
|
||||
|
||||
mean_log_prob = logprob_sum / logprob_count
|
||||
perplexity = math.exp(-mean_log_prob)
|
||||
if logits_file:
|
||||
kl_div = kl_div_sum_ab / kl_div_count
|
||||
print(f"KL div: {kl_div:.6f}")
|
||||
|
||||
if collect_logits:
|
||||
save_tensor(ref_logits, logits_file)
|
||||
|
||||
print(f"Perplexity: {perplexity:.6f}")
|
||||
print(f"KL div: {kl_div:.6f}")
|
||||
|
||||
del model_instance
|
||||
del eval_ids
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"tokenize_fn": "transformers",
|
||||
"tokenizer_dir": "/mnt/str/eval_models/llama3.2-1b/hf/",
|
||||
"tokenizer_dir": "/mnt/str/models/llama3.2-1b-instruct/hf/",
|
||||
"dataset": "wiki2",
|
||||
"eval_stride": 512,
|
||||
"eval_len": 2048,
|
||||
|
||||
Reference in New Issue
Block a user