Attempt to add standard ppl test (experimental)

This commit is contained in:
turboderp
2023-12-30 01:39:03 +01:00
parent e4d4713757
commit a52d410d4a

View File

@@ -45,6 +45,7 @@ parser.add_argument("-s", "--speed", action = "store_true", help = "Test raw gen
parser.add_argument("-mix", "--mix_layers", type = str, help = "Load replacement layers from secondary model. Example: --mix_layers 1,6-7:/mnt/models/other_model")
parser.add_argument("-nwu", "--no_warmup", action = "store_true", help = "Skip warmup before testing model")
parser.add_argument("-sl", "--stream_layers", action = "store_true", help = "Load model layer by layer (perplexity evaluation only)")
parser.add_argument("-sp", "--standard_perplexity", choices = ["wiki2"], help = "Run standard (HF) perplexity test, stride 512")
# Initialize model and tokenizer
@@ -169,41 +170,89 @@ if args.prompt:
# Test perplexity
if args.eval_dataset:
if args.eval_dataset or args.standard_perplexity:
with torch.inference_mode():
eval_dataset = args.eval_dataset
eval_rows = args.eval_rows
eval_length = args.eval_length
print(f" -- Running perplexity test")
print(f" -- Dataset: {eval_dataset}")
print(f" -- Tokenizing eval data, {eval_rows} rows x {eval_length} tokens...")
eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
if args.standard_perplexity:
if args.eval_dataset:
print(f" !! Note, overriding specified --eval_dataset with {args.standard_perplexity}")
from datasets import load_dataset
if args.standard_perplexity == "wiki2":
ds = "wikitext"
part = "wikitext-2-raw-v1"
split = "test"
# if args.standard_perplexity == "c4":
# ds = "allenai/c4"
# part = "allenai--c4"
# split = "train"
print(f" -- Loading dataset {ds}, {part}, {split}...")
test = load_dataset(ds, part, split = split)
print(f" -- Tokenizing samples...")
text = "\n\n".join(test["text"])
eval_tokens = tokenizer.encode(text)
stride = 512
seqs = []
eval_len = []
a = 0
while True:
b = a + model.config.max_seq_len
if b > eval_tokens.shape[-1]: break
seqs.append(eval_tokens[:, a:b])
eval_len.append(b if a == 0 else stride)
a += stride
eval_tokens = torch.cat(seqs, dim = 0)
else:
eval_dataset = args.eval_dataset
eval_rows = args.eval_rows
eval_length = args.eval_length
print(f" -- Dataset: {eval_dataset}")
print(f" -- Tokenizing eval data, {eval_rows} rows x {eval_length} tokens...")
eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]
logprob_sum = 0.0
logprob_count = 0
def ppl(input_ids_, logits_):
def ppl(input_ids__, logits__, lengths__):
logprob_sum_ = 0.0
logprob_count_ = 0
chunksize = logits_.shape[1] * 16000 // logits_.shape[2]
b_ = 0
while b_ < logits_.shape[1]:
a_ = b_
b_ = min(b_ + chunksize, logits_.shape[1])
assert logits__.shape[0] == input_ids__.shape[0]
ll = logits__.shape[1]
logits_f = logits_[:, a_:b_, :].float() + 1e-10
target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)
for bi in range(logits__.shape[0]):
cl = ll - lengths__[bi]
logits_ = logits__[bi:bi+1, cl:, :]
input_ids_ = input_ids__[bi:bi+1, cl:]
log_probs = F.log_softmax(logits_f, dim=-1)
token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
logprob_sum_ += token_log_probs.sum().item()
logprob_count_ += target_ids.numel()
chunksize = logits_.shape[1] * 16000 // logits_.shape[2] + 1
b_ = 0
while b_ < logits_.shape[1]:
a_ = b_
b_ = min(b_ + chunksize, logits_.shape[1])
logits_f = logits_[:, a_:b_, :].float() + 1e-10
target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)
log_probs = F.log_softmax(logits_f, dim=-1)
token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
logprob_sum_ += token_log_probs.sum().item()
logprob_count_ += target_ids.numel()
return logprob_sum_, logprob_count_
@@ -243,7 +292,7 @@ if args.eval_dataset:
input_ids = eval_tokens[a:b, :]
logits = x[:, :-1, :]
logprob_sum__, logprob_count__ = ppl(input_ids, logits)
logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[a:b])
logprob_sum += logprob_sum__
logprob_count += logprob_count__
@@ -271,7 +320,7 @@ if args.eval_dataset:
logits = model.forward(input_ids, cache)
logits = logits[:, :-1, :]
logprob_sum__, logprob_count__ = ppl(input_ids, logits)
logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1])
logprob_sum += logprob_sum__
logprob_count += logprob_count__
@@ -318,16 +367,22 @@ if args.eval_dataset:
print(f" -- Evaluation perplexity: {perplexity:.4f}")
if args.eval_token:
print(f" -- Inference (token)", end = "")
sys.stdout.flush()
cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
test_ppl_token()
if args.standard_perplexity:
print(f" !! Note, can't evalutate token perplexity on standard test")
else:
print(f" -- Inference (token)", end = "")
sys.stdout.flush()
cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
test_ppl_token()
if args.eval_token_8bit:
print(f" -- Inference (token, 8-bit cache)", end = "")
sys.stdout.flush()
cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
test_ppl_token()
if args.standard_perplexity:
print(f" !! Note, can't evalutate token perplexity on standard test")
else:
print(f" -- Inference (token, 8-bit cache)", end = "")
sys.stdout.flush()
cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
test_ppl_token()
# Test prompt speed