compare_q.py: Fix some logic for the KLD (KL divergence) test

This commit is contained in:
turboderp
2025-05-18 21:55:26 +02:00
parent d860f8e1e1
commit c0a2028fb5
2 changed files with 7 additions and 4 deletions

View File

@@ -130,6 +130,7 @@ def test_ppl(data_spec: dict, spec: dict, logits_file: str):
print(f"Testing: {model_dir} ({spec['label']})")
collect_logits = False
if logits_file:
if "out_logits" in spec:
collect_logits = True
@@ -180,15 +181,17 @@ def test_ppl(data_spec: dict, spec: dict, logits_file: str):
torch.cuda.empty_cache()
pb.update(rows)
mean_log_prob = logprob_sum / logprob_count
perplexity = math.exp(-mean_log_prob)
mean_log_prob = logprob_sum / logprob_count
perplexity = math.exp(-mean_log_prob)
if logits_file:
kl_div = kl_div_sum_ab / kl_div_count
print(f"KL div: {kl_div:.6f}")
if collect_logits:
save_tensor(ref_logits, logits_file)
print(f"Perplexity: {perplexity:.6f}")
print(f"KL div: {kl_div:.6f}")
del model_instance
del eval_ids

View File

@@ -1,6 +1,6 @@
{
"tokenize_fn": "transformers",
"tokenizer_dir": "/mnt/str/eval_models/llama3.2-1b/hf/",
"tokenizer_dir": "/mnt/str/models/llama3.2-1b-instruct/hf/",
"dataset": "wiki2",
"eval_stride": 512,
"eval_len": 2048,