mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-29 18:51:34 +00:00
compare_q.py: Option to capture logits in streaming mode (for large unquantized models)
This commit is contained in:
@@ -138,6 +138,8 @@ def test_ppl(data_spec: dict, spec: dict, logits_file: str):
|
||||
else:
|
||||
collect_logits = False
|
||||
ref_logits = load_tensor(logits_file)
|
||||
if not isinstance(ref_logits, list):
|
||||
ref_logits = ref_logits.split(1, 0)
|
||||
|
||||
with ProgressBar("Evaluating", rows) as pb:
|
||||
for row in range(rows):
|
||||
@@ -322,16 +324,19 @@ def main(args):
|
||||
m = json.load(f)
|
||||
models_spec += m
|
||||
|
||||
logits_file = None
|
||||
for idx, spec in enumerate(models_spec):
|
||||
if "out_logits" in spec:
|
||||
logits_dir = spec["out_logits"]
|
||||
if not os.path.exists(logits_dir):
|
||||
os.makedirs(logits_dir)
|
||||
logits_file = os.path.join(logits_dir, dict_hash(test_data_spec) + ".safetensors")
|
||||
logits_idx = idx
|
||||
if logits_file is not None:
|
||||
models_spec = [models_spec[logits_idx]] + models_spec[:logits_idx] + models_spec[logits_idx + 1:]
|
||||
if args.logits_file:
|
||||
logits_file = args.logits_file
|
||||
else:
|
||||
logits_file = None
|
||||
for idx, spec in enumerate(models_spec):
|
||||
if "out_logits" in spec:
|
||||
logits_dir = spec["out_logits"]
|
||||
if not os.path.exists(logits_dir):
|
||||
os.makedirs(logits_dir)
|
||||
logits_file = os.path.join(logits_dir, dict_hash(test_data_spec) + ".safetensors")
|
||||
logits_idx = idx
|
||||
if logits_file is not None:
|
||||
models_spec = [models_spec[logits_idx]] + models_spec[:logits_idx] + models_spec[logits_idx + 1:]
|
||||
|
||||
if args.mask:
|
||||
masks = args.mask.split(";")
|
||||
@@ -370,6 +375,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-t", "--title", type = str, default = "Very plot", help = "Plot title")
|
||||
parser.add_argument("-kld", "--kld", action = "store_true", help = "Test KL divergence")
|
||||
parser.add_argument("-mask", "--mask", type = str, help = "Semicolon-separated list of strings to match against model labels for inclusion")
|
||||
parser.add_argument("-lf", "--logits_file", type = str, help = "Reference logits file for KLD", required = False)
|
||||
_args = parser.parse_args()
|
||||
main(_args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user