add logprobs support for exl3
@@ -2,6 +2,7 @@ import asyncio
 import gc
 import pathlib
 import re
+from itertools import zip_longest
 from typing import (
     Any,
     AsyncIterator,
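
The new `zip_longest` import is used by `get_logprobs` (added below): it pairs each decoded token piece with its log-probability and, unlike plain `zip`, pads with `None` instead of truncating when the two lists differ in length. A minimal standalone illustration, with made-up tokens and values:

    from itertools import zip_longest

    tokens = ["Hello", ",", " world"]
    values = [-0.01, -1.2]  # deliberately one value short

    # zip_longest fills the missing value with None rather than dropping " world"
    print(dict(zip_longest(tokens, values)))
    # {'Hello': -0.01, ',': -1.2, ' world': None}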
@@ -608,6 +609,22 @@ class ExllamaV3Container(BaseModelContainer):
             "unk_token": self.tokenizer.unk_token,
         }
 
+    def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
+        top_tokens = [
+            self.tokenizer.get_id_to_piece_list(True)[index]
+            for index in token_ids.flatten().tolist()
+        ]
+
+        top_values = torch.log(token_probs).flatten().tolist()
+
+        # Cannot return -inf in JSON
+        cleaned_values = [
+            -1000 if value == float("-inf") else value for value in top_values
+        ]
+
+        return dict(zip_longest(top_tokens, cleaned_values))
+
     async def generate(
         self,
         request_id: str,
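
A standalone sketch of what `get_logprobs` computes. The tokenizer lookup is replaced with a hypothetical piece list, and the `(1, 1, k)` tensor shape is an assumption based on the `(1, 0, 1)` empty-tensor defaults used in `handle_logprobs` below:

    import torch
    from itertools import zip_longest

    # Hypothetical stand-in for self.tokenizer.get_id_to_piece_list(True)
    piece_list = ["<pad>", "Hello", " world", "!"]

    token_ids = torch.tensor([[[2, 3, 0]]])          # top-k token ids, shape (1, 1, k)
    token_probs = torch.tensor([[[0.7, 0.3, 0.0]]])  # probabilities, not log-probs

    top_tokens = [piece_list[i] for i in token_ids.flatten().tolist()]
    top_values = torch.log(token_probs).flatten().tolist()

    # log(0) is -inf, which JSON cannot represent, hence the -1000 clamp
    cleaned = [-1000 if v == float("-inf") else v for v in top_values]
    print(dict(zip_longest(top_tokens, cleaned)))
    # {' world': -0.3566..., '!': -1.2039..., '<pad>': -1000}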
@@ -730,6 +747,26 @@ class ExllamaV3Container(BaseModelContainer):
         # Clean up and remove the job from active IDs
         del self.active_job_ids[request_id]
 
+    def handle_logprobs(self, result: dict, generation: dict):
+        top_tokens = unwrap(
+            result.get("top_k_tokens"),
+            torch.empty((1, 0, 1), dtype=torch.long),
+        )
+
+        top_probs = unwrap(
+            result.get("top_k_probs"),
+            torch.empty((1, 0, 1), dtype=torch.float),
+        )
+
+        if top_tokens.numel() > 0 and top_probs.numel() > 0:
+            logprobs = self.get_logprobs(top_tokens, top_probs)
+            generation["logprobs"] = logprobs
+
+            # The first logprob is the selected token prob
+            generation["token_probs"] = {
+                token: logprobs[token] for token in list(logprobs.keys())[:1]
+            }
+
     def handle_finish_chunk(self, result: dict, request_id: str, full_text: str):
         eos_reason = result.get("eos_reason")
 
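
Two details worth noting in `handle_logprobs`: `unwrap` substitutes empty `(1, 0, 1)` tensors when the job result carries no top-k data, so the `numel() > 0` guard cleanly skips such results; and, per the in-code comment, the first entry of the dict returned by `get_logprobs` corresponds to the sampled token, which the `[:1]` slice extracts. A small sketch of that slice, with made-up values:

    # Assuming get_logprobs returned this insertion-ordered mapping,
    # with the sampled token first:
    logprobs = {" Paris": -0.05, " Lyon": -3.9, " Nice": -5.2}

    generation = {"logprobs": logprobs}
    generation["token_probs"] = {
        token: logprobs[token] for token in list(logprobs.keys())[:1]
    }
    print(generation["token_probs"])  # {' Paris': -0.05}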
@@ -915,6 +952,7 @@ class ExllamaV3Container(BaseModelContainer):
             stop_conditions=stop_conditions,
             banned_strings=params.banned_strings,
             embeddings=mm_embeddings_content,
+            return_top_tokens=params.logprobs,
         )
 
         generated_tokens = 0
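
`params.logprobs` does double duty: it is the top-k count forwarded to the ExLlamaV3 job via `return_top_tokens`, and the flag checked below before any post-processing runs. A hypothetical client-side request against a local tabbyAPI instance; the port and exact payload fields are assumptions following the OpenAI-style completions endpoint tabbyAPI exposes:

    import requests

    resp = requests.post(
        "http://127.0.0.1:5000/v1/completions",  # default port, assumed
        json={
            "prompt": "The capital of France is",
            "max_tokens": 3,
            "logprobs": 3,  # ends up as return_top_tokens in the exl3 backend
        },
    )
    print(resp.json())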
@@ -948,6 +986,10 @@ class ExllamaV3Container(BaseModelContainer):
                 "generated_tokens": generated_tokens,
                 "offset": len(full_response),
             }
 
+            if params.logprobs > 0:
+                self.handle_logprobs(result, generation)
+
             yield generation
 
             if result.get("eos"):
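
Putting it together, each streamed `generation` chunk gains `logprobs` and `token_probs` entries whenever the client requested them. An illustrative chunk, showing only the keys visible in this diff with made-up values:

    generation = {
        "generated_tokens": 1,
        "offset": 0,
        "logprobs": {" Paris": -0.05, " Lyon": -3.9, " Nice": -5.2},
        "token_probs": {" Paris": -0.05},
    }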