Mirror of https://github.com/turboderp-org/exllamav2.git
Attempt to add standard ppl test (experimental)
@@ -45,6 +45,7 @@ parser.add_argument("-s", "--speed", action = "store_true", help = "Test raw gen
 parser.add_argument("-mix", "--mix_layers", type = str, help = "Load replacement layers from secondary model. Example: --mix_layers 1,6-7:/mnt/models/other_model")
 parser.add_argument("-nwu", "--no_warmup", action = "store_true", help = "Skip warmup before testing model")
 parser.add_argument("-sl", "--stream_layers", action = "store_true", help = "Load model layer by layer (perplexity evaluation only)")
+parser.add_argument("-sp", "--standard_perplexity", choices = ["wiki2"], help = "Run standard (HF) perplexity test, stride 512")

 # Initialize model and tokenizer

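A hypothetical invocation of the new flag might look like the following (the script name and model path here are assumptions for illustration, not part of the diff):

    python test_inference.py -m /path/to/model -sp wiki2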
@@ -169,41 +170,89 @@ if args.prompt:

 # Test perplexity

-if args.eval_dataset:
+if args.eval_dataset or args.standard_perplexity:

     with torch.inference_mode():

-        eval_dataset = args.eval_dataset
-        eval_rows = args.eval_rows
-        eval_length = args.eval_length
-
         print(f" -- Running perplexity test")
-        print(f" -- Dataset: {eval_dataset}")
-        print(f" -- Tokenizing eval data, {eval_rows} rows x {eval_length} tokens...")

-        eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
+        if args.standard_perplexity:
+
+            if args.eval_dataset:
+                print(f" !! Note, overriding specified --eval_dataset with {args.standard_perplexity}")
+
+            from datasets import load_dataset
+
+            if args.standard_perplexity == "wiki2":
+                ds = "wikitext"
+                part = "wikitext-2-raw-v1"
+                split = "test"
+            # if args.standard_perplexity == "c4":
+            #     ds = "allenai/c4"
+            #     part = "allenai--c4"
+            #     split = "train"
+
+            print(f" -- Loading dataset {ds}, {part}, {split}...")
+            test = load_dataset(ds, part, split = split)
+
+            print(f" -- Tokenizing samples...")
+            text = "\n\n".join(test["text"])
+            eval_tokens = tokenizer.encode(text)
+
+            stride = 512
+            seqs = []
+            eval_len = []
+            a = 0
+            while True:
+                b = a + model.config.max_seq_len
+                if b > eval_tokens.shape[-1]: break
+                seqs.append(eval_tokens[:, a:b])
+                eval_len.append(b if a == 0 else stride)
+                a += stride
+
+            eval_tokens = torch.cat(seqs, dim = 0)
+
+        else:
+
+            eval_dataset = args.eval_dataset
+            eval_rows = args.eval_rows
+            eval_length = args.eval_length
+
+            print(f" -- Dataset: {eval_dataset}")
+            print(f" -- Tokenizing eval data, {eval_rows} rows x {eval_length} tokens...")
+
+            eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
+            eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]

         logprob_sum = 0.0
         logprob_count = 0

-        def ppl(input_ids_, logits_):
+        def ppl(input_ids__, logits__, lengths__):

             logprob_sum_ = 0.0
             logprob_count_ = 0

-            chunksize = logits_.shape[1] * 16000 // logits_.shape[2]
-            b_ = 0
-            while b_ < logits_.shape[1]:
-                a_ = b_
-                b_ = min(b_ + chunksize, logits_.shape[1])
+            assert logits__.shape[0] == input_ids__.shape[0]
+            ll = logits__.shape[1]

-                logits_f = logits_[:, a_:b_, :].float() + 1e-10
-                target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)
+            for bi in range(logits__.shape[0]):
+                cl = ll - lengths__[bi]
+                logits_ = logits__[bi:bi+1, cl:, :]
+                input_ids_ = input_ids__[bi:bi+1, cl:]

-                log_probs = F.log_softmax(logits_f, dim=-1)
-                token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
-                logprob_sum_ += token_log_probs.sum().item()
-                logprob_count_ += target_ids.numel()
+                chunksize = logits_.shape[1] * 16000 // logits_.shape[2] + 1
+                b_ = 0
+                while b_ < logits_.shape[1]:
+                    a_ = b_
+                    b_ = min(b_ + chunksize, logits_.shape[1])
+
+                    logits_f = logits_[:, a_:b_, :].float() + 1e-10
+                    target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)
+
+                    log_probs = F.log_softmax(logits_f, dim=-1)
+                    token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
+                    logprob_sum_ += token_log_probs.sum().item()
+                    logprob_count_ += target_ids.numel()

             return logprob_sum_, logprob_count_
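For readers unfamiliar with the strided scheme this hunk implements, below is a minimal, self-contained sketch of the same windowing and scoring. The toy dimensions and the uniform stand-in logits are assumptions for illustration; only the window/eval_len construction mirrors the patch, with the first window scoring every token and each later window scoring only its final stride tokens, so the overlapped prefix serves purely as context and no token is counted twice.

import torch
import torch.nn.functional as F

max_seq_len = 8        # stands in for model.config.max_seq_len
stride = 4             # the patch hard-codes stride = 512
vocab_size = 32

tokens = torch.randint(0, vocab_size, (1, 50))   # stand-in token stream

# Build overlapping windows, advancing by `stride` each step.
seqs, eval_len = [], []
a = 0
while True:
    b = a + max_seq_len
    if b > tokens.shape[-1]: break
    seqs.append(tokens[:, a:b])
    eval_len.append(b if a == 0 else stride)     # tokens to score per window
    a += stride
windows = torch.cat(seqs, dim = 0)               # [n_windows, max_seq_len]

logprob_sum, logprob_count = 0.0, 0
for bi in range(windows.shape[0]):
    input_ids = windows[bi:bi + 1]
    # A real run gets these from model.forward(input_ids)[:, :-1, :];
    # uniform zeros keep the sketch self-contained.
    logits = torch.zeros(1, max_seq_len - 1, vocab_size)
    n = min(eval_len[bi], logits.shape[1])       # number of targets to score
    log_probs = F.log_softmax(logits[:, -n:, :].float(), dim = -1)
    targets = input_ids[:, -n:]                  # next-token targets
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    logprob_sum += token_log_probs.sum().item()
    logprob_count += targets.numel()

print(f"ppl: {torch.exp(torch.tensor(-logprob_sum / logprob_count)).item():.4f}")
# Uniform logits give ppl == vocab_size (32.0000), a quick sanity check.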
@@ -243,7 +292,7 @@ if args.eval_dataset:
             input_ids = eval_tokens[a:b, :]
             logits = x[:, :-1, :]

-            logprob_sum__, logprob_count__ = ppl(input_ids, logits)
+            logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[a:b])
             logprob_sum += logprob_sum__
             logprob_count += logprob_count__
@@ -271,7 +320,7 @@ if args.eval_dataset:
             logits = model.forward(input_ids, cache)
             logits = logits[:, :-1, :]

-            logprob_sum__, logprob_count__ = ppl(input_ids, logits)
+            logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1])
             logprob_sum += logprob_sum__
             logprob_count += logprob_count__
@@ -318,16 +367,22 @@ if args.eval_dataset:
         print(f" -- Evaluation perplexity: {perplexity:.4f}")

         if args.eval_token:
-            print(f" -- Inference (token)", end = "")
-            sys.stdout.flush()
-            cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
-            test_ppl_token()
+            if args.standard_perplexity:
+                print(f" !! Note, can't evaluate token perplexity on standard test")
+            else:
+                print(f" -- Inference (token)", end = "")
+                sys.stdout.flush()
+                cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
+                test_ppl_token()

         if args.eval_token_8bit:
-            print(f" -- Inference (token, 8-bit cache)", end = "")
-            sys.stdout.flush()
-            cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
-            test_ppl_token()
+            if args.standard_perplexity:
+                print(f" !! Note, can't evaluate token perplexity on standard test")
+            else:
+                print(f" -- Inference (token, 8-bit cache)", end = "")
+                sys.stdout.flush()
+                cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
+                test_ppl_token()


 # Test prompt speed