diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 5f4e86b..04c6d08 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -405,6 +405,9 @@ class ExllamaV2Container: def get_model_path(self, is_draft: bool = False): """Get the path for this model.""" + if is_draft and not self.draft_config: + return None + model_path = pathlib.Path( self.draft_config.model_dir if is_draft else self.config.model_dir ) diff --git a/common/auth.py b/common/auth.py index 623de63..7c0d83b 100644 --- a/common/auth.py +++ b/common/auth.py @@ -106,8 +106,7 @@ def get_key_permission(request: Request): async def check_api_key( - x_api_key: str = Header(None), - authorization: str = Header(None) + x_api_key: str = Header(None), authorization: str = Header(None) ): """Check if the API key is valid.""" diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 0cf921a..5f6d1d1 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize -from common import config, model, gen_logging, sampling +from common import config, model, sampling from common.auth import check_admin_key, check_api_key, get_key_permission from common.downloader import hf_repo_download from common.networking import handle_request_error, run_with_request_disconnect @@ -18,7 +18,6 @@ from endpoints.OAI.types.chat_completion import ( ) from endpoints.OAI.types.download import DownloadRequest, DownloadResponse from endpoints.OAI.types.lora import ( - LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse, @@ -27,7 +26,6 @@ from endpoints.OAI.types.model import ( ModelCard, ModelList, ModelLoadRequest, - ModelCardParameters, ModelLoadResponse, ) from endpoints.OAI.types.sampler_overrides import ( @@ -50,8 +48,13 @@ from endpoints.OAI.utils.completion import ( generate_completion, stream_generate_completion, ) -from endpoints.OAI.utils.model import get_model_list, stream_model_load -from endpoints.OAI.utils.lora import get_lora_list +from endpoints.OAI.utils.model import ( + get_current_model, + get_current_model_list, + get_model_list, + stream_model_load, +) +from endpoints.OAI.utils.lora import get_active_loras, get_lora_list router = APIRouter() @@ -172,7 +175,7 @@ async def chat_completion_request( # Model list endpoint @router.get("/v1/models", dependencies=[Depends(check_api_key)]) @router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) -async def list_models() -> ModelList: +async def list_models(request: Request) -> ModelList: """Lists all models in the model directory.""" model_config = config.model_config() model_dir = unwrap(model_config.get("model_dir"), "models") @@ -180,7 +183,11 @@ async def list_models() -> ModelList: draft_model_dir = config.draft_model_config().get("draft_model_dir") - models = get_model_list(model_path.resolve(), draft_model_dir) + if get_key_permission(request) == "admin": + models = get_model_list(model_path.resolve(), draft_model_dir) + else: + models = await get_current_model_list() + if unwrap(model_config.get("use_dummy_models"), False): models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) @@ -194,43 +201,23 @@ async def list_models() -> ModelList: ) async def current_model() -> ModelCard: """Returns the currently loaded model.""" - model_params = model.container.get_model_parameters() - draft_model_params = model_params.pop("draft", {}) - if draft_model_params: - model_params["draft"] = ModelCard( - id=unwrap(draft_model_params.get("name"), "unknown"), - parameters=ModelCardParameters.model_validate(draft_model_params), - ) - else: - draft_model_params = None - - model_card = ModelCard( - id=unwrap(model_params.pop("name", None), "unknown"), - parameters=ModelCardParameters.model_validate(model_params), - logging=gen_logging.PREFERENCES, - ) - - if draft_model_params: - draft_card = ModelCard( - id=unwrap(draft_model_params.pop("name", None), "unknown"), - parameters=ModelCardParameters.model_validate(draft_model_params), - ) - - model_card.parameters.draft = draft_card - - return model_card + return get_current_model() @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) -async def list_draft_models() -> ModelList: +async def list_draft_models(request: Request) -> ModelList: """Lists all draft models in the model directory.""" - draft_model_dir = unwrap( - config.draft_model_config().get("draft_model_dir"), "models" - ) - draft_model_path = pathlib.Path(draft_model_dir) - models = get_model_list(draft_model_path.resolve()) + if get_key_permission(request) == "admin": + draft_model_dir = unwrap( + config.draft_model_config().get("draft_model_dir"), "models" + ) + draft_model_path = pathlib.Path(draft_model_dir) + + models = get_model_list(draft_model_path.resolve()) + else: + models = await get_current_model_list(is_draft=True) return models @@ -313,10 +300,14 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes # Lora list endpoint @router.get("/v1/loras", dependencies=[Depends(check_api_key)]) @router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) -async def list_all_loras() -> LoraList: +async def list_all_loras(request: Request) -> LoraList: """Lists all LoRAs in the lora directory.""" - lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) - loras = get_lora_list(lora_path.resolve()) + + if get_key_permission(request) == "admin": + lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) + loras = get_lora_list(lora_path.resolve()) + else: + loras = get_active_loras() return loras @@ -328,17 +319,8 @@ async def list_all_loras() -> LoraList: ) async def active_loras() -> LoraList: """Returns the currently loaded loras.""" - loras = LoraList( - data=[ - LoraCard( - id=pathlib.Path(lora.lora_path).parent.name, - scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha, - ) - for lora in model.container.get_loras() - ] - ) - return loras + return get_active_loras() # Load lora endpoint @@ -452,9 +434,17 @@ async def key_permission(request: Request) -> AuthPermissionResponse: @router.get("/v1/templates", dependencies=[Depends(check_api_key)]) @router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) -async def list_templates() -> TemplateList: - templates = get_all_templates() - template_strings = [template.stem for template in templates] +async def list_templates(request: Request) -> TemplateList: + """Get a list of all templates.""" + + template_strings = [] + if get_key_permission(request) == "admin": + templates = get_all_templates() + template_strings = [template.stem for template in templates] + else: + if model.container and model.container.prompt_template: + template_strings.append(model.container.prompt_template.name) + return TemplateList(data=template_strings) @@ -464,6 +454,7 @@ async def list_templates() -> TemplateList: ) async def switch_template(data: TemplateSwitchRequest): """Switch the currently loaded template""" + if not data.name: error_message = handle_request_error( "New template name not found.", @@ -496,11 +487,16 @@ async def unload_template(): # Sampler override endpoints @router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) @router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) -async def list_sampler_overrides() -> SamplerOverrideListResponse: +async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: """API wrapper to list all currently applied sampler overrides""" + if get_key_permission(request) == "admin": + presets = sampling.get_all_presets() + else: + presets = [] + return SamplerOverrideListResponse( - presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump() + presets=presets, **sampling.overrides_container.model_dump() ) diff --git a/endpoints/OAI/utils/lora.py b/endpoints/OAI/utils/lora.py index d00910f..3e31f68 100644 --- a/endpoints/OAI/utils/lora.py +++ b/endpoints/OAI/utils/lora.py @@ -1,5 +1,6 @@ import pathlib +from common import model from endpoints.OAI.types.lora import LoraCard, LoraList @@ -12,3 +13,18 @@ def get_lora_list(lora_path: pathlib.Path): lora_list.data.append(lora_card) return lora_list + + +def get_active_loras(): + if model.container: + active_loras = [ + LoraCard( + id=pathlib.Path(lora.lora_path).parent.name, + scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha, + ) + for lora in model.container.get_loras() + ] + else: + active_loras = [] + + return LoraList(data=active_loras) diff --git a/endpoints/OAI/utils/model.py b/endpoints/OAI/utils/model.py index 0502193..a8057e4 100644 --- a/endpoints/OAI/utils/model.py +++ b/endpoints/OAI/utils/model.py @@ -2,11 +2,12 @@ import pathlib from asyncio import CancelledError from typing import Optional -from common import model +from common import gen_logging, model from common.networking import get_generator_error, handle_request_disconnect - +from common.utils import unwrap from endpoints.OAI.types.model import ( ModelCard, + ModelCardParameters, ModelList, ModelLoadRequest, ModelLoadResponse, @@ -31,6 +32,50 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N return model_card_list +async def get_current_model_list(is_draft: bool = False): + """Gets the current model in list format and with path only.""" + current_models = [] + + # Make sure the model container exists + if model.container: + model_path = model.container.get_model_path(is_draft) + if model_path: + current_models.append(ModelCard(id=model_path.name)) + + return ModelList(data=current_models) + + +def get_current_model(): + """Gets the current model with all parameters.""" + + model_params = model.container.get_model_parameters() + draft_model_params = model_params.pop("draft", {}) + + if draft_model_params: + model_params["draft"] = ModelCard( + id=unwrap(draft_model_params.get("name"), "unknown"), + parameters=ModelCardParameters.model_validate(draft_model_params), + ) + else: + draft_model_params = None + + model_card = ModelCard( + id=unwrap(model_params.pop("name", None), "unknown"), + parameters=ModelCardParameters.model_validate(model_params), + logging=gen_logging.PREFERENCES, + ) + + if draft_model_params: + draft_card = ModelCard( + id=unwrap(draft_model_params.pop("name", None), "unknown"), + parameters=ModelCardParameters.model_validate(draft_model_params), + ) + + model_card.parameters.draft = draft_card + + return model_card + + async def stream_model_load( data: ModelLoadRequest, model_path: pathlib.Path,