Merge remote-tracking branch 'upstream/main' into HEAD

Author: TerminalMan
Date:   2024-09-11 15:57:18 +01:00
28 changed files with 386 additions and 171 deletions


@@ -22,6 +22,7 @@ from endpoints.OAI.utils.chat_completion import (
 )
 from endpoints.OAI.utils.completion import (
     generate_completion,
+    load_inline_model,
     stream_generate_completion,
 )
 from endpoints.OAI.utils.embeddings import get_embeddings
@@ -42,7 +43,7 @@ def setup():
 # Completions endpoint
 @router.post(
     "/v1/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    dependencies=[Depends(check_api_key)],
 )
 async def completion_request(
     request: Request, data: CompletionRequest
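
Dropping Depends(check_model_container) from the route means a request no longer fails up front when no model is loaded; the check moves into the handler body (next hunk), where it only applies if the request did not name a model to load inline. For reference, a minimal sketch (assumed, not this project's exact code) of what such a dependency typically does:

from fastapi import HTTPException

async def check_model_container():
    # `model` is the shared module this router already uses
    # (see model.container.model_dir below); the 400 status is an assumption.
    if model.container is None:
        raise HTTPException(
            status_code=400, detail="No model is currently loaded."
        )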
@@ -53,6 +54,18 @@ async def completion_request(
     If stream = true, this returns an SSE stream.
     """
 
+    if data.model:
+        inline_load_task = asyncio.create_task(load_inline_model(data.model, request))
+
+        await run_with_request_disconnect(
+            request,
+            inline_load_task,
+            disconnect_message=f"Model switch for generation {request.state.id} "
+            + "cancelled by user.",
+        )
+    else:
+        await check_model_container()
+
     model_path = model.container.model_dir
 
     if isinstance(data.prompt, list):
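
Here run_with_request_disconnect races the load task against the client connection, so an expensive model switch is cancelled if the caller gives up. A minimal sketch of that pattern, matching the call shape above but with assumed cancellation details (the repository's real implementation may differ):

import asyncio

from fastapi import HTTPException, Request

async def run_with_request_disconnect(
    request: Request, call_task: asyncio.Task, disconnect_message: str
):
    """Await call_task, cancelling it if the client disconnects first."""

    async def poll_disconnect():
        # Request.is_disconnected() is Starlette's public disconnect check
        while not await request.is_disconnected():
            await asyncio.sleep(0.5)

    disconnect_task = asyncio.create_task(poll_disconnect())
    done, _ = await asyncio.wait(
        {call_task, disconnect_task}, return_when=asyncio.FIRST_COMPLETED
    )

    if disconnect_task in done:
        call_task.cancel()
        # The 422 status is an assumption; the project may report this differently.
        raise HTTPException(status_code=422, detail=disconnect_message)

    disconnect_task.cancel()
    return call_task.result()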
@@ -85,7 +98,7 @@ async def completion_request(
 # Chat completions endpoint
 @router.post(
     "/v1/chat/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    dependencies=[Depends(check_api_key)],
 )
 async def chat_completion_request(
     request: Request, data: ChatCompletionRequest
@@ -96,6 +109,11 @@ async def chat_completion_request(
     If stream = true, this returns an SSE stream.
     """
 
+    if data.model:
+        await load_inline_model(data.model, request)
+    else:
+        await check_model_container()
+
     if model.container.prompt_template is None:
         error_message = handle_request_error(
             "Chat completions are disabled because a prompt template is not set.",