From 9ad69e8ab6419f878fe0ba405f77e1364e8e6548 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 23 Jul 2024 14:08:48 -0400 Subject: [PATCH] API: Migrate universal routes to core Place OAI-specific routes in the appropriate folder. This is in preparation for adding new API servers that can be optionally enabled. Signed-off-by: kingbri --- backends/exllamav2/model.py | 2 + endpoints/OAI/router.py | 427 +---------------- endpoints/core/router.py | 432 ++++++++++++++++++ endpoints/{OAI => core}/types/auth.py | 0 endpoints/{OAI => core}/types/download.py | 0 endpoints/{OAI => core}/types/lora.py | 0 endpoints/{OAI => core}/types/model.py | 0 .../{OAI => core}/types/sampler_overrides.py | 0 endpoints/{OAI => core}/types/template.py | 0 endpoints/{OAI => core}/types/token.py | 0 endpoints/{OAI => core}/utils/lora.py | 2 +- endpoints/{OAI => core}/utils/model.py | 2 +- endpoints/server.py | 4 + 13 files changed, 442 insertions(+), 427 deletions(-) create mode 100644 endpoints/core/router.py rename endpoints/{OAI => core}/types/auth.py (100%) rename endpoints/{OAI => core}/types/download.py (100%) rename endpoints/{OAI => core}/types/lora.py (100%) rename endpoints/{OAI => core}/types/model.py (100%) rename endpoints/{OAI => core}/types/sampler_overrides.py (100%) rename endpoints/{OAI => core}/types/template.py (100%) rename endpoints/{OAI => core}/types/token.py (100%) rename endpoints/{OAI => core}/utils/lora.py (92%) rename endpoints/{OAI => core}/utils/model.py (98%) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 200be6b..057e8c1 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1080,6 +1080,8 @@ class ExllamaV2Container: else [self.tokenizer.eos_token_id] ) + print(self.tokenizer.eos_token_id) + # Ban the EOS token if specified. If not, append to stop conditions # as well.
# Set this below logging to avoid polluting the stop strings array diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 4269c89..771b7f3 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,45 +1,18 @@ import asyncio -import pathlib from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize -from common import config, model, sampling -from common.auth import check_admin_key, check_api_key, get_key_permission -from common.downloader import hf_repo_download +from common import config, model +from common.auth import check_api_key from common.model import check_model_container from common.networking import handle_request_error, run_with_request_disconnect -from common.templating import PromptTemplate, get_all_templates from common.utils import unwrap -from endpoints.OAI.types.auth import AuthPermissionResponse from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse from endpoints.OAI.types.chat_completion import ( ChatCompletionRequest, ChatCompletionResponse, ) -from endpoints.OAI.types.download import DownloadRequest, DownloadResponse -from endpoints.OAI.types.lora import ( - LoraList, - LoraLoadRequest, - LoraLoadResponse, -) -from endpoints.OAI.types.model import ( - ModelCard, - ModelList, - ModelLoadRequest, - ModelLoadResponse, -) -from endpoints.OAI.types.sampler_overrides import ( - SamplerOverrideListResponse, - SamplerOverrideSwitchRequest, -) -from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest -from endpoints.OAI.types.token import ( - TokenEncodeRequest, - TokenEncodeResponse, - TokenDecodeRequest, - TokenDecodeResponse, -) from endpoints.OAI.utils.chat_completion import ( format_prompt_with_template, generate_chat_completion, @@ -49,13 +22,6 @@ from endpoints.OAI.utils.completion import ( generate_completion, stream_generate_completion, ) -from endpoints.OAI.utils.model import ( - get_current_model, - get_current_model_list, - get_model_list, - stream_model_load, -) -from endpoints.OAI.utils.lora import get_active_loras, get_lora_list router = APIRouter() @@ -159,392 +125,3 @@ async def chat_completion_request( disconnect_message=f"Chat completion {request.state.id} cancelled by user.", ) return response - - -# Model list endpoint -@router.get("/v1/models", dependencies=[Depends(check_api_key)]) -@router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) -async def list_models(request: Request) -> ModelList: - """ - Lists all models in the model directory. - - Requires an admin key to see all models. 
- """ - - model_config = config.model_config() - model_dir = unwrap(model_config.get("model_dir"), "models") - model_path = pathlib.Path(model_dir) - - draft_model_dir = config.draft_model_config().get("draft_model_dir") - - if get_key_permission(request) == "admin": - models = get_model_list(model_path.resolve(), draft_model_dir) - else: - models = await get_current_model_list() - - if unwrap(model_config.get("use_dummy_models"), False): - models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) - - return models - - -# Currently loaded model endpoint -@router.get( - "/v1/model", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def current_model() -> ModelCard: - """Returns the currently loaded model.""" - - return get_current_model() - - -@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) -async def list_draft_models(request: Request) -> ModelList: - """ - Lists all draft models in the model directory. - - Requires an admin key to see all draft models. - """ - - if get_key_permission(request) == "admin": - draft_model_dir = unwrap( - config.draft_model_config().get("draft_model_dir"), "models" - ) - draft_model_path = pathlib.Path(draft_model_dir) - - models = get_model_list(draft_model_path.resolve()) - else: - models = await get_current_model_list(is_draft=True) - - return models - - -# Load model endpoint -@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) -async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: - """Loads a model into the model container. This returns an SSE stream.""" - - # Verify request parameters - if not data.name: - error_message = handle_request_error( - "A model name was not provided for load.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models")) - model_path = model_path / data.name - - draft_model_path = None - if data.draft: - if not data.draft.draft_model_name: - error_message = handle_request_error( - "Could not find the draft model name for model load.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - draft_model_path = unwrap( - config.draft_model_config().get("draft_model_dir"), "models" - ) - - if not model_path.exists(): - error_message = handle_request_error( - "Could not find the model path for load. Check model name or config.yml?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - return EventSourceResponse( - stream_model_load(data, model_path, draft_model_path), ping=maxsize - ) - - -# Unload model endpoint -@router.post( - "/v1/model/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def unload_model(): - """Unloads the currently loaded model.""" - await model.unload_model(skip_wait=True) - - -@router.post("/v1/download", dependencies=[Depends(check_admin_key)]) -async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse: - """Downloads a model from HuggingFace.""" - - try: - download_task = asyncio.create_task(hf_repo_download(**data.model_dump())) - - # For now, the downloader and request data are 1:1 - download_path = await run_with_request_disconnect( - request, - download_task, - "Download request cancelled by user. 
Files have been cleaned up.", - ) - - return DownloadResponse(download_path=str(download_path)) - except Exception as exc: - error_message = handle_request_error(str(exc)).error.message - - raise HTTPException(400, error_message) from exc - - -# Lora list endpoint -@router.get("/v1/loras", dependencies=[Depends(check_api_key)]) -@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) -async def list_all_loras(request: Request) -> LoraList: - """ - Lists all LoRAs in the lora directory. - - Requires an admin key to see all LoRAs. - """ - - if get_key_permission(request) == "admin": - lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) - loras = get_lora_list(lora_path.resolve()) - else: - loras = get_active_loras() - - return loras - - -# Currently loaded loras endpoint -@router.get( - "/v1/lora", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def active_loras() -> LoraList: - """Returns the currently loaded loras.""" - - return get_active_loras() - - -# Load lora endpoint -@router.post( - "/v1/lora/load", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: - """Loads a LoRA into the model container.""" - - if not data.loras: - error_message = handle_request_error( - "List of loras to load is not found.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) - if not lora_dir.exists(): - error_message = handle_request_error( - "A parent lora directory does not exist for load. Check your config.yml?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - load_result = await model.load_loras( - lora_dir, **data.model_dump(), skip_wait=data.skip_queue - ) - - return LoraLoadResponse( - success=unwrap(load_result.get("success"), []), - failure=unwrap(load_result.get("failure"), []), - ) - - -# Unload lora endpoint -@router.post( - "/v1/lora/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def unload_loras(): - """Unloads the currently loaded loras.""" - - await model.unload_loras() - - -# Encode tokens endpoint -@router.post( - "/v1/token/encode", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: - """Encodes a string or chat completion messages into tokens.""" - - if isinstance(data.text, str): - text = data.text - else: - special_tokens_dict = model.container.get_special_tokens( - unwrap(data.add_bos_token, True) - ) - - template_vars = { - "messages": data.text, - "add_generation_prompt": False, - **special_tokens_dict, - } - - text, _ = model.container.prompt_template.render(template_vars) - - raw_tokens = model.container.encode_tokens(text, **data.get_params()) - tokens = unwrap(raw_tokens, []) - response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) - - return response - - -# Decode tokens endpoint -@router.post( - "/v1/token/decode", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: - """Decodes tokens into a string.""" - - message = model.container.decode_tokens(data.tokens, **data.get_params()) - response = TokenDecodeResponse(text=unwrap(message, "")) - - return response - - -@router.get("/v1/auth/permission", 
dependencies=[Depends(check_api_key)]) -async def key_permission(request: Request) -> AuthPermissionResponse: - """ - Gets the access level/permission of a provided key in headers. - - Priority: - - X-admin-key - - X-api-key - - Authorization - """ - - try: - permission = get_key_permission(request) - return AuthPermissionResponse(permission=permission) - except ValueError as exc: - error_message = handle_request_error(str(exc)).error.message - - raise HTTPException(400, error_message) from exc - - -@router.get("/v1/templates", dependencies=[Depends(check_api_key)]) -@router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) -async def list_templates(request: Request) -> TemplateList: - """ - Get a list of all templates. - - Requires an admin key to see all templates. - """ - - template_strings = [] - if get_key_permission(request) == "admin": - templates = get_all_templates() - template_strings = [template.stem for template in templates] - else: - if model.container and model.container.prompt_template: - template_strings.append(model.container.prompt_template.name) - - return TemplateList(data=template_strings) - - -@router.post( - "/v1/template/switch", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def switch_template(data: TemplateSwitchRequest): - """Switch the currently loaded template.""" - - if not data.name: - error_message = handle_request_error( - "New template name not found.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - try: - model.container.prompt_template = PromptTemplate.from_file(data.name) - except FileNotFoundError as e: - error_message = handle_request_error( - f"The template name {data.name} doesn't exist. Check the spelling?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) from e - - -@router.post( - "/v1/template/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def unload_template(): - """Unloads the currently selected template""" - - model.container.prompt_template = None - - -# Sampler override endpoints -@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) -@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) -async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: - """ - List all currently applied sampler overrides. - - Requires an admin key to see all override presets. - """ - - if get_key_permission(request) == "admin": - presets = sampling.get_all_presets() - else: - presets = [] - - return SamplerOverrideListResponse( - presets=presets, **sampling.overrides_container.model_dump() - ) - - -@router.post( - "/v1/sampling/override/switch", - dependencies=[Depends(check_admin_key)], -) -async def switch_sampler_override(data: SamplerOverrideSwitchRequest): - """Switch the currently loaded override preset""" - - if data.preset: - try: - sampling.overrides_from_file(data.preset) - except FileNotFoundError as e: - error_message = handle_request_error( - f"Sampler override preset with name {data.preset} does not exist. 
" - + "Check the spelling?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) from e - elif data.overrides: - sampling.overrides_from_dict(data.overrides) - else: - error_message = handle_request_error( - "A sampler override preset or dictionary wasn't provided.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - -@router.post( - "/v1/sampling/override/unload", - dependencies=[Depends(check_admin_key)], -) -async def unload_sampler_override(): - """Unloads the currently selected override preset""" - - sampling.overrides_from_dict({}) diff --git a/endpoints/core/router.py b/endpoints/core/router.py new file mode 100644 index 0000000..cd0ed37 --- /dev/null +++ b/endpoints/core/router.py @@ -0,0 +1,432 @@ +import asyncio +import pathlib +from sys import maxsize +from fastapi import APIRouter, Depends, HTTPException, Request +from sse_starlette import EventSourceResponse + +from common import config, model, sampling +from common.auth import check_admin_key, check_api_key, get_key_permission +from common.downloader import hf_repo_download +from common.model import check_model_container +from common.networking import handle_request_error, run_with_request_disconnect +from common.templating import PromptTemplate, get_all_templates +from common.utils import unwrap +from endpoints.core.types.auth import AuthPermissionResponse +from endpoints.core.types.download import DownloadRequest, DownloadResponse +from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse +from endpoints.core.types.model import ( + ModelCard, + ModelList, + ModelLoadRequest, + ModelLoadResponse, +) +from endpoints.core.types.sampler_overrides import ( + SamplerOverrideListResponse, + SamplerOverrideSwitchRequest, +) +from endpoints.core.types.template import TemplateList, TemplateSwitchRequest +from endpoints.core.types.token import ( + TokenDecodeRequest, + TokenDecodeResponse, + TokenEncodeRequest, + TokenEncodeResponse, +) +from endpoints.core.utils.lora import get_active_loras, get_lora_list +from endpoints.core.utils.model import ( + get_current_model, + get_current_model_list, + get_model_list, + stream_model_load, +) + + +router = APIRouter() + + +# Model list endpoint +@router.get("/v1/models", dependencies=[Depends(check_api_key)]) +@router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) +async def list_models(request: Request) -> ModelList: + """ + Lists all models in the model directory. + + Requires an admin key to see all models. + """ + + model_config = config.model_config() + model_dir = unwrap(model_config.get("model_dir"), "models") + model_path = pathlib.Path(model_dir) + + draft_model_dir = config.draft_model_config().get("draft_model_dir") + + if get_key_permission(request) == "admin": + models = get_model_list(model_path.resolve(), draft_model_dir) + else: + models = await get_current_model_list() + + if unwrap(model_config.get("use_dummy_models"), False): + models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) + + return models + + +# Currently loaded model endpoint +@router.get( + "/v1/model", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def current_model() -> ModelCard: + """Returns the currently loaded model.""" + + return get_current_model() + + +@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) +async def list_draft_models(request: Request) -> ModelList: + """ + Lists all draft models in the model directory. 
+ + Requires an admin key to see all draft models. + """ + + if get_key_permission(request) == "admin": + draft_model_dir = unwrap( + config.draft_model_config().get("draft_model_dir"), "models" + ) + draft_model_path = pathlib.Path(draft_model_dir) + + models = get_model_list(draft_model_path.resolve()) + else: + models = await get_current_model_list(is_draft=True) + + return models + + +# Load model endpoint +@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) +async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: + """Loads a model into the model container. This returns an SSE stream.""" + + # Verify request parameters + if not data.name: + error_message = handle_request_error( + "A model name was not provided for load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models")) + model_path = model_path / data.name + + draft_model_path = None + if data.draft: + if not data.draft.draft_model_name: + error_message = handle_request_error( + "Could not find the draft model name for model load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + draft_model_path = unwrap( + config.draft_model_config().get("draft_model_dir"), "models" + ) + + if not model_path.exists(): + error_message = handle_request_error( + "Could not find the model path for load. Check model name or config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + return EventSourceResponse( + stream_model_load(data, model_path, draft_model_path), ping=maxsize + ) + + +# Unload model endpoint +@router.post( + "/v1/model/unload", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def unload_model(): + """Unloads the currently loaded model.""" + await model.unload_model(skip_wait=True) + + +@router.post("/v1/download", dependencies=[Depends(check_admin_key)]) +async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse: + """Downloads a model from HuggingFace.""" + + try: + download_task = asyncio.create_task(hf_repo_download(**data.model_dump())) + + # For now, the downloader and request data are 1:1 + download_path = await run_with_request_disconnect( + request, + download_task, + "Download request cancelled by user. Files have been cleaned up.", + ) + + return DownloadResponse(download_path=str(download_path)) + except Exception as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + +# Lora list endpoint +@router.get("/v1/loras", dependencies=[Depends(check_api_key)]) +@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) +async def list_all_loras(request: Request) -> LoraList: + """ + Lists all LoRAs in the lora directory. + + Requires an admin key to see all LoRAs. 
+ """ + + if get_key_permission(request) == "admin": + lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) + loras = get_lora_list(lora_path.resolve()) + else: + loras = get_active_loras() + + return loras + + +# Currently loaded loras endpoint +@router.get( + "/v1/lora", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def active_loras() -> LoraList: + """Returns the currently loaded loras.""" + + return get_active_loras() + + +# Load lora endpoint +@router.post( + "/v1/lora/load", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: + """Loads a LoRA into the model container.""" + + if not data.loras: + error_message = handle_request_error( + "List of loras to load is not found.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) + if not lora_dir.exists(): + error_message = handle_request_error( + "A parent lora directory does not exist for load. Check your config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + load_result = await model.load_loras( + lora_dir, **data.model_dump(), skip_wait=data.skip_queue + ) + + return LoraLoadResponse( + success=unwrap(load_result.get("success"), []), + failure=unwrap(load_result.get("failure"), []), + ) + + +# Unload lora endpoint +@router.post( + "/v1/lora/unload", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def unload_loras(): + """Unloads the currently loaded loras.""" + + await model.unload_loras() + + +# Encode tokens endpoint +@router.post( + "/v1/token/encode", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: + """Encodes a string or chat completion messages into tokens.""" + + if isinstance(data.text, str): + text = data.text + else: + special_tokens_dict = model.container.get_special_tokens( + unwrap(data.add_bos_token, True) + ) + + template_vars = { + "messages": data.text, + "add_generation_prompt": False, + **special_tokens_dict, + } + + text, _ = model.container.prompt_template.render(template_vars) + + raw_tokens = model.container.encode_tokens(text, **data.get_params()) + tokens = unwrap(raw_tokens, []) + response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) + + return response + + +# Decode tokens endpoint +@router.post( + "/v1/token/decode", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: + """Decodes tokens into a string.""" + + message = model.container.decode_tokens(data.tokens, **data.get_params()) + response = TokenDecodeResponse(text=unwrap(message, "")) + + return response + + +@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)]) +async def key_permission(request: Request) -> AuthPermissionResponse: + """ + Gets the access level/permission of a provided key in headers. 
+ + Priority: + - X-admin-key + - X-api-key + - Authorization + """ + + try: + permission = get_key_permission(request) + return AuthPermissionResponse(permission=permission) + except ValueError as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + +@router.get("/v1/templates", dependencies=[Depends(check_api_key)]) +@router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) +async def list_templates(request: Request) -> TemplateList: + """ + Get a list of all templates. + + Requires an admin key to see all templates. + """ + + template_strings = [] + if get_key_permission(request) == "admin": + templates = get_all_templates() + template_strings = [template.stem for template in templates] + else: + if model.container and model.container.prompt_template: + template_strings.append(model.container.prompt_template.name) + + return TemplateList(data=template_strings) + + +@router.post( + "/v1/template/switch", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def switch_template(data: TemplateSwitchRequest): + """Switch the currently loaded template.""" + + if not data.name: + error_message = handle_request_error( + "New template name not found.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + try: + model.container.prompt_template = PromptTemplate.from_file(data.name) + except FileNotFoundError as e: + error_message = handle_request_error( + f"The template name {data.name} doesn't exist. Check the spelling?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) from e + + +@router.post( + "/v1/template/unload", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def unload_template(): + """Unloads the currently selected template""" + + model.container.prompt_template = None + + +# Sampler override endpoints +@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) +@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) +async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: + """ + List all currently applied sampler overrides. + + Requires an admin key to see all override presets. + """ + + if get_key_permission(request) == "admin": + presets = sampling.get_all_presets() + else: + presets = [] + + return SamplerOverrideListResponse( + presets=presets, **sampling.overrides_container.model_dump() + ) + + +@router.post( + "/v1/sampling/override/switch", + dependencies=[Depends(check_admin_key)], +) +async def switch_sampler_override(data: SamplerOverrideSwitchRequest): + """Switch the currently loaded override preset""" + + if data.preset: + try: + sampling.overrides_from_file(data.preset) + except FileNotFoundError as e: + error_message = handle_request_error( + f"Sampler override preset with name {data.preset} does not exist. 
" + + "Check the spelling?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) from e + elif data.overrides: + sampling.overrides_from_dict(data.overrides) + else: + error_message = handle_request_error( + "A sampler override preset or dictionary wasn't provided.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + +@router.post( + "/v1/sampling/override/unload", + dependencies=[Depends(check_admin_key)], +) +async def unload_sampler_override(): + """Unloads the currently selected override preset""" + + sampling.overrides_from_dict({}) diff --git a/endpoints/OAI/types/auth.py b/endpoints/core/types/auth.py similarity index 100% rename from endpoints/OAI/types/auth.py rename to endpoints/core/types/auth.py diff --git a/endpoints/OAI/types/download.py b/endpoints/core/types/download.py similarity index 100% rename from endpoints/OAI/types/download.py rename to endpoints/core/types/download.py diff --git a/endpoints/OAI/types/lora.py b/endpoints/core/types/lora.py similarity index 100% rename from endpoints/OAI/types/lora.py rename to endpoints/core/types/lora.py diff --git a/endpoints/OAI/types/model.py b/endpoints/core/types/model.py similarity index 100% rename from endpoints/OAI/types/model.py rename to endpoints/core/types/model.py diff --git a/endpoints/OAI/types/sampler_overrides.py b/endpoints/core/types/sampler_overrides.py similarity index 100% rename from endpoints/OAI/types/sampler_overrides.py rename to endpoints/core/types/sampler_overrides.py diff --git a/endpoints/OAI/types/template.py b/endpoints/core/types/template.py similarity index 100% rename from endpoints/OAI/types/template.py rename to endpoints/core/types/template.py diff --git a/endpoints/OAI/types/token.py b/endpoints/core/types/token.py similarity index 100% rename from endpoints/OAI/types/token.py rename to endpoints/core/types/token.py diff --git a/endpoints/OAI/utils/lora.py b/endpoints/core/utils/lora.py similarity index 92% rename from endpoints/OAI/utils/lora.py rename to endpoints/core/utils/lora.py index 3e31f68..c8c9cc4 100644 --- a/endpoints/OAI/utils/lora.py +++ b/endpoints/core/utils/lora.py @@ -1,7 +1,7 @@ import pathlib from common import model -from endpoints.OAI.types.lora import LoraCard, LoraList +from endpoints.core.types.lora import LoraCard, LoraList def get_lora_list(lora_path: pathlib.Path): diff --git a/endpoints/OAI/utils/model.py b/endpoints/core/utils/model.py similarity index 98% rename from endpoints/OAI/utils/model.py rename to endpoints/core/utils/model.py index a8057e4..0cfb26a 100644 --- a/endpoints/OAI/utils/model.py +++ b/endpoints/core/utils/model.py @@ -5,7 +5,7 @@ from typing import Optional from common import gen_logging, model from common.networking import get_generator_error, handle_request_disconnect from common.utils import unwrap -from endpoints.OAI.types.model import ( +from endpoints.core.types.model import ( ModelCard, ModelCardParameters, ModelList, diff --git a/endpoints/server.py b/endpoints/server.py index 8c10b63..dc501f5 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -5,6 +5,7 @@ from loguru import logger from common.logger import UVICORN_LOG_CONFIG from common.networking import get_global_depends +from endpoints.core.router import router as CoreRouter from endpoints.OAI.router import router as OAIRouter @@ -32,6 +33,9 @@ def setup_app(): app.include_router(OAIRouter) + # Include core API request paths + app.include_router(CoreRouter) + return app