mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
API: Fix types for chat completions
Messages were mistakenly being sent as Pydantic objects, but templates expect dictionaries. Properly convert these before render. In addition, initialize all Optional lists as an empty list since this will cause the least problems when interacting with other parts of API code, such as templates. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -11,7 +11,7 @@ from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
|
||||
class ChatCompletionLogprob(BaseModel):
|
||||
token: str
|
||||
logprob: float
|
||||
top_logprobs: Optional[List["ChatCompletionLogprob"]] = None
|
||||
top_logprobs: Optional[List["ChatCompletionLogprob"]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatCompletionLogprobs(BaseModel):
|
||||
@@ -30,8 +30,10 @@ class ChatCompletionMessagePart(BaseModel):
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: str = "user"
|
||||
content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
|
||||
tool_calls: Optional[List[ToolCall]] = None
|
||||
content: Optional[Union[str, List[ChatCompletionMessagePart]]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
tool_calls_json: SkipJsonSchema[Optional[str]] = None
|
||||
|
||||
|
||||
@@ -76,13 +78,15 @@ class ChatCompletionRequest(CommonCompletionRequest):
|
||||
# tools is follows the format OAI schema, functions is more flexible
|
||||
# both are available in the chat template.
|
||||
|
||||
tools: Optional[List[ToolSpec]] = None
|
||||
functions: Optional[List[Dict]] = None
|
||||
tools: Optional[List[ToolSpec]] = Field(default_factory=list)
|
||||
functions: Optional[List[Dict]] = Field(default_factory=list)
|
||||
|
||||
# Typically collected from Chat Template.
|
||||
# Don't include this in the OpenAPI docs
|
||||
# TODO: Use these custom parameters
|
||||
tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = None
|
||||
tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
tool_call_end: SkipJsonSchema[Optional[str]] = None
|
||||
tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
|
||||
|
||||
|
||||
@@ -210,14 +210,14 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
|
||||
async def format_messages_with_template(
|
||||
messages: List[ChatCompletionMessage],
|
||||
existing_template_vars: Optional[dict] = None,
|
||||
add_bos_token: bool = True,
|
||||
ban_eos_token: bool = False,
|
||||
):
|
||||
"""Barebones function to format chat completion messages into a prompt."""
|
||||
|
||||
template_vars = unwrap(existing_template_vars, {})
|
||||
mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None
|
||||
|
||||
# Convert all messages to a dictionary representation
|
||||
message_dicts: List[dict] = []
|
||||
for message in messages:
|
||||
if isinstance(message.content, list):
|
||||
concatenated_content = ""
|
||||
@@ -238,9 +238,12 @@ async def format_messages_with_template(
|
||||
# store the list of dicts rather than the ToolCallProcessor object.
|
||||
message.tool_calls = ToolCallProcessor.dump(message.tool_calls)
|
||||
|
||||
message_dicts.append(message.model_dump())
|
||||
|
||||
# Get all special tokens
|
||||
special_tokens_dict = model.container.get_special_tokens()
|
||||
|
||||
template_vars.update({"messages": messages, **special_tokens_dict})
|
||||
template_vars.update({"messages": message_dicts, **special_tokens_dict})
|
||||
|
||||
prompt = await model.container.prompt_template.render(template_vars)
|
||||
return prompt, mm_embeddings, template_vars
|
||||
@@ -270,7 +273,7 @@ async def apply_chat_template(
|
||||
)
|
||||
|
||||
prompt, mm_embeddings, template_vars = await format_messages_with_template(
|
||||
data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
|
||||
data.messages, data.template_vars
|
||||
)
|
||||
|
||||
# Append response prefix if present
|
||||
|
||||
Reference in New Issue
Block a user