Fix regular ppl test

This commit is contained in:
turboderp
2023-12-30 22:14:43 +01:00
parent 4d5ef3b53d
commit 5ddf57f945

View File

@@ -236,7 +236,7 @@ if args.eval_dataset or args.standard_perplexity:
ll = logits__.shape[1]
for bi in range(logits__.shape[0]):
cl = ll - lengths__[bi]
cl = max(ll - lengths__[bi], 0)
logits_ = logits__[bi:bi+1, cl:, :]
input_ids_ = input_ids__[bi:bi+1, cl:]