OAI: Allow /v1/encode endpoint to handle vision requests

* Add more robust checks for OAI chat completion message lists on the /v1/encode endpoint
* Add a TODO to support other aspects of chat completions
* Fix an oversight where `embeddings` was not defined in advance on the /v1/chat/completions endpoint
This commit is contained in:
DocShotgun
2024-11-19 11:14:37 -08:00
parent c42655336b
commit 5611365c07
4 changed files with 36 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
import asyncio
import pathlib
from sys import maxsize
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sse_starlette import EventSourceResponse
@@ -357,10 +358,27 @@ async def unload_embedding_model():
)
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
"""Encodes a string or chat completion messages into tokens."""
embeddings = MultimodalEmbeddingWrapper()
if isinstance(data.text, str):
text = data.text
else:
elif isinstance(data.text, list) and "oai" in config.network.api_servers:
# TODO: Support additional chat completion args for encode
# i.e. add_generation_prompt, template selection, tool args, template kwargs
if model.container.prompt_template is None:
error_message = handle_request_error(
"Tokenization of chat completion requests is disabled "
"because a prompt template is not set.",
exc_info=False,
).error.message
raise HTTPException(422, error_message)
from endpoints.OAI.utils.chat_completion import preprocess_vision_request
if model.container.use_vision:
data.text, embeddings = await preprocess_vision_request(data.text)
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True)
)
@@ -371,9 +389,16 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
**special_tokens_dict,
}
text, _ = model.container.prompt_template.render(template_vars)
text = await model.container.prompt_template.render(template_vars)
else:
error_message = handle_request_error(
"OAI API server must be enabled to handle chat completion message inputs.",
exc_info=False,
).error.message
raw_tokens = model.container.encode_tokens(text, **data.get_params())
raise HTTPException(422, error_message)
raw_tokens = model.container.encode_tokens(text, embeddings, **data.get_params())
tokens = unwrap(raw_tokens, [])
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))

View File

@@ -23,7 +23,7 @@ class CommonTokenRequest(BaseModel):
class TokenEncodeRequest(CommonTokenRequest):
"""Represents a tokenization request."""
text: Union[str, List[Dict[str, str]]]
text: Union[str, List[Dict]]
class TokenEncodeResponse(BaseModel):