diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py
index ba0a968..07418fc 100644
--- a/OAI/types/chat_completion.py
+++ b/OAI/types/chat_completion.py
@@ -6,6 +6,16 @@ from uuid import uuid4
 from OAI.types.common import UsageStats, CommonCompletionRequest
 
 
+class ChatCompletionLogprobs(BaseModel):
+    token: str
+    logprob: float
+    top_logprobs: List["ChatCompletionLogprobs"]
+
+
+class WrappedChatCompletionLogprobs(BaseModel):
+    content: List[ChatCompletionLogprobs]
+
+
 class ChatCompletionMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
@@ -16,6 +26,7 @@ class ChatCompletionRespChoice(BaseModel):
     index: int = 0
     finish_reason: str
     message: ChatCompletionMessage
+    logprobs: Optional[WrappedChatCompletionLogprobs] = None
 
 
 class ChatCompletionStreamChoice(BaseModel):
@@ -23,6 +34,7 @@ class ChatCompletionStreamChoice(BaseModel):
     index: int = 0
     finish_reason: Optional[str]
     delta: Union[ChatCompletionMessage, dict] = {}
+    logprobs: Optional[WrappedChatCompletionLogprobs] = None
 
 
 # Inherited from common request
diff --git a/OAI/utils/completion.py b/OAI/utils/completion.py
index fddf666..46b7e70 100644
--- a/OAI/utils/completion.py
+++ b/OAI/utils/completion.py
@@ -17,32 +17,35 @@ from OAI.types.completion import (
 from OAI.types.common import UsageStats
 
 
-def create_completion_response(**kwargs):
+def create_completion_response(generation: dict, model_name: Optional[str]):
     """Create a completion response from the provided text."""
-    token_probs = unwrap(kwargs.get("token_probs"), {})
-    logprobs = unwrap(kwargs.get("logprobs"), [])
-    offset = unwrap(kwargs.get("offset"), [])
+    logprob_response = None
 
-    logprob_response = CompletionLogProbs(
-        text_offset=offset if isinstance(offset, list) else [offset],
-        token_logprobs=token_probs.values(),
-        tokens=token_probs.keys(),
-        top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
-    )
+    token_probs = unwrap(generation.get("token_probs"), {})
+    if token_probs:
+        logprobs = unwrap(generation.get("logprobs"), [])
+        offset = unwrap(generation.get("offset"), [])
+
+        logprob_response = CompletionLogProbs(
+            text_offset=offset if isinstance(offset, list) else [offset],
+            token_logprobs=token_probs.values(),
+            tokens=token_probs.keys(),
+            top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
+        )
 
     choice = CompletionRespChoice(
         finish_reason="Generated",
-        text=unwrap(kwargs.get("text"), ""),
+        text=unwrap(generation.get("text"), ""),
         logprobs=logprob_response,
     )
 
-    prompt_tokens = unwrap(kwargs.get("prompt_tokens"), 0)
-    completion_tokens = unwrap(kwargs.get("completion_tokens"), 0)
+    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generation.get("completion_tokens"), 0)
 
     response = CompletionResponse(
         choices=[choice],
-        model=unwrap(kwargs.get("model_name"), ""),
+        model=unwrap(model_name, ""),
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
@@ -53,17 +56,18 @@ def create_completion_response(**kwargs):
     return response
 
 
-def create_chat_completion_response(
-    text: str,
-    prompt_tokens: Optional[int],
-    completion_tokens: Optional[int],
-    model_name: Optional[str],
-):
+def create_chat_completion_response(generation: dict, model_name: Optional[str]):
     """Create a chat completion response from the provided text."""
-    message = ChatCompletionMessage(role="assistant", content=unwrap(text, ""))
+
+    message = ChatCompletionMessage(
+        role="assistant", content=unwrap(generation.get("text"), "")
+    )
 
     choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)
 
+    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generation.get("completion_tokens"), 0)
+
     response = ChatCompletionResponse(
         choices=[choice],
         model=unwrap(model_name, ""),
@@ -79,15 +83,18 @@ def create_chat_completion_response(
 
 def create_chat_completion_stream_chunk(
     const_id: str,
-    text: Optional[str] = None,
+    generation: Optional[dict] = None,
     model_name: Optional[str] = None,
     finish_reason: Optional[str] = None,
 ):
     """Create a chat completion stream chunk from the provided text."""
+
     if finish_reason:
         message = {}
     else:
-        message = ChatCompletionMessage(role="assistant", content=text)
+        message = ChatCompletionMessage(
+            role="assistant", content=unwrap(generation.get("text"), "")
+        )
 
     # The finish reason can be None
     choice = ChatCompletionStreamChoice(finish_reason=finish_reason, delta=message)
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index e86f4cb..0c37533 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -505,7 +505,7 @@ class ExllamaV2Container:
         generations = list(self.generate_gen(prompt, **kwargs))
 
         joined_generation = {
-            "chunk": "",
+            "text": "",
             "prompt_tokens": 0,
             "generation_tokens": 0,
             "offset": [],
@@ -515,7 +515,7 @@ class ExllamaV2Container:
 
         if generations:
             for generation in generations:
-                joined_generation["chunk"] += unwrap(generation.get("chunk"), "")
+                joined_generation["text"] += unwrap(generation.get("text"), "")
                 joined_generation["offset"].append(unwrap(generation.get("offset"), []))
                 joined_generation["token_probs"].update(
                     unwrap(generation.get("token_probs"), {})
@@ -835,7 +835,7 @@ class ExllamaV2Container:
                 elapsed > stream_interval or eos or generated_tokens == max_tokens
             ):
                 generation = {
-                    "chunk": chunk_buffer,
+                    "text": chunk_buffer,
                     "prompt_tokens": prompt_tokens,
                     "generated_tokens": generated_tokens,
                     "offset": len(full_response),
diff --git a/main.py b/main.py
index e875a06..edb1747 100644
--- a/main.py
+++ b/main.py
@@ -462,10 +462,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
                 if await request.is_disconnected():
                     break
 
-                response = create_completion_response(
-                    **generation,
-                    model_name=model_path.name,
-                )
+                response = create_completion_response(generation, model_path.name)
 
                 yield get_sse_packet(response.model_dump_json())
 
@@ -483,7 +480,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
     generation = await call_with_semaphore(
         partial(MODEL_CONTAINER.generate, data.prompt, **data.to_gen_params())
     )
-    response = create_completion_response(**generation)
+    response = create_completion_response(generation, model_path.name)
 
     return response
 
@@ -548,7 +545,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     break
 
                 response = create_chat_completion_stream_chunk(
-                    const_id, generation.get("chunk"), model_path.name
+                    const_id, generation, model_path.name
                 )
 
                 yield get_sse_packet(response.model_dump_json())
@@ -568,13 +565,10 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             generate_with_semaphore(generator), media_type="text/event-stream"
         )
 
-    response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+    generation = await call_with_semaphore(
         partial(MODEL_CONTAINER.generate, prompt, **data.to_gen_params())
    )
-
-    response = create_chat_completion_response(
-        response_text, prompt_tokens, completion_tokens, model_path.name
-    )
+    response = create_chat_completion_response(generation, model_path.name)
 
     return response
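
Illustrative usage sketch, not part of the patch: after this change the handlers pass the backend's generation dict straight through to the response builders along with the model name. The dict keys below are the ones the helpers read in the diff; the literal values, the "my-model" string, and the result variable names are invented for the example.

    from OAI.utils.completion import (
        create_completion_response,
        create_chat_completion_response,
    )

    # Shape of the generation dict consumed by the helpers (values are made up).
    generation = {
        "text": "Hello there!",
        "prompt_tokens": 12,
        "completion_tokens": 4,
        "token_probs": {},  # empty, so create_completion_response skips logprobs
    }

    # The whole dict is forwarded; the model name is passed separately.
    completion = create_completion_response(generation, "my-model")
    chat = create_chat_completion_response(generation, "my-model")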