Mirror of https://github.com/turboderp-org/exllamav3.git (synced 2026-04-20 14:29:51 +00:00)
Commit: "Reduce VRAM overhead in ppl test"
This commit is contained in:
@@ -107,7 +107,8 @@ def test_ppl(data_spec: dict, spec: dict):
     pb.update(row)
     input_ids = eval_ids[row:row + 1, :]
     logits = fwd_fn(model_instance, input_ids)
-    logits = logits[:, :-1, :].float() + 1e-10
+    logits = logits[:, :-1, :].float()
+    logits += 1e-10
     log_probs = F.log_softmax(logits, dim = -1)
     del logits
     target_ids = input_ids[:, 1:].to(log_probs.device)
@@ -63,7 +63,8 @@ def main(args):
     pb.update(row)
     input_ids = eval_ids[row:row + 1, :]
     logits = model.forward(input_ids, {"attn_mode": "flash_attn_nc"})
-    logits = logits[:, :-1, :vocab_size].float() + 1e-10
+    logits = logits[:, :-1, :vocab_size].float()
+    logits += 1e-10
     log_probs = F.log_softmax(logits, dim = -1)
     del logits
     target_ids = input_ids[:, 1:].to(log_probs.device)
@@ -71,6 +72,7 @@ def main(args):
     target_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
     logprob_sum += target_log_probs.sum().item()
     logprob_count += target_ids.numel()
     del log_probs
     del target_log_probs
     del target_ids
+    torch.cuda.empty_cache()
Reference in New Issue
Block a user