mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-30 03:01:23 +00:00
Test script: Allow --eval_rows in wiki2 ppl test
This commit is contained in:
@@ -44,7 +44,7 @@ torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)
|
|||||||
# (!!!) NOTE: These go on top of the engine arguments that can be found in `model_init.py` (!!!)
|
# (!!!) NOTE: These go on top of the engine arguments that can be found in `model_init.py` (!!!)
|
||||||
parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model")
|
parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model")
|
||||||
parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
|
parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
|
||||||
parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
|
parser.add_argument("-er", "--eval_rows", type = int, default = None, help = "Number of rows to apply from dataset (default 128)")
|
||||||
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
|
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
|
||||||
parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
|
parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
|
||||||
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
|
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
|
||||||
@@ -267,13 +267,15 @@ if args.eval_dataset or args.standard_perplexity:
|
|||||||
seqs.append(eval_tokens[:, a:b])
|
seqs.append(eval_tokens[:, a:b])
|
||||||
eval_len.append(b if a == 0 else stride)
|
eval_len.append(b if a == 0 else stride)
|
||||||
a += stride
|
a += stride
|
||||||
|
if args.eval_rows and len(seqs) >= args.eval_rows:
|
||||||
|
break
|
||||||
|
|
||||||
eval_tokens = torch.cat(seqs, dim = 0)
|
eval_tokens = torch.cat(seqs, dim = 0)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
eval_dataset = args.eval_dataset
|
eval_dataset = args.eval_dataset
|
||||||
eval_rows = args.eval_rows
|
eval_rows = args.eval_rows or 128
|
||||||
eval_length = args.eval_length
|
eval_length = args.eval_length
|
||||||
|
|
||||||
print(f" -- Dataset: {eval_dataset}")
|
print(f" -- Dataset: {eval_dataset}")
|
||||||
|
|||||||
Reference in New Issue
Block a user