Add token ppl and 8-bit cache test to test_inference script

This commit is contained in:
turboderp
2023-12-03 22:04:01 +01:00
parent 38d393718d
commit 0aeca11fa6

View File

@@ -3,6 +3,7 @@ from exllamav2 import(
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
model_init,
)
@@ -32,6 +33,8 @@ parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 mode
# Perplexity evaluation options. -ed selects the dataset, -er / -el bound how much of it is used,
# and -et / -e8 additionally request token-by-token evaluation through a cache (regular or 8-bit).
parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit cache")
# Generation and prompt-processing speed test options.
parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)")
parser.add_argument("-t", "--tokens", type = int, default = 128, help = "Max no. tokens")
parser.add_argument("-ps", "--prompt_speed", action = "store_true", help = "Test prompt processing (batch) speed over context length")
@@ -146,9 +149,51 @@ if args.eval_dataset:
mean_log_prob = logprob_sum / logprob_count
perplexity = math.exp(-mean_log_prob)
print(f" -- Evaluation perplexity: {perplexity:.4f}")
def test_ppl_token():
    """Measure perplexity by running the evaluation rows one token at a time.

    Reads ``eval_tokens``, ``model`` and ``cache`` from the enclosing scope.
    For every row, ``cache.current_seq_len`` is reset to 0 and each token is
    passed individually to ``model.forward`` together with the cache; the
    log-probability assigned to the token that follows is accumulated.

    A progress dot is printed every 256 positions, and the resulting
    perplexity is printed at the end. Nothing is returned: the running
    totals and results are published through module-level globals
    (``logprob_sum``, ``logprob_count``, ``mean_log_prob``, ``perplexity``
    and the per-step intermediates).

    If no token pairs are available (no rows, or rows shorter than two
    tokens), ``mean_log_prob`` and ``perplexity`` are set to NaN and a notice
    is printed instead of dividing by zero.
    """
    # NOTE: target_ids and token_log_probs are declared global but never
    # assigned in this function; the declaration is kept as-is.
    global logprob_sum, logprob_count, i, input_ids
    global logits, target_ids, log_probs, token_log_probs
    global mean_log_prob, perplexity

    logprob_sum = 0
    logprob_count = 0

    for i in range(eval_tokens.shape[0]):

        # Each row is evaluated from sequence position 0 in the cache
        cache.current_seq_len = 0

        for j in range(eval_tokens.shape[1] - 1):

            # Flush only when a dot was actually written, not on every token
            if j % 256 == 0:
                print(".", end="")
                sys.stdout.flush()

            # Single-token slice, kept 2-D so the model sees a (1, 1) input
            input_ids = eval_tokens[i:i + 1, j:j + 1]
            logits = model.forward(input_ids, cache)
            logits = logits.float() + 1e-10

            # Score the token that actually follows position j
            log_probs = F.log_softmax(logits, dim = -1)
            logprob_sum += log_probs[0, 0, eval_tokens[i, j+1]]
            logprob_count += 1

    print()

    # Guard against an empty evaluation set instead of raising ZeroDivisionError
    if logprob_count == 0:
        mean_log_prob = perplexity = float("nan")
        print(" -- Evaluation perplexity: n/a (no tokens evaluated)")
        return

    mean_log_prob = logprob_sum / logprob_count
    perplexity = math.exp(-mean_log_prob)
    print(f" -- Evaluation perplexity: {perplexity:.4f}")
# Token-by-token perplexity runs: one with the regular cache, one with the 8-bit cache.
# Each enabled run rebinds the module-level `cache` that test_ppl_token() reads.
for _requested, _banner, _cache_type in (
    (args.eval_token, f" -- Inference (token)", ExLlamaV2Cache),
    (args.eval_token_8bit, f" -- Inference (token, 8-bit cache)", ExLlamaV2Cache_8bit),
):
    if not _requested:
        continue

    # Banner has no newline; progress dots from test_ppl_token() follow on the same line
    print(_banner, end="")
    sys.stdout.flush()

    cache = _cache_type(model, max_seq_len = eval_length)
    test_ppl_token()
# Test prompt speed