Merge branch 'main' into breaking

This commit is contained in:
kingbri
2025-06-17 22:24:32 -04:00
7 changed files with 149 additions and 77 deletions

View File

@@ -1003,7 +1003,6 @@ class ExllamaV2Container(BaseModelContainer):
params: BaseSamplerRequest,
gen_settings: ExLlamaV2Sampler.Settings,
grammar_handler: ExLlamaV2Grammar,
banned_strings: List[str],
):
# Apply settings
gen_settings.temperature = params.temperature
@@ -1111,16 +1110,6 @@ class ExllamaV2Container(BaseModelContainer):
params.grammar_string, self.model, self.tokenizer
)
# Set banned strings
banned_strings = params.banned_strings
if banned_strings and len(grammar_handler.filters) > 0:
logger.warning(
"Disabling banned_strings because "
"they cannot be used with grammar filters."
)
banned_strings = []
# Speculative Ngram
self.generator.speculative_ngram = params.speculative_ngram
@@ -1226,15 +1215,23 @@ class ExllamaV2Container(BaseModelContainer):
prompts = [prompt]
gen_settings = ExLlamaV2Sampler.Settings()
grammar_handler = ExLlamaV2Grammar()
banned_strings = []
self.assign_gen_params(
params,
gen_settings,
grammar_handler,
banned_strings,
)
# Set banned strings
banned_strings = params.banned_strings
if banned_strings and len(grammar_handler.filters) > 0:
logger.warning(
"Disabling banned_strings because "
"they cannot be used with grammar filters."
)
banned_strings = []
# Set CFG scale and negative prompt
cfg_scale = params.cfg_scale
negative_prompt = None

View File

@@ -1,19 +1,10 @@
"""Vision utilities for ExLlamaV2."""
import aiohttp
import base64
import io
import re
from async_lru import alru_cache
from fastapi import HTTPException
from PIL import Image
from common import model
from common.networking import (
handle_request_error,
)
from common.optional_dependencies import dependencies
from common.tabby_config import config
from common.image_util import get_image
# Since this is used outside the Exl2 backend, the dependency
# may be optional
@@ -21,49 +12,9 @@ if dependencies.exllamav2:
from exllamav2.generator import ExLlamaV2MMEmbedding
def _image_request_error(message: str) -> HTTPException:
    """Build the 400 response raised for any image input that cannot be loaded."""

    error_message = handle_request_error(message, exc_info=False).error.message
    return HTTPException(400, error_message)


async def get_image(url: str) -> Image.Image:
    """Load an image from a base64 data URL or from a remote URL.

    Args:
        url: Either a ``data:image/<type>;base64,<payload>`` URL, or a
            location that is fetched with an HTTP GET request.

    Returns:
        The image opened by ``Image.open`` from the in-memory bytes.

    Raises:
        HTTPException: With status 400 when the data URL is malformed or its
            payload is not valid base64, when fetch requests are disabled in
            the config, or when the download fails or answers with a status
            other than 200.
    """

    if url.startswith("data:image"):
        # Handle base64 image
        match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url)
        if match is None:
            raise _image_request_error("Failed to read base64 image input.")

        try:
            bytes_image = base64.b64decode(match.group(1))
        except ValueError as exc:
            # binascii.Error (bad padding / invalid characters) subclasses
            # ValueError; report it as a client error like a failed match
            # instead of letting it escape as an unhandled exception.
            raise _image_request_error(
                "Failed to read base64 image input."
            ) from exc

        return Image.open(io.BytesIO(bytes_image))

    # Handle image URL
    if config.network.disable_fetch_requests:
        raise _image_request_error(
            f"Failed to fetch image from {url} as fetch requests are disabled."
        )

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                if response.status != 200:
                    raise _image_request_error(
                        f"Failed to fetch image from {url}."
                    )

                bytes_image = await response.read()
    except aiohttp.ClientError as exc:
        # Connection failures and invalid URLs are caused by the request's
        # input as well, so they get the same 400 as a non-200 response.
        raise _image_request_error(f"Failed to fetch image from {url}.") from exc

    return Image.open(io.BytesIO(bytes_image))
# Fetch the return type on runtime
@alru_cache(20)
async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding":
async def get_image_embedding_exl2(url: str) -> "ExLlamaV2MMEmbedding":
image = await get_image(url)
return model.container.vision_model.get_image_embeddings(
model=model.container.model,
@@ -74,4 +25,4 @@ async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding":
def clear_image_embedding_cache():
get_image_embedding.cache_clear()
get_image_embedding_exl2.cache_clear()