Add cosine_error and SQNR measures

This commit is contained in:
turboderp
2025-05-30 19:43:20 +02:00
parent 34d2f1f5fa
commit 2cc8f718da
3 changed files with 44 additions and 4 deletions

View File

@@ -5,6 +5,7 @@ import argparse
from exllamav3.util.file import disk_lru_cache
from exllamav3.util.progress import ProgressBar
from exllamav3.util.memory import free_mem
from exllamav3.util.measures import cosine_error, sqnr
from exllamav3 import Config, Model, Tokenizer
from datasets import load_dataset
import torch
@@ -83,10 +84,14 @@ def main(args):
max_diff = 0
rfn_error_sum = 0
cos_error_sum = 0
sqnr_sum = 0
rows = state_a.shape[0]
for i in range(rows):
sa = state_a[i].to(float, copy = True)
sb = state_b[i].to(float)
cos_error_sum += cosine_error(sa, sb)
sqnr_sum += sqnr(sa, sb)
sa -= sb
rfn_error_sum += (torch.linalg.norm(sa, 'fro') / torch.linalg.norm(sb, 'fro').mean()).item()
sa.abs_()
@@ -95,7 +100,15 @@ def main(args):
del sa, sb
rfn_error = rfn_error_sum / rows
print(f" -- {module_a.key:40} error: {rfn_error:.6f} max_diff/norm: {max_diff:.6f}")
cos_error = cos_error_sum / rows
sqnr_ = sqnr_sum / rows
print(
f" -- {module_a.key:40}"
f" rfn_err: {rfn_error:.6f}"
f" max_diff/norm: {max_diff:.6f}"
f" sqnr: {sqnr_:9.6f}"
f" cos_err: {cos_error:.6f}"
)
# Compare logits
topk_max = args.topk_max