mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 06:19:10 +00:00
Add cosine_error and SQNR measures
This commit is contained in:
@@ -5,6 +5,7 @@ import argparse
|
||||
from exllamav3.util.file import disk_lru_cache
|
||||
from exllamav3.util.progress import ProgressBar
|
||||
from exllamav3.util.memory import free_mem
|
||||
from exllamav3.util.measures import cosine_error, sqnr
|
||||
from exllamav3 import Config, Model, Tokenizer
|
||||
from datasets import load_dataset
|
||||
import torch
|
||||
@@ -83,10 +84,14 @@ def main(args):
|
||||
|
||||
max_diff = 0
|
||||
rfn_error_sum = 0
|
||||
cos_error_sum = 0
|
||||
sqnr_sum = 0
|
||||
rows = state_a.shape[0]
|
||||
for i in range(rows):
|
||||
sa = state_a[i].to(float, copy = True)
|
||||
sb = state_b[i].to(float)
|
||||
cos_error_sum += cosine_error(sa, sb)
|
||||
sqnr_sum += sqnr(sa, sb)
|
||||
sa -= sb
|
||||
rfn_error_sum += (torch.linalg.norm(sa, 'fro') / torch.linalg.norm(sb, 'fro').mean()).item()
|
||||
sa.abs_()
|
||||
@@ -95,7 +100,15 @@ def main(args):
|
||||
|
||||
del sa, sb
|
||||
rfn_error = rfn_error_sum / rows
|
||||
print(f" -- {module_a.key:40} error: {rfn_error:.6f} max_diff/norm: {max_diff:.6f}")
|
||||
cos_error = cos_error_sum / rows
|
||||
sqnr_ = sqnr_sum / rows
|
||||
print(
|
||||
f" -- {module_a.key:40}"
|
||||
f" rfn_err: {rfn_error:.6f}"
|
||||
f" max_diff/norm: {max_diff:.6f}"
|
||||
f" sqnr: {sqnr_:9.6f}"
|
||||
f" cos_err: {cos_error:.6f}"
|
||||
)
|
||||
|
||||
# Compare logits
|
||||
topk_max = args.topk_max
|
||||
|
||||
@@ -9,6 +9,7 @@ from ..util.progress import ProgressBar
|
||||
from ..util.memory import free_mem
|
||||
from ..util import Timer, human_time
|
||||
from ..util.tensor import save_tensor_image
|
||||
from ..util.measures import cosine_error, sqnr
|
||||
from .calibration_data import get_default_calibration
|
||||
from .compile import compile_model, dsize
|
||||
from safetensors.torch import save_file
|
||||
@@ -227,7 +228,9 @@ def get_state_error(x, ref):
|
||||
x = x.view(-1, x.shape[-1]).float()
|
||||
ref = ref.view(-1, ref.shape[-1]).float()
|
||||
err = torch.linalg.norm(x - ref, 'fro') / torch.linalg.norm(ref, 'fro')
|
||||
return err.item()
|
||||
sq = sqnr(x, ref)
|
||||
cos = cosine_error(x, ref)
|
||||
return err.item(), cos, sq
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -430,6 +433,8 @@ def main(args, job_state):
|
||||
|
||||
# Advance state
|
||||
error = 0
|
||||
cos_error = 0
|
||||
sqnr_ = 0
|
||||
with ProgressBar(f" -- Forward pass: {module.key}", len(state)) as progress:
|
||||
for i in range(len(state)):
|
||||
progress.update(i)
|
||||
@@ -441,16 +446,23 @@ def main(args, job_state):
|
||||
state[i] = module.forward(state[i], params).cpu()
|
||||
if i < num_ref_states and len(linears):
|
||||
ref_states[i] = ref_states[i].to(state[i].device)
|
||||
error += get_state_error(state[i], ref_states[i])
|
||||
rfn, cos, sq = get_state_error(state[i], ref_states[i])
|
||||
error += rfn
|
||||
cos_error += cos
|
||||
sqnr_ += sq
|
||||
ref_states[i] = None
|
||||
error /= num_ref_states
|
||||
cos_error /= num_ref_states
|
||||
sqnr_ /= num_ref_states
|
||||
|
||||
# Feedback after module
|
||||
module_time = time.time() - start_module_time
|
||||
print(
|
||||
f" -- Quantized: {module.key:{config.stc.max_key_len() + 8}}" +
|
||||
(f" bpw: {final_bpw:5.2f}" if final_bpw else f" no_weights") +
|
||||
(f" rfn: {error:.6f}" if module.num_slices == 1 else " rfn: N/A ") +
|
||||
(f" rfn: {error:.6f}" if module.num_slices == 1 else " rfn: N/A ") +
|
||||
f" cos: {cos_error:.6f}"
|
||||
f" sqnr: {sqnr_:.6f}"
|
||||
f" [{module_time:.2f} s]"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
15
exllamav3/util/measures.py
Normal file
15
exllamav3/util/measures.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def sqnr(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8):
|
||||
a_flat = a.view(a.shape[0], -1)
|
||||
b_flat = b.view(b.shape[0], -1)
|
||||
signal_power = torch.sum(b_flat ** 2, dim = 1)
|
||||
noise_power = torch.sum((a_flat - b_flat) ** 2, dim = 1) + eps
|
||||
return 10.0 * torch.log10(signal_power / noise_power).mean().item() # dB
|
||||
|
||||
def cosine_error(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8):
|
||||
a_flat = a.view(a.shape[0], -1)
|
||||
b_flat = b.view(b.shape[0], -1)
|
||||
cos_sim = F.cosine_similarity(a_flat, b_flat, dim = 1, eps = eps)
|
||||
return 1.0 - cos_sim.mean().item()
|
||||
Reference in New Issue
Block a user