Change test_inference.py ppl calculation to exactly match logic in convert.py

This commit is contained in:
turboderp
2023-09-20 09:56:25 +02:00
parent 2a3ff14af2
commit 73a133405f

View File

@@ -125,14 +125,15 @@ if args.eval_dataset:
# NOTE(review): the lines below are a unified-diff hunk whose +/- markers were
# stripped by the page extraction; adjacent near-duplicate lines are the
# removed/added pair from the diff, not sequential statements. Per the commit
# message, the added lines bring this perplexity loop in line with convert.py.
# Take a single-row batch of evaluation tokens (batch dim kept as size 1).
input_ids = eval_tokens[i:i+1, :]
# Presumably the REMOVED line: old code dropped the final token before forward.
input_ids = input_ids[:, :-1]
# Presumably the ADDED line: keep the full sequence; the last-position logit is
# trimmed after the forward pass instead (see the logits slice below).
input_ids = input_ids[:, :]
# Reset cache position so each evaluation row starts from an empty context.
if cache is not None: cache.current_seq_len = 0
logits = model.forward(input_ids, cache)
# Likely ADDED: drop the logit at the last position — it has no next-token
# target to score against. TODO confirm against convert.py.
logits = logits[:, :-1, :]
# Likely ADDED: upcast to float32 and add a small epsilon, presumably to avoid
# log(0)/-inf in the log-softmax below — verify this matches convert.py exactly.
logits = logits.float() + 1e-10
# Targets are the inputs shifted left by one (next-token prediction).
target_ids = input_ids[:, 1:].to(logits.device)
# Removed/added pair: whitespace-only change (`dim=-1` -> `dim = -1`).
log_probs = F.log_softmax(logits, dim=-1)
log_probs = F.log_softmax(logits, dim = -1)
# Pick out the log-probability assigned to each actual next token.
token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
# Accumulate sum and count; perplexity is presumably exp(-sum/count) elsewhere.
logprob_sum += token_log_probs.sum().item()
logprob_count += target_ids.numel()