diff --git a/common/multimodal.py b/common/multimodal.py index 74d4964..5b93f23 100644 --- a/common/multimodal.py +++ b/common/multimodal.py @@ -1,7 +1,6 @@ from typing import List from backends.exllamav2.vision import get_image_embedding from common import model -from pydantic import BaseModel from loguru import logger from common.optional_dependencies import dependencies @@ -10,27 +9,22 @@ if dependencies.exllamav2: from exllamav2 import ExLlamaV2VisionTower -class MultimodalEmbeddingWrapper(BaseModel): +class MultimodalEmbeddingWrapper: """Common multimodal embedding wrapper""" type: str = None content: List = [] text_alias: List[str] = [] + async def add(self, url: str): + # Determine the type of vision embedding to use + if not self.type: + if isinstance(model.container.vision_model, ExLlamaV2VisionTower): + self.type = "ExLlamaV2MMEmbedding" -async def add_image_embedding( - embeddings: MultimodalEmbeddingWrapper, url: str -) -> MultimodalEmbeddingWrapper: - # Determine the type of vision embedding to use - if not embeddings.type: - if isinstance(model.container.vision_model, ExLlamaV2VisionTower): - embeddings.type = "ExLlamaV2MMEmbedding" - - if embeddings.type == "ExLlamaV2MMEmbedding": - embedding = await get_image_embedding(url) - embeddings.content.append(embedding) - embeddings.text_alias.append(embedding.text_alias) - else: - logger.error("No valid vision model to create embedding") - - return embeddings + if self.type == "ExLlamaV2MMEmbedding": + embedding = await get_image_embedding(url) + self.content.append(embedding) + self.text_alias.append(embedding.text_alias) + else: + logger.error("No valid vision model to create embedding") diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index c14a8dc..7a31f39 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -10,7 +10,7 @@ from jinja2 import TemplateError from loguru import logger from common import model -from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding +from common.multimodal import MultimodalEmbeddingWrapper from common.networking import ( get_generator_error, handle_request_disconnect, @@ -483,9 +483,7 @@ async def preprocess_vision_request(messages: List[ChatCompletionMessage]): if content.type == "text": concatenated_content += content.text elif content.type == "image_url": - embeddings = await add_image_embedding( - embeddings, content.image_url.url - ) + await embeddings.add(content.image_url.url) concatenated_content += embeddings.text_alias[-1] message.content = concatenated_content