Mirror of https://github.com/theroyallab/tabbyAPI.git
OAI: Initial vision support in OAI chat completions
* Support image_url inputs containing URLs or base64 strings, following the OAI vision spec
* Use an async LRU cache for image embeddings
* Add a generic wrapper class for multimodal embeddings
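For reference, the image_url inputs named above follow the OAI vision spec: each image is a content part whose url field holds either a remote URL or a base64 data URI. A minimal sketch of such a request body (the model name and image file are placeholders, not values from this commit):

    import base64

    with open("cat.png", "rb") as f:  # placeholder image file
        b64_image = base64.b64encode(f.read()).decode()

    payload = {
        "model": "my-vision-model",  # placeholder
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                # Remote URL form
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                # Inline base64 form
                {"type": "image_url", "image_url": {"url": "data:image/png;base64," + b64_image}},
            ],
        }],
    }
    # POST payload as JSON to /v1/chat/completions.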
backends/exllamav2/model.py

@@ -6,6 +6,8 @@ import gc
 import math
 import pathlib
 import traceback
+from backends.exllamav2.vision import clear_image_embedding_cache
+from common.multimodal import MultimodalEmbeddingWrapper
 import torch
 import uuid
 from copy import deepcopy
@@ -816,6 +818,9 @@ class ExllamaV2Container:
         # Delete references held in the grammar module
         clear_grammar_func_cache()
 
+        # Clear the image embedding cache
+        clear_image_embedding_cache()
+
         # Unload LoRAs
         if self.generator and self.generator.generator.current_loras:
             for lora in self.generator.generator.current_loras:
@@ -908,12 +913,17 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))
 
     async def generate(
-        self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
+        self,
+        prompt: str,
+        embeddings: MultimodalEmbeddingWrapper,
+        request_id: str,
+        abort_event: asyncio.Event = None,
+        **kwargs,
     ):
         """Generate a response to a prompt."""
         generations = []
         async for generation in self.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            prompt, embeddings, request_id, abort_event, **kwargs
         ):
             generations.append(generation)
 
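After this change, callers pass the multimodal wrapper positionally between the prompt and the request id. A hypothetical call site; container, wrapper, and request_id are assumed to already exist and are not names from this diff:

    # Hypothetical call site illustrating the new argument order.
    generations = await container.generate(
        prompt,      # formatted prompt, including any image text aliases
        wrapper,     # MultimodalEmbeddingWrapper built during preprocessing
        request_id,
        max_tokens=256,
    )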
@@ -979,6 +989,7 @@ class ExllamaV2Container:
     async def generate_gen(
         self,
         prompt: str,
+        embeddings: MultimodalEmbeddingWrapper,
         request_id: str,
         abort_event: Optional[asyncio.Event] = None,
         **kwargs,
@@ -1246,7 +1257,10 @@ class ExllamaV2Container:
         # Encode both positive and negative prompts
         input_ids = [
             self.tokenizer.encode(
-                prompt, add_bos=add_bos_token, encode_special_tokens=True
+                prompt,
+                add_bos=add_bos_token,
+                encode_special_tokens=True,
+                embeddings=embeddings.content,
             )
             for prompt in prompts
         ]
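The encode call forwards the raw embedding objects; each ExLlamaV2MMEmbedding carries a text_alias string, and wherever that alias appears in the prompt the tokenizer splices in the embedding's token IDs in place of ordinary text tokens. A conceptual sketch of the call shape, not the library's internals; tokenizer and embedding are assumed to exist:

    # Conceptual sketch only; the alias text is generated by exllamav2.
    prompt = f"Describe this image: {embedding.text_alias}"
    ids = tokenizer.encode(
        prompt,
        add_bos=True,
        encode_special_tokens=True,
        embeddings=[embedding],  # list of ExLlamaV2MMEmbedding
    )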
@@ -1297,6 +1311,7 @@ class ExllamaV2Container:
             banned_strings=banned_strings,
             token_healing=token_healing,
             identifier=job_id,
+            embeddings=embeddings.content,
         )
 
         # Save generated tokens and full response
backends/exllamav2/vision.py

@@ -4,18 +4,14 @@ import io
 import base64
 import re
 from PIL import Image
+from common import model
 import aiohttp
 from common.networking import (
     handle_request_error,
 )
 from fastapi import HTTPException
-from exllamav2 import (
-    ExLlamaV2,
-    ExLlamaV2Tokenizer,
-    ExLlamaV2VisionTower,
-    ExLlamaV2MMEmbedding,
-)
-from functools import lru_cache
+from exllamav2.generator import ExLlamaV2MMEmbedding
+from async_lru import alru_cache
 
 
 async def get_image(url: str) -> Image:
@@ -50,14 +46,16 @@ async def get_image(url: str) -> Image:
     return Image.open(io.BytesIO(bytes_image))
 
 
-@lru_cache(20)
-async def get_image_embedding(
-    model: ExLlamaV2,
-    tokenizer: ExLlamaV2Tokenizer,
-    vision_model: ExLlamaV2VisionTower,
-    url: str,
-) -> ExLlamaV2MMEmbedding:
+@alru_cache(20)
+async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
     image = await get_image(url)
-    return vision_model.get_image_embeddings(
-        model=model, tokenizer=tokenizer, image=image
+    return model.container.vision_model.get_image_embeddings(
+        model=model.container.model,
+        tokenizer=model.container.tokenizer,
+        image=image,
+        text_alias=None,
     )
+
+
+def clear_image_embedding_cache():
+    get_image_embedding.cache_clear()
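async_lru's alru_cache is an awaitable analogue of functools.lru_cache: the wrapped coroutine gains cache_clear() (the call the model unload path makes via clear_image_embedding_cache) and cache_info(). A self-contained sketch:

    import asyncio

    from async_lru import alru_cache

    @alru_cache(maxsize=20)
    async def fetch(key: str) -> str:
        await asyncio.sleep(0.1)  # stand-in for download + GPU embedding work
        return key.upper()

    async def main():
        await fetch("a")           # miss: runs the coroutine
        await fetch("a")           # hit: returns the cached result
        print(fetch.cache_info())  # CacheInfo(hits=1, misses=1, ...)
        fetch.cache_clear()        # same call clear_image_embedding_cache() makes

    asyncio.run(main())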
common/multimodal.py (new file, 36 lines)

@@ -0,0 +1,36 @@
+from typing import List
+from backends.exllamav2.vision import get_image_embedding
+from common import model
+from pydantic import BaseModel
+from loguru import logger
+
+from common.optional_dependencies import dependencies
+
+if dependencies.exllamav2:
+    from exllamav2 import ExLlamaV2VisionTower
+
+
+class MultimodalEmbeddingWrapper(BaseModel):
+    """Common multimodal embedding wrapper"""
+
+    type: str = None
+    content: List = []
+    text_alias: List[str] = []
+
+
+async def add_image_embedding(
+    embeddings: MultimodalEmbeddingWrapper, url: str
+) -> MultimodalEmbeddingWrapper:
+    # Determine the type of vision embedding to use
+    if not embeddings.type:
+        if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
+            embeddings.type = "ExLlamaV2MMEmbedding"
+
+    if embeddings.type == "ExLlamaV2MMEmbedding":
+        embedding = await get_image_embedding(url)
+        embeddings.content.append(embedding)
+        embeddings.text_alias.append(embedding.text_alias)
+    else:
+        logger.error("No valid vision model to create embedding")
+
+    return embeddings
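A hypothetical use of the wrapper, assuming a vision-capable model is already loaded in model.container: each add_image_embedding call appends one embedding and its alias, so content and text_alias stay aligned index for index.

    # Hypothetical helper; not part of this commit.
    from typing import List

    from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding

    async def embed_images(urls: List[str]) -> MultimodalEmbeddingWrapper:
        wrapper = MultimodalEmbeddingWrapper()
        for url in urls:
            wrapper = await add_image_embedding(wrapper, url)
        return wrapper  # wrapper.content[i] pairs with wrapper.text_alias[i]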
endpoints/OAI/router.py

@@ -17,6 +17,7 @@ from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
     format_prompt_with_template,
     generate_chat_completion,
+    preprocess_vision_request,
     stream_generate_chat_completion,
 )
 from endpoints.OAI.utils.completion import (
@@ -126,6 +127,8 @@ async def chat_completion_request(
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
+        if model.container.use_vision:
+            data.messages, embeddings = await preprocess_vision_request(data.messages)
         prompt = await format_prompt_with_template(data)
 
     # Set an empty JSON schema if the request wants a JSON response
@@ -136,12 +139,14 @@ async def chat_completion_request(
 
     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, request, model_path),
+            stream_generate_chat_completion(
+                prompt, embeddings, data, request, model_path
+            ),
             ping=maxsize,
         )
     else:
         generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, request, model_path)
+            generate_chat_completion(prompt, embeddings, data, request, model_path)
         )
 
         response = await run_with_request_disconnect(
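Both branches forward the same (prompt, embeddings, data, request, model_path) arguments; only the transport differs. A placeholder client sketch that exercises the streaming branch (host, port, and model name are illustrative, not values from this commit):

    import json
    import urllib.request

    req = urllib.request.Request(
        "http://localhost:5000/v1/chat/completions",  # placeholder host/port
        data=json.dumps({
            "model": "my-vision-model",  # placeholder
            "stream": True,              # selects the EventSourceResponse branch
            "messages": [{"role": "user", "content": "Hello"}],
        }).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        for line in resp:
            print(line.decode().rstrip())  # SSE "data: {...}" chunks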
endpoints/OAI/utils/chat_completion.py

@@ -3,9 +3,10 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from typing import List, Optional
+from typing import Dict, List, Optional
 import json
 
+from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
@@ -279,7 +280,11 @@ async def format_prompt_with_template(
 
 
 async def stream_generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     """Generator for the generation process."""
     abort_event = asyncio.Event()
@@ -298,6 +303,7 @@ async def stream_generate_chat_completion(
                     n,
                     gen_queue,
                     prompt,
+                    embeddings,
                     request.state.id,
                     abort_event,
                     **task_gen_params.model_dump(exclude={"prompt"}),
@@ -372,7 +378,11 @@ async def stream_generate_chat_completion(
 
 
 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
 
@@ -381,7 +391,10 @@ async def generate_chat_completion(
     gen_tasks.append(
         asyncio.create_task(
             model.container.generate(
-                prompt, request.state.id, **data.model_dump(exclude={"prompt"})
+                prompt,
+                embeddings,
+                request.state.id,
+                **data.model_dump(exclude={"prompt"}),
             )
         )
     )
@@ -459,3 +472,22 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
             tool_call["function"]["arguments"]
         )
     return [ToolCall(**tool_call) for tool_call in tool_calls]
+
+
+async def preprocess_vision_request(messages: List[Dict]):
+    embeddings = MultimodalEmbeddingWrapper()
+    for message in messages:
+        if isinstance(message["content"], list):
+            concatenated_content = ""
+            for content in message["content"]:
+                if content["type"] == "text":
+                    concatenated_content += content["text"]
+                elif content["type"] == "image_url":
+                    embeddings = await add_image_embedding(
+                        embeddings, content["image_url"]["url"]
+                    )
+                    concatenated_content += embeddings.text_alias[-1]
+
+            message["content"] = concatenated_content
+
+    return messages, embeddings
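The effect of preprocess_vision_request on a mixed message, as a worked example; the alias string below is a stand-in, since the real alias text comes from the generated embedding:

    before = [{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in "},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            {"type": "text", "text": "?"},
        ],
    }]
    # messages, embeddings = await preprocess_vision_request(before)
    # messages becomes (alias shown as a stand-in):
    after = [{"role": "user", "content": "What is in {{IMG-1}}?"}]
    # embeddings.content then holds one ExLlamaV2MMEmbedding for the image.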
endpoints/OAI/utils/completion.py

@@ -7,6 +7,7 @@ Also serves as a common module for completions and chat completions.
 import asyncio
 import pathlib
 from asyncio import CancelledError
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import HTTPException, Request
 from typing import List, Union
 
@@ -87,6 +88,7 @@ async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
     prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
     request_id: str,
     abort_event: asyncio.Event,
     **kwargs,
@@ -95,7 +97,7 @@ async def _stream_collector(
 
     try:
         new_generation = model.container.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            prompt, embeddings, request_id, abort_event, **kwargs
        )
         async for generation in new_generation:
             generation["index"] = task_idx
pyproject.toml

@@ -29,6 +29,7 @@ dependencies = [
     "lm-format-enforcer >= 0.9.6",
     "aiofiles",
     "aiohttp",
+    "async_lru",
     "huggingface_hub",
     "psutil",
     "httptools>=0.5.0",