API: Add logprobs for chat completions

Adds chat completion logprob support following OAI's spec. Tokens are
not converted with tiktoken here, since that would add an extra
dependency for no real benefit.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-02-08 01:16:38 -05:00
Committed by: Brian Dashore
Parent: c02fe4d1db
Commit: c7428f0bcd
3 changed files with 58 additions and 8 deletions
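For context, OAI's spec nests logprobs under each choice as a content list, one entry per generated token, each carrying its own top_logprobs candidates. A rough sketch of the target response shape, with illustrative tokens and values (OAI's spec also includes a per-token bytes field, which the models below omit):

# Illustrative fragment of an OAI-style chat completion response
response_fragment = {
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "Hello"},
            "logprobs": {
                "content": [
                    {
                        "token": "Hello",
                        "logprob": -0.31,
                        "top_logprobs": [
                            {"token": "Hello", "logprob": -0.31},
                            {"token": "Hi", "logprob": -1.42},
                        ],
                    }
                ]
            },
        }
    ]
}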

View File

@@ -6,14 +6,14 @@ from uuid import uuid4
 from OAI.types.common import UsageStats, CommonCompletionRequest

-class ChatCompletionLogprobs(BaseModel):
+class ChatCompletionLogprob(BaseModel):
     token: str
     logprob: float
-    top_logprobs: List["ChatCompletionLogprobs"]
+    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None

-class WrappedChatCompletionLogprobs(BaseModel):
-    content: List[ChatCompletionLogprobs]
+class ChatCompletionLogprobs(BaseModel):
+    content: List[ChatCompletionLogprob] = Field(default_factory=list)

 class ChatCompletionMessage(BaseModel):
@@ -26,7 +26,7 @@ class ChatCompletionRespChoice(BaseModel):
     index: int = 0
     finish_reason: str
     message: ChatCompletionMessage
-    logprobs: Optional[WrappedChatCompletionLogprobs] = None
+    logprobs: Optional[ChatCompletionLogprobs] = None

 class ChatCompletionStreamChoice(BaseModel):
@@ -34,7 +34,7 @@ class ChatCompletionStreamChoice(BaseModel):
     index: int = 0
     finish_reason: Optional[str]
     delta: Union[ChatCompletionMessage, dict] = {}
-    logprobs: Optional[WrappedChatCompletionLogprobs] = None
+    logprobs: Optional[ChatCompletionLogprobs] = None

 # Inherited from common request
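A minimal sketch of how the renamed models compose, with hypothetical tokens and values (assumes pydantic v2 for model_dump; use .dict() on v1):

from OAI.types.chat_completion import ChatCompletionLogprob, ChatCompletionLogprobs

entry = ChatCompletionLogprob(
    token="Hello",
    logprob=-0.31,
    top_logprobs=[
        ChatCompletionLogprob(token="Hello", logprob=-0.31),
        ChatCompletionLogprob(token="Hi", logprob=-1.42),
    ],
)

# Nested candidates leave top_logprobs as None, matching the Optional default
print(ChatCompletionLogprobs(content=[entry]).model_dump(exclude_none=True))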

View File

@@ -3,6 +3,8 @@ from typing import Optional
 from common.utils import unwrap
 from OAI.types.chat_completion import (
+    ChatCompletionLogprobs,
+    ChatCompletionLogprob,
     ChatCompletionMessage,
     ChatCompletionRespChoice,
     ChatCompletionStreamChunk,
@@ -63,7 +65,32 @@ def create_chat_completion_response(generation: dict, model_name: Optional[str])
         role="assistant", content=unwrap(generation.get("text"), "")
     )
-    choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)
+    logprob_response = None
+
+    token_probs = unwrap(generation.get("token_probs"), {})
+    if token_probs:
+        logprobs = unwrap(generation.get("logprobs"), [])
+        collected_token_probs = []
+
+        for index, token in enumerate(token_probs.keys()):
+            top_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=logprob)
+                for token, logprob in logprobs[index].items()
+            ]
+
+            collected_token_probs.append(
+                ChatCompletionLogprob(
+                    token=token,
+                    logprob=token_probs[token],
+                    top_logprobs=top_logprobs,
+                )
+            )
+
+        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
+
+    choice = ChatCompletionRespChoice(
+        finish_reason="Generated", message=message, logprobs=logprob_response
+    )

     prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
     completion_tokens = unwrap(generation.get("completion_tokens"), 0)
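The loop above relies on token_probs being an insertion-ordered dict of generated token -> logprob, with logprobs as a parallel list holding one top-candidates dict per position. A hypothetical payload that satisfies those shapes (assuming create_chat_completion_response is imported from this repo's OAI utils module, whose path isn't shown in this diff):

# Hypothetical generation payload; shapes inferred from the code above
generation = {
    "text": "Hello there",
    "token_probs": {"Hello": -0.31, " there": -0.05},  # one entry per generated token
    "logprobs": [                                      # parallel list, same order
        {"Hello": -0.31, "Hi": -1.42},
        {" there": -0.05, " world": -3.20},
    ],
    "prompt_tokens": 12,
    "completion_tokens": 2,
}

response = create_chat_completion_response(generation, "example-model")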
@@ -89,6 +116,8 @@
 ):
     """Create a chat completion stream chunk from the provided text."""

+    logprob_response = None
+
     if finish_reason:
         message = {}
     else:
@@ -96,8 +125,27 @@
             role="assistant", content=unwrap(generation.get("text"), "")
         )

+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), {})
+            top_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=logprob)
+                for token, logprob in logprobs.items()
+            ]
+
+            generated_token = next(iter(token_probs))
+            token_prob_response = ChatCompletionLogprob(
+                token=generated_token,
+                logprob=token_probs[generated_token],
+                top_logprobs=top_logprobs,
+            )
+
+            logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
+
     # The finish reason can be None
-    choice = ChatCompletionStreamChoice(finish_reason=finish_reason, delta=message)
+    choice = ChatCompletionStreamChoice(
+        finish_reason=finish_reason, delta=message, logprobs=logprob_response
+    )

     chunk = ChatCompletionStreamChunk(
         id=const_id, choices=[choice], model=unwrap(model_name, "")
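The streaming path handles one token per chunk: token_probs carries a single entry and logprobs is a flat candidates dict rather than a list. A hypothetical call, with keyword names taken from identifiers in the body above (the full signature sits outside this hunk, so treat this as a sketch):

chunk = create_chat_completion_stream_chunk(
    const_id="chatcmpl-123",  # shared id across every chunk of one stream
    generation={
        "text": " there",
        "token_probs": {" there": -0.05},
        "logprobs": {" there": -0.05, " world": -3.20},
    },
    model_name="example-model",
)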

View File

@@ -520,6 +520,8 @@ class ExllamaV2Container:
             joined_generation["token_probs"].update(
                 unwrap(generation.get("token_probs"), {})
            )
+
+            # Include empty logprob dicts for index preservation
+            joined_generation["logprobs"].append(
+                unwrap(generation.get("logprobs"), {})
+            )
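The comment above is the key invariant: joined_generation["logprobs"] is positional while token_probs is keyed, so logprobs[i] must line up with the i-th generated token even when a chunk produced no top-logprob data. A standalone sketch of that join with hypothetical chunk data:

# Standalone sketch of the join invariant
joined_generation = {"token_probs": {}, "logprobs": []}

chunks = [
    {"token_probs": {"Hello": -0.31}, "logprobs": {"Hello": -0.31, "Hi": -1.42}},
    {"token_probs": {" there": -0.05}},  # chunk without top-logprob data
]

for generation in chunks:
    joined_generation["token_probs"].update(generation.get("token_probs", {}))
    # An empty dict still occupies position i, keeping list and dict aligned
    joined_generation["logprobs"].append(generation.get("logprobs", {}))

assert len(joined_generation["logprobs"]) == len(joined_generation["token_probs"])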