From bd9e78e19e19550d91b73881bf85b6e29249f7bb Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 17 Nov 2024 21:12:38 -0500 Subject: [PATCH] API: Add inline exception for dummy models If an API key sends a dummy model, it shouldn't error as the server is catering to clients that expect specific OAI model names. This is a problem with inline model loading since these names would error by default. Therefore, add an exception if the provided name is in the dummy model names (which also doubles as inline strict exceptions). However, the dummy model names weren't configurable, so add a new option to specify exception names, otherwise the default is gpt-3.5-turbo. Signed-off-by: kingbri --- common/config_models.py | 17 ++++++++++++++--- config_sample.yml | 8 +++++++- endpoints/OAI/utils/completion.py | 31 +++++++++++++++++++++++-------- endpoints/core/router.py | 3 ++- endpoints/core/utils/model.py | 7 +++++++ 5 files changed, 53 insertions(+), 13 deletions(-) diff --git a/common/config_models.py b/common/config_models.py index 40b4109..b113194 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -141,14 +141,25 @@ class ModelConfig(BaseConfigModel): False, description=( "Allow direct loading of models " - "from a completion or chat completion request (default: False)." + "from a completion or chat completion request (default: False).\n" + "This method of loading is strict by default.\n" + "Enable dummy models to add exceptions for invalid model names." ), ) use_dummy_models: Optional[bool] = Field( False, description=( - "Sends dummy model names when the models endpoint is queried.\n" - "Enable this if the client is looking for specific OAI models." + "Sends dummy model names when the models endpoint is queried. " + "(default: False)\n" + "Enable this if the client is looking for specific OAI models.\n" + ), + ) + dummy_model_names: List[str] = Field( + default=["gpt-3.5-turbo"], + description=( + "A list of fake model names that are sent via the /v1/models endpoint. " + '(default: ["gpt-3.5-turbo"])\n' + "Also used as bypasses for strict mode if inline_model_loading is true." ), ) model_name: Optional[str] = Field( diff --git a/config_sample.yml b/config_sample.yml index 83f2fc7..39593db 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -49,12 +49,18 @@ model: model_dir: models # Allow direct loading of models from a completion or chat completion request (default: False). + # This method of loading is strict by default. + # Enable dummy models to add exceptions for invalid model names. inline_model_loading: false - # Sends dummy model names when the models endpoint is queried. + # Sends dummy model names when the models endpoint is queried. (default: False) # Enable this if the client is looking for specific OAI models. use_dummy_models: false + # A list of fake model names that are sent via the /v1/models endpoint. (default: ["gpt-3.5-turbo"]) + # Also used as bypasses for strict mode if inline_model_loading is true. + dummy_model_names: ["gpt-3.5-turbo"] + # An initial model to load. # Make sure the model is located in the model directory! # REQUIRED: This must be filled out to load a model on startup. diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index cfcdeba..9fd8b90 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -130,18 +130,33 @@ async def load_inline_model(model_name: str, request: Request): return - # Error if an invalid key is passed - if get_key_permission(request) != "admin": - error_message = handle_request_error( - f"Unable to switch model to {model_name} because " - + "an admin key isn't provided", - exc_info=False, - ).error.message + is_dummy_model = ( + config.model.use_dummy_models and model_name in config.model.dummy_model_names + ) - raise HTTPException(401, error_message) + # Error if an invalid key is passed + # If a dummy model is provided, don't error + if get_key_permission(request) != "admin": + if not is_dummy_model: + error_message = handle_request_error( + f"Unable to switch model to {model_name} because " + + "an admin key isn't provided", + exc_info=False, + ).error.message + + raise HTTPException(401, error_message) + else: + return # Start inline loading # Past here, user is assumed to be admin + + # Skip if the model is a dummy + if is_dummy_model: + logger.warning(f"Dummy model {model_name} provided. Skipping inline load.") + + return + model_path = pathlib.Path(config.model.model_dir) model_path = model_path / model_name diff --git a/endpoints/core/router.py b/endpoints/core/router.py index f2b4247..597930b 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -39,6 +39,7 @@ from endpoints.core.utils.lora import get_active_loras, get_lora_list from endpoints.core.utils.model import ( get_current_model, get_current_model_list, + get_dummy_models, get_model_list, stream_model_load, ) @@ -82,7 +83,7 @@ async def list_models(request: Request) -> ModelList: models = await get_current_model_list() if config.model.use_dummy_models: - models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) + models.data[:0] = get_dummy_models() return models diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index 973337d..c2c209b 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -92,6 +92,13 @@ def get_current_model(): return model_card +def get_dummy_models(): + if config.model.dummy_model_names: + return [ModelCard(id=dummy_id) for dummy_id in config.model.dummy_model_names] + else: + return [ModelCard(id="gpt-3.5-turbo")] + + async def stream_model_load( data: ModelLoadRequest, model_path: pathlib.Path,