mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
Add cosine_error and SQNR measures
This commit is contained in:
@@ -5,6 +5,7 @@ import argparse
|
||||
from exllamav3.util.file import disk_lru_cache
|
||||
from exllamav3.util.progress import ProgressBar
|
||||
from exllamav3.util.memory import free_mem
|
||||
from exllamav3.util.measures import cosine_error, sqnr
|
||||
from exllamav3 import Config, Model, Tokenizer
|
||||
from datasets import load_dataset
|
||||
import torch
|
||||
@@ -83,10 +84,14 @@ def main(args):
|
||||
|
||||
max_diff = 0
|
||||
rfn_error_sum = 0
|
||||
cos_error_sum = 0
|
||||
sqnr_sum = 0
|
||||
rows = state_a.shape[0]
|
||||
for i in range(rows):
|
||||
sa = state_a[i].to(float, copy = True)
|
||||
sb = state_b[i].to(float)
|
||||
cos_error_sum += cosine_error(sa, sb)
|
||||
sqnr_sum += sqnr(sa, sb)
|
||||
sa -= sb
|
||||
rfn_error_sum += (torch.linalg.norm(sa, 'fro') / torch.linalg.norm(sb, 'fro').mean()).item()
|
||||
sa.abs_()
|
||||
@@ -95,7 +100,15 @@ def main(args):
|
||||
|
||||
del sa, sb
|
||||
rfn_error = rfn_error_sum / rows
|
||||
print(f" -- {module_a.key:40} error: {rfn_error:.6f} max_diff/norm: {max_diff:.6f}")
|
||||
cos_error = cos_error_sum / rows
|
||||
sqnr_ = sqnr_sum / rows
|
||||
print(
|
||||
f" -- {module_a.key:40}"
|
||||
f" rfn_err: {rfn_error:.6f}"
|
||||
f" max_diff/norm: {max_diff:.6f}"
|
||||
f" sqnr: {sqnr_:9.6f}"
|
||||
f" cos_err: {cos_error:.6f}"
|
||||
)
|
||||
|
||||
# Compare logits
|
||||
topk_max = args.topk_max
|
||||
|
||||
Reference in New Issue
Block a user