Test script: Allow --eval_rows in wiki2 ppl test

This commit is contained in:
turboderp
2025-01-09 11:14:48 +01:00
parent 318435db81
commit c8fa853c89

View File

@@ -44,7 +44,7 @@ torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)
# (!!!) NOTE: These go on top of the engine arguments that can be found in `model_init.py` (!!!)
parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model")
parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
parser.add_argument("-er", "--eval_rows", type = int, default = None, help = "Number of rows to apply from dataset (default 128)")
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
@@ -267,13 +267,15 @@ if args.eval_dataset or args.standard_perplexity:
seqs.append(eval_tokens[:, a:b])
eval_len.append(b if a == 0 else stride)
a += stride
if args.eval_rows and len(seqs) >= args.eval_rows:
break
eval_tokens = torch.cat(seqs, dim = 0)
else:
eval_dataset = args.eval_dataset
eval_rows = args.eval_rows
eval_rows = args.eval_rows or 128
eval_length = args.eval_length
print(f" -- Dataset: {eval_dataset}")