Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-03-15 00:07:28 +00:00
OAI: Allow /v1/encode endpoint to handle vision requests

* More robust checks for OAI chat completion message lists on /v1/encode endpoint
* Added TODO to support other aspects of chat completions
* Fix oversight where embeddings was not defined in advance on /v1/chat/completions endpoint
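
With this change the encode endpoint accepts the same message-list input as chat completions, including vision content parts, instead of only plain strings. A minimal sketch of the kind of request this enables, assuming a locally running tabbyAPI server with the OAI API server and a vision model enabled; the host, port, API key, and image URL are placeholder assumptions, not values from this commit:

import requests

# Sketch: tokenize a chat-completion-style message list with an image part.
# The message shape follows the OpenAI-style vision format; "text" may also
# be a plain string (see the TokenEncodeRequest change below).
payload = {
    "text": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    "add_bos_token": True,
}

resp = requests.post(
    "http://127.0.0.1:5000/v1/encode",  # route taken from the commit title
    headers={"Authorization": "Bearer example-key"},
    json=payload,
)
resp.raise_for_status()
body = resp.json()
print(body["length"], body["tokens"][:8])  # TokenEncodeResponse: tokens and length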
@@ -1,4 +1,5 @@
 import asyncio
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 from sys import maxsize

@@ -124,6 +125,8 @@ async def chat_completion_request(
 
     model_path = model.container.model_dir
 
+    embeddings = MultimodalEmbeddingWrapper()
+
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
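
The second hunk above is the "embeddings was not defined in advance" fix from the commit message: the wrapper used to be assigned only on the message-list path, so a plain string prompt could reach the later generation call with the name unbound. A minimal, self-contained illustration of that failure mode and the fix; the wrapper class and both handlers below are hypothetical stand-ins, not code from this repository:

class MultimodalEmbeddingWrapper:
    # Stand-in for common.multimodal.MultimodalEmbeddingWrapper.
    def __init__(self, parts=None):
        self.parts = parts or []

def handler_buggy(messages):
    if isinstance(messages, list):
        embeddings = MultimodalEmbeddingWrapper(messages)  # only bound on this branch
    return messages, embeddings  # UnboundLocalError when messages is a str

def handler_fixed(messages):
    embeddings = MultimodalEmbeddingWrapper()  # empty default, bound on every path
    if isinstance(messages, list):
        embeddings = MultimodalEmbeddingWrapper(messages)
    return messages, embeddings

print(handler_fixed("hello"))  # works: falls back to the empty wrapper
try:
    handler_buggy("hello")
except UnboundLocalError as err:
    print("buggy handler:", err)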
@@ -1,6 +1,7 @@
 import asyncio
 import pathlib
 from sys import maxsize
 
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request, Response
 from sse_starlette import EventSourceResponse
@@ -357,10 +358,27 @@ async def unload_embedding_model():
 )
 async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
     """Encodes a string or chat completion messages into tokens."""
 
+    embeddings = MultimodalEmbeddingWrapper()
+
     if isinstance(data.text, str):
         text = data.text
-    else:
+    elif isinstance(data.text, list) and "oai" in config.network.api_servers:
+        # TODO: Support additional chat completion args for encode
+        # i.e. add_generation_prompt, template selection, tool args, template kwargs
+        if model.container.prompt_template is None:
+            error_message = handle_request_error(
+                "Tokenization of chat completion requests is disabled "
+                "because a prompt template is not set.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(422, error_message)
+
+        from endpoints.OAI.utils.chat_completion import preprocess_vision_request
+        if model.container.use_vision:
+            data.text, embeddings = await preprocess_vision_request(data.text)
+
         special_tokens_dict = model.container.get_special_tokens(
             unwrap(data.add_bos_token, True)
         )

@@ -371,9 +389,16 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
             **special_tokens_dict,
         }
 
-        text, _ = model.container.prompt_template.render(template_vars)
+        text = await model.container.prompt_template.render(template_vars)
+    else:
+        error_message = handle_request_error(
+            "OAI API server must be enabled to handle chat completion message inputs.",
+            exc_info=False,
+        ).error.message
 
-    raw_tokens = model.container.encode_tokens(text, **data.get_params())
+        raise HTTPException(422, error_message)
+
+    raw_tokens = model.container.encode_tokens(text, embeddings, **data.get_params())
     tokens = unwrap(raw_tokens, [])
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
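
In the message-list branch above, tokenization only happens after the messages are rendered through the model's prompt template (the awaited render call). A self-contained sketch of what that rendering step amounts to, using a toy Jinja2 template; tabbyAPI prompt templates are Jinja-based, but this template string is invented for illustration and only messages mirrors the template_vars built in the diff:

from jinja2 import Template

# Toy chat template; real templates come with the model or tabbyAPI config.
chat_template = Template(
    "{% for m in messages %}<|{{ m.role }}|>{{ m.content }}\n{% endfor %}"
)

messages = [{"role": "user", "content": "Describe this image."}]
prompt = chat_template.render(messages=messages)
print(prompt)  # <|user|>Describe this image.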
@@ -23,7 +23,7 @@ class CommonTokenRequest(BaseModel):
 class TokenEncodeRequest(CommonTokenRequest):
     """Represents a tokenization request."""
 
-    text: Union[str, List[Dict[str, str]]]
+    text: Union[str, List[Dict]]
 
 
 class TokenEncodeResponse(BaseModel):
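
The widened annotation reflects that vision messages are not flat string-to-string dicts: a message's content can itself be a list of typed parts, which List[Dict[str, str]] rejects. A small sketch of the difference; the two model classes are illustrative reductions, and the message shape follows the OpenAI-style vision format:

from typing import Dict, List, Union
from pydantic import BaseModel, ValidationError

class OldRequest(BaseModel):
    text: Union[str, List[Dict[str, str]]]

class NewRequest(BaseModel):
    text: Union[str, List[Dict]]

vision_messages = [
    {
        "role": "user",
        "content": [  # a list of parts, not a plain string
            {"type": "text", "text": "What is shown here?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/photo.png"}},
        ],
    }
]

NewRequest(text=vision_messages)  # accepted: dict values may be any type
try:
    OldRequest(text=vision_messages)  # rejected: "content" is not a str
except ValidationError:
    print("rejected by the old List[Dict[str, str]] annotation")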