Reduce VRAM overhead in ppl test

turboderp
2025-04-17 22:22:47 +02:00
parent f8b9ab72ae
commit 9849c92b36
2 changed files with 5 additions and 2 deletions

@@ -107,7 +107,8 @@ def test_ppl(data_spec: dict, spec: dict):
         pb.update(row)
         input_ids = eval_ids[row:row + 1, :]
         logits = fwd_fn(model_instance, input_ids)
-        logits = logits[:, :-1, :].float() + 1e-10
+        logits = logits[:, :-1, :].float()
+        logits += 1e-10
         log_probs = F.log_softmax(logits, dim = -1)
         del logits
         target_ids = input_ids[:, 1:].to(log_probs.device)
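The fused expression logits[:, :-1, :].float() + 1e-10 briefly holds two full-size float32 tensors on the GPU at once: the .float() copy and the out-of-place sum. Splitting it into a cast followed by an in-place += caps the peak at a single float32 copy, which is where the VRAM saving comes from. A minimal sketch of the effect on peak allocation (shape, dtype and variable names are illustrative, not taken from the repo):

import torch

x = torch.randn(1, 2048, 131072, device = "cuda", dtype = torch.half)

torch.cuda.reset_peak_memory_stats()
y = x.float() + 1e-10                 # float32 copy AND sum tensor live at once
fused_peak = torch.cuda.max_memory_allocated()
del y

torch.cuda.reset_peak_memory_stats()
y = x.float()                         # single float32 allocation
y += 1e-10                            # in-place add, no second tensor
split_peak = torch.cuda.max_memory_allocated()

print(f"fused peak: {fused_peak / 2**20:.0f} MiB, split peak: {split_peak / 2**20:.0f} MiB")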

@@ -63,7 +63,8 @@ def main(args):
         pb.update(row)
         input_ids = eval_ids[row:row + 1, :]
         logits = model.forward(input_ids, {"attn_mode": "flash_attn_nc"})
-        logits = logits[:, :-1, :vocab_size].float() + 1e-10
+        logits = logits[:, :-1, :vocab_size].float()
+        logits += 1e-10
         log_probs = F.log_softmax(logits, dim = -1)
         del logits
         target_ids = input_ids[:, 1:].to(log_probs.device)
@@ -71,6 +72,7 @@ def main(args):
         target_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
         logprob_sum += target_log_probs.sum().item()
         logprob_count += target_ids.numel()
+        del log_probs
         del target_log_probs
         del target_ids
         torch.cuda.empty_cache()
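The second file also gains del log_probs ahead of the existing cleanup, so the full-vocabulary log-probability tensor is released before the next iteration's forward pass rather than staying alive until the name is rebound. Together with torch.cuda.empty_cache(), this keeps the per-row footprint to roughly one large tensor at a time. A sketch of the pattern (tensor and shape are hypothetical):

import torch

log_probs = torch.empty(1, 2048, 131072, device = "cuda")  # stand-in for the per-row tensor
# ... gather target log-probs, accumulate sums ...
del log_probs              # drop the reference so the caching allocator can reuse the block
torch.cuda.empty_cache()   # hand unused cached blocks back to the driver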