Merge pull request #130 from bartowski1182/main

WIP: Add 'model' argument to /v1/chat/completions to load a new model on the fly
This commit is contained in:
Brian Dashore
2024-09-04 21:46:41 -04:00
committed by GitHub
4 changed files with 66 additions and 4 deletions

View File

@@ -83,6 +83,9 @@ model:
# Enable this if the program is looking for a specific OAI model
#use_dummy_models: False
# Allow direct loading of models from a completion or chat completion request
inline_model_loading: False
# An initial model to load. Make sure the model is located in the model directory!
# A model can be loaded later via the API.
# REQUIRED: This must be filled out to load a model on startup!

View File

@@ -21,6 +21,7 @@ from endpoints.OAI.utils.chat_completion import (
)
from endpoints.OAI.utils.completion import (
generate_completion,
load_inline_model,
stream_generate_completion,
)
from endpoints.OAI.utils.embeddings import get_embeddings
@@ -41,7 +42,7 @@ def setup():
# Completions endpoint
@router.post(
"/v1/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
dependencies=[Depends(check_api_key)],
)
async def completion_request(
request: Request, data: CompletionRequest
@@ -52,6 +53,11 @@ async def completion_request(
If stream = true, this returns an SSE stream.
"""
if data.model:
await load_inline_model(data.model, request)
else:
await check_model_container()
model_path = model.container.model_dir
if isinstance(data.prompt, list):
@@ -86,7 +92,7 @@ async def completion_request(
# Chat completions endpoint
@router.post(
"/v1/chat/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
dependencies=[Depends(check_api_key)],
)
async def chat_completion_request(
request: Request, data: ChatCompletionRequest
@@ -97,6 +103,11 @@ async def chat_completion_request(
If stream = true, this returns an SSE stream.
"""
if data.model:
await load_inline_model(data.model, request)
else:
await check_model_container()
if model.container.prompt_template is None:
error_message = handle_request_error(
"Chat completions are disabled because a prompt template is not set.",

View File

@@ -56,6 +56,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
add_generation_prompt: Optional[bool] = True
template_vars: Optional[dict] = {}
response_prefix: Optional[str] = None
model: Optional[str] = None
# tools is follows the format OAI schema, functions is more flexible
# both are available in the chat template.

View File

@@ -1,4 +1,8 @@
"""Completion utilities for OAI server."""
"""
Completion utilities for OAI server.
Also serves as a common module for completions and chat completions.
"""
import asyncio
import pathlib
@@ -9,7 +13,8 @@ from typing import List, Union
from loguru import logger
from common import model
from common import config, model
from common.auth import get_key_permission
from common.networking import (
get_generator_error,
handle_request_disconnect,
@@ -103,6 +108,48 @@ async def _stream_collector(
await gen_queue.put(e)
async def load_inline_model(model_name: str, request: Request):
"""Load a model from the data.model parameter"""
# Return if the model container already exists
if model.container and model.container.model_dir.name == model_name:
return
model_config = config.model_config()
# Inline model loading isn't enabled or the user isn't an admin
if not get_key_permission(request) == "admin":
error_message = handle_request_error(
f"Unable to switch model to {model_name} because "
+ "an admin key isn't provided",
exc_info=False,
).error.message
raise HTTPException(401, error_message)
if not unwrap(model_config.get("inline_model_loading"), False):
logger.warning(
f"Unable to switch model to {model_name} because "
'"inline_model_load" is not True in config.yml.'
)
return
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / model_name
# Model path doesn't exist
if not model_path.exists():
logger.warning(
f"Could not find model path {str(model_path)}. Skipping inline model load."
)
return
# Load the model
await model.load_model(model_path)
async def stream_generate_completion(
data: CompletionRequest, request: Request, model_path: pathlib.Path
):