Mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Allow reduced max_input_len when measuring ppl
This commit is contained in:
@@ -116,6 +116,8 @@ if args.eval_dataset:
|
||||
logprob_sum = 0.0
|
||||
logprob_count = 0
|
||||
|
||||
cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if eval_length > model.config.max_input_len else None
|
||||
|
||||
for i in range(eval_tokens.shape[0]):
|
||||
|
||||
if i % 10 == 0: print(".", end = "")
|
||||
@@ -124,7 +126,8 @@ if args.eval_dataset:
|
||||
input_ids = eval_tokens[i:i+1, :]
|
||||
|
||||
input_ids = input_ids[:, :-1]
|
||||
logits = model.forward(input_ids)
|
||||
if cache is not None: cache.current_seq_len = 0
|
||||
logits = model.forward(input_ids, cache)
|
||||
logits = logits.float() + 1e-10
|
||||
|
||||
target_ids = input_ids[:, 1:].to(logits.device)
|
||||
|
||||
Reference in New Issue
Block a user