API: Add logprobs for chat completions
Adds chat completion logprob support following OAI's spec. Tokens are not converted with tiktoken here, since that would add an extra dependency for no real benefit.

Signed-off-by: kingbri <bdashore3@proton.me>
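For context, a rough sketch of the logprob block a client can expect in each chat completion choice after this change (field names mirror the OAI spec encoded by the new models below; the values are illustrative, and tokens stay plain strings rather than tiktoken IDs):

# Illustrative response fragment only; values are made up.
choice = {
    "index": 0,
    "finish_reason": "Generated",
    "message": {"role": "assistant", "content": "Hello"},
    "logprobs": {
        "content": [
            {
                "token": "Hello",
                "logprob": -0.11,
                "top_logprobs": [
                    {"token": "Hello", "logprob": -0.11},
                    {"token": "Hi", "logprob": -2.35},
                ],
            }
        ]
    },
}

# One entry per generated token; tokens are raw strings, not tiktoken IDs.
for entry in choice["logprobs"]["content"]:
    print(entry["token"], entry["logprob"], entry["top_logprobs"])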
@@ -6,14 +6,14 @@ from uuid import uuid4
 from OAI.types.common import UsageStats, CommonCompletionRequest


-class ChatCompletionLogprobs(BaseModel):
+class ChatCompletionLogprob(BaseModel):
     token: str
     logprob: float
-    top_logprobs: List["ChatCompletionLogprobs"]
+    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None


-class WrappedChatCompletionLogprobs(BaseModel):
-    content: List[ChatCompletionLogprobs]
+class ChatCompletionLogprobs(BaseModel):
+    content: List[ChatCompletionLogprob] = Field(default_factory=list)


 class ChatCompletionMessage(BaseModel):
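A minimal standalone sketch of how the new models nest and what their defaults do (assumes pydantic, as in the rest of this file; values are made up):

from typing import List, Optional

from pydantic import BaseModel, Field


class ChatCompletionLogprob(BaseModel):
    token: str
    logprob: float
    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None


class ChatCompletionLogprobs(BaseModel):
    content: List[ChatCompletionLogprob] = Field(default_factory=list)


# One ChatCompletionLogprob per generated token; top_logprobs holds the
# alternative candidates for that position and may be omitted entirely.
entry = ChatCompletionLogprob(
    token="Hello",
    logprob=-0.11,
    top_logprobs=[ChatCompletionLogprob(token="Hi", logprob=-2.35)],
)
response_logprobs = ChatCompletionLogprobs(content=[entry])

print(response_logprobs.content[0].token)   # "Hello"
print(ChatCompletionLogprobs().content)     # [] via default_factory, not a shared list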
@@ -26,7 +26,7 @@ class ChatCompletionRespChoice(BaseModel):
     index: int = 0
     finish_reason: str
     message: ChatCompletionMessage
-    logprobs: Optional[WrappedChatCompletionLogprobs] = None
+    logprobs: Optional[ChatCompletionLogprobs] = None


 class ChatCompletionStreamChoice(BaseModel):
@@ -34,7 +34,7 @@ class ChatCompletionStreamChoice(BaseModel):
     index: int = 0
     finish_reason: Optional[str]
     delta: Union[ChatCompletionMessage, dict] = {}
-    logprobs: Optional[WrappedChatCompletionLogprobs] = None
+    logprobs: Optional[ChatCompletionLogprobs] = None


 # Inherited from common request
@@ -3,6 +3,8 @@ from typing import Optional

 from common.utils import unwrap
 from OAI.types.chat_completion import (
+    ChatCompletionLogprobs,
+    ChatCompletionLogprob,
     ChatCompletionMessage,
     ChatCompletionRespChoice,
     ChatCompletionStreamChunk,
@@ -63,7 +65,32 @@ def create_chat_completion_response(generation: dict, model_name: Optional[str])
         role="assistant", content=unwrap(generation.get("text"), "")
     )

-    choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)
+    logprob_response = None
+
+    token_probs = unwrap(generation.get("token_probs"), {})
+    if token_probs:
+        logprobs = unwrap(generation.get("logprobs"), [])
+
+        collected_token_probs = []
+        for index, token in enumerate(token_probs.keys()):
+            top_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=logprob)
+                for token, logprob in logprobs[index].items()
+            ]
+
+            collected_token_probs.append(
+                ChatCompletionLogprob(
+                    token=token,
+                    logprob=token_probs[token],
+                    top_logprobs=top_logprobs,
+                )
+            )
+
+        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
+
+    choice = ChatCompletionRespChoice(
+        finish_reason="Generated", message=message, logprobs=logprob_response
+    )

     prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
     completion_tokens = unwrap(generation.get("completion_tokens"), 0)
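For reference, a hypothetical sketch of the generation payload shapes this function consumes: token_probs maps each chosen token to its logprob, while logprobs is a position-indexed list of candidate dicts, so logprobs[index] lines up with the index-th key of token_probs.

# Hypothetical generation dict for two generated tokens (values made up).
generation = {
    "text": "Hello world",
    "token_probs": {"Hello": -0.11, " world": -0.42},   # chosen token -> its logprob
    "logprobs": [                                        # one candidate dict per position
        {"Hello": -0.11, "Hi": -2.35},                   # candidates for position 0
        {" world": -0.42, " there": -1.87},              # candidates for position 1
    ],
}

# The loop above turns this into roughly:
#   ChatCompletionLogprobs(content=[
#       ChatCompletionLogprob(token="Hello",  logprob=-0.11, top_logprobs=[...]),
#       ChatCompletionLogprob(token=" world", logprob=-0.42, top_logprobs=[...]),
#   ])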
@@ -89,6 +116,8 @@ def create_chat_completion_stream_chunk(
 ):
     """Create a chat completion stream chunk from the provided text."""

+    logprob_response = None
+
     if finish_reason:
         message = {}
     else:
@@ -96,8 +125,27 @@ def create_chat_completion_stream_chunk(
             role="assistant", content=unwrap(generation.get("text"), "")
         )

+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), {})
+            top_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=logprob)
+                for token, logprob in logprobs.items()
+            ]
+
+            generated_token = next(iter(token_probs))
+            token_prob_response = ChatCompletionLogprob(
+                token=generated_token,
+                logprob=token_probs[generated_token],
+                top_logprobs=top_logprobs,
+            )
+
+            logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
+
     # The finish reason can be None
-    choice = ChatCompletionStreamChoice(finish_reason=finish_reason, delta=message)
+    choice = ChatCompletionStreamChoice(
+        finish_reason=finish_reason, delta=message, logprobs=logprob_response
+    )

     chunk = ChatCompletionStreamChunk(
         id=const_id, choices=[choice], model=unwrap(model_name, "")
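In the streaming path each chunk carries only the newly generated token, so token_probs has exactly one key (fetched with next(iter(...))) and logprobs is a flat candidate dict rather than a per-position list, yielding a one-element content list. A hypothetical single chunk:

# Hypothetical single-token stream chunk (values made up).
generation = {
    "text": " world",
    "token_probs": {" world": -0.42},                 # the one token in this chunk
    "logprobs": {" world": -0.42, " there": -1.87},   # flat candidate dict, no list
}

# The block above produces roughly:
#   ChatCompletionLogprobs(content=[
#       ChatCompletionLogprob(token=" world", logprob=-0.42, top_logprobs=[...]),
#   ])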
@@ -520,6 +520,8 @@ class ExllamaV2Container:
                 joined_generation["token_probs"].update(
                     unwrap(generation.get("token_probs"), {})
                 )

+                # Include empty logprob dicts for index preservation
+                joined_generation["logprobs"].append(
+                    unwrap(generation.get("logprobs"), {})
+                )
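A small sketch of why the empty-dict append matters: when streamed chunks are joined for the non-streaming response, logprobs[index] must stay aligned with the index-th generated token, so every chunk contributes an entry even when it carries no logprob data (chunk contents here are hypothetical):

# Two hypothetical chunks; the second has no logprob data.
chunks = [
    {"token_probs": {"Hello": -0.11}, "logprobs": {"Hello": -0.11, "Hi": -2.35}},
    {"token_probs": {" world": -0.42}},
]

joined = {"token_probs": {}, "logprobs": []}
for chunk in chunks:
    joined["token_probs"].update(chunk.get("token_probs", {}))
    # Appending {} for the second chunk keeps logprobs[1] pointing at " world".
    joined["logprobs"].append(chunk.get("logprobs", {}))

assert len(joined["logprobs"]) == len(joined["token_probs"]) == 2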