compare_q.py: Option to capture logits in streaming mode (for large unquantized models)

This commit is contained in:
turboderp
2025-05-31 01:11:56 +02:00
parent 7aa775b6b3
commit 8ff65b8742
2 changed files with 102 additions and 10 deletions

View File

@@ -138,6 +138,8 @@ def test_ppl(data_spec: dict, spec: dict, logits_file: str):
else:
collect_logits = False
ref_logits = load_tensor(logits_file)
if not isinstance(ref_logits, list):
ref_logits = ref_logits.split(1, 0)
with ProgressBar("Evaluating", rows) as pb:
for row in range(rows):
@@ -322,16 +324,19 @@ def main(args):
m = json.load(f)
models_spec += m
logits_file = None
for idx, spec in enumerate(models_spec):
if "out_logits" in spec:
logits_dir = spec["out_logits"]
if not os.path.exists(logits_dir):
os.makedirs(logits_dir)
logits_file = os.path.join(logits_dir, dict_hash(test_data_spec) + ".safetensors")
logits_idx = idx
if logits_file is not None:
models_spec = [models_spec[logits_idx]] + models_spec[:logits_idx] + models_spec[logits_idx + 1:]
if args.logits_file:
logits_file = args.logits_file
else:
logits_file = None
for idx, spec in enumerate(models_spec):
if "out_logits" in spec:
logits_dir = spec["out_logits"]
if not os.path.exists(logits_dir):
os.makedirs(logits_dir)
logits_file = os.path.join(logits_dir, dict_hash(test_data_spec) + ".safetensors")
logits_idx = idx
if logits_file is not None:
models_spec = [models_spec[logits_idx]] + models_spec[:logits_idx] + models_spec[logits_idx + 1:]
if args.mask:
masks = args.mask.split(";")
@@ -370,6 +375,7 @@ if __name__ == "__main__":
parser.add_argument("-t", "--title", type = str, default = "Very plot", help = "Plot title")
parser.add_argument("-kld", "--kld", action = "store_true", help = "Test KL divergence")
parser.add_argument("-mask", "--mask", type = str, help = "Semicolon-separated list of strings to match against model labels for inclusion")
parser.add_argument("-lf", "--logits_file", type = str, help = "Reference logits file for KLD", required = False)
_args = parser.parse_args()
main(_args)