diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py
index 62353d9..91558c1 100644
--- a/OAI/types/chat_completion.py
+++ b/OAI/types/chat_completion.py
@@ -5,8 +5,8 @@ from typing import Union, List, Optional
 from OAI.types.common import UsageStats, CommonCompletionRequest
 
 class ChatCompletionMessage(BaseModel):
-    role: str
-    content: str
+    role: Optional[str] = None
+    content: Optional[str] = None
 
 class ChatCompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
@@ -17,8 +17,8 @@ class ChatCompletionRespChoice(BaseModel):
 class ChatCompletionStreamChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: str
-    delta: ChatCompletionMessage
+    finish_reason: Optional[str]
+    delta: Union[ChatCompletionMessage, dict] = {}
 
 # Inherited from common request
 class ChatCompletionRequest(CommonCompletionRequest):
diff --git a/OAI/utils.py b/OAI/utils.py
index 3cfdfef..7fc3126 100644
--- a/OAI/utils.py
+++ b/OAI/utils.py
@@ -57,16 +57,21 @@ def create_chat_completion_response(text: str, prompt_tokens: int, completion_to
 
     return response
 
-def create_chat_completion_stream_chunk(const_id: str, text: str, model_name: Optional[str]):
-    # TODO: Add method to get token amounts in model for UsageStats
-
-    message = ChatCompletionMessage(
-        role = "assistant",
-        content = text
-    )
+def create_chat_completion_stream_chunk(const_id: str,
+                                        text: Optional[str] = None,
+                                        model_name: Optional[str] = None,
+                                        finish_reason: Optional[str] = None):
+    if finish_reason:
+        message = {}
+    else:
+        message = ChatCompletionMessage(
+            role = "assistant",
+            content = text
+        )
 
+    # The finish reason can be None
     choice = ChatCompletionStreamChoice(
-        finish_reason = "Generated",
+        finish_reason = finish_reason,
         delta = message
     )
 
@@ -95,8 +100,8 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):
     return model_card_list
 
 def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
-    # Check if fastchat is available
 
+    # Check if fastchat is available
     if not _fastchat_available:
         raise ModuleNotFoundError(
             "Fastchat must be installed to parse these chat completion messages.\n"
@@ -114,7 +119,7 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
     for message in messages:
         msg_role = message.role
         if msg_role == "system":
-            conv.system_message = message.content
+            conv.set_system_message(message.content)
         elif msg_role == "user":
             conv.append_message(conv.roles[0], message.content)
         elif msg_role == "assistant":
diff --git a/main.py b/main.py
index 0d059a3..8c9b776 100644
--- a/main.py
+++ b/main.py
@@ -225,8 +225,8 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         const_id = f"chatcmpl-{uuid4().hex}"
         async def generator():
             try:
-                new_generation, prompt_tokens, completion_tokens = model_container.generate_gen(prompt, **data.to_gen_params())
-                for part in new_generation:
+                new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+                for (part, _, _) in new_generation:
                     if await request.is_disconnected():
                         break
 
@@ -239,6 +239,15 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     yield response.json(ensure_ascii=False)
             except Exception as e:
                 yield get_generator_error(e)
+            finally:
+
+                # Always finish no matter what
+                finish_response = create_chat_completion_stream_chunk(
+                    const_id,
+                    finish_reason = "stop"
+                )
+
+                yield finish_response.json(ensure_ascii=False)
 
         return EventSourceResponse(generator())
     else: