OAI: Strictly type chat completions

Previously, the messages were a list of dicts. Dicts are untyped
and don't provide strict type hinting. Add types for chat completion
messages and reformat the existing code to use them.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-11-19 23:15:47 -05:00
parent 0fadb1e5e8
commit 8ffc636dce
3 changed files with 37 additions and 25 deletions

View File

@@ -1,17 +1,16 @@
"""Chat completion utilities for OAI server."""
import asyncio
import json
import pathlib
from asyncio import CancelledError
from typing import Dict, List, Optional
import json
from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
from typing import List, Optional
from fastapi import HTTPException, Request
from jinja2 import TemplateError
from loguru import logger
from common import model
from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
from common.networking import (
get_generator_error,
handle_request_disconnect,
@@ -214,21 +213,21 @@ async def format_prompt_with_template(
unwrap(data.ban_eos_token, False),
)
# Deal with list in messages.content
# Just replace the content list with the very first text message
# Convert list to text-based content
# Use the first instance of text inside the part list
for message in data.messages:
if isinstance(message["content"], list):
message["content"] = next(
if isinstance(message.content, list):
message.content = next(
(
content["text"]
for content in message["content"]
if content["type"] == "text"
content.text
for content in message.content
if content.type == "text"
),
"",
)
if "tool_calls" in message:
message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
if message.tool_calls:
message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
# Overwrite any protected vars with their values
data.template_vars.update(
@@ -474,20 +473,21 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
return [ToolCall(**tool_call) for tool_call in tool_calls]
async def preprocess_vision_request(messages: List[Dict]):
# TODO: Combine this with the existing preprocessor in format_prompt_with_template
async def preprocess_vision_request(messages: List[ChatCompletionMessage]):
embeddings = MultimodalEmbeddingWrapper()
for message in messages:
if isinstance(message["content"], list):
if isinstance(message.content, list):
concatenated_content = ""
for content in message["content"]:
if content["type"] == "text":
concatenated_content += content["text"]
elif content["type"] == "image_url":
for content in message.content:
if content.type == "text":
concatenated_content += content.text
elif content.type == "image_url":
embeddings = await add_image_embedding(
embeddings, content["image_url"]["url"]
embeddings, content.image_url.url
)
concatenated_content += embeddings.text_alias[-1]
message["content"] = concatenated_content
message.content = concatenated_content
return messages, embeddings