diff --git a/common/config_models.py b/common/config_models.py index 40b4109..b113194 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -141,14 +141,25 @@ class ModelConfig(BaseConfigModel): False, description=( "Allow direct loading of models " - "from a completion or chat completion request (default: False)." + "from a completion or chat completion request (default: False).\n" + "This method of loading is strict by default.\n" + "Enable dummy models to add exceptions for invalid model names." ), ) use_dummy_models: Optional[bool] = Field( False, description=( - "Sends dummy model names when the models endpoint is queried.\n" - "Enable this if the client is looking for specific OAI models." + "Sends dummy model names when the models endpoint is queried. " + "(default: False)\n" + "Enable this if the client is looking for specific OAI models.\n" + ), + ) + dummy_model_names: List[str] = Field( + default=["gpt-3.5-turbo"], + description=( + "A list of fake model names that are sent via the /v1/models endpoint. " + '(default: ["gpt-3.5-turbo"])\n' + "Also used as bypasses for strict mode if inline_model_loading is true." ), ) model_name: Optional[str] = Field( diff --git a/config_sample.yml b/config_sample.yml index 83f2fc7..39593db 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -49,12 +49,18 @@ model: model_dir: models # Allow direct loading of models from a completion or chat completion request (default: False). + # This method of loading is strict by default. + # Enable dummy models to add exceptions for invalid model names. inline_model_loading: false - # Sends dummy model names when the models endpoint is queried. + # Sends dummy model names when the models endpoint is queried. (default: False) # Enable this if the client is looking for specific OAI models. use_dummy_models: false + # A list of fake model names that are sent via the /v1/models endpoint. (default: ["gpt-3.5-turbo"]) + # Also used as bypasses for strict mode if inline_model_loading is true. + dummy_model_names: ["gpt-3.5-turbo"] + # An initial model to load. # Make sure the model is located in the model directory! # REQUIRED: This must be filled out to load a model on startup. diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index cfcdeba..9fd8b90 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -130,18 +130,33 @@ async def load_inline_model(model_name: str, request: Request): return - # Error if an invalid key is passed - if get_key_permission(request) != "admin": - error_message = handle_request_error( - f"Unable to switch model to {model_name} because " - + "an admin key isn't provided", - exc_info=False, - ).error.message + is_dummy_model = ( + config.model.use_dummy_models and model_name in config.model.dummy_model_names + ) - raise HTTPException(401, error_message) + # Error if an invalid key is passed + # If a dummy model is provided, don't error + if get_key_permission(request) != "admin": + if not is_dummy_model: + error_message = handle_request_error( + f"Unable to switch model to {model_name} because " + + "an admin key isn't provided", + exc_info=False, + ).error.message + + raise HTTPException(401, error_message) + else: + return # Start inline loading # Past here, user is assumed to be admin + + # Skip if the model is a dummy + if is_dummy_model: + logger.warning(f"Dummy model {model_name} provided. Skipping inline load.") + + return + model_path = pathlib.Path(config.model.model_dir) model_path = model_path / model_name diff --git a/endpoints/core/router.py b/endpoints/core/router.py index f2b4247..597930b 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -39,6 +39,7 @@ from endpoints.core.utils.lora import get_active_loras, get_lora_list from endpoints.core.utils.model import ( get_current_model, get_current_model_list, + get_dummy_models, get_model_list, stream_model_load, ) @@ -82,7 +83,7 @@ async def list_models(request: Request) -> ModelList: models = await get_current_model_list() if config.model.use_dummy_models: - models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) + models.data[:0] = get_dummy_models() return models diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index 973337d..c2c209b 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -92,6 +92,13 @@ def get_current_model(): return model_card +def get_dummy_models(): + if config.model.dummy_model_names: + return [ModelCard(id=dummy_id) for dummy_id in config.model.dummy_model_names] + else: + return [ModelCard(id="gpt-3.5-turbo")] + + async def stream_model_load( data: ModelLoadRequest, model_path: pathlib.Path,