diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 5fdf81f..cc752c5 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -108,6 +108,59 @@ async def _stream_collector(
         await gen_queue.put(e)
 
 
+async def load_inline_model(model_name: str, request: Request):
+    """Load the model named by a request's data.model parameter.
+
+    Returns without loading when a container for the requested model
+    already exists, when inline model loading is not enabled in the model
+    config, or when the model path does not exist. The admin check runs
+    before the config check, so a non-admin request for a different model
+    is rejected even when inline loading is disabled.
+
+    Raises:
+        HTTPException: 401 when the request's key permission is not "admin".
+    """
+
+    # Return if a container for this model already exists
+    if model.container and model.container.model_dir.name == model_name:
+        return
+
+    model_config = config.model_config()
+
+    # Switching models requires an admin key
+    if get_key_permission(request) != "admin":
+        error_message = handle_request_error(
+            f"Unable to switch model to {model_name} because "
+            "an admin key isn't provided",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(401, error_message)
+
+    # Inline model loading must be enabled in config.yml
+    if not unwrap(model_config.get("inline_model_loading"), False):
+        logger.warning(
+            f"Unable to switch model to {model_name} because "
+            '"inline_model_loading" is not True in config.yml.'
+        )
+
+        return
+
+    model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
+    model_path = model_path / model_name
+
+    # Model path doesn't exist
+    if not model_path.exists():
+        logger.warning(
+            f"Could not find model path {str(model_path)}. Skipping inline model load."
+        )
+
+        return
+
+    # Load the model
+    await model.load_model(model_path)
+
+
 async def stream_generate_completion(
     data: CompletionRequest, request: Request, model_path: pathlib.Path
 ):
@@ -178,47 +231,6 @@ async def stream_generate_completion(
         )
 
 
-async def load_inline_model(model_name: str, request: Request):
-    """Load a model from the data.model parameter"""
-
-    # Return if the model container already exists
-    if model.container and model.container.model_dir.name == model_name:
-        return
-
-    model_config = config.model_config()
-
-    # Inline model loading isn't enabled or the user isn't an admin
-    if not get_key_permission(request) == "admin":
-        logger.warning(
-            f"Unable to switch model to {model_name} "
-            "because an admin key isn't provided."
-        )
-
-        return
-
-    if not unwrap(model_config.get("inline_model_loading"), False):
-        logger.warning(
-            f"Unable to switch model to {model_name} because "
-            '"inline_model_load" is not True in config.yml.'
-        )
-
-        return
-
-    model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
-    model_path = model_path / model_name
-
-    # Model path doesn't exist
-    if not model_path.exists():
-        logger.warning(
-            f"Could not find model path {str(model_path)}. Skipping inline model load."
-        )
-
-        return
-
-    # Load the model
-    await model.load_model(model_path)
-
-
 async def generate_completion(
     data: CompletionRequest, request: Request, model_path: pathlib.Path
 ):