Add cosine_error and SQNR measures

This commit is contained in:
turboderp
2025-05-30 19:43:20 +02:00
parent 34d2f1f5fa
commit 2cc8f718da
3 changed files with 44 additions and 4 deletions

View File

@@ -5,6 +5,7 @@ import argparse
from exllamav3.util.file import disk_lru_cache
from exllamav3.util.progress import ProgressBar
from exllamav3.util.memory import free_mem
from exllamav3.util.measures import cosine_error, sqnr
from exllamav3 import Config, Model, Tokenizer
from datasets import load_dataset
import torch
@@ -83,10 +84,14 @@ def main(args):
max_diff = 0
rfn_error_sum = 0
cos_error_sum = 0
sqnr_sum = 0
rows = state_a.shape[0]
for i in range(rows):
sa = state_a[i].to(float, copy = True)
sb = state_b[i].to(float)
cos_error_sum += cosine_error(sa, sb)
sqnr_sum += sqnr(sa, sb)
sa -= sb
rfn_error_sum += (torch.linalg.norm(sa, 'fro') / torch.linalg.norm(sb, 'fro').mean()).item()
sa.abs_()
@@ -95,7 +100,15 @@ def main(args):
del sa, sb
rfn_error = rfn_error_sum / rows
print(f" -- {module_a.key:40} error: {rfn_error:.6f} max_diff/norm: {max_diff:.6f}")
cos_error = cos_error_sum / rows
sqnr_ = sqnr_sum / rows
print(
f" -- {module_a.key:40}"
f" rfn_err: {rfn_error:.6f}"
f" max_diff/norm: {max_diff:.6f}"
f" sqnr: {sqnr_:9.6f}"
f" cos_err: {cos_error:.6f}"
)
# Compare logits
topk_max = args.topk_max

View File

@@ -9,6 +9,7 @@ from ..util.progress import ProgressBar
from ..util.memory import free_mem
from ..util import Timer, human_time
from ..util.tensor import save_tensor_image
from ..util.measures import cosine_error, sqnr
from .calibration_data import get_default_calibration
from .compile import compile_model, dsize
from safetensors.torch import save_file
@@ -227,7 +228,9 @@ def get_state_error(x, ref):
x = x.view(-1, x.shape[-1]).float()
ref = ref.view(-1, ref.shape[-1]).float()
err = torch.linalg.norm(x - ref, 'fro') / torch.linalg.norm(ref, 'fro')
return err.item()
sq = sqnr(x, ref)
cos = cosine_error(x, ref)
return err.item(), cos, sq
@torch.inference_mode()
@@ -430,6 +433,8 @@ def main(args, job_state):
# Advance state
error = 0
cos_error = 0
sqnr_ = 0
with ProgressBar(f" -- Forward pass: {module.key}", len(state)) as progress:
for i in range(len(state)):
progress.update(i)
@@ -441,16 +446,23 @@ def main(args, job_state):
state[i] = module.forward(state[i], params).cpu()
if i < num_ref_states and len(linears):
ref_states[i] = ref_states[i].to(state[i].device)
error += get_state_error(state[i], ref_states[i])
rfn, cos, sq = get_state_error(state[i], ref_states[i])
error += rfn
cos_error += cos
sqnr_ += sq
ref_states[i] = None
error /= num_ref_states
cos_error /= num_ref_states
sqnr_ /= num_ref_states
# Feedback after module
module_time = time.time() - start_module_time
print(
f" -- Quantized: {module.key:{config.stc.max_key_len() + 8}}" +
(f" bpw: {final_bpw:5.2f}" if final_bpw else f" no_weights") +
(f" rfn: {error:.6f}" if module.num_slices == 1 else " rfn: N/A ") +
(f" rfn: {error:.6f}" if module.num_slices == 1 else " rfn: N/A ") +
f" cos: {cos_error:.6f}"
f" sqnr: {sqnr_:.6f}"
f" [{module_time:.2f} s]"
)
sys.stdout.flush()

View File

@@ -0,0 +1,15 @@
import torch
import torch.nn.functional as F
def sqnr(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> float:
    """
    Signal-to-quantization-noise ratio of `a` measured against the reference `b`, in dB.

    Both tensors are flattened to (a.shape[0], -1) and (b.shape[0], -1). For each
    row the power of the reference (sum of squares of `b`) is divided by the power
    of the difference `a - b`, converted to dB, and the per-row dB values are
    averaged.

    :param a:
        Tensor under test, at least 1-D

    :param b:
        Reference tensor; its flattened form must be broadcast-compatible with
        that of `a`

    :param eps:
        Added to the noise power so the ratio stays finite when a row of `a`
        equals the corresponding row of `b`

    :return:
        Mean per-row SQNR in dB as a Python float. A row of `b` that is all zeros
        contributes -inf, since its signal power is zero.
    """

    # reshape() rather than view(): view() raises on non-contiguous inputs whose
    # trailing dims cannot be merged without a copy (e.g. permuted tensors)
    a_flat = a.reshape(a.shape[0], -1)
    b_flat = b.reshape(b.shape[0], -1)

    signal_power = torch.sum(b_flat ** 2, dim = 1)
    noise_power = torch.sum((a_flat - b_flat) ** 2, dim = 1) + eps

    # Average is taken over the per-row dB values, not over the raw power ratios
    return 10.0 * torch.log10(signal_power / noise_power).mean().item()  # dB
def cosine_error(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> float:
    """
    Cosine error between `a` and `b`: one minus the mean row-wise cosine similarity.

    Both tensors are flattened to (a.shape[0], -1) and (b.shape[0], -1), the
    cosine similarity is computed per row, and the similarities are averaged.

    :param a:
        First tensor, at least 1-D

    :param b:
        Second tensor; its flattened form must be broadcast-compatible with that
        of `a`

    :param eps:
        Passed through to F.cosine_similarity as its `eps` argument

    :return:
        1 - mean cosine similarity as a Python float: 0.0 when every row pair
        points in the same direction, 2.0 when every row pair is opposite
    """

    # reshape() rather than view(): view() raises on non-contiguous inputs whose
    # trailing dims cannot be merged without a copy (e.g. permuted tensors)
    a_flat = a.reshape(a.shape[0], -1)
    b_flat = b.reshape(b.shape[0], -1)

    cos_sim = F.cosine_similarity(a_flat, b_flat, dim = 1, eps = eps)
    return 1.0 - cos_sim.mean().item()