Model: Fix logprobs unwrapping

Take the log of the token probs: they are already normalized, so the
log gives the proper logprob value. Also, don't error out if a token
has no matching prob; zip_longest pairs that token with None instead.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-02-08 19:28:19 -05:00
committed by Brian Dashore
parent c7428f0bcd
commit 43bba526bf

View File

@@ -1,5 +1,6 @@
"""The model container class for ExLlamaV2 models."""
import gc
from itertools import zip_longest
import pathlib
import time
@@ -486,9 +487,11 @@ class ExllamaV2Container:
)
top_values = top_values[0].tolist()
return dict(zip(top_tokens, top_values, strict=True))
return dict(zip_longest(top_tokens, top_values))
def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
normalized_probs = torch.log(token_probs)
tokens = list(
map(
lambda index: self.tokenizer.extended_id_to_piece.get(
@@ -498,7 +501,7 @@ class ExllamaV2Container:
)
)
return dict(zip(tokens, token_probs[0].tolist(), strict=True))
return dict(zip_longest(tokens, normalized_probs[0].tolist()))
def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""