mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-29 18:51:53 +00:00
Model: Fix logprobs unwrapping
Take the log of the token probs, since they're already normalized, which reflects the proper value. Also, don't error out if a token prob doesn't exist in the dict; return None from zip_longest instead. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
"""The model container class for ExLlamaV2 models."""
|
"""The model container class for ExLlamaV2 models."""
|
||||||
import gc
|
import gc
|
||||||
|
from itertools import zip_longest
|
||||||
import pathlib
|
import pathlib
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -486,9 +487,11 @@ class ExllamaV2Container:
|
|||||||
)
|
)
|
||||||
top_values = top_values[0].tolist()
|
top_values = top_values[0].tolist()
|
||||||
|
|
||||||
return dict(zip(top_tokens, top_values, strict=True))
|
return dict(zip_longest(top_tokens, top_values))
|
||||||
|
|
||||||
def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
|
def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
|
||||||
|
normalized_probs = torch.log(token_probs)
|
||||||
|
|
||||||
tokens = list(
|
tokens = list(
|
||||||
map(
|
map(
|
||||||
lambda index: self.tokenizer.extended_id_to_piece.get(
|
lambda index: self.tokenizer.extended_id_to_piece.get(
|
||||||
@@ -498,7 +501,7 @@ class ExllamaV2Container:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return dict(zip(tokens, token_probs[0].tolist(), strict=True))
|
return dict(zip_longest(tokens, normalized_probs[0].tolist()))
|
||||||
|
|
||||||
def generate(self, prompt: str, **kwargs):
|
def generate(self, prompt: str, **kwargs):
|
||||||
"""Generate a response to a prompt"""
|
"""Generate a response to a prompt"""
|
||||||
|
|||||||
Reference in New Issue
Block a user