Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-03-14 15:57:27 +00:00
Model: Fix logprobs unwrapping
Take the log of the token probs: since they're already normalized, their log reflects the proper logprob value. Also, don't error out when a token's prob is missing from the zipped pairs; return None instead of raising from zip.

Signed-off-by: kingbri <bdashore3@proton.me>
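As a sketch of the reasoning behind both changes (plain Python/PyTorch with made-up values, not tabbyAPI code; the tensor shape and token strings are illustrative assumptions):

import torch
from itertools import zip_longest

# The sampler's probs are already normalized (they sum to 1),
# so taking their log yields the logprobs to report.
token_probs = torch.tensor([[0.7, 0.3]])
normalized_probs = torch.log(token_probs)  # roughly [[-0.357, -1.204]]

# Suppose one more token piece than prob comes back.
tokens = ["Hello", " world", "!"]

# zip(..., strict=True) raises ValueError on the length mismatch;
# zip_longest pads the missing prob with None instead.
print(dict(zip_longest(tokens, normalized_probs[0].tolist())))
# {'Hello': -0.356..., ' world': -1.203..., '!': None}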
@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""
 import gc
+from itertools import zip_longest
 import pathlib
 import time
 
@@ -486,9 +487,11 @@ class ExllamaV2Container:
         )
         top_values = top_values[0].tolist()
 
-        return dict(zip(top_tokens, top_values, strict=True))
+        return dict(zip_longest(top_tokens, top_values))
 
     def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
+        normalized_probs = torch.log(token_probs)
+
         tokens = list(
             map(
                 lambda index: self.tokenizer.extended_id_to_piece.get(
@@ -498,7 +501,7 @@ class ExllamaV2Container:
             )
         )
 
-        return dict(zip(tokens, token_probs[0].tolist(), strict=True))
+        return dict(zip_longest(tokens, normalized_probs[0].tolist()))
 
     def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
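For reference, a hedged sketch of roughly how get_token_probs behaves after this commit, with a stubbed id_to_piece dict standing in for self.tokenizer.extended_id_to_piece (the names, shapes, and values here are illustrative assumptions, not the real container code):

import torch
from itertools import zip_longest

# Illustrative stand-in for self.tokenizer.extended_id_to_piece.
id_to_piece = {0: "Hello", 1: " world"}

def get_token_probs(token_ids: torch.Tensor, token_probs: torch.Tensor) -> dict:
    """Map token ids to log-probabilities, tolerating unknown ids."""
    # The probs are already normalized, so their log is the logprob.
    normalized_probs = torch.log(token_probs)
    # Unknown ids resolve to None via .get() instead of raising KeyError.
    tokens = [id_to_piece.get(i) for i in token_ids[0].tolist()]
    # zip_longest pads any length mismatch with None instead of erroring.
    return dict(zip_longest(tokens, normalized_probs[0].tolist()))

print(get_token_probs(torch.tensor([[0, 1, 2]]), torch.tensor([[0.6, 0.3, 0.1]])))
# {'Hello': -0.51..., ' world': -1.20..., None: -2.30...}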