From 6cc3bd975253d9061f9bbb6d1899fbcf8a840c00 Mon Sep 17 00:00:00 2001
From: Orion
Date: Tue, 4 Jun 2024 01:57:15 +0800
Subject: [PATCH] feat: list support in message.content (#122)

---
 endpoints/OAI/types/chat_completion.py |  3 ++-
 endpoints/OAI/utils/chat_completion.py | 13 +++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 92265a7..be5cfea 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -41,7 +41,8 @@ class ChatCompletionStreamChoice(BaseModel):
 class ChatCompletionRequest(CommonCompletionRequest):
     # Messages
     # Take in a string as well even though it's not part of the OAI spec
-    messages: Union[str, List[Dict[str, str]]]
+    # Support messages.content as a list of dicts
+    messages: Union[str, List[Dict[str, Union[str, List[Dict[str, str]]]]]]
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 19c50c0..9e82b1b 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -151,6 +151,19 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         unwrap(data.ban_eos_token, False),
     )
 
+    # Deal with list in messages.content
+    # Just replace the content list with the very first text message
+    for message in data.messages:
+        if message["role"] == "user" and isinstance(message["content"], list):
+            message["content"] = next(
+                (
+                    content["text"]
+                    for content in message["content"]
+                    if content["type"] == "text"
+                ),
+                "",
+            )
+
     # Overwrite any protected vars with their values
     data.template_vars.update(
         {