mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
OAI: Strictly type chat completions
Previously, the messages were a list of dicts. These are untyped and don't provide strict hinting. Add types for chat completion messages and reformat existing code. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,17 +1,16 @@
|
||||
"""Chat completion utilities for OAI server."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pathlib
|
||||
from asyncio import CancelledError
|
||||
from typing import Dict, List, Optional
|
||||
import json
|
||||
|
||||
from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
|
||||
from typing import List, Optional
|
||||
from fastapi import HTTPException, Request
|
||||
from jinja2 import TemplateError
|
||||
from loguru import logger
|
||||
|
||||
from common import model
|
||||
from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
|
||||
from common.networking import (
|
||||
get_generator_error,
|
||||
handle_request_disconnect,
|
||||
@@ -214,21 +213,21 @@ async def format_prompt_with_template(
|
||||
unwrap(data.ban_eos_token, False),
|
||||
)
|
||||
|
||||
# Deal with list in messages.content
|
||||
# Just replace the content list with the very first text message
|
||||
# Convert list to text-based content
|
||||
# Use the first instance of text inside the part list
|
||||
for message in data.messages:
|
||||
if isinstance(message["content"], list):
|
||||
message["content"] = next(
|
||||
if isinstance(message.content, list):
|
||||
message.content = next(
|
||||
(
|
||||
content["text"]
|
||||
for content in message["content"]
|
||||
if content["type"] == "text"
|
||||
content.text
|
||||
for content in message.content
|
||||
if content.type == "text"
|
||||
),
|
||||
"",
|
||||
)
|
||||
|
||||
if "tool_calls" in message:
|
||||
message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
|
||||
if message.tool_calls:
|
||||
message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
|
||||
|
||||
# Overwrite any protected vars with their values
|
||||
data.template_vars.update(
|
||||
@@ -474,20 +473,21 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
|
||||
return [ToolCall(**tool_call) for tool_call in tool_calls]
|
||||
|
||||
|
||||
async def preprocess_vision_request(messages: List[Dict]):
|
||||
# TODO: Combine this with the existing preprocessor in format_prompt_with_template
|
||||
async def preprocess_vision_request(messages: List[ChatCompletionMessage]):
|
||||
embeddings = MultimodalEmbeddingWrapper()
|
||||
for message in messages:
|
||||
if isinstance(message["content"], list):
|
||||
if isinstance(message.content, list):
|
||||
concatenated_content = ""
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
concatenated_content += content["text"]
|
||||
elif content["type"] == "image_url":
|
||||
for content in message.content:
|
||||
if content.type == "text":
|
||||
concatenated_content += content.text
|
||||
elif content.type == "image_url":
|
||||
embeddings = await add_image_embedding(
|
||||
embeddings, content["image_url"]["url"]
|
||||
embeddings, content.image_url.url
|
||||
)
|
||||
concatenated_content += embeddings.text_alias[-1]
|
||||
|
||||
message["content"] = concatenated_content
|
||||
message.content = concatenated_content
|
||||
|
||||
return messages, embeddings
|
||||
|
||||
Reference in New Issue
Block a user