Add Q4 cache mode

This commit is contained in:
turboderp
2024-03-03 23:34:11 +01:00
parent b4e6c5e9c9
commit bafe539728
9 changed files with 366 additions and 28 deletions

View File

@@ -4,6 +4,7 @@ from exllamav2 import(
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Cache_Q4,
ExLlamaV2Tokenizer,
model_init,
)
@@ -40,7 +41,8 @@ parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity eval
parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit cache")
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
parser.add_argument("-eq4", "--eval_token_q4", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q4 cache")
parser.add_argument("-eb", "--eval_bos", action = "store_true", help = "Add BOS token to every row in perplexity test (required by Gemma and maybe other models.)")
parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)")
parser.add_argument("-pnb", "--prompt_no_bos", action = "store_true", help = "Don't add BOS token to prompt")
@@ -61,7 +63,7 @@ args = parser.parse_args()
# Check conflicting settings
if args.stream_layers:
if args.eval_token or args.eval_token_8bit:
if args.eval_token or args.eval_token_8bit or args.eval_token_q4:
print(" ## Can't test token ppl while streaming layers")
sys.exit()
if args.prompt:
@@ -423,6 +425,15 @@ if args.eval_dataset or args.standard_perplexity:
cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
test_ppl_token()
if args.eval_token_q4:
if args.standard_perplexity:
print(f" !! Note, can't evalutate token perplexity on standard test")
else:
print(f" -- Inference (token, Q4 cache)", end = "")
sys.stdout.flush()
cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length)
test_ppl_token()
# Test prompt speed