mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Fix ppl test for long seq lengths
This commit is contained in:
@@ -681,6 +681,7 @@ class ExLlamaV2:
|
||||
return_last_state: bool = False,
|
||||
position_offsets: torch.Tensor | None = None,
|
||||
abort_event: threading.Event | None = None,
|
||||
cpu_logits: bool = False,
|
||||
**kwargs) \
|
||||
-> torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None:
|
||||
"""
|
||||
@@ -717,6 +718,11 @@ class ExLlamaV2:
|
||||
:param abort_event:
|
||||
Optional event that, if set, will abort the forward pass. Function will return None if aborted.
|
||||
|
||||
:param cpu_logits:
|
||||
If True, logits are collected and returned in system RAM. This is somewhat slower but can prevent
|
||||
out-of-memory errors when computing logits for all positions in a long sequence, such as during a
|
||||
perplexity test.
|
||||
|
||||
:return:
|
||||
FP16 logits tensor, shape (batch_size, q_len, vocab_size)
|
||||
(optional) state tensor, shape (batch_size, q_len, hidden_size)
|
||||
@@ -819,6 +825,8 @@ class ExLlamaV2:
|
||||
if abort_event and abort_event.is_set(): return
|
||||
|
||||
if not _preprocess_only:
|
||||
if cpu_logits:
|
||||
r["logits"] = r["logits"].cpu()
|
||||
result = r["logits"] if result is None else torch.cat((result, r["logits"]), dim = 1)
|
||||
|
||||
chunk_begin = chunk_end
|
||||
|
||||
@@ -292,6 +292,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
|
||||
def ppl(input_ids__, logits__, lengths__, bins = False):
|
||||
|
||||
logits_device = model.modules[-1].device()
|
||||
|
||||
if bins:
|
||||
num_bins = (max(lengths__) + 255) // 256
|
||||
logprob_sum_ = [0.0] * num_bins
|
||||
@@ -317,8 +319,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
a_ = b_
|
||||
b_ = min(b_ + chunksize, logits_.shape[1])
|
||||
|
||||
logits_f = logits_[:, a_:b_, :].float() + 1e-10
|
||||
target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)
|
||||
logits_f = logits_[:, a_:b_, :].to(logits_device).float() + 1e-10
|
||||
target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_f.device)
|
||||
|
||||
log_probs = F.log_softmax(logits_f, dim=-1)
|
||||
token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
|
||||
@@ -398,7 +400,7 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
|
||||
input_ids = input_ids[:, :]
|
||||
if cache is not None: cache.current_seq_len = 0
|
||||
logits = model.forward(input_ids, cache)
|
||||
logits = model.forward(input_ids, cache, cpu_logits = input_ids.numel() > 2048)
|
||||
logits = logits[:, :-1, :]
|
||||
|
||||
logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1], args.eval_context_lens)
|
||||
|
||||
Reference in New Issue
Block a user