mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Fix ppl test for long seq lengths
This commit is contained in:
@@ -292,6 +292,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
|
||||
def ppl(input_ids__, logits__, lengths__, bins = False):
|
||||
|
||||
logits_device = model.modules[-1].device()
|
||||
|
||||
if bins:
|
||||
num_bins = (max(lengths__) + 255) // 256
|
||||
logprob_sum_ = [0.0] * num_bins
|
||||
@@ -317,8 +319,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
a_ = b_
|
||||
b_ = min(b_ + chunksize, logits_.shape[1])
|
||||
|
||||
logits_f = logits_[:, a_:b_, :].float() + 1e-10
|
||||
target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)
|
||||
logits_f = logits_[:, a_:b_, :].to(logits_device).float() + 1e-10
|
||||
target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_f.device)
|
||||
|
||||
log_probs = F.log_softmax(logits_f, dim=-1)
|
||||
token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
|
||||
@@ -398,7 +400,7 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
|
||||
input_ids = input_ids[:, :]
|
||||
if cache is not None: cache.current_seq_len = 0
|
||||
logits = model.forward(input_ids, cache)
|
||||
logits = model.forward(input_ids, cache, cpu_logits = input_ids.numel() > 2048)
|
||||
logits = logits[:, :-1, :]
|
||||
|
||||
logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1], args.eval_context_lens)
|
||||
|
||||
Reference in New Issue
Block a user