OAI: Initial vision support in OAI chat completions

* Support image_url inputs containing URLs or base64-encoded strings, following the OAI vision spec (see the example request sketch below)
* Use an async LRU cache for image embeddings
* Add a generic wrapper class for multimodal embeddings
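
For context, the OAI vision spec delivers images as image_url content parts inside a chat message, either as a remote URL or as a base64 data URI. A rough sketch of such a request body (the model name, URL, and truncated base64 payload are placeholders):

request = {
    "model": "my-vision-model",  # placeholder
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # Remote URL variant
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/cat.png"},
                },
                # Base64 variant, sent as a data URI
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/png;base64,iVBORw0KGg..."},
                },
            ],
        }
    ],
}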
Author: DocShotgun
Date:   2024-11-17 21:23:09 -08:00
parent  5fa298e601
commit  dd41eec8a4

7 changed files with 115 additions and 26 deletions

File: backends/exllamav2/model.py

@@ -6,6 +6,8 @@ import gc
 import math
 import pathlib
 import traceback
+from backends.exllamav2.vision import clear_image_embedding_cache
+from common.multimodal import MultimodalEmbeddingWrapper
 import torch
 import uuid
 from copy import deepcopy
@@ -816,6 +818,9 @@ class ExllamaV2Container:
         # Delete references held in the grammar module
         clear_grammar_func_cache()

+        # Clear the image embedding cache
+        clear_image_embedding_cache()
+
         # Unload LoRAs
         if self.generator and self.generator.generator.current_loras:
             for lora in self.generator.generator.current_loras:
@@ -908,12 +913,17 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))

     async def generate(
-        self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
+        self,
+        prompt: str,
+        embeddings: MultimodalEmbeddingWrapper,
+        request_id: str,
+        abort_event: asyncio.Event = None,
+        **kwargs,
     ):
         """Generate a response to a prompt."""

         generations = []
         async for generation in self.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            prompt, embeddings, request_id, abort_event, **kwargs
         ):
             generations.append(generation)
@@ -979,6 +989,7 @@ class ExllamaV2Container:
     async def generate_gen(
         self,
         prompt: str,
+        embeddings: MultimodalEmbeddingWrapper,
         request_id: str,
         abort_event: Optional[asyncio.Event] = None,
         **kwargs,
@@ -1246,7 +1257,10 @@ class ExllamaV2Container:
         # Encode both positive and negative prompts
         input_ids = [
             self.tokenizer.encode(
-                prompt, add_bos=add_bos_token, encode_special_tokens=True
+                prompt,
+                add_bos=add_bos_token,
+                encode_special_tokens=True,
+                embeddings=embeddings.content,
             )
             for prompt in prompts
         ]
@@ -1297,6 +1311,7 @@ class ExllamaV2Container:
banned_strings=banned_strings,
token_healing=token_healing,
identifier=job_id,
embeddings=embeddings.content,
)
# Save generated tokens and full response
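
Note that common/multimodal.py itself is not part of this diff. From the call sites above (embeddings.content is handed to tokenizer.encode and to the generator job as a list of ExLlamaV2MMEmbedding objects), the wrapper plausibly amounts to a thin container. A minimal sketch; the add() helper and field layout are assumptions, not the actual class:

from typing import List

from exllamav2.generator import ExLlamaV2MMEmbedding


class MultimodalEmbeddingWrapper:
    """Sketch of a generic wrapper for multimodal embeddings."""

    def __init__(self):
        # The list consumed via embeddings.content in the diff above
        self.content: List[ExLlamaV2MMEmbedding] = []

    def add(self, embedding: ExLlamaV2MMEmbedding):
        # Hypothetical helper: accumulate one embedding per image in a request
        self.content.append(embedding)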

File: backends/exllamav2/vision.py

@@ -4,18 +4,14 @@ import io
 import base64
 import re
 from PIL import Image
+from common import model
 import aiohttp
 from common.networking import (
     handle_request_error,
 )
 from fastapi import HTTPException
-from exllamav2 import (
-    ExLlamaV2,
-    ExLlamaV2Tokenizer,
-    ExLlamaV2VisionTower,
-    ExLlamaV2MMEmbedding,
-)
-from functools import lru_cache
+from exllamav2.generator import ExLlamaV2MMEmbedding
+from async_lru import alru_cache


 async def get_image(url: str) -> Image:
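
The body of get_image is untouched by this commit and elided between the hunks. Judging from the imports (base64, re, aiohttp, PIL) and the commit description, it presumably branches on data URIs versus remote URLs roughly like the sketch below; the real code likely routes failures through handle_request_error and HTTPException rather than a bare ValueError:

async def get_image(url: str) -> Image:
    if url.startswith("data:image"):
        # Base64 variant: strip the "data:image/...;base64," prefix
        match = re.match(r"^data:image\/[a-z+]+;base64,(.+)$", url)
        if not match:
            raise ValueError("Invalid base64 image data URI")
        bytes_image = base64.b64decode(match.group(1))
    else:
        # Remote URL variant: fetch the raw bytes over HTTP
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                bytes_image = await response.read()

    return Image.open(io.BytesIO(bytes_image))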
@@ -50,14 +46,16 @@ async def get_image(url: str) -> Image:
     return Image.open(io.BytesIO(bytes_image))


-@lru_cache(20)
-async def get_image_embedding(
-    model: ExLlamaV2,
-    tokenizer: ExLlamaV2Tokenizer,
-    vision_model: ExLlamaV2VisionTower,
-    url: str,
-) -> ExLlamaV2MMEmbedding:
+@alru_cache(20)
+async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
     image = await get_image(url)
-    return vision_model.get_image_embeddings(
-        model=model, tokenizer=tokenizer, image=image
+    return model.container.vision_model.get_image_embeddings(
+        model=model.container.model,
+        tokenizer=model.container.tokenizer,
+        image=image,
+        text_alias=None,
     )
+
+
+def clear_image_embedding_cache():
+    get_image_embedding.cache_clear()
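
A note on the caching change: functools.lru_cache is unsafe on coroutine functions because it caches the coroutine object itself, which can only be awaited once; a second cache hit raises "cannot reuse already awaited coroutine". async_lru's alru_cache awaits the coroutine and caches its result instead, while keeping the cache_clear() hook used above. A minimal standalone illustration:

import asyncio

from async_lru import alru_cache


@alru_cache(maxsize=2)
async def fetch(key: str) -> str:
    print(f"computing {key}")  # printed on cache misses only
    await asyncio.sleep(0.1)
    return key.upper()


async def main():
    print(await fetch("a"))  # miss: runs the coroutine
    print(await fetch("a"))  # hit: result served from the cache
    fetch.cache_clear()      # same hook the commit calls on model unload
    print(await fetch("a"))  # miss again after clearing


asyncio.run(main())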