OAI: Initial vision support in OAI chat completions

* Support image_url inputs containing URLs or base64-encoded images, following the OAI vision spec (see the example request below)
* Use an async LRU cache for image embeddings
* Add a generic wrapper class for multimodal embeddings
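
For reference, a request in the OAI vision format carries a list of content parts per message. A minimal sketch of such a payload (the model name and image file are illustrative, not part of this commit):

    import base64

    with open("cat.png", "rb") as f:
        b64_image = base64.b64encode(f.read()).decode()

    payload = {
        "model": "my-vision-model",  # illustrative
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    # image_url accepts either a plain URL or a base64 data URI
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{b64_image}"},
                    },
                ],
            }
        ],
    }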
Author: DocShotgun
Date: 2024-11-17 21:23:09 -08:00
parent 5fa298e601
commit dd41eec8a4
7 changed files with 115 additions and 26 deletions

backends/exllamav2/model.py

@@ -6,6 +6,8 @@ import gc
 import math
 import pathlib
 import traceback
+from backends.exllamav2.vision import clear_image_embedding_cache
+from common.multimodal import MultimodalEmbeddingWrapper
 import torch
 import uuid
 from copy import deepcopy
@@ -816,6 +818,9 @@ class ExllamaV2Container:
         # Delete references held in the grammar module
         clear_grammar_func_cache()

+        # Clear the image embedding cache
+        clear_image_embedding_cache()
+
         # Unload LoRAs
         if self.generator and self.generator.generator.current_loras:
             for lora in self.generator.generator.current_loras:
@@ -908,12 +913,17 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))

     async def generate(
-        self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
+        self,
+        prompt: str,
+        embeddings: MultimodalEmbeddingWrapper,
+        request_id: str,
+        abort_event: asyncio.Event = None,
+        **kwargs,
     ):
         """Generate a response to a prompt."""

         generations = []
         async for generation in self.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            prompt, embeddings, request_id, abort_event, **kwargs
         ):
             generations.append(generation)
@@ -979,6 +989,7 @@ class ExllamaV2Container:
     async def generate_gen(
         self,
         prompt: str,
+        embeddings: MultimodalEmbeddingWrapper,
         request_id: str,
         abort_event: Optional[asyncio.Event] = None,
         **kwargs,
@@ -1246,7 +1257,10 @@ class ExllamaV2Container:
         # Encode both positive and negative prompts
         input_ids = [
             self.tokenizer.encode(
-                prompt, add_bos=add_bos_token, encode_special_tokens=True
+                prompt,
+                add_bos=add_bos_token,
+                encode_special_tokens=True,
+                embeddings=embeddings.content,
             )
             for prompt in prompts
         ]
@@ -1297,6 +1311,7 @@ class ExllamaV2Container:
             banned_strings=banned_strings,
             token_healing=token_healing,
             identifier=job_id,
+            embeddings=embeddings.content,
         )

         # Save generated tokens and full response
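
To illustrate the new call shape, a minimal sketch (hedged: `container` stands for a loaded ExllamaV2Container, the prompt and request id are arbitrary, and text-only callers would pass an empty wrapper):

    # Hedged sketch of the updated generate() signature.
    wrapper = MultimodalEmbeddingWrapper()  # empty wrapper for a text-only request
    response = await container.generate(
        "Hello!",        # prompt
        wrapper,         # multimodal embeddings; .content may be an empty list
        "request-123",   # request_id
    )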

backends/exllamav2/vision.py

@@ -4,18 +4,14 @@ import io
 import base64
 import re
 from PIL import Image
+from common import model
 import aiohttp
 from common.networking import (
     handle_request_error,
 )
 from fastapi import HTTPException
-from exllamav2 import (
-    ExLlamaV2,
-    ExLlamaV2Tokenizer,
-    ExLlamaV2VisionTower,
-    ExLlamaV2MMEmbedding,
-)
-from functools import lru_cache
+from exllamav2.generator import ExLlamaV2MMEmbedding
+from async_lru import alru_cache


 async def get_image(url: str) -> Image:
@@ -50,14 +46,16 @@ async def get_image(url: str) -> Image:
     return Image.open(io.BytesIO(bytes_image))


-@lru_cache(20)
-async def get_image_embedding(
-    model: ExLlamaV2,
-    tokenizer: ExLlamaV2Tokenizer,
-    vision_model: ExLlamaV2VisionTower,
-    url: str,
-) -> ExLlamaV2MMEmbedding:
+@alru_cache(20)
+async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
     image = await get_image(url)
-    return vision_model.get_image_embeddings(
-        model=model, tokenizer=tokenizer, image=image
+    return model.container.vision_model.get_image_embeddings(
+        model=model.container.model,
+        tokenizer=model.container.tokenizer,
+        image=image,
+        text_alias=None,
     )
+
+
+def clear_image_embedding_cache():
+    get_image_embedding.cache_clear()
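
The switch from functools.lru_cache to async_lru.alru_cache matters because lru_cache on a coroutine function would memoize the coroutine object itself, which can only be awaited once; alru_cache caches the awaited result instead. A minimal sketch of the behavior (the fetch body is illustrative):

    import asyncio
    from async_lru import alru_cache

    @alru_cache(20)  # same bound as get_image_embedding
    async def fetch(url: str) -> str:
        await asyncio.sleep(0.1)  # stand-in for the real network fetch
        return f"bytes for {url}"

    async def main():
        await fetch("https://example.com/a.png")  # miss: runs the body
        await fetch("https://example.com/a.png")  # hit: returns cached result
        fetch.cache_clear()  # the call behind clear_image_embedding_cache()

    asyncio.run(main())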

common/multimodal.py (new file, 36 lines)

@@ -0,0 +1,36 @@
+from typing import List
+from backends.exllamav2.vision import get_image_embedding
+from common import model
+from pydantic import BaseModel
+from loguru import logger
+
+from common.optional_dependencies import dependencies
+
+if dependencies.exllamav2:
+    from exllamav2 import ExLlamaV2VisionTower
+
+
+class MultimodalEmbeddingWrapper(BaseModel):
+    """Common multimodal embedding wrapper"""
+
+    type: str = None
+    content: List = []
+    text_alias: List[str] = []
+
+
+async def add_image_embedding(
+    embeddings: MultimodalEmbeddingWrapper, url: str
+) -> MultimodalEmbeddingWrapper:
+    # Determine the type of vision embedding to use
+    if not embeddings.type:
+        if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
+            embeddings.type = "ExLlamaV2MMEmbedding"
+
+    if embeddings.type == "ExLlamaV2MMEmbedding":
+        embedding = await get_image_embedding(url)
+        embeddings.content.append(embedding)
+        embeddings.text_alias.append(embedding.text_alias)
+    else:
+        logger.error("No valid vision model to create embedding")
+
+    return embeddings
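
A rough usage sketch of the wrapper (hedged: the URL and prompt are illustrative, and `model.container` must already hold a vision-capable model):

    # inside an async context
    embeddings = MultimodalEmbeddingWrapper()
    embeddings = await add_image_embedding(embeddings, "https://example.com/cat.png")

    # Each embedding exposes a text alias; placing the alias in the prompt
    # marks where tokenizer.encode() should splice in the image tokens.
    prompt = f"{embeddings.text_alias[-1]}\nWhat is in this picture?"

Keeping the backend-specific `type` on the wrapper is what makes the class generic: other multimodal backends can register their own embedding type later without changing the callers.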

endpoints/OAI/router.py

@@ -17,6 +17,7 @@ from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
     format_prompt_with_template,
     generate_chat_completion,
+    preprocess_vision_request,
     stream_generate_chat_completion,
 )
 from endpoints.OAI.utils.completion import (
@@ -126,6 +127,8 @@ async def chat_completion_request(
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
+        if model.container.use_vision:
+            data.messages, embeddings = await preprocess_vision_request(data.messages)
         prompt = await format_prompt_with_template(data)

     # Set an empty JSON schema if the request wants a JSON response
@@ -136,12 +139,14 @@ async def chat_completion_request(
     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, request, model_path),
+            stream_generate_chat_completion(
+                prompt, embeddings, data, request, model_path
+            ),
             ping=maxsize,
         )
     else:
         generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, request, model_path)
+            generate_chat_completion(prompt, embeddings, data, request, model_path)
         )

         response = await run_with_request_disconnect(
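
On the client side, this path is exercised by any OAI-style chat completion call whose messages contain image_url parts. A hedged sketch using requests (the host, port, route, and auth header are assumptions, not part of this commit):

    import requests

    resp = requests.post(
        "http://localhost:5000/v1/chat/completions",    # assumed host/route
        headers={"Authorization": "Bearer <api-key>"},  # if auth is enabled
        json=payload,  # the vision payload sketched near the top of this page
    )
    print(resp.json()["choices"][0]["message"]["content"])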

endpoints/OAI/utils/chat_completion.py

@@ -3,9 +3,10 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from typing import List, Optional
+from typing import Dict, List, Optional
 import json

+from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
@@ -279,7 +280,11 @@ async def format_prompt_with_template(
 async def stream_generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     """Generator for the generation process."""

     abort_event = asyncio.Event()
@@ -298,6 +303,7 @@ async def stream_generate_chat_completion(
                     n,
                     gen_queue,
                     prompt,
+                    embeddings,
                     request.state.id,
                     abort_event,
                     **task_gen_params.model_dump(exclude={"prompt"}),
@@ -372,7 +378,11 @@ async def stream_generate_chat_completion(
 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
@@ -381,7 +391,10 @@ async def generate_chat_completion(
     gen_tasks.append(
         asyncio.create_task(
             model.container.generate(
-                prompt, request.state.id, **data.model_dump(exclude={"prompt"})
+                prompt,
+                embeddings,
+                request.state.id,
+                **data.model_dump(exclude={"prompt"}),
             )
         )
     )
@@ -459,3 +472,22 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
             tool_call["function"]["arguments"]
         )
     return [ToolCall(**tool_call) for tool_call in tool_calls]
+
+
+async def preprocess_vision_request(messages: List[Dict]):
+    embeddings = MultimodalEmbeddingWrapper()
+    for message in messages:
+        if isinstance(message["content"], list):
+            concatenated_content = ""
+            for content in message["content"]:
+                if content["type"] == "text":
+                    concatenated_content += content["text"]
+                elif content["type"] == "image_url":
+                    embeddings = await add_image_embedding(
+                        embeddings, content["image_url"]["url"]
+                    )
+                    concatenated_content += embeddings.text_alias[-1]
+            message["content"] = concatenated_content
+
+    return messages, embeddings
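
To make the transformation concrete, a hedged before/after sketch (the alias string is whatever the embedding assigns; a placeholder is shown here):

    # inside an async context
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is this? "},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }]

    messages, embeddings = await preprocess_vision_request(messages)

    # The multimodal content list is flattened into a single string, with each
    # image replaced by its embedding's text alias (placeholder shown):
    # [{"role": "user", "content": "What is this? {{EMBED_1}}"}]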

endpoints/OAI/utils/completion.py

@@ -7,6 +7,7 @@ Also serves as a common module for completions and chat completions.
 import asyncio
 import pathlib
 from asyncio import CancelledError

+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import HTTPException, Request
 from typing import List, Union
@@ -87,6 +88,7 @@ async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
     prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
     request_id: str,
     abort_event: asyncio.Event,
     **kwargs,
@@ -95,7 +97,7 @@ async def _stream_collector(
     try:
         new_generation = model.container.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            prompt, embeddings, request_id, abort_event, **kwargs
         )
         async for generation in new_generation:
             generation["index"] = task_idx

pyproject.toml

@@ -29,6 +29,7 @@ dependencies = [
     "lm-format-enforcer >= 0.9.6",
     "aiofiles",
     "aiohttp",
+    "async_lru",
     "huggingface_hub",
     "psutil",
     "httptools>=0.5.0",