mirror of https://github.com/turboderp-org/exllamav2.git
Add token ppl and 8-bit cache test to test_inference script
@@ -3,6 +3,7 @@ from exllamav2 import(
     ExLlamaV2,
     ExLlamaV2Config,
     ExLlamaV2Cache,
+    ExLlamaV2Cache_8bit,
     ExLlamaV2Tokenizer,
     model_init,
 )
@@ -32,6 +33,8 @@ parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 mode
 parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
 parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
 parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
+parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
+parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit cache")
 parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)")
 parser.add_argument("-t", "--tokens", type = int, default = 128, help = "Max no. tokens")
 parser.add_argument("-ps", "--prompt_speed", action = "store_true", help = "Test prompt processing (batch) speed over context length")
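With these options in place, the new tests are selected from the command line alongside the existing perplexity settings. A hypothetical invocation might look like the following (assuming the usual -m/--model_dir flag registered by model_init; the model and dataset paths are placeholders):

python test_inference.py -m /path/to/model -ed wikitext.parquet -er 16 -el 2048 -et -e8

Both flags can be passed in the same run, and each constructs its own cache and reports a separate perplexity figure. Note that they only take effect together with -ed/--eval_dataset, since the token-by-token loop reuses the tokenized evaluation set built in that branch.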
@@ -146,9 +149,51 @@ if args.eval_dataset:
 
     mean_log_prob = logprob_sum / logprob_count
     perplexity = math.exp(-mean_log_prob)
 
     print(f" -- Evaluation perplexity: {perplexity:.4f}")
 
+    def test_ppl_token():
+        global logprob_sum, logprob_count, i, input_ids
+        global logits, target_ids, log_probs, token_log_probs
+        global mean_log_prob, perplexity
+
+        logprob_sum = 0
+        logprob_count = 0
+
+        for i in range(eval_tokens.shape[0]):
+
+            cache.current_seq_len = 0
+
+            for j in range(eval_tokens.shape[1] - 1):
+                if j % 256 == 0: print(".", end="")
+                sys.stdout.flush()
+
+                input_ids = eval_tokens[i:i + 1, j:j + 1]
+                logits = model.forward(input_ids, cache)
+                logits = logits.float() + 1e-10
+
+                log_probs = F.log_softmax(logits, dim = -1)
+                logprob_sum += log_probs[0, 0, eval_tokens[i, j+1]]
+                logprob_count += 1
+
+            print()
+
+        mean_log_prob = logprob_sum / logprob_count
+        perplexity = math.exp(-mean_log_prob)
+        print(f" -- Evaluation perplexity: {perplexity:.4f}")
+
+
+    if args.eval_token:
+        print(f" -- Inference (token)", end="")
+        sys.stdout.flush()
+        cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
+        test_ppl_token()
+
+    if args.eval_token_8bit:
+        print(f" -- Inference (token, 8-bit cache)", end="")
+        sys.stdout.flush()
+        cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
+        test_ppl_token()
 
 
 # Test prompt speed
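For reference, the quantity the new test reports is standard token-level perplexity: the model is fed one token at a time against the growing cache, the log-probability assigned to each true next token is accumulated, and perplexity is the exponential of the negated mean. A minimal standalone sketch of that calculation follows (illustrative only, not part of the repository; random tensors stand in for real model output):

import torch
import torch.nn.functional as F

def perplexity_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # logits: (seq_len, vocab_size) raw model outputs; targets: (seq_len,) true next-token ids
    log_probs = F.log_softmax(logits.float(), dim = -1)
    # pick out the log-probability of each correct token, then average
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return float(torch.exp(-token_log_probs.mean()))

# exercise the helper with random placeholder data
logits = torch.randn(16, 32000)
targets = torch.randint(0, 32000, (16,))
print(perplexity_from_logits(logits, targets))

Since the only difference between the two branches above is whether ExLlamaV2Cache or ExLlamaV2Cache_8bit is constructed, comparing the two reported figures gives a direct measure of how much accuracy the 8-bit cache costs.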