OAI: Initial vision support in OAI chat completions

* Support image_url inputs containing URLs or base64-encoded images, following the OAI vision spec
* Use an async LRU cache for image embeddings
* Add a generic wrapper class for multimodal embeddings
DocShotgun
2024-11-17 21:23:09 -08:00
parent 5fa298e601
commit dd41eec8a4
7 changed files with 115 additions and 26 deletions
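
The wrapper class and the cached embedding helper referenced in the commit message live in common/multimodal.py, which isn't shown in this excerpt. Below is a minimal sketch of the shapes the diff relies on, assuming async_lru's alru_cache and aiohttp for the cache and fetch; only the names MultimodalEmbeddingWrapper, add_image_embedding, and the text_alias attribute are confirmed by the diff, and the cache decorator, fetch logic, and alias format are illustrative guesses.

import base64
from dataclasses import dataclass, field
from typing import List

import aiohttp
from async_lru import alru_cache  # assumed async LRU cache implementation


@dataclass
class MultimodalEmbeddingWrapper:
    """Generic container pairing image embeddings with prompt text aliases."""

    content: List[bytes] = field(default_factory=list)
    text_alias: List[str] = field(default_factory=list)


@alru_cache(maxsize=32)
async def get_image_embedding(url: str) -> bytes:
    """Fetch and decode one image, cached so a repeated URL embeds only once."""

    if url.startswith("data:"):
        # base64 data URI, as allowed by the OAI vision spec
        return base64.b64decode(url.split(",", 1)[1])
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.read()


async def add_image_embedding(
    embeddings: MultimodalEmbeddingWrapper, url: str
) -> MultimodalEmbeddingWrapper:
    """Embed one image and register a text alias for prompt substitution."""

    embeddings.content.append(await get_image_embedding(url))
    # Hypothetical alias format; the real token scheme isn't in this diff.
    embeddings.text_alias.append(f"<image_{len(embeddings.content)}>")
    return embeddings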

@@ -3,9 +3,10 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from typing import List, Optional
+from typing import Dict, List, Optional
 import json
+from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
@@ -279,7 +280,11 @@ async def format_prompt_with_template(
 async def stream_generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     """Generator for the generation process."""
     abort_event = asyncio.Event()
@@ -298,6 +303,7 @@ async def stream_generate_chat_completion(
                     n,
                     gen_queue,
                     prompt,
+                    embeddings,
                     request.state.id,
                     abort_event,
                     **task_gen_params.model_dump(exclude={"prompt"}),
@@ -372,7 +378,11 @@ async def stream_generate_chat_completion(
 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
@@ -381,7 +391,10 @@ async def generate_chat_completion(
     gen_tasks.append(
         asyncio.create_task(
             model.container.generate(
-                prompt, request.state.id, **data.model_dump(exclude={"prompt"})
+                prompt,
+                embeddings,
+                request.state.id,
+                **data.model_dump(exclude={"prompt"}),
             )
         )
     )
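
Both entry points now take the embeddings wrapper alongside the prompt. A hypothetical caller, condensed for illustration (the actual route body is not part of this excerpt, and the format_prompt_with_template call shape here is assumed), would preprocess the vision request first and thread the wrapper through:

    # Hypothetical route body, not the file's actual code.
    data.messages, embeddings = await preprocess_vision_request(data.messages)
    prompt = await format_prompt_with_template(data)  # templating on flattened text
    if data.stream:
        generator = stream_generate_chat_completion(
            prompt, embeddings, data, request, model_path
        )
    else:
        response = await generate_chat_completion(
            prompt, embeddings, data, request, model_path
        )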
@@ -459,3 +472,22 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
             tool_call["function"]["arguments"]
         )
     return [ToolCall(**tool_call) for tool_call in tool_calls]
+
+
+async def preprocess_vision_request(messages: List[Dict]):
+    embeddings = MultimodalEmbeddingWrapper()
+    for message in messages:
+        if isinstance(message["content"], list):
+            concatenated_content = ""
+            for content in message["content"]:
+                if content["type"] == "text":
+                    concatenated_content += content["text"]
+                elif content["type"] == "image_url":
+                    embeddings = await add_image_embedding(
+                        embeddings, content["image_url"]["url"]
+                    )
+                    concatenated_content += embeddings.text_alias[-1]
+
+            message["content"] = concatenated_content
+
+    return messages, embeddings
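
preprocess_vision_request flattens OAI-style list content into a single string per message, replacing each image_url part with the alias that add_image_embedding registered on the wrapper. A worked example with a hypothetical image URL (a base64 data URI in "url" takes the same path):

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image? "},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/cat.png"},
                },
            ],
        }
    ]

    messages, embeddings = await preprocess_vision_request(messages)
    # The message's content is now "What is in this image? " followed by the
    # image's alias (embeddings.text_alias[-1]), ready for text templating.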