Fix ppl test for long seq lengths

This commit is contained in:
turboderp
2024-07-10 08:05:57 +02:00
parent 0122b1192f
commit 1179b8a5e5
2 changed files with 13 additions and 3 deletions

View File

@@ -292,6 +292,8 @@ if args.eval_dataset or args.standard_perplexity:
def ppl(input_ids__, logits__, lengths__, bins = False):
+ logits_device = model.modules[-1].device()
if bins:
num_bins = (max(lengths__) + 255) // 256
logprob_sum_ = [0.0] * num_bins
@@ -317,8 +319,8 @@ if args.eval_dataset or args.standard_perplexity:
a_ = b_
b_ = min(b_ + chunksize, logits_.shape[1])
- logits_f = logits_[:, a_:b_, :].float() + 1e-10
- target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)
+ logits_f = logits_[:, a_:b_, :].to(logits_device).float() + 1e-10
+ target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_f.device)
log_probs = F.log_softmax(logits_f, dim=-1)
token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
@@ -398,7 +400,7 @@ if args.eval_dataset or args.standard_perplexity:
input_ids = input_ids[:, :]
if cache is not None: cache.current_seq_len = 0
- logits = model.forward(input_ids, cache)
+ logits = model.forward(input_ids, cache, cpu_logits = input_ids.numel() > 2048)
logits = logits[:, :-1, :]
logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1], args.eval_context_lens)