mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-03-15 00:07:26 +00:00
220 lines
6.2 KiB
Python
220 lines
6.2 KiB
Python
|
|
from exllamav2 import(
|
|
ExLlamaV2,
|
|
ExLlamaV2Config,
|
|
ExLlamaV2Cache,
|
|
ExLlamaV2Tokenizer,
|
|
model_init,
|
|
)
|
|
|
|
import argparse, os, math, time
|
|
import pandas, fastparquet
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from conversion.tokenize import get_tokens
|
|
from conversion.quantize import list_live_tensors
|
|
|
|
import sys
|
|
import json
|
|
|
|
torch.cuda._lazy_init()
|
|
torch.set_printoptions(precision = 10)
|
|
# torch.backends.cuda.matmul.allow_tf32 = True
|
|
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
|
|
# torch.set_float32_matmul_precision("medium")
|
|
|
|
parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model")
|
|
parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
|
|
parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
|
|
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
|
|
parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt")
|
|
parser.add_argument("-t", "--tokens", type = int, default = 128, help = "Max no. tokens")
|
|
parser.add_argument("-ps", "--prompt_speed", action = "store_true", help = "Test prompt processing (batch) speed over context length")
|
|
parser.add_argument("-s", "--speed", action = "store_true", help = "Test raw generation speed over context length")
|
|
|
|
# Initialize model and tokenizer
|
|
|
|
model_init.add_args(parser)
|
|
args = parser.parse_args()
|
|
model_init.check_args(args)
|
|
model_init.print_options(args)
|
|
model, tokenizer = model_init.init(args)
|
|
|
|
# Test generation
|
|
|
|
if args.prompt:
|
|
|
|
with torch.inference_mode():
|
|
|
|
cache = ExLlamaV2Cache(model)
|
|
|
|
ids = tokenizer.encode(args.prompt)
|
|
tokens_prompt = ids.shape[-1]
|
|
|
|
print(f" -- Warmup...")
|
|
|
|
model.forward(ids[:, -1:])
|
|
|
|
print(f" -- Generating (greedy sampling)...")
|
|
print()
|
|
print(args.prompt, end = "")
|
|
sys.stdout.flush()
|
|
|
|
time_begin = time.time()
|
|
|
|
if ids.shape[-1] > 1: model.forward(ids[:, :-1], cache, preprocess_only = True)
|
|
|
|
torch.cuda.synchronize()
|
|
time_prompt = time.time()
|
|
|
|
for i in range(args.tokens):
|
|
|
|
text1 = tokenizer.decode(ids[:, -2:])[0]
|
|
|
|
logits = model.forward(ids[:, -1:], cache)
|
|
sample = torch.argmax(logits[0, -1]).cpu().unsqueeze(0).unsqueeze(0)
|
|
ids = torch.cat((ids, sample), dim = -1)
|
|
|
|
text2 = tokenizer.decode(ids[:, -3:])[0]
|
|
text2 = text2[len(text1):]
|
|
|
|
print (text2, end = "")
|
|
# sys.stdout.flush()
|
|
|
|
time_end = time.time()
|
|
|
|
print()
|
|
print()
|
|
|
|
total_prompt = time_prompt - time_begin
|
|
total_gen = time_end - time_prompt
|
|
print(f"Prompt processed in {total_prompt:.2f} seconds, {tokens_prompt} tokens, {tokens_prompt / total_prompt:.2f} tokens/second")
|
|
print(f"Response generated in {total_gen:.2f} seconds, {args.tokens} tokens, {args.tokens / total_gen:.2f} tokens/second")
|
|
|
|
cache = None
|
|
|
|
|
|
# Test perplexity
|
|
|
|
if args.eval_dataset:
|
|
|
|
with torch.inference_mode():
|
|
|
|
eval_dataset = args.eval_dataset
|
|
eval_rows = args.eval_rows
|
|
eval_length = args.eval_length
|
|
|
|
print(f" -- Running perplexity test")
|
|
print(f" -- Dataset: {eval_dataset}")
|
|
print(f" -- Tokenizing eval data, {eval_rows} rows x {eval_length} tokens...")
|
|
|
|
eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
|
|
|
|
print(f" -- Inference", end = "")
|
|
sys.stdout.flush()
|
|
|
|
logprob_sum = 0.0
|
|
logprob_count = 0
|
|
|
|
for i in range(eval_tokens.shape[0]):
|
|
#for i in range(126, 127):
|
|
|
|
if i % 10 == 0: print(".", end = "")
|
|
sys.stdout.flush()
|
|
|
|
input_ids = eval_tokens[i:i+1, :]
|
|
|
|
input_ids = input_ids[:, :-1]
|
|
logits = model.forward(input_ids)
|
|
|
|
# print (tokenizer.decode(input_ids))
|
|
|
|
target_ids = input_ids[:, 1:].to(logits.device)
|
|
|
|
log_probs = F.log_softmax(logits, dim=-1)
|
|
token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
|
|
logprob_sum += token_log_probs.sum().item()
|
|
logprob_count += target_ids.numel()
|
|
|
|
print()
|
|
|
|
mean_log_prob = logprob_sum / logprob_count
|
|
perplexity = math.exp(-mean_log_prob)
|
|
|
|
print(f" -- Evaluation perplexity: {perplexity:.4f}")
|
|
|
|
xx = 0
|
|
|
|
|
|
# Test prompt speed
|
|
|
|
if args.prompt_speed:
|
|
|
|
with torch.inference_mode():
|
|
|
|
cache = ExLlamaV2Cache(model)
|
|
|
|
ids = torch.randint(0, model.config.vocab_size - 1, (1, model.config.max_seq_len))
|
|
|
|
print(f" -- Warmup...")
|
|
|
|
model.forward(ids[:, -1:])
|
|
|
|
print(f" -- Measuring prompt speed...")
|
|
|
|
current_len = 128
|
|
while True:
|
|
|
|
time_begin = time.time()
|
|
|
|
cache.current_seq_len = 0
|
|
model.forward(ids[:, :current_len], cache, preprocess_only = True)
|
|
torch.cuda.synchronize()
|
|
|
|
time_end = time.time()
|
|
tps = current_len / (time_end - time_begin)
|
|
|
|
print(f" ** Length {current_len:>5} tokens: {tps:>11.4f} t/s")
|
|
|
|
current_len_ = current_len
|
|
current_len = min(current_len + 128, model.config.max_seq_len)
|
|
if current_len == current_len_: break
|
|
|
|
cache = None
|
|
|
|
|
|
# Test token speed
|
|
|
|
if args.speed:
|
|
|
|
with torch.inference_mode():
|
|
|
|
cache = ExLlamaV2Cache(model)
|
|
|
|
print(f" -- Measuring token speed...")
|
|
ids = tokenizer.encode("X")
|
|
model.forward(ids[:, :])
|
|
|
|
current_idx = ids.shape[-1]
|
|
next_stop = 128
|
|
|
|
while True:
|
|
|
|
time_begin = time.time()
|
|
|
|
tokens = next_stop - current_idx
|
|
for i in range(tokens):
|
|
|
|
logits = model.forward(ids[:, -1:], cache)
|
|
sample = torch.argmax(logits[0, -1]).cpu().unsqueeze(0).unsqueeze(0)
|
|
ids = torch.cat((ids, sample), dim=-1)
|
|
|
|
time_end = time.time()
|
|
tps = tokens / (time_end - time_begin)
|
|
|
|
print(f" ** Position {current_idx:>5} + {tokens:>3} tokens: {tps:>9.4f} t/s")
|
|
|
|
current_idx = next_stop
|
|
next_stop = min(next_stop + 128, model.config.max_seq_len)
|
|
if next_stop == current_idx: break
|