# tabbyAPI/endpoints/OAI/router.py

import asyncio
from asyncio import CancelledError, InvalidStateError
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from common import model
from common.auth import check_api_key
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, DisconnectHandler, run_with_request_disconnect
from common.tabby_config import config
from common.logger import xlogger
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
    ChatCompletionRequest,
    ChatCompletionResponse,
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.common_ import load_inline_model
from endpoints.OAI.utils.chat_completion import (
    apply_chat_template,
    generate_chat_completion,
    stream_generate_chat_completion,
)
from endpoints.OAI.utils.completion import (
    generate_completion,
    stream_generate_completion,
)
from endpoints.OAI.utils.embeddings import get_embeddings
api_name = "OAI"
router = APIRouter()
urls = {
"Completions": "http://{host}:{port}/v1/completions",
"Chat completions": "http://{host}:{port}/v1/chat/completions",
}
# Block when model is still loading while second inline load request comes in
load_lock: asyncio.Lock = asyncio.Lock()
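# NOTE: an asyncio.Lock only serializes tasks on a single event loop; if the
# server were ever run with multiple worker processes, this lock would not
# guard loads across workers (an observation about asyncio, not a documented
# constraint of this file).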


def setup():
    return router
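
# setup() is presumably called by tabbyAPI's server bootstrap to obtain this
# router for mounting; the call site is outside this file (an assumption).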


# Completions endpoint
@router.post(
    "/v1/completions",
    dependencies=[Depends(check_api_key)],
)
async def completion_request(
    request: Request, data: CompletionRequest
) -> CompletionResponse:
    """
    Generates a completion from a prompt.

    If stream = true, this returns an SSE stream.
    """

    raw_json = await request.json()
    xlogger.debug("[ENDPOINT] /v1/completions", {"raw": raw_json})

    # Use the lock as a context manager so it is released even if the load fails
    async with load_lock:
        if data.model:
            await load_inline_model(data.model, request)
        else:
            await check_model_container()

        model_path = model.container.model_dir

    # Prepare raw prompt (will be str or list[str])
    prompt = data.prompt

    # Set an empty JSON schema if the request wants a JSON response
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    try:
        disconnect_handler = DisconnectHandler(request, "/v1/completions")
        await disconnect_handler.poll()

        if data.stream and not config.developer.disable_request_streaming:
            return EventSourceResponse(
                stream_generate_completion(
                    prompt, data, request, model_path, disconnect_handler
                ),
                # A ping interval of sys.maxsize effectively disables
                # sse-starlette's periodic keep-alive pings
                ping=maxsize,
            )
        else:
            response = await generate_completion(
                prompt, data, request, model_path, disconnect_handler
            )

            return response
    except (CancelledError, InvalidStateError) as ex:
        raise HTTPException(
            422, "/v1/completions request cancelled by user."
        ) from ex


# Chat completions endpoint
@router.post(
    "/v1/chat/completions",
    dependencies=[Depends(check_api_key)],
)
async def chat_completion_request(
    request: Request, data: ChatCompletionRequest
) -> ChatCompletionResponse:
    """
    Generates a chat completion from a prompt.

    If stream = true, this returns an SSE stream.
    """

    raw_json = await request.json()
    xlogger.debug("[ENDPOINT] /v1/chat/completions", {"raw": raw_json})

    # Same pattern as /v1/completions: release the lock even if the load fails
    async with load_lock:
        if data.model:
            await load_inline_model(data.model, request)
        else:
            await check_model_container()

        model_path = model.container.model_dir

    # Chat completions require a prompt template to render the messages
    if model.container.prompt_template is None:
        error_message = handle_request_error(
            "Chat completions are disabled because a prompt template is not set.",
            exc_info=False,
        ).error.message

        raise HTTPException(422, error_message)

    prompt, mm_embeddings = await apply_chat_template(data)

    # Set an empty JSON schema if the request wants a JSON response
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    try:
        disconnect_handler = DisconnectHandler(request, "/v1/chat/completions")
        await disconnect_handler.poll()

        if data.stream and not config.developer.disable_request_streaming:
            return EventSourceResponse(
                stream_generate_chat_completion(
                    prompt, mm_embeddings, data, request, model_path, disconnect_handler
                ),
                ping=maxsize,
            )
        else:
            response = await generate_chat_completion(
                prompt, mm_embeddings, data, request, model_path, disconnect_handler
            )

            return response
    except (CancelledError, InvalidStateError) as ex:
        raise HTTPException(
            422, "/v1/chat/completions request cancelled by user."
        ) from ex


# Embeddings endpoint
@router.post(
    "/v1/embeddings",
    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
    """Generates embeddings for the given input."""

    embeddings_task = asyncio.create_task(get_embeddings(data, request))
    response = await run_with_request_disconnect(
        request,
        embeddings_task,
        f"Embeddings request {request.state.id} cancelled",
    )

    return response
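
# A minimal usage sketch (assumed wiring; the real tabbyAPI entrypoint may
# mount this router differently):
#
#     from fastapi import FastAPI
#     from endpoints.OAI.router import setup
#
#     app = FastAPI()
#     app.include_router(setup())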