From c8fa853c894f6104b47aae7efb18f95496b4f3ac Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Thu, 9 Jan 2025 11:14:48 +0100 Subject: [PATCH] Test script: Allow --eval_rows in wiki2 ppl test --- test_inference.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test_inference.py b/test_inference.py index 2d170e0..9d81365 100644 --- a/test_inference.py +++ b/test_inference.py @@ -44,7 +44,7 @@ torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150) # (!!!) NOTE: These go on top of the engine arguments that can be found in `model_init.py` (!!!) parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model") parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)") -parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset") +parser.add_argument("-er", "--eval_rows", type = int, default = None, help = "Number of rows to apply from dataset (default 128)") parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample") parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache") parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache") @@ -267,13 +267,15 @@ if args.eval_dataset or args.standard_perplexity: seqs.append(eval_tokens[:, a:b]) eval_len.append(b if a == 0 else stride) a += stride + if args.eval_rows and len(seqs) >= args.eval_rows: + break eval_tokens = torch.cat(seqs, dim = 0) else: eval_dataset = args.eval_dataset - eval_rows = args.eval_rows + eval_rows = args.eval_rows or 128 eval_length = args.eval_length print(f" -- Dataset: {eval_dataset}")