Merge pull request #249 from theroyallab/vision

Vision
This commit is contained in:
Brian
2024-11-22 17:45:49 -05:00
committed by GitHub
13 changed files with 321 additions and 69 deletions

View File

@@ -6,6 +6,8 @@ import gc
import math
import pathlib
import traceback
from backends.exllamav2.vision import clear_image_embedding_cache
from common.multimodal import MultimodalEmbeddingWrapper
import torch
import uuid
from copy import deepcopy
@@ -20,6 +22,7 @@ from exllamav2 import (
ExLlamaV2Cache_TP,
ExLlamaV2Tokenizer,
ExLlamaV2Lora,
ExLlamaV2VisionTower,
)
from exllamav2.generator import (
ExLlamaV2Sampler,
@@ -91,6 +94,10 @@ class ExllamaV2Container:
autosplit_reserve: List[float] = [96 * 1024**2]
use_tp: bool = False
# Vision vars
use_vision: bool = False
vision_model: Optional[ExLlamaV2VisionTower] = None
# Load state
model_is_loading: bool = False
model_loaded: bool = False
@@ -144,6 +151,15 @@ class ExllamaV2Container:
# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)
# Set vision state and error if vision isn't supported on the current model
self.use_vision = unwrap(kwargs.get("vision"), False)
if self.use_vision and not self.config.vision_model_type:
raise ValueError(
"The provided model does not have vision capabilities that are "
"supported by ExllamaV2. "
"Please reload with vision disabled."
)
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
@@ -477,6 +493,7 @@ class ExllamaV2Container:
"prompt_template": self.prompt_template.name
if self.prompt_template
else None,
"use_vision": self.use_vision,
}
if self.draft_config:
@@ -620,6 +637,14 @@ class ExllamaV2Container:
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)
# Load vision tower if it exists
if self.use_vision:
self.vision_model = ExLlamaV2VisionTower(self.config)
for value in self.vision_model.load_gen(callback_gen=progress_callback):
if value:
yield value
self.model = ExLlamaV2(self.config)
if not self.quiet:
logger.info("Loading model: " + self.config.model_dir)
@@ -811,6 +836,9 @@ class ExllamaV2Container:
# Delete references held in the grammar module
clear_grammar_func_cache()
# Clear the image embedding cache
clear_image_embedding_cache()
# Unload LoRAs
if self.generator and self.generator.generator.current_loras:
for lora in self.generator.generator.current_loras:
@@ -824,6 +852,16 @@ class ExllamaV2Container:
self.model.unload()
self.model = None
if self.vision_model:
# TODO: Remove this with newer exl2 versions
# Required otherwise unload function won't finish
try:
self.vision_model.unload()
except AttributeError:
pass
self.vision_model = None
if self.draft_model:
self.draft_model.unload()
self.draft_model = None
@@ -855,11 +893,15 @@ class ExllamaV2Container:
def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string."""
mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
return (
self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
embeddings=mm_embeddings_content,
)
.flatten()
.tolist()
@@ -903,7 +945,11 @@ class ExllamaV2Container:
return dict(zip_longest(top_tokens, cleaned_values))
async def generate(
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
self,
prompt: str,
request_id: str,
abort_event: asyncio.Event = None,
**kwargs,
):
"""Generate a response to a prompt."""
generations = []
@@ -1238,10 +1284,17 @@ class ExllamaV2Container:
else:
stop_conditions += eos_tokens
# Get multimodal embeddings if present
mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
# Encode both positive and negative prompts
input_ids = [
self.tokenizer.encode(
prompt, add_bos=add_bos_token, encode_special_tokens=True
prompt,
add_bos=add_bos_token,
encode_special_tokens=True,
embeddings=mm_embeddings_content,
)
for prompt in prompts
]
@@ -1292,6 +1345,7 @@ class ExllamaV2Container:
banned_strings=banned_strings,
token_healing=token_healing,
identifier=job_id,
embeddings=mm_embeddings_content,
)
# Save generated tokens and full response

View File

@@ -0,0 +1,70 @@
"""Vision utilities for ExLlamaV2."""
import io
import base64
import re
from PIL import Image
from common import model
import aiohttp
from common.networking import (
handle_request_error,
)
from common.tabby_config import config
from fastapi import HTTPException
from exllamav2.generator import ExLlamaV2MMEmbedding
from async_lru import alru_cache
async def get_image(url: str) -> Image.Image:
    """Resolve *url* to a PIL image.

    Supports two input forms:
      * ``data:image/<subtype>;base64,<payload>`` data URIs, decoded in place
      * http(s) URLs, fetched with aiohttp (unless external fetch requests
        are disabled in the network config)

    Raises:
        HTTPException(400): on a malformed data URI, when fetches are
            disabled, or when the remote server returns a non-200 status.
    """

    if url.startswith("data:image"):
        # Handle base64 image.
        # Subtype pattern allows word chars plus '+', '.' and '-' so MIME
        # types like image/svg+xml or image/vnd.microsoft.icon also match.
        match = re.match(r"^data:image\/[\w.+-]+;base64,(.*)$", url)
        if match:
            base64_image = match.group(1)
            bytes_image = base64.b64decode(base64_image)
        else:
            error_message = handle_request_error(
                "Failed to read base64 image input.",
                exc_info=False,
            ).error.message

            raise HTTPException(400, error_message)

    else:
        # Handle image URL
        if config.network.disable_fetch_requests:
            error_message = handle_request_error(
                f"Failed to fetch image from {url} as fetch requests are disabled.",
                exc_info=False,
            ).error.message

            raise HTTPException(400, error_message)

        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                if response.status == 200:
                    bytes_image = await response.read()
                else:
                    error_message = handle_request_error(
                        f"Failed to fetch image from {url}.",
                        exc_info=False,
                    ).error.message

                    raise HTTPException(400, error_message)

    return Image.open(io.BytesIO(bytes_image))
@alru_cache(20)
async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
    """Fetch the image at *url* and turn it into an ExLlamaV2 MM embedding.

    Results are memoized per URL in an async LRU cache (max 20 entries);
    the cache is flushed by clear_image_embedding_cache on model unload.
    """

    pil_image = await get_image(url)

    container = model.container
    return container.vision_model.get_image_embeddings(
        model=container.model,
        tokenizer=container.tokenizer,
        image=pil_image,
        text_alias=None,
    )
def clear_image_embedding_cache():
    """Evict every cached image embedding.

    Invoked during model unload so cached ExLlamaV2MMEmbedding objects
    don't keep references into the old model alive.
    """
    get_image_embedding.cache_clear()

View File

@@ -78,6 +78,13 @@ class NetworkConfig(BaseConfigModel):
"Turn on this option if you are ONLY connecting from localhost."
),
)
disable_fetch_requests: Optional[bool] = Field(
False,
description=(
"Disable fetching external content in response to requests, "
"such as images from URLs."
),
)
send_tracebacks: Optional[bool] = Field(
False,
description=(
@@ -281,6 +288,12 @@ class ModelConfig(BaseConfigModel):
"NOTE: Only works with chat completion message lists!"
),
)
vision: Optional[bool] = Field(
False,
description=(
"Enables vision support if the model supports it. (default: False)"
),
)
num_experts_per_token: Optional[int] = Field(
None,
description=(

View File

@@ -33,6 +33,7 @@ class ModelType(Enum):
MODEL = "model"
DRAFT = "draft"
EMBEDDING = "embedding"
VISION = "vision"
def load_progress(module, modules):
@@ -70,29 +71,39 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
# Create a new container
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
model_type = "draft" if container.draft_config else "model"
# Add possible types of models that can be loaded
model_type = [ModelType.MODEL]
if container.use_vision:
model_type.insert(0, ModelType.VISION)
if container.draft_config:
model_type.insert(0, ModelType.DRAFT)
load_status = container.load_gen(load_progress, **kwargs)
progress = get_loading_progress_bar()
progress.start()
try:
index = 0
async for module, modules in load_status:
current_model_type = model_type[index].value
if module == 0:
loading_task = progress.add_task(
f"[cyan]Loading {model_type} modules", total=modules
f"[cyan]Loading {current_model_type} modules", total=modules
)
else:
progress.advance(loading_task)
yield module, modules, model_type
yield module, modules, current_model_type
if module == modules:
# Switch to model progress if the draft model is loaded
if model_type == "draft":
model_type = "model"
else:
if index == len(model_type):
progress.stop()
else:
index += 1
finally:
progress.stop()

30
common/multimodal.py Normal file
View File

@@ -0,0 +1,30 @@
from typing import List
from backends.exllamav2.vision import get_image_embedding
from common import model
from loguru import logger
from common.optional_dependencies import dependencies
if dependencies.exllamav2:
from exllamav2 import ExLlamaV2VisionTower
class MultimodalEmbeddingWrapper:
    """Common multimodal embedding wrapper.

    Accumulates per-request image embeddings and the text aliases used to
    reference them inside a rendered prompt.
    """

    def __init__(self):
        # State must live on the instance. The previous class-level mutable
        # defaults (content = [] / text_alias = [] on the class body) were
        # shared by every wrapper instance, so embeddings from one request
        # leaked into all subsequent requests.
        self.type = None
        self.content = []
        self.text_alias = []

    async def add(self, url: str):
        """Fetch the image at *url* and append its embedding and alias.

        The backend embedding type is detected from the loaded vision model
        on first use; logs an error when no supported vision model exists.
        """

        # Determine the type of vision embedding to use
        if not self.type:
            if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
                self.type = "ExLlamaV2MMEmbedding"

        if self.type == "ExLlamaV2MMEmbedding":
            embedding = await get_image_embedding(url)
            self.content.append(embedding)
            self.text_alias.append(embedding.text_alias)
        else:
            logger.error("No valid vision model to create embedding")

View File

@@ -20,6 +20,9 @@ network:
# Turn on this option if you are ONLY connecting from localhost.
disable_auth: false
# Disable fetching external content in response to requests, such as images from URLs.
disable_fetch_requests: false
# Send tracebacks over the API (default: False).
# NOTE: Only enable this for debug purposes.
send_tracebacks: false
@@ -130,6 +133,9 @@ model:
# NOTE: Only works with chat completion message lists!
prompt_template:
# Enables vision support if the model supports it. (default: False)
vision: false
# Number of experts to use per token.
# Fetched from the model's config.json if empty.
# NOTE: For MoE models only.

View File

@@ -15,7 +15,7 @@ from endpoints.OAI.types.chat_completion import (
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
apply_chat_template,
generate_chat_completion,
stream_generate_chat_completion,
)
@@ -123,10 +123,7 @@ async def chat_completion_request(
model_path = model.container.model_dir
if isinstance(data.messages, str):
prompt = data.messages
else:
prompt = await format_prompt_with_template(data)
prompt, embeddings = await apply_chat_template(data)
# Set an empty JSON schema if the request wants a JSON response
if data.response_format.type == "json":
@@ -136,12 +133,14 @@ async def chat_completion_request(
if data.stream and not disable_request_streaming:
return EventSourceResponse(
stream_generate_chat_completion(prompt, data, request, model_path),
stream_generate_chat_completion(
prompt, embeddings, data, request, model_path
),
ping=maxsize,
)
else:
generate_task = asyncio.create_task(
generate_chat_completion(prompt, data, request, model_path)
generate_chat_completion(prompt, embeddings, data, request, model_path)
)
response = await run_with_request_disconnect(

View File

@@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema
from time import time
from typing import Union, List, Optional, Dict
from typing import Literal, Union, List, Optional, Dict
from uuid import uuid4
from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
@@ -18,10 +18,21 @@ class ChatCompletionLogprobs(BaseModel):
content: List[ChatCompletionLogprob] = Field(default_factory=list)
class ChatCompletionImageUrl(BaseModel):
    # Image reference inside a message part: either an http(s) URL or a
    # "data:image/...;base64," data URI
    url: str
class ChatCompletionMessagePart(BaseModel):
    # One element of an OpenAI-style multimodal content list
    type: Literal["text", "image_url"] = "text"
    text: Optional[str] = None  # populated when type == "text"
    image_url: Optional[ChatCompletionImageUrl] = None  # when type == "image_url"
class ChatCompletionMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
role: str = "user"
content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
tool_calls: Optional[List[ToolCall]] = None
tool_calls_json: SkipJsonSchema[Optional[str]] = None
class ChatCompletionRespChoice(BaseModel):
@@ -51,7 +62,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
# WIP this can probably be tightened, or maybe match the OAI lib type
# in openai\types\chat\chat_completion_message_param.py
messages: Union[str, List[Dict]]
messages: List[ChatCompletionMessage] = Field(default_factory=list)
prompt_template: Optional[str] = None
add_generation_prompt: Optional[bool] = True
template_vars: Optional[dict] = {}

View File

@@ -1,16 +1,16 @@
"""Chat completion utilities for OAI server."""
import asyncio
import json
import pathlib
from asyncio import CancelledError
from typing import List, Optional
import json
from fastapi import HTTPException, Request
from jinja2 import TemplateError
from loguru import logger
from common import model
from common.multimodal import MultimodalEmbeddingWrapper
from common.networking import (
get_generator_error,
handle_request_disconnect,
@@ -177,11 +177,11 @@ def _create_stream_chunk(
return chunk
async def _append_template_metadata(data: ChatCompletionRequest):
async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict):
"""Adding metadata is a one-time process."""
template_metadata = await model.container.prompt_template.extract_metadata(
data.template_vars
template_vars
)
# Stop strings
@@ -199,7 +199,44 @@ async def _append_template_metadata(data: ChatCompletionRequest):
data.stop.extend(template_metadata.tool_starts)
async def format_prompt_with_template(
async def format_messages_with_template(
    messages: List[ChatCompletionMessage],
    existing_template_vars: Optional[dict] = None,
    add_bos_token: bool = True,
    ban_eos_token: bool = False,
):
    """Barebones function to format chat completion messages into a prompt.

    Flattens multimodal content lists into plain strings (collecting image
    embeddings along the way when vision is enabled), then renders the
    container's prompt template.

    Returns a tuple of (rendered prompt, multimodal embedding wrapper or
    None, final template vars dict).
    """

    template_vars = unwrap(existing_template_vars, {})

    # Only build an embedding wrapper when the loaded model has vision
    mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None

    for message in messages:
        if isinstance(message.content, list):
            concatenated_content = ""
            for content in message.content:
                if content.type == "text":
                    concatenated_content += content.text
                elif content.type == "image_url" and mm_embeddings:
                    # Register the image and splice its text alias into the
                    # prompt so the tokenizer can map it back to the embedding
                    await mm_embeddings.add(content.image_url.url)
                    concatenated_content += mm_embeddings.text_alias[-1]

            # Convert the message content into a concatenated string.
            # NOTE(review): this mutates the caller's message objects in
            # place — confirm no caller needs the original content list.
            message.content = concatenated_content

        if message.tool_calls:
            # NOTE(review): json.dumps on ToolCall objects will raise
            # TypeError unless ToolCall instances are JSON-serializable —
            # verify (pydantic models usually need .model_dump() first)
            message.tool_calls_json = json.dumps(message.tool_calls, indent=2)

    special_tokens_dict = model.container.get_special_tokens(
        add_bos_token, ban_eos_token
    )

    # Messages and special tokens override any user-supplied template vars
    template_vars.update({"messages": messages, **special_tokens_dict})

    prompt = await model.container.prompt_template.render(template_vars)
    return prompt, mm_embeddings, template_vars
async def apply_chat_template(
data: ChatCompletionRequest, tool_precursor: Optional[str] = None
):
"""
@@ -208,40 +245,18 @@ async def format_prompt_with_template(
"""
try:
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True),
unwrap(data.ban_eos_token, False),
)
# Deal with list in messages.content
# Just replace the content list with the very first text message
for message in data.messages:
if isinstance(message["content"], list):
message["content"] = next(
(
content["text"]
for content in message["content"]
if content["type"] == "text"
),
"",
)
if "tool_calls" in message:
message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
# Overwrite any protected vars with their values
data.template_vars.update(
{
"messages": data.messages,
"add_generation_prompt": data.add_generation_prompt,
"tools_json": json.dumps(data.model_dump()["tools"], indent=2),
"functions_json": json.dumps(data.functions, indent=2),
"tool_precursor": tool_precursor,
**special_tokens_dict,
}
)
prompt = await model.container.prompt_template.render(data.template_vars)
prompt, mm_embeddings, template_vars = await format_messages_with_template(
data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
)
# Append response prefix if present
if data.response_prefix:
@@ -255,14 +270,14 @@ async def format_prompt_with_template(
# Removes the starting BOS token if present
# This is to prevent add_bos_token from adding multiple bos tokens
bos_token = special_tokens_dict.get("bos_token")
bos_token = template_vars.get("bos_token")
if bos_token and prompt.startswith(bos_token):
prompt = prompt.removeprefix(bos_token)
# Add template metadata
await _append_template_metadata(data)
await _append_template_metadata(data, template_vars)
return prompt
return prompt, mm_embeddings
except KeyError as exc:
error_message = handle_request_error(
@@ -279,7 +294,11 @@ async def format_prompt_with_template(
async def stream_generate_chat_completion(
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
data: ChatCompletionRequest,
request: Request,
model_path: pathlib.Path,
):
"""Generator for the generation process."""
abort_event = asyncio.Event()
@@ -300,6 +319,7 @@ async def stream_generate_chat_completion(
prompt,
request.state.id,
abort_event,
embeddings=embeddings,
**task_gen_params.model_dump(exclude={"prompt"}),
)
)
@@ -372,7 +392,11 @@ async def stream_generate_chat_completion(
async def generate_chat_completion(
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
data: ChatCompletionRequest,
request: Request,
model_path: pathlib.Path,
):
gen_tasks: List[asyncio.Task] = []
@@ -381,7 +405,10 @@ async def generate_chat_completion(
gen_tasks.append(
asyncio.create_task(
model.container.generate(
prompt, request.state.id, **data.model_dump(exclude={"prompt"})
prompt,
request.state.id,
embeddings=embeddings,
**data.model_dump(exclude={"prompt"}),
)
)
)
@@ -427,13 +454,11 @@ async def generate_tool_calls(
if gen["stop_str"] in tool_data.tool_call_start:
if "text" in gen:
# non streaming, all generations will have the text they generated
pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
pre_tool_prompt = await apply_chat_template(data, gen["text"])
elif current_generations is not None:
# streaming, we wont have text in the generation,
# we'll have to use the current_generations
pre_tool_prompt = await format_prompt_with_template(
data, current_generations
)
pre_tool_prompt = await apply_chat_template(data, current_generations)
gen_tasks.append(
asyncio.create_task(

View File

@@ -1,6 +1,8 @@
import asyncio
import pathlib
from sys import maxsize
from typing import Optional
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sse_starlette import EventSourceResponse
@@ -13,6 +15,7 @@ from common.tabby_config import config
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
from common.health import HealthManager
from endpoints.OAI.utils.chat_completion import format_messages_with_template
from endpoints.core.types.auth import AuthPermissionResponse
from endpoints.core.types.download import DownloadRequest, DownloadResponse
from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
@@ -359,22 +362,47 @@ async def unload_embedding_model():
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
"""Encodes a string or chat completion messages into tokens."""
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None
if isinstance(data.text, str):
text = data.text
else:
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True)
)
elif isinstance(data.text, list):
if "oai" not in config.network.api_servers:
error_message = handle_request_error(
"Enable the OAI server to handle chat completion messages.",
exc_info=False,
).error.message
raise HTTPException(422, error_message)
if not model.container.prompt_template:
error_message = handle_request_error(
"Cannot tokenize chat completion message because "
+ "a prompt template is not set.",
exc_info=False,
).error.message
raise HTTPException(422, error_message)
template_vars = {
"messages": data.text,
"add_generation_prompt": False,
**special_tokens_dict,
}
text, _ = model.container.prompt_template.render(template_vars)
# Don't need template vars again
text, mm_embeddings, _ = await format_messages_with_template(
data.text, template_vars, data.add_bos_token
)
else:
error_message = handle_request_error(
"Unable to tokenize the provided text. Check your formatting?",
exc_info=False,
).error.message
raw_tokens = model.container.encode_tokens(text, **data.get_params())
raise HTTPException(422, error_message)
raw_tokens = model.container.encode_tokens(
text, embeddings=mm_embeddings, **data.get_params()
)
tokens = unwrap(raw_tokens, [])
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))

View File

@@ -21,6 +21,7 @@ class ModelCardParameters(BaseModel):
chunk_size: Optional[int] = 2048
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
use_vision: Optional[bool] = False
# Draft is another model, so include it in the card params
draft: Optional["ModelCard"] = None
@@ -107,6 +108,7 @@ class ModelLoadRequest(BaseModel):
cache_mode: Optional[str] = None
chunk_size: Optional[int] = None
prompt_template: Optional[str] = None
vision: Optional[bool] = None
num_experts_per_token: Optional[int] = None
# Non-config arguments

View File

@@ -1,7 +1,9 @@
"""Tokenization types"""
from pydantic import BaseModel
from typing import Dict, List, Union
from typing import List, Union
from endpoints.OAI.types.chat_completion import ChatCompletionMessage
class CommonTokenRequest(BaseModel):
@@ -23,7 +25,7 @@ class CommonTokenRequest(BaseModel):
class TokenEncodeRequest(CommonTokenRequest):
"""Represents a tokenization request."""
text: Union[str, List[Dict[str, str]]]
text: Union[str, List[ChatCompletionMessage]]
class TokenEncodeResponse(BaseModel):

View File

@@ -29,6 +29,7 @@ dependencies = [
"lm-format-enforcer >= 0.9.6",
"aiofiles",
"aiohttp",
"async_lru",
"huggingface_hub",
"psutil",
"httptools>=0.5.0",