Fix OoM when testing PPL with large vocab

turboderp
2023-12-12 13:02:42 +01:00
parent c1dbe4221f
commit 9f01116fb4

@@ -170,14 +170,20 @@ if args.eval_dataset:
         if cache is not None: cache.current_seq_len = 0
         logits = model.forward(input_ids, cache)
         logits = logits[:, :-1, :]
-        logits = logits.float() + 1e-10
-        target_ids = input_ids[:, 1:].to(logits.device)
+        chunksize = logits.shape[1] * 16000 // logits.shape[2]
+        b = 0
+        while b < logits.shape[1]:
+            a = b
+            b = min(b + chunksize, logits.shape[1])
 
-        log_probs = F.log_softmax(logits, dim = -1)
-        token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
-        logprob_sum += token_log_probs.sum().item()
-        logprob_count += target_ids.numel()
+            logits_f = logits[:, a:b, :].float() + 1e-10
+            target_ids = input_ids[:, a+1:b+1].to(logits.device)
+
+            log_probs = F.log_softmax(logits_f, dim = -1)
+            token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
+            logprob_sum += token_log_probs.sum().item()
+            logprob_count += target_ids.numel()
 
     print()
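
The change, sketched as a standalone function: upcasting the full [batch, seq_len, vocab] logits tensor to float32 and running log_softmax over it allocates several vocab-wide temporaries at once, which is what runs out of memory when the vocabulary is large. Slicing the sequence dimension into chunks sized inversely to the vocab size (the 16000 constant from the diff) bounds the float32 temporaries while producing the same log-prob sums. This is an illustrative sketch, not code from the repository; the name chunked_logprob_sum and the budget parameter are made up here, and the chunk size is floored at 1 as an extra safeguard the diff does not include.

import torch.nn.functional as F

def chunked_logprob_sum(logits, input_ids, budget = 16000):
    # logits:    [1, seq_len, vocab] model output, typically float16
    # input_ids: [1, seq_len] token ids; position t+1 is the target for position t
    # budget:    scaling constant as in the diff; chunks shrink as vocab grows

    logits = logits[:, :-1, :]                        # last position has no target token
    seq_len, vocab = logits.shape[1], logits.shape[2]
    chunksize = max(1, seq_len * budget // vocab)     # same formula as the diff, floored at 1

    logprob_sum = 0.0
    logprob_count = 0
    b = 0
    while b < seq_len:
        a = b
        b = min(b + chunksize, seq_len)

        # Only this chunk is upcast to float32, so the softmax temporaries stay small
        logits_f = logits[:, a:b, :].float() + 1e-10
        target_ids = input_ids[:, a + 1:b + 1].to(logits_f.device)

        log_probs = F.log_softmax(logits_f, dim = -1)
        token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
        logprob_sum += token_log_probs.sum().item()
        logprob_count += target_ids.numel()

    return logprob_sum, logprob_count

Perplexity over the eval set is then math.exp(-logprob_sum / logprob_count), the same value as the unchunked computation apart from floating-point accumulation order; chunking only changes peak memory.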