Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-03-15 00:07:28 +00:00)
@@ -6,6 +6,8 @@ import gc
 import math
 import pathlib
 import traceback
+from backends.exllamav2.vision import clear_image_embedding_cache
+from common.multimodal import MultimodalEmbeddingWrapper
 import torch
 import uuid
 from copy import deepcopy
@@ -20,6 +22,7 @@ from exllamav2 import (
     ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
+    ExLlamaV2VisionTower,
 )
 from exllamav2.generator import (
     ExLlamaV2Sampler,
@@ -91,6 +94,10 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]
     use_tp: bool = False

+    # Vision vars
+    use_vision: bool = False
+    vision_model: Optional[ExLlamaV2VisionTower] = None
+
     # Load state
     model_is_loading: bool = False
     model_loaded: bool = False
@@ -144,6 +151,15 @@ class ExllamaV2Container:
         # Apply a model's config overrides while respecting user settings
         kwargs = await self.set_model_overrides(**kwargs)

+        # Set vision state and error if vision isn't supported on the current model
+        self.use_vision = unwrap(kwargs.get("vision"), False)
+        if self.use_vision and not self.config.vision_model_type:
+            raise ValueError(
+                "The provided model does not have vision capabilities that are "
+                "supported by ExllamaV2. "
+                "Please reload with vision disabled."
+            )
+
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
@@ -477,6 +493,7 @@ class ExllamaV2Container:
             "prompt_template": self.prompt_template.name
             if self.prompt_template
             else None,
+            "use_vision": self.use_vision,
         }

         if self.draft_config:
@@ -620,6 +637,14 @@ class ExllamaV2Container:
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
             self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)

+        # Load vision tower if it exists
+        if self.use_vision:
+            self.vision_model = ExLlamaV2VisionTower(self.config)
+
+            for value in self.vision_model.load_gen(callback_gen=progress_callback):
+                if value:
+                    yield value
+
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
             logger.info("Loading model: " + self.config.model_dir)
@@ -811,6 +836,9 @@ class ExllamaV2Container:
         # Delete references held in the grammar module
         clear_grammar_func_cache()

+        # Clear the image embedding cache
+        clear_image_embedding_cache()
+
         # Unload LoRAs
         if self.generator and self.generator.generator.current_loras:
             for lora in self.generator.generator.current_loras:
@@ -824,6 +852,16 @@ class ExllamaV2Container:
             self.model.unload()
         self.model = None

+        if self.vision_model:
+            # TODO: Remove this with newer exl2 versions
+            # Required otherwise unload function won't finish
+            try:
+                self.vision_model.unload()
+            except AttributeError:
+                pass
+
+        self.vision_model = None
+
         if self.draft_model:
             self.draft_model.unload()
         self.draft_model = None
@@ -855,11 +893,15 @@ class ExllamaV2Container:
     def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string."""

+        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         return (
             self.tokenizer.encode(
                 text,
                 add_bos=unwrap(kwargs.get("add_bos_token"), True),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+                embeddings=mm_embeddings_content,
             )
             .flatten()
             .tolist()
@@ -903,7 +945,11 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))

     async def generate(
-        self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
+        self,
+        prompt: str,
+        request_id: str,
+        abort_event: asyncio.Event = None,
+        **kwargs,
     ):
         """Generate a response to a prompt."""
         generations = []
@@ -1238,10 +1284,17 @@ class ExllamaV2Container:
         else:
             stop_conditions += eos_tokens

+        # Get multimodal embeddings if present
+        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         # Encode both positive and negative prompts
         input_ids = [
             self.tokenizer.encode(
-                prompt, add_bos=add_bos_token, encode_special_tokens=True
+                prompt,
+                add_bos=add_bos_token,
+                encode_special_tokens=True,
+                embeddings=mm_embeddings_content,
             )
             for prompt in prompts
         ]
@@ -1292,6 +1345,7 @@ class ExllamaV2Container:
             banned_strings=banned_strings,
             token_healing=token_healing,
             identifier=job_id,
+            embeddings=mm_embeddings_content,
         )

         # Save generated tokens and full response

backends/exllamav2/vision.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""Vision utilities for ExLlamaV2."""
|
||||
|
||||
import io
|
||||
import base64
|
||||
import re
|
||||
from PIL import Image
|
||||
from common import model
|
||||
import aiohttp
|
||||
from common.networking import (
|
||||
handle_request_error,
|
||||
)
|
||||
from common.tabby_config import config
|
||||
from fastapi import HTTPException
|
||||
from exllamav2.generator import ExLlamaV2MMEmbedding
|
||||
from async_lru import alru_cache
|
||||
|
||||
|
||||
async def get_image(url: str) -> Image:
|
||||
if url.startswith("data:image"):
|
||||
# Handle base64 image
|
||||
match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url)
|
||||
if match:
|
||||
base64_image = match.group(1)
|
||||
bytes_image = base64.b64decode(base64_image)
|
||||
else:
|
||||
error_message = handle_request_error(
|
||||
"Failed to read base64 image input.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
else:
|
||||
# Handle image URL
|
||||
if config.network.disable_fetch_requests:
|
||||
error_message = handle_request_error(
|
||||
f"Failed to fetch image from {url} as fetch requests are disabled.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
bytes_image = await response.read()
|
||||
else:
|
||||
error_message = handle_request_error(
|
||||
f"Failed to fetch image from {url}.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
return Image.open(io.BytesIO(bytes_image))
|
||||
|
||||
|
||||
@alru_cache(20)
|
||||
async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
|
||||
image = await get_image(url)
|
||||
return model.container.vision_model.get_image_embeddings(
|
||||
model=model.container.model,
|
||||
tokenizer=model.container.tokenizer,
|
||||
image=image,
|
||||
text_alias=None,
|
||||
)
|
||||
|
||||
|
||||
def clear_image_embedding_cache():
|
||||
get_image_embedding.cache_clear()
|
||||
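
A note on the caching above: get_image_embedding is memoized per URL by alru_cache(20), so repeated requests that reference the same image reuse one embedding, and clear_image_embedding_cache() is what the unload path calls so cached embeddings never outlive the model they were computed against. A minimal, self-contained sketch of that pattern with a stand-in coroutine (the function and URL below are made up for illustration):

import asyncio
from async_lru import alru_cache


@alru_cache(20)  # keep up to 20 awaited results, keyed by the url argument
async def fake_embed(url: str) -> str:
    print(f"computing embedding for {url}")  # only runs on a cache miss
    return f"embedding({url})"


async def main():
    await fake_embed("https://example.com/cat.png")  # miss: prints
    await fake_embed("https://example.com/cat.png")  # hit: silent
    fake_embed.cache_clear()  # same call the unload path makes on the real cache
    await fake_embed("https://example.com/cat.png")  # miss again after clearing


asyncio.run(main())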

@@ -78,6 +78,13 @@ class NetworkConfig(BaseConfigModel):
             "Turn on this option if you are ONLY connecting from localhost."
         ),
     )
+    disable_fetch_requests: Optional[bool] = Field(
+        False,
+        description=(
+            "Disable fetching external content in response to requests, "
+            "such as images from URLs."
+        ),
+    )
     send_tracebacks: Optional[bool] = Field(
         False,
         description=(
@@ -281,6 +288,12 @@ class ModelConfig(BaseConfigModel):
             "NOTE: Only works with chat completion message lists!"
         ),
     )
+    vision: Optional[bool] = Field(
+        False,
+        description=(
+            "Enables vision support if the model supports it. (default: False)"
+        ),
+    )
     num_experts_per_token: Optional[int] = Field(
         None,
         description=(

@@ -33,6 +33,7 @@ class ModelType(Enum):
     MODEL = "model"
     DRAFT = "draft"
     EMBEDDING = "embedding"
+    VISION = "vision"


 def load_progress(module, modules):
@@ -70,29 +71,39 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     # Create a new container
     container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)

-    model_type = "draft" if container.draft_config else "model"
+    # Add possible types of models that can be loaded
+    model_type = [ModelType.MODEL]
+
+    if container.use_vision:
+        model_type.insert(0, ModelType.VISION)
+
+    if container.draft_config:
+        model_type.insert(0, ModelType.DRAFT)

     load_status = container.load_gen(load_progress, **kwargs)

     progress = get_loading_progress_bar()
     progress.start()

     try:
+        index = 0
         async for module, modules in load_status:
+            current_model_type = model_type[index].value
             if module == 0:
                 loading_task = progress.add_task(
-                    f"[cyan]Loading {model_type} modules", total=modules
+                    f"[cyan]Loading {current_model_type} modules", total=modules
                 )
             else:
                 progress.advance(loading_task)

-            yield module, modules, model_type
+            yield module, modules, current_model_type

             if module == modules:
-                # Switch to model progress if the draft model is loaded
-                if model_type == "draft":
-                    model_type = "model"
-                else:
+                if index == len(model_type):
                     progress.stop()
+                else:
+                    index += 1
     finally:
         progress.stop()

common/multimodal.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+from typing import List
+from backends.exllamav2.vision import get_image_embedding
+from common import model
+from loguru import logger
+
+from common.optional_dependencies import dependencies
+
+if dependencies.exllamav2:
+    from exllamav2 import ExLlamaV2VisionTower
+
+
+class MultimodalEmbeddingWrapper:
+    """Common multimodal embedding wrapper"""
+
+    type: str = None
+    content: List = []
+    text_alias: List[str] = []
+
+    async def add(self, url: str):
+        # Determine the type of vision embedding to use
+        if not self.type:
+            if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
+                self.type = "ExLlamaV2MMEmbedding"
+
+        if self.type == "ExLlamaV2MMEmbedding":
+            embedding = await get_image_embedding(url)
+            self.content.append(embedding)
+            self.text_alias.append(embedding.text_alias)
+        else:
+            logger.error("No valid vision model to create embedding")
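
For orientation, this is roughly how the wrapper is driven by the chat-completion path later in this change: one wrapper per request, add() awaited once per image URL, and the collected content handed to tokenization or generation via the embeddings keyword. A hedged sketch assuming a loaded, vision-capable model container (the helper name below is invented for the example):

from typing import List

from common import model
from common.multimodal import MultimodalEmbeddingWrapper


async def embed_image_urls(urls: List[str]) -> MultimodalEmbeddingWrapper:
    # One wrapper per request; add() resolves each URL to an embedding and
    # records the text alias that stands in for the image inside the prompt.
    embeddings = MultimodalEmbeddingWrapper()
    for url in urls:
        await embeddings.add(url)
    return embeddings


# The wrapper then rides along into tokenization or generation, e.g.:
#   model.container.encode_tokens(prompt, embeddings=embeddings)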

@@ -20,6 +20,9 @@ network:
   # Turn on this option if you are ONLY connecting from localhost.
   disable_auth: false

+  # Disable fetching external content in response to requests, such as images from URLs.
+  disable_fetch_requests: false
+
   # Send tracebacks over the API (default: False).
   # NOTE: Only enable this for debug purposes.
   send_tracebacks: false
@@ -130,6 +133,9 @@ model:
   # NOTE: Only works with chat completion message lists!
   prompt_template:

+  # Enables vision support if the model supports it. (default: False)
+  vision: false
+
   # Number of experts to use per token.
   # Fetched from the model's config.json if empty.
   # NOTE: For MoE models only.

@@ -15,7 +15,7 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
-    format_prompt_with_template,
+    apply_chat_template,
     generate_chat_completion,
     stream_generate_chat_completion,
 )
@@ -123,10 +123,7 @@ async def chat_completion_request(

     model_path = model.container.model_dir

-    if isinstance(data.messages, str):
-        prompt = data.messages
-    else:
-        prompt = await format_prompt_with_template(data)
+    prompt, embeddings = await apply_chat_template(data)

     # Set an empty JSON schema if the request wants a JSON response
     if data.response_format.type == "json":
@@ -136,12 +133,14 @@ async def chat_completion_request(

     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, request, model_path),
+            stream_generate_chat_completion(
+                prompt, embeddings, data, request, model_path
+            ),
             ping=maxsize,
         )
     else:
         generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, request, model_path)
+            generate_chat_completion(prompt, embeddings, data, request, model_path)
         )

         response = await run_with_request_disconnect(

@@ -1,7 +1,7 @@
 from pydantic import BaseModel, Field
 from pydantic.json_schema import SkipJsonSchema
 from time import time
-from typing import Union, List, Optional, Dict
+from typing import Literal, Union, List, Optional, Dict
 from uuid import uuid4

 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
@@ -18,10 +18,21 @@ class ChatCompletionLogprobs(BaseModel):
     content: List[ChatCompletionLogprob] = Field(default_factory=list)


+class ChatCompletionImageUrl(BaseModel):
+    url: str
+
+
+class ChatCompletionMessagePart(BaseModel):
+    type: Literal["text", "image_url"] = "text"
+    text: Optional[str] = None
+    image_url: Optional[ChatCompletionImageUrl] = None
+
+
 class ChatCompletionMessage(BaseModel):
-    role: Optional[str] = None
-    content: Optional[str] = None
+    role: str = "user"
+    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
     tool_calls: Optional[List[ToolCall]] = None
     tool_calls_json: SkipJsonSchema[Optional[str]] = None


 class ChatCompletionRespChoice(BaseModel):
@@ -51,7 +62,7 @@ class ChatCompletionRequest(CommonCompletionRequest):

     # WIP this can probably be tightened, or maybe match the OAI lib type
     # in openai\types\chat\chat_completion_message_param.py
-    messages: Union[str, List[Dict]]
+    messages: List[ChatCompletionMessage] = Field(default_factory=list)
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
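
With the models above, a chat completion message's content can be either a plain string or a list of typed parts, and an image part's image_url.url may be an http(s) URL or a data:image base64 URI (both handled by backends/exllamav2/vision.py). A minimal illustration of the new message shape, written as a Python literal with made-up values:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {
                "type": "image_url",
                "image_url": {"url": "https://example.com/cat.png"},
            },
        ],
    }
]
# Plain-string content still validates, so existing text-only clients keep working.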

@@ -1,16 +1,16 @@
 """Chat completion utilities for OAI server."""

 import asyncio
+import json
 import pathlib
 from asyncio import CancelledError
 from typing import List, Optional
-import json

 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger

 from common import model
+from common.multimodal import MultimodalEmbeddingWrapper
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
@@ -177,11 +177,11 @@ def _create_stream_chunk(
     return chunk


-async def _append_template_metadata(data: ChatCompletionRequest):
+async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict):
     """Adding metadata is a one-time process."""

     template_metadata = await model.container.prompt_template.extract_metadata(
-        data.template_vars
+        template_vars
     )

     # Stop strings
@@ -199,7 +199,44 @@ async def _append_template_metadata(data: ChatCompletionRequest):
         data.stop.extend(template_metadata.tool_starts)


-async def format_prompt_with_template(
+async def format_messages_with_template(
+    messages: List[ChatCompletionMessage],
+    existing_template_vars: Optional[dict] = None,
+    add_bos_token: bool = True,
+    ban_eos_token: bool = False,
+):
+    """Barebones function to format chat completion messages into a prompt."""
+
+    template_vars = unwrap(existing_template_vars, {})
+    mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None
+
+    for message in messages:
+        if isinstance(message.content, list):
+            concatenated_content = ""
+            for content in message.content:
+                if content.type == "text":
+                    concatenated_content += content.text
+                elif content.type == "image_url" and mm_embeddings:
+                    await mm_embeddings.add(content.image_url.url)
+                    concatenated_content += mm_embeddings.text_alias[-1]
+
+            # Convert the message content into a concatenated string
+            message.content = concatenated_content
+
+        if message.tool_calls:
+            message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
+
+    special_tokens_dict = model.container.get_special_tokens(
+        add_bos_token, ban_eos_token
+    )
+
+    template_vars.update({"messages": messages, **special_tokens_dict})
+
+    prompt = await model.container.prompt_template.render(template_vars)
+    return prompt, mm_embeddings, template_vars
+
+
+async def apply_chat_template(
     data: ChatCompletionRequest, tool_precursor: Optional[str] = None
 ):
     """
@@ -208,40 +245,18 @@ async def format_prompt_with_template(
     """

     try:
-        special_tokens_dict = model.container.get_special_tokens(
-            unwrap(data.add_bos_token, True),
-            unwrap(data.ban_eos_token, False),
-        )
-
-        # Deal with list in messages.content
-        # Just replace the content list with the very first text message
-        for message in data.messages:
-            if isinstance(message["content"], list):
-                message["content"] = next(
-                    (
-                        content["text"]
-                        for content in message["content"]
-                        if content["type"] == "text"
-                    ),
-                    "",
-                )
-
-            if "tool_calls" in message:
-                message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
-
         # Overwrite any protected vars with their values
         data.template_vars.update(
             {
-                "messages": data.messages,
                 "add_generation_prompt": data.add_generation_prompt,
                 "tools_json": json.dumps(data.model_dump()["tools"], indent=2),
                 "functions_json": json.dumps(data.functions, indent=2),
                 "tool_precursor": tool_precursor,
-                **special_tokens_dict,
             }
         )

-        prompt = await model.container.prompt_template.render(data.template_vars)
+        prompt, mm_embeddings, template_vars = await format_messages_with_template(
+            data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
+        )

         # Append response prefix if present
         if data.response_prefix:
@@ -255,14 +270,14 @@ async def format_prompt_with_template(

         # Removes the starting BOS token if present
         # This is to prevent add_bos_token from adding multiple bos tokens
-        bos_token = special_tokens_dict.get("bos_token")
+        bos_token = template_vars.get("bos_token")
         if bos_token and prompt.startswith(bos_token):
             prompt = prompt.removeprefix(bos_token)

         # Add template metadata
-        await _append_template_metadata(data)
+        await _append_template_metadata(data, template_vars)

-        return prompt
+        return prompt, mm_embeddings

     except KeyError as exc:
         error_message = handle_request_error(
@@ -279,7 +294,11 @@ async def format_prompt_with_template(


 async def stream_generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     """Generator for the generation process."""
     abort_event = asyncio.Event()
@@ -300,6 +319,7 @@ async def stream_generate_chat_completion(
                 prompt,
                 request.state.id,
                 abort_event,
+                embeddings=embeddings,
                 **task_gen_params.model_dump(exclude={"prompt"}),
             )
         )
@@ -372,7 +392,11 @@ async def stream_generate_chat_completion(


 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []

@@ -381,7 +405,10 @@ async def generate_chat_completion(
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    prompt, request.state.id, **data.model_dump(exclude={"prompt"})
+                    prompt,
+                    request.state.id,
+                    embeddings=embeddings,
+                    **data.model_dump(exclude={"prompt"}),
                 )
             )
         )
@@ -427,13 +454,11 @@ async def generate_tool_calls(
         if gen["stop_str"] in tool_data.tool_call_start:
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
-                pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
+                pre_tool_prompt = await apply_chat_template(data, gen["text"])
             elif current_generations is not None:
                 # streaming, we wont have text in the generation,
                 # we'll have to use the current_generations
-                pre_tool_prompt = await format_prompt_with_template(
-                    data, current_generations
-                )
+                pre_tool_prompt = await apply_chat_template(data, current_generations)

             gen_tasks.append(
                 asyncio.create_task(

@@ -1,6 +1,8 @@
 import asyncio
 import pathlib
 from sys import maxsize
+from typing import Optional
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request, Response
 from sse_starlette import EventSourceResponse

@@ -13,6 +15,7 @@ from common.tabby_config import config
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import unwrap
 from common.health import HealthManager
+from endpoints.OAI.utils.chat_completion import format_messages_with_template
 from endpoints.core.types.auth import AuthPermissionResponse
 from endpoints.core.types.download import DownloadRequest, DownloadResponse
 from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
@@ -359,22 +362,47 @@ async def unload_embedding_model():
 async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
     """Encodes a string or chat completion messages into tokens."""

+    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None
+
     if isinstance(data.text, str):
         text = data.text
-    else:
-        special_tokens_dict = model.container.get_special_tokens(
-            unwrap(data.add_bos_token, True)
-        )
+    elif isinstance(data.text, list):
+        if "oai" not in config.network.api_servers:
+            error_message = handle_request_error(
+                "Enable the OAI server to handle chat completion messages.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(422, error_message)
+
+        if not model.container.prompt_template:
+            error_message = handle_request_error(
+                "Cannot tokenize chat completion message because "
+                + "a prompt template is not set.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(422, error_message)

         template_vars = {
             "messages": data.text,
             "add_generation_prompt": False,
-            **special_tokens_dict,
         }

-        text, _ = model.container.prompt_template.render(template_vars)
+        # Don't need template vars again
+        text, mm_embeddings, _ = await format_messages_with_template(
+            data.text, template_vars, data.add_bos_token
+        )
+    else:
+        error_message = handle_request_error(
+            "Unable to tokenize the provided text. Check your formatting?",
+            exc_info=False,
+        ).error.message

-    raw_tokens = model.container.encode_tokens(text, **data.get_params())
+        raise HTTPException(422, error_message)

+    raw_tokens = model.container.encode_tokens(
+        text, embeddings=mm_embeddings, **data.get_params()
+    )
     tokens = unwrap(raw_tokens, [])
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
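
Because TokenEncodeRequest.text now also accepts a list of ChatCompletionMessage objects (see the types change below), the token encode endpoint can count tokens for a templated, possibly multimodal conversation; plain strings behave as before. A hedged sketch of a request body for the handler above, assuming the usual JSON payload shape for this endpoint (the exact route is not shown in this diff):

# Hypothetical client-side payload for the encode_tokens handler shown above.
encode_request = {
    "add_bos_token": True,
    "text": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/cat.png"},
                },
            ],
        }
    ],
}
# Message-list input requires the OAI server to be enabled and a prompt template
# to be set; otherwise the handler raises HTTP 422 as shown above.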

@@ -21,6 +21,7 @@ class ModelCardParameters(BaseModel):
     chunk_size: Optional[int] = 2048
     prompt_template: Optional[str] = None
     num_experts_per_token: Optional[int] = None
+    use_vision: Optional[bool] = False

     # Draft is another model, so include it in the card params
     draft: Optional["ModelCard"] = None
@@ -107,6 +108,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = None
     chunk_size: Optional[int] = None
     prompt_template: Optional[str] = None
+    vision: Optional[bool] = None
     num_experts_per_token: Optional[int] = None

     # Non-config arguments

@@ -1,7 +1,9 @@
 """Tokenization types"""

 from pydantic import BaseModel
-from typing import Dict, List, Union
+from typing import List, Union

+from endpoints.OAI.types.chat_completion import ChatCompletionMessage
+

 class CommonTokenRequest(BaseModel):
@@ -23,7 +25,7 @@ class CommonTokenRequest(BaseModel):
 class TokenEncodeRequest(CommonTokenRequest):
     """Represents a tokenization request."""

-    text: Union[str, List[Dict[str, str]]]
+    text: Union[str, List[ChatCompletionMessage]]


 class TokenEncodeResponse(BaseModel):

@@ -29,6 +29,7 @@ dependencies = [
     "lm-format-enforcer >= 0.9.6",
     "aiofiles",
     "aiohttp",
+    "async_lru",
     "huggingface_hub",
     "psutil",
     "httptools>=0.5.0",