diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py
index 07418fc..bafbdd2 100644
--- a/OAI/types/chat_completion.py
+++ b/OAI/types/chat_completion.py
@@ -6,14 +6,14 @@ from uuid import uuid4
 from OAI.types.common import UsageStats, CommonCompletionRequest
 
 
-class ChatCompletionLogprobs(BaseModel):
+class ChatCompletionLogprob(BaseModel):
     token: str
     logprob: float
-    top_logprobs: List["ChatCompletionLogprobs"]
+    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None
 
 
-class WrappedChatCompletionLogprobs(BaseModel):
-    content: List[ChatCompletionLogprobs]
+class ChatCompletionLogprobs(BaseModel):
+    content: List[ChatCompletionLogprob] = Field(default_factory=list)
 
 
 class ChatCompletionMessage(BaseModel):
@@ -26,7 +26,7 @@ class ChatCompletionRespChoice(BaseModel):
     index: int = 0
     finish_reason: str
     message: ChatCompletionMessage
-    logprobs: Optional[WrappedChatCompletionLogprobs] = None
+    logprobs: Optional[ChatCompletionLogprobs] = None
 
 
 class ChatCompletionStreamChoice(BaseModel):
@@ -34,7 +34,7 @@ class ChatCompletionStreamChoice(BaseModel):
     index: int = 0
     finish_reason: Optional[str]
     delta: Union[ChatCompletionMessage, dict] = {}
-    logprobs: Optional[WrappedChatCompletionLogprobs] = None
+    logprobs: Optional[ChatCompletionLogprobs] = None
 
 
 # Inherited from common request
diff --git a/OAI/utils/completion.py b/OAI/utils/completion.py
index 46b7e70..81effe2 100644
--- a/OAI/utils/completion.py
+++ b/OAI/utils/completion.py
@@ -3,6 +3,8 @@ from typing import Optional
 
 from common.utils import unwrap
 from OAI.types.chat_completion import (
+    ChatCompletionLogprobs,
+    ChatCompletionLogprob,
     ChatCompletionMessage,
     ChatCompletionRespChoice,
     ChatCompletionStreamChunk,
@@ -63,7 +65,32 @@ def create_chat_completion_response(generation: dict, model_name: Optional[str])
         role="assistant", content=unwrap(generation.get("text"), "")
     )
 
-    choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)
+    logprob_response = None
+
+    token_probs = unwrap(generation.get("token_probs"), {})
+    if token_probs:
+        logprobs = unwrap(generation.get("logprobs"), [])
+
+        collected_token_probs = []
+        for index, token in enumerate(token_probs.keys()):
+            top_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=logprob)
+                for token, logprob in logprobs[index].items()
+            ]
+
+            collected_token_probs.append(
+                ChatCompletionLogprob(
+                    token=token,
+                    logprob=token_probs[token],
+                    top_logprobs=top_logprobs,
+                )
+            )
+
+        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
+
+    choice = ChatCompletionRespChoice(
+        finish_reason="Generated", message=message, logprobs=logprob_response
+    )
 
     prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
     completion_tokens = unwrap(generation.get("completion_tokens"), 0)
@@ -89,6 +116,8 @@
 ):
     """Create a chat completion stream chunk from the provided text."""
 
+    logprob_response = None
+
     if finish_reason:
         message = {}
     else:
@@ -96,8 +125,27 @@
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
         )
+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), {})
+            top_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=logprob)
+                for token, logprob in logprobs.items()
+            ]
+
+            generated_token = next(iter(token_probs))
+            token_prob_response = ChatCompletionLogprob(
+                token=generated_token,
+                logprob=token_probs[generated_token],
+                top_logprobs=top_logprobs,
+            )
+
+            logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
+
     # The finish reason can be None
-    choice = ChatCompletionStreamChoice(finish_reason=finish_reason, delta=message)
+    choice = ChatCompletionStreamChoice(
+        finish_reason=finish_reason, delta=message, logprobs=logprob_response
+    )
 
     chunk = ChatCompletionStreamChunk(
         id=const_id, choices=[choice], model=unwrap(model_name, "")
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 0c37533..45a79df 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -520,6 +520,8 @@ class ExllamaV2Container:
             joined_generation["token_probs"].update(
                 unwrap(generation.get("token_probs"), {})
             )
+
+            # Include empty logprob dicts for index preservation
             joined_generation["logprobs"].append(
                 unwrap(generation.get("logprobs"), {})
             )
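
Note: below is a minimal sketch of the payload these models serialize, built the same
way create_chat_completion_response assembles it from the backend's generation dict.
The token strings and logprob values are invented for illustration; only the class
names and fields come from the OAI/types/chat_completion.py changes above.

    # Sketch only, assuming the Pydantic models added in this diff.
    from OAI.types.chat_completion import ChatCompletionLogprob, ChatCompletionLogprobs

    token_probs = {"Hello": -0.12}                   # generated token -> its logprob
    top_candidates = {"Hello": -0.12, "Hi": -2.31}   # alternatives at the same step

    entry = ChatCompletionLogprob(
        token="Hello",
        logprob=token_probs["Hello"],
        top_logprobs=[
            ChatCompletionLogprob(token=tok, logprob=lp)
            for tok, lp in top_candidates.items()
        ],
    )
    payload = ChatCompletionLogprobs(content=[entry])

    # Pydantic v2 shown; use payload.json() on v1.
    print(payload.model_dump_json())
    # {"content": [{"token": "Hello", "logprob": -0.12, "top_logprobs": [...]}]}

The streaming path wraps a single ChatCompletionLogprob per chunk in the same
ChatCompletionLogprobs container, so a client can parse both paths with one schema.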