mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 14:28:54 +00:00
OAI: Initial vision support in OAI chat completions
* Support image_url inputs containing URLs or base64 strings following OAI vision spec * Use async lru cache for image embeddings * Add generic wrapper class for multimodal embeddings
This commit is contained in:
@@ -6,6 +6,8 @@ import gc
|
||||
import math
|
||||
import pathlib
|
||||
import traceback
|
||||
from backends.exllamav2.vision import clear_image_embedding_cache
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
import torch
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
@@ -816,6 +818,9 @@ class ExllamaV2Container:
|
||||
# Delete references held in the grammar module
|
||||
clear_grammar_func_cache()
|
||||
|
||||
# Clear the image embedding cache
|
||||
clear_image_embedding_cache()
|
||||
|
||||
# Unload LoRAs
|
||||
if self.generator and self.generator.generator.current_loras:
|
||||
for lora in self.generator.generator.current_loras:
|
||||
@@ -908,12 +913,17 @@ class ExllamaV2Container:
|
||||
return dict(zip_longest(top_tokens, cleaned_values))
|
||||
|
||||
async def generate(
|
||||
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
|
||||
self,
|
||||
prompt: str,
|
||||
embeddings: MultimodalEmbeddingWrapper,
|
||||
request_id: str,
|
||||
abort_event: asyncio.Event = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate a response to a prompt."""
|
||||
generations = []
|
||||
async for generation in self.generate_gen(
|
||||
prompt, request_id, abort_event, **kwargs
|
||||
prompt, embeddings, request_id, abort_event, **kwargs
|
||||
):
|
||||
generations.append(generation)
|
||||
|
||||
@@ -979,6 +989,7 @@ class ExllamaV2Container:
|
||||
async def generate_gen(
|
||||
self,
|
||||
prompt: str,
|
||||
embeddings: MultimodalEmbeddingWrapper,
|
||||
request_id: str,
|
||||
abort_event: Optional[asyncio.Event] = None,
|
||||
**kwargs,
|
||||
@@ -1246,7 +1257,10 @@ class ExllamaV2Container:
|
||||
# Encode both positive and negative prompts
|
||||
input_ids = [
|
||||
self.tokenizer.encode(
|
||||
prompt, add_bos=add_bos_token, encode_special_tokens=True
|
||||
prompt,
|
||||
add_bos=add_bos_token,
|
||||
encode_special_tokens=True,
|
||||
embeddings=embeddings.content,
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
@@ -1297,6 +1311,7 @@ class ExllamaV2Container:
|
||||
banned_strings=banned_strings,
|
||||
token_healing=token_healing,
|
||||
identifier=job_id,
|
||||
embeddings=embeddings.content,
|
||||
)
|
||||
|
||||
# Save generated tokens and full response
|
||||
|
||||
@@ -4,18 +4,14 @@ import io
|
||||
import base64
|
||||
import re
|
||||
from PIL import Image
|
||||
from common import model
|
||||
import aiohttp
|
||||
from common.networking import (
|
||||
handle_request_error,
|
||||
)
|
||||
from fastapi import HTTPException
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Tokenizer,
|
||||
ExLlamaV2VisionTower,
|
||||
ExLlamaV2MMEmbedding,
|
||||
)
|
||||
from functools import lru_cache
|
||||
from exllamav2.generator import ExLlamaV2MMEmbedding
|
||||
from async_lru import alru_cache
|
||||
|
||||
|
||||
async def get_image(url: str) -> Image:
|
||||
@@ -50,14 +46,16 @@ async def get_image(url: str) -> Image:
|
||||
return Image.open(io.BytesIO(bytes_image))
|
||||
|
||||
|
||||
@lru_cache(20)
|
||||
async def get_image_embedding(
|
||||
model: ExLlamaV2,
|
||||
tokenizer: ExLlamaV2Tokenizer,
|
||||
vision_model: ExLlamaV2VisionTower,
|
||||
url: str,
|
||||
) -> ExLlamaV2MMEmbedding:
|
||||
@alru_cache(20)
|
||||
async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
|
||||
image = await get_image(url)
|
||||
return vision_model.get_image_embeddings(
|
||||
model=model, tokenizer=tokenizer, image=image
|
||||
return model.container.vision_model.get_image_embeddings(
|
||||
model=model.container.model,
|
||||
tokenizer=model.container.tokenizer,
|
||||
image=image,
|
||||
text_alias=None,
|
||||
)
|
||||
|
||||
|
||||
def clear_image_embedding_cache():
|
||||
get_image_embedding.cache_clear()
|
||||
|
||||
Reference in New Issue
Block a user