From b9a58ff01b4b7ca37141aa5120e609b762ac5676 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 11 Jul 2024 11:16:24 -0400 Subject: [PATCH 01/89] Auth: Make key permission check work on Requests Pass a request and internally unwrap the headers. In addition, allow X-admin-key to get checked in an API key request. Signed-off-by: kingbri --- common/auth.py | 31 ++++++++++++++++++++++++++++--- endpoints/OAI/router.py | 21 +++++++-------------- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/common/auth.py b/common/auth.py index fa53262..7ba13de 100644 --- a/common/auth.py +++ b/common/auth.py @@ -5,11 +5,13 @@ application, it should be fine. import secrets import yaml -from fastapi import Header, HTTPException +from fastapi import Header, HTTPException, Request from pydantic import BaseModel from loguru import logger from typing import Optional +from common.utils import coalesce + class AuthKeys(BaseModel): """ @@ -75,7 +77,23 @@ def load_auth_keys(disable_from_config: bool): ) -async def validate_key_permission(test_key: str): +def get_key_permission(request: Request): + """ + Gets the key permission from a request. + + Internal only! Use the depends functions for incoming requests. + """ + + # Hyphens are okay here + test_key = coalesce( + request.headers.get("authorization"), + request.headers.get("x-admin-key"), + request.headers.get("x-api-key"), + ) + + if test_key is None: + raise ValueError("The provided authentication key is missing.") + if test_key.lower().startswith("bearer"): test_key = test_key.split(" ")[1] @@ -88,7 +106,9 @@ async def validate_key_permission(test_key: str): async def check_api_key( - x_api_key: str = Header(None), authorization: str = Header(None) + x_api_key: str = Header(None), + x_admin_key: str = Header(None), + authorization: str = Header(None), ): """Check if the API key is valid.""" @@ -101,6 +121,11 @@ async def check_api_key( raise HTTPException(401, "Invalid API key") return x_api_key + if x_admin_key: + if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"): + raise HTTPException(401, "Invalid API key") + return x_admin_key + if authorization: split_key = authorization.split(" ") if len(split_key) < 2: diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 0e4f27b..78c29b1 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,16 +1,15 @@ import asyncio import pathlib -from fastapi import APIRouter, Depends, HTTPException, Header, Request +from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize -from typing import Optional from common import config, model, gen_logging, sampling -from common.auth import check_admin_key, check_api_key, validate_key_permission +from common.auth import check_admin_key, check_api_key, get_key_permission from common.downloader import hf_repo_download from common.networking import handle_request_error, run_with_request_disconnect from common.templating import PromptTemplate, get_all_templates -from common.utils import coalesce, unwrap +from common.utils import unwrap from endpoints.OAI.types.auth import AuthPermissionResponse from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse from endpoints.OAI.types.chat_completion import ( @@ -432,24 +431,18 @@ async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: @router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)]) -async def get_key_permission( - x_admin_key: Optional[str] = Header(None), - x_api_key: Optional[str] = Header(None), - authorization: Optional[str] = Header(None), -) -> AuthPermissionResponse: +async def key_permission(request: Request) -> AuthPermissionResponse: """ Gets the access level/permission of a provided key in headers. Priority: - - X-api-key - - X-admin-key - Authorization + - X-admin-key + - X-api-key """ - test_key = coalesce(x_admin_key, x_api_key, authorization) - try: - permission = await validate_key_permission(test_key) + permission = get_key_permission(request) return AuthPermissionResponse(permission=permission) except ValueError as exc: error_message = handle_request_error(str(exc)).error.message From dfb4c51d5ff41bf0a6d317804d90fef600ac34c8 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 11 Jul 2024 11:26:50 -0400 Subject: [PATCH 02/89] OAI: Fix function idioms Make functions mean the same thing to avoid confusion. Signed-off-by: kingbri --- endpoints/OAI/router.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 78c29b1..0cf921a 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -192,7 +192,7 @@ async def list_models() -> ModelList: "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) -async def get_current_model() -> ModelCard: +async def current_model() -> ModelCard: """Returns the currently loaded model.""" model_params = model.container.get_model_parameters() draft_model_params = model_params.pop("draft", {}) @@ -313,7 +313,7 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes # Lora list endpoint @router.get("/v1/loras", dependencies=[Depends(check_api_key)]) @router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) -async def get_all_loras() -> LoraList: +async def list_all_loras() -> LoraList: """Lists all LoRAs in the lora directory.""" lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) loras = get_lora_list(lora_path.resolve()) @@ -326,9 +326,9 @@ async def get_all_loras() -> LoraList: "/v1/lora", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) -async def get_active_loras() -> LoraList: +async def active_loras() -> LoraList: """Returns the currently loaded loras.""" - active_loras = LoraList( + loras = LoraList( data=[ LoraCard( id=pathlib.Path(lora.lora_path).parent.name, @@ -338,7 +338,7 @@ async def get_active_loras() -> LoraList: ] ) - return active_loras + return loras # Load lora endpoint @@ -452,7 +452,7 @@ async def key_permission(request: Request) -> AuthPermissionResponse: @router.get("/v1/templates", dependencies=[Depends(check_api_key)]) @router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) -async def get_templates() -> TemplateList: +async def list_templates() -> TemplateList: templates = get_all_templates() template_strings = [template.stem for template in templates] return TemplateList(data=template_strings) From 10890913b81668785fe89abc2cd756bfa775dad4 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 11 Jul 2024 12:33:06 -0400 Subject: [PATCH 03/89] Auth: Revert x-admin-key allowance in API key check These kinda clash with each other. Use the correct header for the correct endpoint. Signed-off-by: kingbri --- common/auth.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/common/auth.py b/common/auth.py index 7ba13de..623de63 100644 --- a/common/auth.py +++ b/common/auth.py @@ -107,8 +107,7 @@ def get_key_permission(request: Request): async def check_api_key( x_api_key: str = Header(None), - x_admin_key: str = Header(None), - authorization: str = Header(None), + authorization: str = Header(None) ): """Check if the API key is valid.""" @@ -121,11 +120,6 @@ async def check_api_key( raise HTTPException(401, "Invalid API key") return x_api_key - if x_admin_key: - if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"): - raise HTTPException(401, "Invalid API key") - return x_admin_key - if authorization: split_key = authorization.split(" ") if len(split_key) < 2: From 1f46a1130ced1b0d6a457134206e3aa134780147 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 11 Jul 2024 14:06:03 -0400 Subject: [PATCH 04/89] OAI: Restrict list permissions for API keys API keys are not allowed to view all the admin's models, templates, draft models, loras, etc. Basically anything that can be viewed on the filesystem outside of anything that's currently loaded is not allowed to be returned unless an admin key is present. This change helps preserve user privacy while not erroring out on list endpoints that the OAI spec requires. Signed-off-by: kingbri --- backends/exllamav2/model.py | 3 + common/auth.py | 3 +- endpoints/OAI/router.py | 108 +++++++++++++++++------------------ endpoints/OAI/utils/lora.py | 16 ++++++ endpoints/OAI/utils/model.py | 49 +++++++++++++++- 5 files changed, 119 insertions(+), 60 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 5f4e86b..04c6d08 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -405,6 +405,9 @@ class ExllamaV2Container: def get_model_path(self, is_draft: bool = False): """Get the path for this model.""" + if is_draft and not self.draft_config: + return None + model_path = pathlib.Path( self.draft_config.model_dir if is_draft else self.config.model_dir ) diff --git a/common/auth.py b/common/auth.py index 623de63..7c0d83b 100644 --- a/common/auth.py +++ b/common/auth.py @@ -106,8 +106,7 @@ def get_key_permission(request: Request): async def check_api_key( - x_api_key: str = Header(None), - authorization: str = Header(None) + x_api_key: str = Header(None), authorization: str = Header(None) ): """Check if the API key is valid.""" diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 0cf921a..5f6d1d1 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize -from common import config, model, gen_logging, sampling +from common import config, model, sampling from common.auth import check_admin_key, check_api_key, get_key_permission from common.downloader import hf_repo_download from common.networking import handle_request_error, run_with_request_disconnect @@ -18,7 +18,6 @@ from endpoints.OAI.types.chat_completion import ( ) from endpoints.OAI.types.download import DownloadRequest, DownloadResponse from endpoints.OAI.types.lora import ( - LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse, @@ -27,7 +26,6 @@ from endpoints.OAI.types.model import ( ModelCard, ModelList, ModelLoadRequest, - ModelCardParameters, ModelLoadResponse, ) from endpoints.OAI.types.sampler_overrides import ( @@ -50,8 +48,13 @@ from endpoints.OAI.utils.completion import ( generate_completion, stream_generate_completion, ) -from endpoints.OAI.utils.model import get_model_list, stream_model_load -from endpoints.OAI.utils.lora import get_lora_list +from endpoints.OAI.utils.model import ( + get_current_model, + get_current_model_list, + get_model_list, + stream_model_load, +) +from endpoints.OAI.utils.lora import get_active_loras, get_lora_list router = APIRouter() @@ -172,7 +175,7 @@ async def chat_completion_request( # Model list endpoint @router.get("/v1/models", dependencies=[Depends(check_api_key)]) @router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) -async def list_models() -> ModelList: +async def list_models(request: Request) -> ModelList: """Lists all models in the model directory.""" model_config = config.model_config() model_dir = unwrap(model_config.get("model_dir"), "models") @@ -180,7 +183,11 @@ async def list_models() -> ModelList: draft_model_dir = config.draft_model_config().get("draft_model_dir") - models = get_model_list(model_path.resolve(), draft_model_dir) + if get_key_permission(request) == "admin": + models = get_model_list(model_path.resolve(), draft_model_dir) + else: + models = await get_current_model_list() + if unwrap(model_config.get("use_dummy_models"), False): models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) @@ -194,43 +201,23 @@ async def list_models() -> ModelList: ) async def current_model() -> ModelCard: """Returns the currently loaded model.""" - model_params = model.container.get_model_parameters() - draft_model_params = model_params.pop("draft", {}) - if draft_model_params: - model_params["draft"] = ModelCard( - id=unwrap(draft_model_params.get("name"), "unknown"), - parameters=ModelCardParameters.model_validate(draft_model_params), - ) - else: - draft_model_params = None - - model_card = ModelCard( - id=unwrap(model_params.pop("name", None), "unknown"), - parameters=ModelCardParameters.model_validate(model_params), - logging=gen_logging.PREFERENCES, - ) - - if draft_model_params: - draft_card = ModelCard( - id=unwrap(draft_model_params.pop("name", None), "unknown"), - parameters=ModelCardParameters.model_validate(draft_model_params), - ) - - model_card.parameters.draft = draft_card - - return model_card + return get_current_model() @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) -async def list_draft_models() -> ModelList: +async def list_draft_models(request: Request) -> ModelList: """Lists all draft models in the model directory.""" - draft_model_dir = unwrap( - config.draft_model_config().get("draft_model_dir"), "models" - ) - draft_model_path = pathlib.Path(draft_model_dir) - models = get_model_list(draft_model_path.resolve()) + if get_key_permission(request) == "admin": + draft_model_dir = unwrap( + config.draft_model_config().get("draft_model_dir"), "models" + ) + draft_model_path = pathlib.Path(draft_model_dir) + + models = get_model_list(draft_model_path.resolve()) + else: + models = await get_current_model_list(is_draft=True) return models @@ -313,10 +300,14 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes # Lora list endpoint @router.get("/v1/loras", dependencies=[Depends(check_api_key)]) @router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) -async def list_all_loras() -> LoraList: +async def list_all_loras(request: Request) -> LoraList: """Lists all LoRAs in the lora directory.""" - lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) - loras = get_lora_list(lora_path.resolve()) + + if get_key_permission(request) == "admin": + lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) + loras = get_lora_list(lora_path.resolve()) + else: + loras = get_active_loras() return loras @@ -328,17 +319,8 @@ async def list_all_loras() -> LoraList: ) async def active_loras() -> LoraList: """Returns the currently loaded loras.""" - loras = LoraList( - data=[ - LoraCard( - id=pathlib.Path(lora.lora_path).parent.name, - scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha, - ) - for lora in model.container.get_loras() - ] - ) - return loras + return get_active_loras() # Load lora endpoint @@ -452,9 +434,17 @@ async def key_permission(request: Request) -> AuthPermissionResponse: @router.get("/v1/templates", dependencies=[Depends(check_api_key)]) @router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) -async def list_templates() -> TemplateList: - templates = get_all_templates() - template_strings = [template.stem for template in templates] +async def list_templates(request: Request) -> TemplateList: + """Get a list of all templates.""" + + template_strings = [] + if get_key_permission(request) == "admin": + templates = get_all_templates() + template_strings = [template.stem for template in templates] + else: + if model.container and model.container.prompt_template: + template_strings.append(model.container.prompt_template.name) + return TemplateList(data=template_strings) @@ -464,6 +454,7 @@ async def list_templates() -> TemplateList: ) async def switch_template(data: TemplateSwitchRequest): """Switch the currently loaded template""" + if not data.name: error_message = handle_request_error( "New template name not found.", @@ -496,11 +487,16 @@ async def unload_template(): # Sampler override endpoints @router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) @router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) -async def list_sampler_overrides() -> SamplerOverrideListResponse: +async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: """API wrapper to list all currently applied sampler overrides""" + if get_key_permission(request) == "admin": + presets = sampling.get_all_presets() + else: + presets = [] + return SamplerOverrideListResponse( - presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump() + presets=presets, **sampling.overrides_container.model_dump() ) diff --git a/endpoints/OAI/utils/lora.py b/endpoints/OAI/utils/lora.py index d00910f..3e31f68 100644 --- a/endpoints/OAI/utils/lora.py +++ b/endpoints/OAI/utils/lora.py @@ -1,5 +1,6 @@ import pathlib +from common import model from endpoints.OAI.types.lora import LoraCard, LoraList @@ -12,3 +13,18 @@ def get_lora_list(lora_path: pathlib.Path): lora_list.data.append(lora_card) return lora_list + + +def get_active_loras(): + if model.container: + active_loras = [ + LoraCard( + id=pathlib.Path(lora.lora_path).parent.name, + scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha, + ) + for lora in model.container.get_loras() + ] + else: + active_loras = [] + + return LoraList(data=active_loras) diff --git a/endpoints/OAI/utils/model.py b/endpoints/OAI/utils/model.py index 0502193..a8057e4 100644 --- a/endpoints/OAI/utils/model.py +++ b/endpoints/OAI/utils/model.py @@ -2,11 +2,12 @@ import pathlib from asyncio import CancelledError from typing import Optional -from common import model +from common import gen_logging, model from common.networking import get_generator_error, handle_request_disconnect - +from common.utils import unwrap from endpoints.OAI.types.model import ( ModelCard, + ModelCardParameters, ModelList, ModelLoadRequest, ModelLoadResponse, @@ -31,6 +32,50 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N return model_card_list +async def get_current_model_list(is_draft: bool = False): + """Gets the current model in list format and with path only.""" + current_models = [] + + # Make sure the model container exists + if model.container: + model_path = model.container.get_model_path(is_draft) + if model_path: + current_models.append(ModelCard(id=model_path.name)) + + return ModelList(data=current_models) + + +def get_current_model(): + """Gets the current model with all parameters.""" + + model_params = model.container.get_model_parameters() + draft_model_params = model_params.pop("draft", {}) + + if draft_model_params: + model_params["draft"] = ModelCard( + id=unwrap(draft_model_params.get("name"), "unknown"), + parameters=ModelCardParameters.model_validate(draft_model_params), + ) + else: + draft_model_params = None + + model_card = ModelCard( + id=unwrap(model_params.pop("name", None), "unknown"), + parameters=ModelCardParameters.model_validate(model_params), + logging=gen_logging.PREFERENCES, + ) + + if draft_model_params: + draft_card = ModelCard( + id=unwrap(draft_model_params.pop("name", None), "unknown"), + parameters=ModelCardParameters.model_validate(draft_model_params), + ) + + model_card.parameters.draft = draft_card + + return model_card + + async def stream_model_load( data: ModelLoadRequest, model_path: pathlib.Path, From 9fc3fc4c5447bce5fca3073b52f15311a19922c0 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 11 Jul 2024 14:10:24 -0400 Subject: [PATCH 05/89] OAI: Amend comments Clarify what the user can and can't see. Signed-off-by: kingbri --- endpoints/OAI/router.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 5f6d1d1..f3ab99f 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -176,7 +176,12 @@ async def chat_completion_request( @router.get("/v1/models", dependencies=[Depends(check_api_key)]) @router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) async def list_models(request: Request) -> ModelList: - """Lists all models in the model directory.""" + """ + Lists all models in the model directory. + + Requires an admin key to see all models. + """ + model_config = config.model_config() model_dir = unwrap(model_config.get("model_dir"), "models") model_path = pathlib.Path(model_dir) @@ -207,7 +212,11 @@ async def current_model() -> ModelCard: @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) async def list_draft_models(request: Request) -> ModelList: - """Lists all draft models in the model directory.""" + """ + Lists all draft models in the model directory. + + Requires an admin key to see all draft models. + """ if get_key_permission(request) == "admin": draft_model_dir = unwrap( @@ -301,7 +310,11 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes @router.get("/v1/loras", dependencies=[Depends(check_api_key)]) @router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) async def list_all_loras(request: Request) -> LoraList: - """Lists all LoRAs in the lora directory.""" + """ + Lists all LoRAs in the lora directory. + + Requires an admin key to see all LoRAs. + """ if get_key_permission(request) == "admin": lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) @@ -406,6 +419,7 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: ) async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: """Decodes tokens into a string.""" + message = model.container.decode_tokens(data.tokens, **data.get_params()) response = TokenDecodeResponse(text=unwrap(message, "")) @@ -435,7 +449,11 @@ async def key_permission(request: Request) -> AuthPermissionResponse: @router.get("/v1/templates", dependencies=[Depends(check_api_key)]) @router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) async def list_templates(request: Request) -> TemplateList: - """Get a list of all templates.""" + """ + Get a list of all templates. + + Requires an admin key to see all templates. + """ template_strings = [] if get_key_permission(request) == "admin": @@ -453,7 +471,7 @@ async def list_templates(request: Request) -> TemplateList: dependencies=[Depends(check_admin_key), Depends(check_model_container)], ) async def switch_template(data: TemplateSwitchRequest): - """Switch the currently loaded template""" + """Switch the currently loaded template.""" if not data.name: error_message = handle_request_error( @@ -488,7 +506,11 @@ async def unload_template(): @router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) @router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: - """API wrapper to list all currently applied sampler overrides""" + """ + List all currently applied sampler overrides. + + Requires an admin key to see all override presets. + """ if get_key_permission(request) == "admin": presets = sampling.get_all_presets() From 073e9fa6f01bbcc919b97f1e25a9a54ce4a76208 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 11 Jul 2024 14:11:37 -0400 Subject: [PATCH 06/89] Dependencies: Bump ExllamaV2 v0.1.7 Signed-off-by: kingbri --- backends/exllamav2/utils.py | 2 +- pyproject.toml | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index 5b1d567..a5e8779 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -7,7 +7,7 @@ import torch def check_exllama_version(): """Verifies the exllama version""" - required_version = version.parse("0.1.6") + required_version = version.parse("0.1.7") current_version = version.parse(package_version("exllamav2").split("+")[0]) unsupported_message = ( diff --git a/pyproject.toml b/pyproject.toml index 9d9aaf8..836113b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,12 +58,12 @@ cu121 = [ "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/bdashore3/flash-attention/releases "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", @@ -85,12 +85,12 @@ cu118 = [ "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", @@ -109,9 +109,9 @@ amd = [ "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", ] # MARK: Ruff options From b149d3398d4f62c282452121fc6057c4439eb123 Mon Sep 17 00:00:00 2001 From: Volodymyr Kuznetsov Date: Mon, 8 Jul 2024 13:42:54 -0700 Subject: [PATCH 07/89] OAI: support stream_options argument --- endpoints/OAI/types/chat_completion.py | 1 + endpoints/OAI/types/common.py | 5 +++++ endpoints/OAI/utils/chat_completion.py | 19 ++++++++++++++++++- 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index be5cfea..b50e646 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -64,3 +64,4 @@ class ChatCompletionStreamChunk(BaseModel): created: int = Field(default_factory=lambda: int(time())) model: str object: str = "chat.completion.chunk" + usage: Optional[UsageStats] = None diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py index d44e41a..6970adf 100644 --- a/endpoints/OAI/types/common.py +++ b/endpoints/OAI/types/common.py @@ -18,6 +18,10 @@ class CompletionResponseFormat(BaseModel): type: str = "text" +class ChatCompletionStreamOptions(BaseModel): + include_usage: Optional[bool] = False + + class CommonCompletionRequest(BaseSamplerRequest): """Represents a common completion request.""" @@ -27,6 +31,7 @@ class CommonCompletionRequest(BaseSamplerRequest): # Generation info (remainder is in BaseSamplerRequest superclass) stream: Optional[bool] = False + stream_options: Optional[ChatCompletionStreamOptions] = None logprobs: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("logprobs", 0) ) diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 9e82b1b..9b91d1d 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -246,6 +246,7 @@ async def stream_generate_chat_completion( gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] disconnect_task = asyncio.create_task(request_disconnect_loop(request)) + need_usage = data.stream_options and data.stream_options.include_usage try: gen_params = data.to_gen_params() @@ -275,10 +276,26 @@ async def stream_generate_chat_completion( raise generation response = _create_stream_chunk(const_id, generation, model_path.name) - yield response.model_dump_json() + yield response.model_dump_json(exclude=None if need_usage else "usage") # Check if all tasks are completed if all(task.done() for task in gen_tasks) and gen_queue.empty(): + if need_usage: + prompt_tokens = unwrap(generation.get("prompt_tokens"), 0) + completion_tokens = unwrap(generation.get("generated_tokens"), 0) + + response = ChatCompletionStreamChunk( + id=const_id, + choices=[], + model=unwrap(model_path.name, ""), + usage=UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + yield response.model_dump_json() break except CancelledError: # Get out if the request gets disconnected From 2e5cf0ea3f56317ac1e156729287206d756eece4 Mon Sep 17 00:00:00 2001 From: Amgad Hasan <109704569+AmgadHasan@users.noreply.github.com> Date: Fri, 12 Jul 2024 13:23:58 +0000 Subject: [PATCH 08/89] Fix docker compose volume mount --- docker/docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index fd6634c..c337b20 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -12,7 +12,7 @@ services: - NAME=TabbyAPI - NVIDIA_VISIBLE_DEVICES=all volumes: - - ./models:/usr/src/app/models + - ./models:/app/models deploy: resources: reservations: From 59175156964b31d5d3117f27115ceaf815d3022a Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 12 Jul 2024 10:09:49 -0400 Subject: [PATCH 09/89] Dependencies: Update flash-attention v2.6.1 Signed-off-by: kingbri --- pyproject.toml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 836113b..ecf6362 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,14 +66,14 @@ cu121 = [ "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/bdashore3/flash-attention/releases - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] cu118 = [ # Torch @@ -93,9 +93,9 @@ cu118 = [ "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] amd = [ # Torch triton for ROCm From c1b61441f46d7601076fe7f737475b3f2392f61d Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 12 Jul 2024 14:35:48 -0400 Subject: [PATCH 10/89] OAI: Fix usage chunk return Place the logic into their proper utility functions and cleanup the code with formatting. Also, OAI's docs specify that a [DONE] return is needed when everything is finished. Signed-off-by: kingbri --- endpoints/OAI/utils/chat_completion.py | 49 ++++++++++++++++---------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 9b91d1d..10f25cd 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -93,22 +93,37 @@ def _create_stream_chunk( const_id: str, generation: Optional[dict] = None, model_name: Optional[str] = None, + is_usage_chunk: bool = False, ): """Create a chat completion stream chunk from the provided text.""" index = generation.get("index") - logprob_response = None + choices = [] + usage_stats = None - if "finish_reason" in generation: + if is_usage_chunk: + prompt_tokens = unwrap(generation.get("prompt_tokens"), 0) + completion_tokens = unwrap(generation.get("generated_tokens"), 0) + + usage_stats = UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + elif "finish_reason" in generation: choice = ChatCompletionStreamChoice( index=index, finish_reason=generation.get("finish_reason"), ) + + choices.append(choice) else: message = ChatCompletionMessage( role="assistant", content=unwrap(generation.get("text"), "") ) + logprob_response = None + token_probs = unwrap(generation.get("token_probs"), {}) if token_probs: logprobs = unwrap(generation.get("logprobs"), {}) @@ -132,8 +147,13 @@ def _create_stream_chunk( logprobs=logprob_response, ) + choices.append(choice) + chunk = ChatCompletionStreamChunk( - id=const_id, choices=[choice], model=unwrap(model_name, "") + id=const_id, + choices=choices, + model=unwrap(model_name, ""), + usage=usage_stats, ) return chunk @@ -246,7 +266,6 @@ async def stream_generate_chat_completion( gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] disconnect_task = asyncio.create_task(request_disconnect_loop(request)) - need_usage = data.stream_options and data.stream_options.include_usage try: gen_params = data.to_gen_params() @@ -276,26 +295,18 @@ async def stream_generate_chat_completion( raise generation response = _create_stream_chunk(const_id, generation, model_path.name) - yield response.model_dump_json(exclude=None if need_usage else "usage") + yield response.model_dump_json() # Check if all tasks are completed if all(task.done() for task in gen_tasks) and gen_queue.empty(): - if need_usage: - prompt_tokens = unwrap(generation.get("prompt_tokens"), 0) - completion_tokens = unwrap(generation.get("generated_tokens"), 0) - - response = ChatCompletionStreamChunk( - id=const_id, - choices=[], - model=unwrap(model_path.name, ""), - usage=UsageStats( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), + # Send a usage chunk + if data.stream_options and data.stream_options.include_usage: + usage_chunk = _create_stream_chunk( + const_id, generation, model_path.name, is_usage_chunk=True ) + yield usage_chunk.model_dump_json() - yield response.model_dump_json() + yield "[DONE]" break except CancelledError: # Get out if the request gets disconnected From 6019c936379ffaa702d6da0454f7ae8eb4e20cae Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 13 Jul 2024 17:59:58 -0400 Subject: [PATCH 11/89] Networking: Gate sending tracebacks over the API It's possible that tracebacks can give too much info about a system when sent over the API. Gate this under a flag to send them only when debugging since this feature is still useful. Signed-off-by: kingbri --- common/networking.py | 12 +++++++++--- config_sample.yml | 4 ++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/common/networking.py b/common/networking.py index 47ebe06..1afce68 100644 --- a/common/networking.py +++ b/common/networking.py @@ -8,6 +8,9 @@ from loguru import logger from pydantic import BaseModel from typing import Optional +from common import config +from common.utils import unwrap + class TabbyRequestErrorMessage(BaseModel): """Common request error type.""" @@ -33,15 +36,18 @@ def get_generator_error(message: str, exc_info: bool = True): def handle_request_error(message: str, exc_info: bool = True): """Log a request error to the console.""" + trace = traceback.format_exc() + send_trace = unwrap(config.network_config().get("send_tracebacks"), False) + error_message = TabbyRequestErrorMessage( - message=message, trace=traceback.format_exc() + message=message, trace=trace if send_trace else None ) request_error = TabbyRequestError(error=error_message) # Log the error and provided message to the console - if error_message.trace and exc_info: - logger.error(error_message.trace) + if trace and exc_info: + logger.error(trace) logger.error(f"Sent to request: {message}") diff --git a/config_sample.yml b/config_sample.yml index 0e4b180..3bde5ee 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -19,6 +19,10 @@ network: # Turn on this option if you are ONLY connecting from localhost disable_auth: False + # Send tracebacks over the API to clients (default: False) + # NOTE: Only enable this for debug purposes + send_tracebacks: False + # Options for logging logging: # Enable prompt logging (default: False) From 9dae46114288dc1feb6a3f70fce5e5d1c117eeab Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 15 Jul 2024 01:09:49 -0400 Subject: [PATCH 12/89] Model: Attempt to recreate generator on a fatal error If a job causes the generator to error, tabby stops working until a relaunch. It's better to try establishing a system of redundancy and remake the generator in the event that it fails. May replace this with an exit signal for a fatal error instead, but not sure. Signed-off-by: kingbri --- backends/exllamav2/model.py | 49 ++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 04c6d08..0c65b38 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -488,15 +488,7 @@ class ExllamaV2Container: yield value # Create async generator - self.generator = ExLlamaV2DynamicGeneratorAsync( - model=self.model, - cache=self.cache, - draft_model=self.draft_model, - draft_cache=self.draft_cache, - tokenizer=self.tokenizer, - max_batch_size=self.max_batch_size, - paged=self.paged, - ) + await self.create_generator() # Clean up any extra vram usage from torch and cuda # (Helps reduce VRAM bottlenecking on Windows) @@ -645,6 +637,34 @@ class ExllamaV2Container: input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) self.model.forward(input_ids, cache=self.cache, preprocess_only=True) + async def create_generator(self): + try: + # Don't acquire locks unless a model is loaded + if self.model_loaded: + await self.load_lock.acquire() + + # Immediately cancel all jobs + await self.wait_for_jobs(skip_wait=True) + + # Create new generator + self.generator = ExLlamaV2DynamicGeneratorAsync( + model=self.model, + cache=self.cache, + draft_model=self.draft_model, + draft_cache=self.draft_cache, + tokenizer=self.tokenizer, + max_batch_size=self.max_batch_size, + paged=self.paged, + ) + finally: + # This means the generator is being recreated + # The load lock is already released in the load function + if self.model_loaded: + self.load_lock.release() + + async with self.load_condition: + self.load_condition.notify_all() + def get_loras(self): """Convenience function to get all loras.""" @@ -1223,3 +1243,14 @@ class ExllamaV2Container: break except asyncio.CancelledError: await job.cancel() + except Exception as ex: + # Create a new generator since the current state is broken + # No need to wait for this to finish + logger.error( + "FATAL ERROR with generation. " + "Attempting to recreate the generator. " + "If this fails, please restart the server.\n" + ) + asyncio.ensure_future(self.create_generator()) + + raise ex From 933404c185a4d66d764a9a939d54144102b527c2 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 15 Jul 2024 11:33:44 -0400 Subject: [PATCH 13/89] Model: Warn user if terminating jobs If skip_wait is true, it's best to let the user know that all jobs will be forcibly cancelled. Signed-off-by: kingbri --- backends/exllamav2/model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 0c65b38..74cf713 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -449,6 +449,11 @@ class ExllamaV2Container: # Immediately abort all jobs if asked if skip_wait: + logger.warning( + "Immediately terminating all jobs. " + "Clients will have their requests cancelled.\n" + ) + # Requires a copy to avoid errors during iteration jobs_copy = self.generator.jobs.copy() for job in jobs_copy.values(): From e20a2d504b95b12560cb3a90d4841a7e9d6b0e1e Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 15 Jul 2024 14:39:55 -0400 Subject: [PATCH 14/89] API: Fix pydantic validation errors on disconnect poll returns Raise a 422 exception for the disconnect. This prevents pydantic errors when returning a "response" which doesn't contain anything in this case. Signed-off-by: kingbri --- common/networking.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/networking.py b/common/networking.py index 1afce68..5540105 100644 --- a/common/networking.py +++ b/common/networking.py @@ -3,7 +3,7 @@ import asyncio import socket import traceback -from fastapi import Request +from fastapi import HTTPException, Request from loguru import logger from pydantic import BaseModel from typing import Optional @@ -84,8 +84,9 @@ async def run_with_request_disconnect( try: return call_task.result() - except (asyncio.CancelledError, asyncio.InvalidStateError): + except (asyncio.CancelledError, asyncio.InvalidStateError) as ex: handle_request_disconnect(disconnect_message) + raise HTTPException(422, disconnect_message) from ex def is_port_in_use(port: int) -> bool: From 38185a1ff4a1a24134a579b5325df2bcb68653d5 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 19 Jul 2024 10:08:57 -0400 Subject: [PATCH 15/89] Auth: Fix key check coalesce Prefer the auth-specific headers before the generic authorization header. Signed-off-by: kingbri --- common/auth.py | 2 +- endpoints/OAI/router.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/auth.py b/common/auth.py index 7c0d83b..4f7c9f8 100644 --- a/common/auth.py +++ b/common/auth.py @@ -86,9 +86,9 @@ def get_key_permission(request: Request): # Hyphens are okay here test_key = coalesce( - request.headers.get("authorization"), request.headers.get("x-admin-key"), request.headers.get("x-api-key"), + request.headers.get("authorization"), ) if test_key is None: diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index f3ab99f..1c0a7c6 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -432,9 +432,9 @@ async def key_permission(request: Request) -> AuthPermissionResponse: Gets the access level/permission of a provided key in headers. Priority: - - Authorization - X-admin-key - X-api-key + - Authorization """ try: From cae94b920c2a226e70b775be4af5063388c8951e Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 21 Jul 2024 21:01:05 -0400 Subject: [PATCH 16/89] API: Add ability to use request IDs Identify which request is being processed to help users disambiguate which logs correspond to which request. Signed-off-by: kingbri --- backends/exllamav2/model.py | 14 +++-- common/gen_logging.py | 6 +- endpoints/OAI/router.py | 10 ++-- endpoints/OAI/utils/chat_completion.py | 77 ++++++++++++++------------ endpoints/OAI/utils/completion.py | 50 +++++++++++++---- endpoints/server.py | 12 +++- 6 files changed, 112 insertions(+), 57 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 74cf713..200be6b 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -828,10 +828,10 @@ class ExllamaV2Container: return dict(zip_longest(top_tokens, cleaned_values)) - async def generate(self, prompt: str, **kwargs): + async def generate(self, prompt: str, request_id: str, **kwargs): """Generate a response to a prompt""" generations = [] - async for generation in self.generate_gen(prompt, **kwargs): + async for generation in self.generate_gen(prompt, request_id, **kwargs): generations.append(generation) joined_generation = { @@ -881,7 +881,11 @@ class ExllamaV2Container: return kwargs async def generate_gen( - self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs + self, + prompt: str, + request_id: str, + abort_event: Optional[asyncio.Event] = None, + **kwargs, ): """ Create generator function for prompt completion. @@ -1116,6 +1120,7 @@ class ExllamaV2Container: # Log generation options to console # Some options are too large, so log the args instead log_generation_params( + request_id=request_id, max_tokens=max_tokens, min_tokens=min_tokens, stream=kwargs.get("stream"), @@ -1138,9 +1143,10 @@ class ExllamaV2Container: ) # Log prompt to console - log_prompt(prompt, negative_prompt) + log_prompt(prompt, request_id, negative_prompt) # Create and add a new job + # Don't use the request ID here as there can be multiple jobs per request job_id = uuid.uuid4().hex job = ExLlamaV2DynamicJobAsync( self.generator, diff --git a/common/gen_logging.py b/common/gen_logging.py index fbf10f6..94c4405 100644 --- a/common/gen_logging.py +++ b/common/gen_logging.py @@ -51,11 +51,13 @@ def log_generation_params(**kwargs): logger.info(f"Generation options: {kwargs}\n") -def log_prompt(prompt: str, negative_prompt: Optional[str]): +def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]): """Logs the prompt to console.""" if PREFERENCES.prompt: formatted_prompt = "\n" + prompt - logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n") + logger.info( + f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n" + ) if negative_prompt: formatted_negative_prompt = "\n" + negative_prompt diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 1c0a7c6..1297d87 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -107,12 +107,14 @@ async def completion_request( ping=maxsize, ) else: - generate_task = asyncio.create_task(generate_completion(data, model_path)) + generate_task = asyncio.create_task( + generate_completion(data, request, model_path) + ) response = await run_with_request_disconnect( request, generate_task, - disconnect_message="Completion generation cancelled by user.", + disconnect_message=f"Completion {request.state.id} cancelled by user.", ) return response @@ -161,13 +163,13 @@ async def chat_completion_request( ) else: generate_task = asyncio.create_task( - generate_chat_completion(prompt, data, model_path) + generate_chat_completion(prompt, data, request, model_path) ) response = await run_with_request_disconnect( request, generate_task, - disconnect_message="Chat completion generation cancelled by user.", + disconnect_message=f"Chat completion {request.state.id} cancelled by user.", ) return response diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 10f25cd..b9c6f71 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -5,7 +5,6 @@ import pathlib from asyncio import CancelledError from copy import deepcopy from typing import List, Optional -from uuid import uuid4 from fastapi import HTTPException, Request from jinja2 import TemplateError @@ -30,9 +29,12 @@ from endpoints.OAI.types.chat_completion import ( ChatCompletionStreamChoice, ) from endpoints.OAI.types.common import UsageStats +from endpoints.OAI.utils.completion import _stream_collector -def _create_response(generations: List[dict], model_name: Optional[str]): +def _create_response( + request_id: str, generations: List[dict], model_name: Optional[str] +): """Create a chat completion response from the provided text.""" prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0) @@ -77,6 +79,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]): choices.append(choice) response = ChatCompletionResponse( + id=f"chatcmpl-{request_id}", choices=choices, model=unwrap(model_name, ""), usage=UsageStats( @@ -90,7 +93,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]): def _create_stream_chunk( - const_id: str, + request_id: str, generation: Optional[dict] = None, model_name: Optional[str] = None, is_usage_chunk: bool = False, @@ -150,7 +153,7 @@ def _create_stream_chunk( choices.append(choice) chunk = ChatCompletionStreamChunk( - id=const_id, + id=f"chatcmpl-{request_id}", choices=choices, model=unwrap(model_name, ""), usage=usage_stats, @@ -235,39 +238,18 @@ def format_prompt_with_template(data: ChatCompletionRequest): raise HTTPException(400, error_message) from exc -async def _stream_collector( - task_idx: int, - gen_queue: asyncio.Queue, - prompt: str, - abort_event: asyncio.Event, - **kwargs, -): - """Collects a stream and places results in a common queue""" - - try: - new_generation = model.container.generate_gen(prompt, abort_event, **kwargs) - async for generation in new_generation: - generation["index"] = task_idx - - await gen_queue.put(generation) - - if "finish_reason" in generation: - break - except Exception as e: - await gen_queue.put(e) - - async def stream_generate_chat_completion( prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path ): """Generator for the generation process.""" - const_id = f"chatcmpl-{uuid4().hex}" abort_event = asyncio.Event() gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: + logger.info(f"Recieved chat completion streaming request {request.state.id}") + gen_params = data.to_gen_params() for n in range(0, data.n): @@ -277,7 +259,14 @@ async def stream_generate_chat_completion( task_gen_params = gen_params gen_task = asyncio.create_task( - _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params) + _stream_collector( + n, + gen_queue, + prompt, + request.state.id, + abort_event, + **task_gen_params, + ) ) gen_tasks.append(gen_task) @@ -286,7 +275,9 @@ async def stream_generate_chat_completion( while True: if disconnect_task.done(): abort_event.set() - handle_request_disconnect("Completion generation cancelled by user.") + handle_request_disconnect( + f"Chat completion generation {request.state.id} cancelled by user." + ) generation = await gen_queue.get() @@ -294,7 +285,9 @@ async def stream_generate_chat_completion( if isinstance(generation, Exception): raise generation - response = _create_stream_chunk(const_id, generation, model_path.name) + response = _create_stream_chunk( + request.state.id, generation, model_path.name + ) yield response.model_dump_json() # Check if all tasks are completed @@ -302,10 +295,17 @@ async def stream_generate_chat_completion( # Send a usage chunk if data.stream_options and data.stream_options.include_usage: usage_chunk = _create_stream_chunk( - const_id, generation, model_path.name, is_usage_chunk=True + request.state.id, + generation, + model_path.name, + is_usage_chunk=True, ) yield usage_chunk.model_dump_json() + logger.info( + f"Finished chat completion streaming request {request.state.id}" + ) + yield "[DONE]" break except CancelledError: @@ -320,7 +320,7 @@ async def stream_generate_chat_completion( async def generate_chat_completion( - prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path + prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path ): gen_tasks: List[asyncio.Task] = [] gen_params = data.to_gen_params() @@ -335,16 +335,23 @@ async def generate_chat_completion( task_gen_params = gen_params gen_tasks.append( - asyncio.create_task(model.container.generate(prompt, **task_gen_params)) + asyncio.create_task( + model.container.generate( + prompt, request.state.id, **task_gen_params + ) + ) ) generations = await asyncio.gather(*gen_tasks) - response = _create_response(generations, model_path.name) + response = _create_response(request.state.id, generations, model_path.name) + + logger.info(f"Finished chat completion request {request.state.id}") return response except Exception as exc: error_message = handle_request_error( - "Chat completion aborted. Maybe the model was unloaded? " + f"Chat completion {request.state.id} aborted. " + "Maybe the model was unloaded? " "Please check the server console." ).error.message diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 2b5dfbf..23f2692 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -7,6 +7,8 @@ from copy import deepcopy from fastapi import HTTPException, Request from typing import List, Union +from loguru import logger + from common import model from common.networking import ( get_generator_error, @@ -24,7 +26,9 @@ from endpoints.OAI.types.completion import ( from endpoints.OAI.types.common import UsageStats -def _create_response(generations: Union[dict, List[dict]], model_name: str = ""): +def _create_response( + request_id: str, generations: Union[dict, List[dict]], model_name: str = "" +): """Create a completion response from the provided choices.""" # Convert the single choice object into a list @@ -61,6 +65,7 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "") completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0) response = CompletionResponse( + id=f"cmpl-{request_id}", choices=choices, model=model_name, usage=UsageStats( @@ -77,13 +82,16 @@ async def _stream_collector( task_idx: int, gen_queue: asyncio.Queue, prompt: str, + request_id: str, abort_event: asyncio.Event, **kwargs, ): """Collects a stream and places results in a common queue""" try: - new_generation = model.container.generate_gen(prompt, abort_event, **kwargs) + new_generation = model.container.generate_gen( + prompt, request_id, abort_event, **kwargs + ) async for generation in new_generation: generation["index"] = task_idx @@ -106,6 +114,8 @@ async def stream_generate_completion( disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: + logger.info(f"Recieved streaming completion request {request.state.id}") + gen_params = data.to_gen_params() for n in range(0, data.n): @@ -116,7 +126,12 @@ async def stream_generate_completion( gen_task = asyncio.create_task( _stream_collector( - n, gen_queue, data.prompt, abort_event, **task_gen_params + n, + gen_queue, + data.prompt, + request.state.id, + abort_event, + **task_gen_params, ) ) @@ -126,7 +141,9 @@ async def stream_generate_completion( while True: if disconnect_task.done(): abort_event.set() - handle_request_disconnect("Completion generation cancelled by user.") + handle_request_disconnect( + f"Completion generation {request.state.id} cancelled by user." + ) generation = await gen_queue.get() @@ -134,31 +151,38 @@ async def stream_generate_completion( if isinstance(generation, Exception): raise generation - response = _create_response(generation, model_path.name) + response = _create_response(request.state.id, generation, model_path.name) yield response.model_dump_json() # Check if all tasks are completed if all(task.done() for task in gen_tasks) and gen_queue.empty(): yield "[DONE]" + logger.info(f"Finished streaming completion request {request.state.id}") break except CancelledError: # Get out if the request gets disconnected abort_event.set() - handle_request_disconnect("Completion generation cancelled by user.") + handle_request_disconnect( + f"Completion generation {request.state.id} cancelled by user." + ) except Exception: yield get_generator_error( - "Completion aborted. Please check the server console." + f"Completion {request.state.id} aborted. Please check the server console." ) -async def generate_completion(data: CompletionRequest, model_path: pathlib.Path): +async def generate_completion( + data: CompletionRequest, request: Request, model_path: pathlib.Path +): """Non-streaming generate for completions""" gen_tasks: List[asyncio.Task] = [] gen_params = data.to_gen_params() try: + logger.info(f"Recieved completion request {request.state.id}") + for n in range(0, data.n): # Deepcopy gen params above the first index # to ensure nested structures aren't shared @@ -169,17 +193,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path) gen_tasks.append( asyncio.create_task( - model.container.generate(data.prompt, **task_gen_params) + model.container.generate( + data.prompt, request.state.id, **task_gen_params + ) ) ) generations = await asyncio.gather(*gen_tasks) - response = _create_response(generations, model_path.name) + response = _create_response(request.state.id, generations, model_path.name) + + logger.info(f"Finished completion request {request.state.id}") return response except Exception as exc: error_message = handle_request_error( - "Completion aborted. Maybe the model was unloaded? " + f"Completion {request.state.id} aborted. Maybe the model was unloaded? " "Please check the server console." ).error.message diff --git a/endpoints/server.py b/endpoints/server.py index 7ceb208..ec59455 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -1,5 +1,6 @@ +from uuid import uuid4 import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from loguru import logger @@ -25,6 +26,15 @@ app.add_middleware( ) +@app.middleware("http") +async def add_request_id(request: Request, call_next): + """Middleware to append an ID to a request""" + + request.state.id = uuid4().hex + response = await call_next(request) + return response + + def setup_app(): """Includes the correct routers for startup""" From 0eedc8ca146dea5d532a8022f4bef8a8c4b899b7 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 22 Jul 2024 12:19:46 -0400 Subject: [PATCH 17/89] API: Switch from request ID middleware to depends Middleware runs on both the request and response. Therefore, streaming responses had increased latency when processing tasks and sending data to the client which resulted in erratic streaming behavior. Use a depends to add request IDs since it only executes when the request is run rather than expecting the response to be sent as well. For the future, it would be best to think about limiting the time between each tick of chunk data to be safe. Signed-off-by: kingbri --- common/networking.py | 8 ++++++++ endpoints/server.py | 14 +++----------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/common/networking.py b/common/networking.py index 5540105..706cb60 100644 --- a/common/networking.py +++ b/common/networking.py @@ -7,6 +7,7 @@ from fastapi import HTTPException, Request from loguru import logger from pydantic import BaseModel from typing import Optional +from uuid import uuid4 from common import config from common.utils import unwrap @@ -100,3 +101,10 @@ def is_port_in_use(port: int) -> bool: test_socket.settimeout(1) with test_socket: return test_socket.connect_ex(("localhost", port)) == 0 + + +async def add_request_id(request: Request): + """FastAPI depends to add a UUID to a request's state.""" + + request.state.id = uuid4().hex + return request diff --git a/endpoints/server.py b/endpoints/server.py index ec59455..0a29e15 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -1,10 +1,10 @@ -from uuid import uuid4 import uvicorn -from fastapi import FastAPI, Request +from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware from loguru import logger from common.logger import UVICORN_LOG_CONFIG +from common.networking import add_request_id from endpoints.OAI.router import router as OAIRouter app = FastAPI( @@ -14,6 +14,7 @@ app = FastAPI( "This docs page is not meant to send requests! Please use a service " "like Postman or a frontend UI." ), + dependencies=[Depends(add_request_id)] ) # ALlow CORS requests @@ -26,15 +27,6 @@ app.add_middleware( ) -@app.middleware("http") -async def add_request_id(request: Request, call_next): - """Middleware to append an ID to a request""" - - request.state.id = uuid4().hex - response = await call_next(request) - return response - - def setup_app(): """Includes the correct routers for startup""" From 21516bd7b5ca90b190c785c0c767e6045136e4ab Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 22 Jul 2024 12:23:49 -0400 Subject: [PATCH 18/89] Model: Skip empty token chunks This helps make the generation loop more efficient by skipping past chunks that aren't providing any tokens anyways. The offset isn't affected. Signed-off-by: kingbri --- backends/exllamav2/model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 200be6b..f42dd00 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1185,13 +1185,15 @@ class ExllamaV2Container: result_id = result.get("identifier") if stage == "streaming" and result_id == job_id: + chunk_tokens = result.get("token_ids") + if chunk_tokens is None: + continue + else: + generated_tokens += chunk_tokens.size(dim=0) + chunk = unwrap(result.get("text"), "") full_response += chunk - chunk_tokens = result.get("token_ids") - if chunk_tokens is not None: - generated_tokens += chunk_tokens.size(dim=0) - generation = { "text": chunk, "prompt_tokens": context_len, From ad4d17bca2a024fb0d22a3d27b90fea59d53d678 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 22 Jul 2024 12:24:34 -0400 Subject: [PATCH 19/89] Tree: Format Signed-off-by: kingbri --- endpoints/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/endpoints/server.py b/endpoints/server.py index 0a29e15..a69ecd1 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -14,7 +14,7 @@ app = FastAPI( "This docs page is not meant to send requests! Please use a service " "like Postman or a frontend UI." ), - dependencies=[Depends(add_request_id)] + dependencies=[Depends(add_request_id)], ) # ALlow CORS requests From 15f891b277c4532d63de25f6e95dd862603cb5db Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 22 Jul 2024 16:25:26 -0400 Subject: [PATCH 20/89] Args: Update to latest config.yml Fix order of params to follow the same flow as config.yml Signed-off-by: kingbri --- common/args.py | 84 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 66 insertions(+), 18 deletions(-) diff --git a/common/args.py b/common/args.py index 14508a7..6bb0804 100644 --- a/common/args.py +++ b/common/args.py @@ -17,13 +17,15 @@ def init_argparser(): """Creates an argument parser that any function can use""" parser = argparse.ArgumentParser( - epilog="These args are only for a subset of the config. " - + "Please edit config.yml for all options!" + epilog="NOTE: These args serve to override parts of the config. " + + "It's highly recommended to edit config.yml for all options and " + + "better descriptions!" ) add_network_args(parser) add_model_args(parser) add_logging_args(parser) add_developer_args(parser) + add_sampling_args(parser) add_config_args(parser) return parser @@ -64,6 +66,11 @@ def add_network_args(parser: argparse.ArgumentParser): type=str_to_bool, help="Disable HTTP token authenticaion with requests", ) + network_group.add_argument( + "--send-tracebacks", + type=str_to_bool, + help="Decide whether to send error tracebacks over the API", + ) def add_model_args(parser: argparse.ArgumentParser): @@ -74,6 +81,17 @@ def add_model_args(parser: argparse.ArgumentParser): "--model-dir", type=str, help="Overrides the directory to look for models" ) model_group.add_argument("--model-name", type=str, help="An initial model to load") + model_group.add_argument( + "--use-dummy-models", + type=str_to_bool, + help="Add dummy OAI model names for API queries", + ) + model_group.add_argument( + "--use-as-default", + type=str, + nargs="+", + help="Names of args to use as a default fallback for API load requests ", + ) model_group.add_argument( "--max-seq-len", type=int, help="Override the maximum model sequence length" ) @@ -82,25 +100,17 @@ def add_model_args(parser: argparse.ArgumentParser): type=str_to_bool, help="Overrides base model context length", ) - model_group.add_argument( - "--cache-size", - type=int, - help="The size of the prompt cache (in number of tokens) to allocate", - ) - model_group.add_argument( - "--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb" - ) - model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK") - model_group.add_argument( - "--prompt-template", - type=str, - help="Set the prompt template for chat completions", - ) model_group.add_argument( "--gpu-split-auto", type=str_to_bool, help="Automatically allocate resources to GPUs", ) + model_group.add_argument( + "--autosplit-reserve", + type=int, + nargs="+", + help="Reserve VRAM used for autosplit loading (in MBs) ", + ) model_group.add_argument( "--gpu-split", type=float, @@ -108,15 +118,44 @@ def add_model_args(parser: argparse.ArgumentParser): help="An integer array of GBs of vram to split between GPUs. " + "Ignored if gpu_split_auto is true", ) + model_group.add_argument( + "--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb" + ) + model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK") + model_group.add_argument( + "--cache-mode", + type=str, + help="Set the quantization level of the K/V cache. Options: (FP16, Q8, Q6, Q4)", + ) + model_group.add_argument( + "--cache-size", + type=int, + help="The size of the prompt cache (in number of tokens) to allocate", + ) + model_group.add_argument( + "--chunk-size", + type=int, + help="Chunk size for prompt ingestion", + ) + model_group.add_argument( + "--max-batch-size", + type=int, + help="Maximum amount of prompts to process at one time", + ) + model_group.add_argument( + "--prompt-template", + type=str, + help="Set the jinja2 prompt template for chat completions", + ) model_group.add_argument( "--num-experts-per-token", type=int, help="Number of experts to use per token in MoE models", ) model_group.add_argument( - "--use-cfg", + "--fasttensors", type=str_to_bool, - help="Enables CFG support", + help="Possibly increases model loading speeds", ) @@ -151,3 +190,12 @@ def add_developer_args(parser: argparse.ArgumentParser): type=str_to_bool, help="Disables API request streaming", ) + + +def add_sampling_args(parser: argparse.ArgumentParser): + """Adds sampling-specific arguments""" + + sampling_group = parser.add_argument_group("sampling") + sampling_group.add_argument( + "--override-preset", type=str, help="Select a sampler override preset" + ) From 191600a150df7eb1c1c9ecea290327f3f102680f Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 22 Jul 2024 18:34:00 -0400 Subject: [PATCH 21/89] Revert "Model: Skip empty token chunks" This reverts commit 21516bd7b5ca90b190c785c0c767e6045136e4ab. This skips EOS and implementing it the proper way seems more costly than necessary. Signed-off-by: kingbri --- backends/exllamav2/model.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index f42dd00..200be6b 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1185,15 +1185,13 @@ class ExllamaV2Container: result_id = result.get("identifier") if stage == "streaming" and result_id == job_id: - chunk_tokens = result.get("token_ids") - if chunk_tokens is None: - continue - else: - generated_tokens += chunk_tokens.size(dim=0) - chunk = unwrap(result.get("text"), "") full_response += chunk + chunk_tokens = result.get("token_ids") + if chunk_tokens is not None: + generated_tokens += chunk_tokens.size(dim=0) + generation = { "text": chunk, "prompt_tokens": context_len, From 522999ebb4478300aa1e5a13443b06dca228cba8 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 22 Jul 2024 21:15:16 -0400 Subject: [PATCH 22/89] Config: Change from gen_logging to logging More accurately reflects the config.yml's sections. Signed-off-by: kingbri --- common/config.py | 16 ++++++++-------- main.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/common/config.py b/common/config.py index 86aedac..4de7d6b 100644 --- a/common/config.py +++ b/common/config.py @@ -46,14 +46,14 @@ def from_args(args: dict): GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override} # Generation Logging config - gen_logging_override = args.get("logging") - if gen_logging_override: - cur_gen_logging_config = gen_logging_config() + logging_override = args.get("logging") + if logging_override: + cur_logging_config = logging_config() GLOBAL_CONFIG["logging"] = { - **cur_gen_logging_config, + **cur_logging_config, **{ - k.replace("log_", ""): gen_logging_override[k] - for k in gen_logging_override + k.replace("log_", ""): logging_override[k] + for k in logging_override }, } @@ -90,8 +90,8 @@ def network_config(): return unwrap(GLOBAL_CONFIG.get("network"), {}) -def gen_logging_config(): - """Returns the generation logging config from the global config""" +def logging_config(): + """Returns the logging config from the global config""" return unwrap(GLOBAL_CONFIG.get("logging"), {}) diff --git a/main.py b/main.py index e089d81..fd29297 100644 --- a/main.py +++ b/main.py @@ -97,7 +97,7 @@ async def entrypoint(args: Optional[dict] = None): load_auth_keys(unwrap(network_config.get("disable_auth"), False)) # Override the generation log options if given - log_config = config.gen_logging_config() + log_config = config.logging_config() if log_config: gen_logging.update_from_dict(log_config) From 3826815edbb69c7445ee3799a8d4885ef8d29106 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 22 Jul 2024 21:33:10 -0400 Subject: [PATCH 23/89] API: Add request logging Log all the parts of a request if the config flag is set. The logged fields are all server side anyways, so nothing is being exposed to clients. Signed-off-by: kingbri --- common/config.py | 5 +---- common/networking.py | 32 ++++++++++++++++++++++++++++- config_sample.yml | 4 ++++ endpoints/server.py | 48 +++++++++++++++++++++++--------------------- 4 files changed, 61 insertions(+), 28 deletions(-) diff --git a/common/config.py b/common/config.py index 4de7d6b..972b382 100644 --- a/common/config.py +++ b/common/config.py @@ -51,10 +51,7 @@ def from_args(args: dict): cur_logging_config = logging_config() GLOBAL_CONFIG["logging"] = { **cur_logging_config, - **{ - k.replace("log_", ""): logging_override[k] - for k in logging_override - }, + **{k.replace("log_", ""): logging_override[k] for k in logging_override}, } developer_override = args.get("developer") diff --git a/common/networking.py b/common/networking.py index 706cb60..7c088a9 100644 --- a/common/networking.py +++ b/common/networking.py @@ -1,9 +1,10 @@ """Common utility functions""" import asyncio +import json import socket import traceback -from fastapi import HTTPException, Request +from fastapi import Depends, HTTPException, Request from loguru import logger from pydantic import BaseModel from typing import Optional @@ -108,3 +109,32 @@ async def add_request_id(request: Request): request.state.id = uuid4().hex return request + + +async def log_request(request: Request): + """FastAPI depends to log a request to the user.""" + + log_message = [f"Information for {request.method} request {request.state.id}:"] + + log_message.append(f"URL: {request.url}") + log_message.append(f"Headers: {dict(request.headers)}") + + if request.method != "GET": + body_bytes = await request.body() + if body_bytes: + body = json.loads(body_bytes.decode("utf-8")) + + log_message.append(f"Body: {dict(body)}") + + logger.info("\n".join(log_message)) + + +def get_global_depends(): + """Returns global dependencies for a FastAPI app.""" + + depends = [Depends(add_request_id)] + + if config.logging_config().get("requests"): + depends.append(Depends(log_request)) + + return depends diff --git a/config_sample.yml b/config_sample.yml index 3bde5ee..12458de 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -31,6 +31,10 @@ logging: # Enable generation parameter logging (default: False) generation_params: False + # Enable request logging (default: False) + # NOTE: Only use this for debugging! + requests: False + # Options for sampling sampling: # Override preset name. Find this in the sampler-overrides folder (default: None) diff --git a/endpoints/server.py b/endpoints/server.py index a69ecd1..8c10b63 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -1,42 +1,44 @@ import uvicorn -from fastapi import Depends, FastAPI +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from loguru import logger from common.logger import UVICORN_LOG_CONFIG -from common.networking import add_request_id +from common.networking import get_global_depends from endpoints.OAI.router import router as OAIRouter -app = FastAPI( - title="TabbyAPI", - summary="An OAI compatible exllamav2 API that's both lightweight and fast", - description=( - "This docs page is not meant to send requests! Please use a service " - "like Postman or a frontend UI." - ), - dependencies=[Depends(add_request_id)], -) - -# ALlow CORS requests -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - def setup_app(): """Includes the correct routers for startup""" + app = FastAPI( + title="TabbyAPI", + summary="An OAI compatible exllamav2 API that's both lightweight and fast", + description=( + "This docs page is not meant to send requests! Please use a service " + "like Postman or a frontend UI." + ), + dependencies=get_global_depends(), + ) + + # ALlow CORS requests + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + app.include_router(OAIRouter) + return app + def export_openapi(): """Function to return the OpenAPI JSON from the API server""" - setup_app() + app = setup_app() return app.openapi() @@ -49,7 +51,7 @@ async def start_api(host: str, port: int): logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions") # Setup app - setup_app() + app = setup_app() config = uvicorn.Config( app, From 14dfaf600a398c5c60e927f2080f76bbcaf665b2 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 22 Jul 2024 21:41:42 -0400 Subject: [PATCH 24/89] Args: Add request logging Signed-off-by: kingbri --- common/args.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/common/args.py b/common/args.py index 6bb0804..bbf007c 100644 --- a/common/args.py +++ b/common/args.py @@ -171,6 +171,11 @@ def add_logging_args(parser: argparse.ArgumentParser): type=str_to_bool, help="Enable generation parameter logging", ) + logging_group.add_argument( + "--log-requests", + type=str_to_bool, + help="Enable request logging", + ) def add_developer_args(parser: argparse.ArgumentParser): From d1706fb06713efaf2f9ea86db629c58958f80c91 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 22 Jul 2024 21:48:59 -0400 Subject: [PATCH 25/89] OAI: Remove double logging if request is cancelled Uvicorn can log in both the request disconnect handler and the CancelledError. However, these sometimes don't work and both need to be checked. But, don't log twice if one works. Signed-off-by: kingbri --- endpoints/OAI/utils/chat_completion.py | 5 +++-- endpoints/OAI/utils/completion.py | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index b9c6f71..c1c0263 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -311,8 +311,9 @@ async def stream_generate_chat_completion( except CancelledError: # Get out if the request gets disconnected - abort_event.set() - handle_request_disconnect("Chat completion generation cancelled by user.") + if not disconnect_task.done(): + abort_event.set() + handle_request_disconnect("Chat completion generation cancelled by user.") except Exception: yield get_generator_error( "Chat completion aborted. Please check the server console." diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 23f2692..fe5520c 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -162,10 +162,11 @@ async def stream_generate_completion( except CancelledError: # Get out if the request gets disconnected - abort_event.set() - handle_request_disconnect( - f"Completion generation {request.state.id} cancelled by user." - ) + if not disconnect_task.done(): + abort_event.set() + handle_request_disconnect( + f"Completion generation {request.state.id} cancelled by user." + ) except Exception: yield get_generator_error( f"Completion {request.state.id} aborted. Please check the server console." From 64c2cc85c9caa718bb66fac9687caa4584790364 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 23 Jul 2024 12:40:32 -0400 Subject: [PATCH 26/89] OAI: Migrate model depends into proper file Use amongst multiple routers. Signed-off-by: kingbri --- common/model.py | 14 ++++++++++++++ endpoints/OAI/router.py | 15 +-------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/common/model.py b/common/model.py index b925f15..a6477c2 100644 --- a/common/model.py +++ b/common/model.py @@ -5,11 +5,13 @@ Containers exist as a common interface for backends. """ import pathlib +from fastapi import HTTPException from loguru import logger from typing import Optional from common import config from common.logger import get_loading_progress_bar +from common.networking import handle_request_error from common.utils import unwrap from endpoints.utils import do_export_openapi @@ -112,3 +114,15 @@ def get_config_default(key, fallback=None, is_draft=False): return unwrap(model_config.get(key), fallback) else: return fallback + + +async def check_model_container(): + """FastAPI depends that checks if a model isn't loaded or currently loading.""" + + if container is None or not (container.model_is_loading or container.model_loaded): + error_message = handle_request_error( + "No models are currently loaded.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 1297d87..4269c89 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -7,6 +7,7 @@ from sys import maxsize from common import config, model, sampling from common.auth import check_admin_key, check_api_key, get_key_permission from common.downloader import hf_repo_download +from common.model import check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.templating import PromptTemplate, get_all_templates from common.utils import unwrap @@ -60,20 +61,6 @@ from endpoints.OAI.utils.lora import get_active_loras, get_lora_list router = APIRouter() -async def check_model_container(): - """FastAPI depends that checks if a model isn't loaded or currently loading.""" - - if model.container is None or not ( - model.container.model_is_loading or model.container.model_loaded - ): - error_message = handle_request_error( - "No models are currently loaded.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - # Completions endpoint @router.post( "/v1/completions", From 9ad69e8ab6419f878fe0ba405f77e1364e8e6548 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 23 Jul 2024 14:08:48 -0400 Subject: [PATCH 27/89] API: Migrate universal routes to core Place OAI specific routes in the appropriate folder. This is in preperation for adding new API servers that can be optionally enabled. Signed-off-by: kingbri --- backends/exllamav2/model.py | 2 + endpoints/OAI/router.py | 427 +---------------- endpoints/core/router.py | 432 ++++++++++++++++++ endpoints/{OAI => core}/types/auth.py | 0 endpoints/{OAI => core}/types/download.py | 0 endpoints/{OAI => core}/types/lora.py | 0 endpoints/{OAI => core}/types/model.py | 0 .../{OAI => core}/types/sampler_overrides.py | 0 endpoints/{OAI => core}/types/template.py | 0 endpoints/{OAI => core}/types/token.py | 0 endpoints/{OAI => core}/utils/lora.py | 2 +- endpoints/{OAI => core}/utils/model.py | 2 +- endpoints/server.py | 4 + 13 files changed, 442 insertions(+), 427 deletions(-) create mode 100644 endpoints/core/router.py rename endpoints/{OAI => core}/types/auth.py (100%) rename endpoints/{OAI => core}/types/download.py (100%) rename endpoints/{OAI => core}/types/lora.py (100%) rename endpoints/{OAI => core}/types/model.py (100%) rename endpoints/{OAI => core}/types/sampler_overrides.py (100%) rename endpoints/{OAI => core}/types/template.py (100%) rename endpoints/{OAI => core}/types/token.py (100%) rename endpoints/{OAI => core}/utils/lora.py (92%) rename endpoints/{OAI => core}/utils/model.py (98%) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 200be6b..057e8c1 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1080,6 +1080,8 @@ class ExllamaV2Container: else [self.tokenizer.eos_token_id] ) + print(self.tokenizer.eos_token_id) + # Ban the EOS token if specified. If not, append to stop conditions # as well. # Set this below logging to avoid polluting the stop strings array diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 4269c89..771b7f3 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,45 +1,18 @@ import asyncio -import pathlib from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize -from common import config, model, sampling -from common.auth import check_admin_key, check_api_key, get_key_permission -from common.downloader import hf_repo_download +from common import config, model +from common.auth import check_api_key from common.model import check_model_container from common.networking import handle_request_error, run_with_request_disconnect -from common.templating import PromptTemplate, get_all_templates from common.utils import unwrap -from endpoints.OAI.types.auth import AuthPermissionResponse from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse from endpoints.OAI.types.chat_completion import ( ChatCompletionRequest, ChatCompletionResponse, ) -from endpoints.OAI.types.download import DownloadRequest, DownloadResponse -from endpoints.OAI.types.lora import ( - LoraList, - LoraLoadRequest, - LoraLoadResponse, -) -from endpoints.OAI.types.model import ( - ModelCard, - ModelList, - ModelLoadRequest, - ModelLoadResponse, -) -from endpoints.OAI.types.sampler_overrides import ( - SamplerOverrideListResponse, - SamplerOverrideSwitchRequest, -) -from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest -from endpoints.OAI.types.token import ( - TokenEncodeRequest, - TokenEncodeResponse, - TokenDecodeRequest, - TokenDecodeResponse, -) from endpoints.OAI.utils.chat_completion import ( format_prompt_with_template, generate_chat_completion, @@ -49,13 +22,6 @@ from endpoints.OAI.utils.completion import ( generate_completion, stream_generate_completion, ) -from endpoints.OAI.utils.model import ( - get_current_model, - get_current_model_list, - get_model_list, - stream_model_load, -) -from endpoints.OAI.utils.lora import get_active_loras, get_lora_list router = APIRouter() @@ -159,392 +125,3 @@ async def chat_completion_request( disconnect_message=f"Chat completion {request.state.id} cancelled by user.", ) return response - - -# Model list endpoint -@router.get("/v1/models", dependencies=[Depends(check_api_key)]) -@router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) -async def list_models(request: Request) -> ModelList: - """ - Lists all models in the model directory. - - Requires an admin key to see all models. - """ - - model_config = config.model_config() - model_dir = unwrap(model_config.get("model_dir"), "models") - model_path = pathlib.Path(model_dir) - - draft_model_dir = config.draft_model_config().get("draft_model_dir") - - if get_key_permission(request) == "admin": - models = get_model_list(model_path.resolve(), draft_model_dir) - else: - models = await get_current_model_list() - - if unwrap(model_config.get("use_dummy_models"), False): - models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) - - return models - - -# Currently loaded model endpoint -@router.get( - "/v1/model", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def current_model() -> ModelCard: - """Returns the currently loaded model.""" - - return get_current_model() - - -@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) -async def list_draft_models(request: Request) -> ModelList: - """ - Lists all draft models in the model directory. - - Requires an admin key to see all draft models. - """ - - if get_key_permission(request) == "admin": - draft_model_dir = unwrap( - config.draft_model_config().get("draft_model_dir"), "models" - ) - draft_model_path = pathlib.Path(draft_model_dir) - - models = get_model_list(draft_model_path.resolve()) - else: - models = await get_current_model_list(is_draft=True) - - return models - - -# Load model endpoint -@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) -async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: - """Loads a model into the model container. This returns an SSE stream.""" - - # Verify request parameters - if not data.name: - error_message = handle_request_error( - "A model name was not provided for load.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models")) - model_path = model_path / data.name - - draft_model_path = None - if data.draft: - if not data.draft.draft_model_name: - error_message = handle_request_error( - "Could not find the draft model name for model load.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - draft_model_path = unwrap( - config.draft_model_config().get("draft_model_dir"), "models" - ) - - if not model_path.exists(): - error_message = handle_request_error( - "Could not find the model path for load. Check model name or config.yml?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - return EventSourceResponse( - stream_model_load(data, model_path, draft_model_path), ping=maxsize - ) - - -# Unload model endpoint -@router.post( - "/v1/model/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def unload_model(): - """Unloads the currently loaded model.""" - await model.unload_model(skip_wait=True) - - -@router.post("/v1/download", dependencies=[Depends(check_admin_key)]) -async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse: - """Downloads a model from HuggingFace.""" - - try: - download_task = asyncio.create_task(hf_repo_download(**data.model_dump())) - - # For now, the downloader and request data are 1:1 - download_path = await run_with_request_disconnect( - request, - download_task, - "Download request cancelled by user. Files have been cleaned up.", - ) - - return DownloadResponse(download_path=str(download_path)) - except Exception as exc: - error_message = handle_request_error(str(exc)).error.message - - raise HTTPException(400, error_message) from exc - - -# Lora list endpoint -@router.get("/v1/loras", dependencies=[Depends(check_api_key)]) -@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) -async def list_all_loras(request: Request) -> LoraList: - """ - Lists all LoRAs in the lora directory. - - Requires an admin key to see all LoRAs. - """ - - if get_key_permission(request) == "admin": - lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) - loras = get_lora_list(lora_path.resolve()) - else: - loras = get_active_loras() - - return loras - - -# Currently loaded loras endpoint -@router.get( - "/v1/lora", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def active_loras() -> LoraList: - """Returns the currently loaded loras.""" - - return get_active_loras() - - -# Load lora endpoint -@router.post( - "/v1/lora/load", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: - """Loads a LoRA into the model container.""" - - if not data.loras: - error_message = handle_request_error( - "List of loras to load is not found.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) - if not lora_dir.exists(): - error_message = handle_request_error( - "A parent lora directory does not exist for load. Check your config.yml?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - load_result = await model.load_loras( - lora_dir, **data.model_dump(), skip_wait=data.skip_queue - ) - - return LoraLoadResponse( - success=unwrap(load_result.get("success"), []), - failure=unwrap(load_result.get("failure"), []), - ) - - -# Unload lora endpoint -@router.post( - "/v1/lora/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def unload_loras(): - """Unloads the currently loaded loras.""" - - await model.unload_loras() - - -# Encode tokens endpoint -@router.post( - "/v1/token/encode", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: - """Encodes a string or chat completion messages into tokens.""" - - if isinstance(data.text, str): - text = data.text - else: - special_tokens_dict = model.container.get_special_tokens( - unwrap(data.add_bos_token, True) - ) - - template_vars = { - "messages": data.text, - "add_generation_prompt": False, - **special_tokens_dict, - } - - text, _ = model.container.prompt_template.render(template_vars) - - raw_tokens = model.container.encode_tokens(text, **data.get_params()) - tokens = unwrap(raw_tokens, []) - response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) - - return response - - -# Decode tokens endpoint -@router.post( - "/v1/token/decode", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: - """Decodes tokens into a string.""" - - message = model.container.decode_tokens(data.tokens, **data.get_params()) - response = TokenDecodeResponse(text=unwrap(message, "")) - - return response - - -@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)]) -async def key_permission(request: Request) -> AuthPermissionResponse: - """ - Gets the access level/permission of a provided key in headers. - - Priority: - - X-admin-key - - X-api-key - - Authorization - """ - - try: - permission = get_key_permission(request) - return AuthPermissionResponse(permission=permission) - except ValueError as exc: - error_message = handle_request_error(str(exc)).error.message - - raise HTTPException(400, error_message) from exc - - -@router.get("/v1/templates", dependencies=[Depends(check_api_key)]) -@router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) -async def list_templates(request: Request) -> TemplateList: - """ - Get a list of all templates. - - Requires an admin key to see all templates. - """ - - template_strings = [] - if get_key_permission(request) == "admin": - templates = get_all_templates() - template_strings = [template.stem for template in templates] - else: - if model.container and model.container.prompt_template: - template_strings.append(model.container.prompt_template.name) - - return TemplateList(data=template_strings) - - -@router.post( - "/v1/template/switch", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def switch_template(data: TemplateSwitchRequest): - """Switch the currently loaded template.""" - - if not data.name: - error_message = handle_request_error( - "New template name not found.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - try: - model.container.prompt_template = PromptTemplate.from_file(data.name) - except FileNotFoundError as e: - error_message = handle_request_error( - f"The template name {data.name} doesn't exist. Check the spelling?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) from e - - -@router.post( - "/v1/template/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def unload_template(): - """Unloads the currently selected template""" - - model.container.prompt_template = None - - -# Sampler override endpoints -@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) -@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) -async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: - """ - List all currently applied sampler overrides. - - Requires an admin key to see all override presets. - """ - - if get_key_permission(request) == "admin": - presets = sampling.get_all_presets() - else: - presets = [] - - return SamplerOverrideListResponse( - presets=presets, **sampling.overrides_container.model_dump() - ) - - -@router.post( - "/v1/sampling/override/switch", - dependencies=[Depends(check_admin_key)], -) -async def switch_sampler_override(data: SamplerOverrideSwitchRequest): - """Switch the currently loaded override preset""" - - if data.preset: - try: - sampling.overrides_from_file(data.preset) - except FileNotFoundError as e: - error_message = handle_request_error( - f"Sampler override preset with name {data.preset} does not exist. " - + "Check the spelling?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) from e - elif data.overrides: - sampling.overrides_from_dict(data.overrides) - else: - error_message = handle_request_error( - "A sampler override preset or dictionary wasn't provided.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - -@router.post( - "/v1/sampling/override/unload", - dependencies=[Depends(check_admin_key)], -) -async def unload_sampler_override(): - """Unloads the currently selected override preset""" - - sampling.overrides_from_dict({}) diff --git a/endpoints/core/router.py b/endpoints/core/router.py new file mode 100644 index 0000000..cd0ed37 --- /dev/null +++ b/endpoints/core/router.py @@ -0,0 +1,432 @@ +import asyncio +import pathlib +from sys import maxsize +from fastapi import APIRouter, Depends, HTTPException, Request +from sse_starlette import EventSourceResponse + +from common import config, model, sampling +from common.auth import check_admin_key, check_api_key, get_key_permission +from common.downloader import hf_repo_download +from common.model import check_model_container +from common.networking import handle_request_error, run_with_request_disconnect +from common.templating import PromptTemplate, get_all_templates +from common.utils import unwrap +from endpoints.core.types.auth import AuthPermissionResponse +from endpoints.core.types.download import DownloadRequest, DownloadResponse +from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse +from endpoints.core.types.model import ( + ModelCard, + ModelList, + ModelLoadRequest, + ModelLoadResponse, +) +from endpoints.core.types.sampler_overrides import ( + SamplerOverrideListResponse, + SamplerOverrideSwitchRequest, +) +from endpoints.core.types.template import TemplateList, TemplateSwitchRequest +from endpoints.core.types.token import ( + TokenDecodeRequest, + TokenDecodeResponse, + TokenEncodeRequest, + TokenEncodeResponse, +) +from endpoints.core.utils.lora import get_active_loras, get_lora_list +from endpoints.core.utils.model import ( + get_current_model, + get_current_model_list, + get_model_list, + stream_model_load, +) + + +router = APIRouter() + + +# Model list endpoint +@router.get("/v1/models", dependencies=[Depends(check_api_key)]) +@router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) +async def list_models(request: Request) -> ModelList: + """ + Lists all models in the model directory. + + Requires an admin key to see all models. + """ + + model_config = config.model_config() + model_dir = unwrap(model_config.get("model_dir"), "models") + model_path = pathlib.Path(model_dir) + + draft_model_dir = config.draft_model_config().get("draft_model_dir") + + if get_key_permission(request) == "admin": + models = get_model_list(model_path.resolve(), draft_model_dir) + else: + models = await get_current_model_list() + + if unwrap(model_config.get("use_dummy_models"), False): + models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) + + return models + + +# Currently loaded model endpoint +@router.get( + "/v1/model", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def current_model() -> ModelCard: + """Returns the currently loaded model.""" + + return get_current_model() + + +@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) +async def list_draft_models(request: Request) -> ModelList: + """ + Lists all draft models in the model directory. + + Requires an admin key to see all draft models. + """ + + if get_key_permission(request) == "admin": + draft_model_dir = unwrap( + config.draft_model_config().get("draft_model_dir"), "models" + ) + draft_model_path = pathlib.Path(draft_model_dir) + + models = get_model_list(draft_model_path.resolve()) + else: + models = await get_current_model_list(is_draft=True) + + return models + + +# Load model endpoint +@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) +async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: + """Loads a model into the model container. This returns an SSE stream.""" + + # Verify request parameters + if not data.name: + error_message = handle_request_error( + "A model name was not provided for load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models")) + model_path = model_path / data.name + + draft_model_path = None + if data.draft: + if not data.draft.draft_model_name: + error_message = handle_request_error( + "Could not find the draft model name for model load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + draft_model_path = unwrap( + config.draft_model_config().get("draft_model_dir"), "models" + ) + + if not model_path.exists(): + error_message = handle_request_error( + "Could not find the model path for load. Check model name or config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + return EventSourceResponse( + stream_model_load(data, model_path, draft_model_path), ping=maxsize + ) + + +# Unload model endpoint +@router.post( + "/v1/model/unload", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def unload_model(): + """Unloads the currently loaded model.""" + await model.unload_model(skip_wait=True) + + +@router.post("/v1/download", dependencies=[Depends(check_admin_key)]) +async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse: + """Downloads a model from HuggingFace.""" + + try: + download_task = asyncio.create_task(hf_repo_download(**data.model_dump())) + + # For now, the downloader and request data are 1:1 + download_path = await run_with_request_disconnect( + request, + download_task, + "Download request cancelled by user. Files have been cleaned up.", + ) + + return DownloadResponse(download_path=str(download_path)) + except Exception as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + +# Lora list endpoint +@router.get("/v1/loras", dependencies=[Depends(check_api_key)]) +@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) +async def list_all_loras(request: Request) -> LoraList: + """ + Lists all LoRAs in the lora directory. + + Requires an admin key to see all LoRAs. + """ + + if get_key_permission(request) == "admin": + lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) + loras = get_lora_list(lora_path.resolve()) + else: + loras = get_active_loras() + + return loras + + +# Currently loaded loras endpoint +@router.get( + "/v1/lora", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def active_loras() -> LoraList: + """Returns the currently loaded loras.""" + + return get_active_loras() + + +# Load lora endpoint +@router.post( + "/v1/lora/load", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: + """Loads a LoRA into the model container.""" + + if not data.loras: + error_message = handle_request_error( + "List of loras to load is not found.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) + if not lora_dir.exists(): + error_message = handle_request_error( + "A parent lora directory does not exist for load. Check your config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + load_result = await model.load_loras( + lora_dir, **data.model_dump(), skip_wait=data.skip_queue + ) + + return LoraLoadResponse( + success=unwrap(load_result.get("success"), []), + failure=unwrap(load_result.get("failure"), []), + ) + + +# Unload lora endpoint +@router.post( + "/v1/lora/unload", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def unload_loras(): + """Unloads the currently loaded loras.""" + + await model.unload_loras() + + +# Encode tokens endpoint +@router.post( + "/v1/token/encode", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: + """Encodes a string or chat completion messages into tokens.""" + + if isinstance(data.text, str): + text = data.text + else: + special_tokens_dict = model.container.get_special_tokens( + unwrap(data.add_bos_token, True) + ) + + template_vars = { + "messages": data.text, + "add_generation_prompt": False, + **special_tokens_dict, + } + + text, _ = model.container.prompt_template.render(template_vars) + + raw_tokens = model.container.encode_tokens(text, **data.get_params()) + tokens = unwrap(raw_tokens, []) + response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) + + return response + + +# Decode tokens endpoint +@router.post( + "/v1/token/decode", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: + """Decodes tokens into a string.""" + + message = model.container.decode_tokens(data.tokens, **data.get_params()) + response = TokenDecodeResponse(text=unwrap(message, "")) + + return response + + +@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)]) +async def key_permission(request: Request) -> AuthPermissionResponse: + """ + Gets the access level/permission of a provided key in headers. + + Priority: + - X-admin-key + - X-api-key + - Authorization + """ + + try: + permission = get_key_permission(request) + return AuthPermissionResponse(permission=permission) + except ValueError as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + +@router.get("/v1/templates", dependencies=[Depends(check_api_key)]) +@router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) +async def list_templates(request: Request) -> TemplateList: + """ + Get a list of all templates. + + Requires an admin key to see all templates. + """ + + template_strings = [] + if get_key_permission(request) == "admin": + templates = get_all_templates() + template_strings = [template.stem for template in templates] + else: + if model.container and model.container.prompt_template: + template_strings.append(model.container.prompt_template.name) + + return TemplateList(data=template_strings) + + +@router.post( + "/v1/template/switch", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def switch_template(data: TemplateSwitchRequest): + """Switch the currently loaded template.""" + + if not data.name: + error_message = handle_request_error( + "New template name not found.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + try: + model.container.prompt_template = PromptTemplate.from_file(data.name) + except FileNotFoundError as e: + error_message = handle_request_error( + f"The template name {data.name} doesn't exist. Check the spelling?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) from e + + +@router.post( + "/v1/template/unload", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def unload_template(): + """Unloads the currently selected template""" + + model.container.prompt_template = None + + +# Sampler override endpoints +@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) +@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) +async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: + """ + List all currently applied sampler overrides. + + Requires an admin key to see all override presets. + """ + + if get_key_permission(request) == "admin": + presets = sampling.get_all_presets() + else: + presets = [] + + return SamplerOverrideListResponse( + presets=presets, **sampling.overrides_container.model_dump() + ) + + +@router.post( + "/v1/sampling/override/switch", + dependencies=[Depends(check_admin_key)], +) +async def switch_sampler_override(data: SamplerOverrideSwitchRequest): + """Switch the currently loaded override preset""" + + if data.preset: + try: + sampling.overrides_from_file(data.preset) + except FileNotFoundError as e: + error_message = handle_request_error( + f"Sampler override preset with name {data.preset} does not exist. " + + "Check the spelling?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) from e + elif data.overrides: + sampling.overrides_from_dict(data.overrides) + else: + error_message = handle_request_error( + "A sampler override preset or dictionary wasn't provided.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + +@router.post( + "/v1/sampling/override/unload", + dependencies=[Depends(check_admin_key)], +) +async def unload_sampler_override(): + """Unloads the currently selected override preset""" + + sampling.overrides_from_dict({}) diff --git a/endpoints/OAI/types/auth.py b/endpoints/core/types/auth.py similarity index 100% rename from endpoints/OAI/types/auth.py rename to endpoints/core/types/auth.py diff --git a/endpoints/OAI/types/download.py b/endpoints/core/types/download.py similarity index 100% rename from endpoints/OAI/types/download.py rename to endpoints/core/types/download.py diff --git a/endpoints/OAI/types/lora.py b/endpoints/core/types/lora.py similarity index 100% rename from endpoints/OAI/types/lora.py rename to endpoints/core/types/lora.py diff --git a/endpoints/OAI/types/model.py b/endpoints/core/types/model.py similarity index 100% rename from endpoints/OAI/types/model.py rename to endpoints/core/types/model.py diff --git a/endpoints/OAI/types/sampler_overrides.py b/endpoints/core/types/sampler_overrides.py similarity index 100% rename from endpoints/OAI/types/sampler_overrides.py rename to endpoints/core/types/sampler_overrides.py diff --git a/endpoints/OAI/types/template.py b/endpoints/core/types/template.py similarity index 100% rename from endpoints/OAI/types/template.py rename to endpoints/core/types/template.py diff --git a/endpoints/OAI/types/token.py b/endpoints/core/types/token.py similarity index 100% rename from endpoints/OAI/types/token.py rename to endpoints/core/types/token.py diff --git a/endpoints/OAI/utils/lora.py b/endpoints/core/utils/lora.py similarity index 92% rename from endpoints/OAI/utils/lora.py rename to endpoints/core/utils/lora.py index 3e31f68..c8c9cc4 100644 --- a/endpoints/OAI/utils/lora.py +++ b/endpoints/core/utils/lora.py @@ -1,7 +1,7 @@ import pathlib from common import model -from endpoints.OAI.types.lora import LoraCard, LoraList +from endpoints.core.types.lora import LoraCard, LoraList def get_lora_list(lora_path: pathlib.Path): diff --git a/endpoints/OAI/utils/model.py b/endpoints/core/utils/model.py similarity index 98% rename from endpoints/OAI/utils/model.py rename to endpoints/core/utils/model.py index a8057e4..0cfb26a 100644 --- a/endpoints/OAI/utils/model.py +++ b/endpoints/core/utils/model.py @@ -5,7 +5,7 @@ from typing import Optional from common import gen_logging, model from common.networking import get_generator_error, handle_request_disconnect from common.utils import unwrap -from endpoints.OAI.types.model import ( +from endpoints.core.types.model import ( ModelCard, ModelCardParameters, ModelList, diff --git a/endpoints/server.py b/endpoints/server.py index 8c10b63..dc501f5 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -5,6 +5,7 @@ from loguru import logger from common.logger import UVICORN_LOG_CONFIG from common.networking import get_global_depends +from endpoints.core.router import router as CoreRouter from endpoints.OAI.router import router as OAIRouter @@ -32,6 +33,9 @@ def setup_app(): app.include_router(OAIRouter) + # Include core API request paths + app.include_router(CoreRouter) + return app From 300f0342337e0522e76c2cbd7ea7afdee7d8ccce Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 23 Jul 2024 14:26:15 -0400 Subject: [PATCH 28/89] API: Add config option to select servers Always enable the core endpoints and allow servers to be selected as needed. Use the OAI server by default. Signed-off-by: kingbri --- config_sample.yml | 4 ++++ endpoints/server.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/config_sample.yml b/config_sample.yml index 12458de..c5d6c9c 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -23,6 +23,10 @@ network: # NOTE: Only enable this for debug purposes send_tracebacks: False + # Select API servers to enable (default: ["OAI"]) + # Possible values: OAI + api_servers: ["OAI"] + # Options for logging logging: # Enable prompt logging (default: False) diff --git a/endpoints/server.py b/endpoints/server.py index dc501f5..dbe35de 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -1,10 +1,13 @@ +from typing import List import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from loguru import logger +from common import config from common.logger import UVICORN_LOG_CONFIG from common.networking import get_global_depends +from common.utils import unwrap from endpoints.core.router import router as CoreRouter from endpoints.OAI.router import router as OAIRouter @@ -31,7 +34,19 @@ def setup_app(): allow_headers=["*"], ) - app.include_router(OAIRouter) + api_servers: List[str] = unwrap(config.network_config().get("api_servers"), []) + + # Map for API id to server router + router_mapping = {"oai": OAIRouter} + + # Include the OAI api by default + if api_servers: + for server in api_servers: + server_name = server.lower() + if server_name in router_mapping: + app.include_router(router_mapping[server_name]) + else: + app.include_router(OAIRouter) # Include core API request paths app.include_router(CoreRouter) From 3e8ffebdd31557d55ac28a3ce725a9a271e43615 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 23 Jul 2024 14:32:50 -0400 Subject: [PATCH 29/89] Tree: Format Signed-off-by: kingbri --- backends/exllamav2/model.py | 2 -- endpoints/server.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 057e8c1..200be6b 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1080,8 +1080,6 @@ class ExllamaV2Container: else [self.tokenizer.eos_token_id] ) - print(self.tokenizer.eos_token_id) - # Ban the EOS token if specified. If not, append to stop conditions # as well. # Set this below logging to avoid polluting the stop strings array diff --git a/endpoints/server.py b/endpoints/server.py index dbe35de..4fc3f0b 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -1,4 +1,3 @@ -from typing import List import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -34,7 +33,7 @@ def setup_app(): allow_headers=["*"], ) - api_servers: List[str] = unwrap(config.network_config().get("api_servers"), []) + api_servers = unwrap(config.network_config().get("api_servers"), []) # Map for API id to server router router_mapping = {"oai": OAIRouter} From 88e4b108b4d9e36bb9186b4130360cfe9c7d2a9c Mon Sep 17 00:00:00 2001 From: Vhallo Date: Tue, 23 Jul 2024 23:48:50 +0200 Subject: [PATCH 30/89] Typo fix in chat_completion.py --- endpoints/OAI/utils/chat_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index c1c0263..80b5715 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -248,7 +248,7 @@ async def stream_generate_chat_completion( disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: - logger.info(f"Recieved chat completion streaming request {request.state.id}") + logger.info(f"Received chat completion streaming request {request.state.id}") gen_params = data.to_gen_params() From b2064bbfb407bf619a87f041fcb62358a6ab21f4 Mon Sep 17 00:00:00 2001 From: Vhallo Date: Tue, 23 Jul 2024 23:49:43 +0200 Subject: [PATCH 31/89] Typo fix in completion.py --- endpoints/OAI/utils/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index fe5520c..52c2bb4 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -114,7 +114,7 @@ async def stream_generate_completion( disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: - logger.info(f"Recieved streaming completion request {request.state.id}") + logger.info(f"Received streaming completion request {request.state.id}") gen_params = data.to_gen_params() From 8c02fe97714a234103c3cc1303293922de7c3162 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 23 Jul 2024 21:37:53 -0400 Subject: [PATCH 32/89] Downloader: Disable timeout This prevents TimeoutErrors from showing up. However, a longer timeout may be necessary since this is in the API. Turning it off for now will help resolve immediate errors. Signed-off-by: kingbri --- common/downloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/downloader.py b/common/downloader.py index a0b16cb..b252a0f 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -145,7 +145,8 @@ async def hf_repo_download( logger.info(f"Saving {repo_id} to {str(download_path)}") try: - async with aiohttp.ClientSession() as session: + timeout = aiohttp.ClientTimeout(total=None) # Turn off timeout + async with aiohttp.ClientSession(timeout=timeout) as session: tasks = [] logger.info(f"Starting download for {repo_id}") From 71de3060bbeff492bb3793b5701b3fc114b2722f Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 23 Jul 2024 21:42:38 -0400 Subject: [PATCH 33/89] Downloader: Make timeout configurable Add an API parameter to set the timeout in seconds. Keep it to None by default for uninterrupted downloads. Signed-off-by: kingbri --- common/downloader.py | 5 +++-- endpoints/core/types/download.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/common/downloader.py b/common/downloader.py index b252a0f..b9e1b72 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -101,6 +101,7 @@ async def hf_repo_download( chunk_limit: Optional[float], include: Optional[List[str]], exclude: Optional[List[str]], + timeout: Optional[int], repo_type: Optional[str] = "model", ): """Gets a repo's information from HuggingFace and downloads it locally.""" @@ -145,8 +146,8 @@ async def hf_repo_download( logger.info(f"Saving {repo_id} to {str(download_path)}") try: - timeout = aiohttp.ClientTimeout(total=None) # Turn off timeout - async with aiohttp.ClientSession(timeout=timeout) as session: + client_timeout = aiohttp.ClientTimeout(total=timeout) # Turn off timeout + async with aiohttp.ClientSession(timeout=client_timeout) as session: tasks = [] logger.info(f"Starting download for {repo_id}") diff --git a/endpoints/core/types/download.py b/endpoints/core/types/download.py index ac681bf..cf49501 100644 --- a/endpoints/core/types/download.py +++ b/endpoints/core/types/download.py @@ -17,6 +17,7 @@ class DownloadRequest(BaseModel): include: List[str] = Field(default_factory=_generate_include_list) exclude: List[str] = Field(default_factory=list) chunk_limit: Optional[int] = None + timeout: Optional[int] = None class DownloadResponse(BaseModel): From 5c082b7e8c4570ceff108191b53b5497f7c00cf5 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 24 Jul 2024 18:56:28 -0400 Subject: [PATCH 34/89] Async: Add option to use Uvloop/Winloop These are faster event loops for asyncio which should improve overall performance. Gate these under an experimental flag for now to stress test these loops. Signed-off-by: kingbri --- common/args.py | 7 ++- config_sample.yml | 5 ++ endpoints/server.py | 5 ++ main.py | 110 +++++++++++++++++++++++++------------------- pyproject.toml | 4 ++ start.py | 4 +- 6 files changed, 85 insertions(+), 50 deletions(-) diff --git a/common/args.py b/common/args.py index bbf007c..e57de78 100644 --- a/common/args.py +++ b/common/args.py @@ -193,7 +193,12 @@ def add_developer_args(parser: argparse.ArgumentParser): developer_group.add_argument( "--cuda-malloc-backend", type=str_to_bool, - help="Disables API request streaming", + help="Runs with the pytorch CUDA malloc backend", + ) + developer_group.add_argument( + "--uvloop", + type=str_to_bool, + help="Run asyncio using Uvloop or Winloop", ) diff --git a/config_sample.yml b/config_sample.yml index c5d6c9c..3070642 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -62,6 +62,11 @@ developer: # This can save a few MBs of VRAM, but has a risk of errors. Use at your own risk. #cuda_malloc_backend: False + # Enable Uvloop or Winloop (default: False) + # Make the program utilize a faster async event loop which can improve performance + # NOTE: It's recommended to enable this, but if something breaks, turn this off. + #uvloop: False + # Options for model overrides and loading # Please read the comments to understand how arguments are handled between initial and API loads model: diff --git a/endpoints/server.py b/endpoints/server.py index 4fc3f0b..401b211 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -1,3 +1,4 @@ +import asyncio import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -71,11 +72,15 @@ async def start_api(host: str, port: int): # Setup app app = setup_app() + # Get the current event loop + loop = asyncio.get_running_loop() + config = uvicorn.Config( app, host=host, port=port, log_config=UVICORN_LOG_CONFIG, + loop=loop, ) server = uvicorn.Server(config) diff --git a/main.py b/main.py index fd29297..2ec98db 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,10 @@ """The main tabbyAPI module. Contains the FastAPI server and endpoints.""" import asyncio -import aiofiles import json import os import pathlib +import platform import signal from loguru import logger from typing import Optional @@ -23,51 +23,8 @@ if not do_export_openapi: from backends.exllamav2.utils import check_exllama_version -async def entrypoint(args: Optional[dict] = None): - """Entry function for program startup""" - - setup_logger() - - # Set up signal aborting - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - if os.getenv("EXPORT_OPENAPI", "").lower() in ("true", "1"): - openapi_json = export_openapi() - - async with aiofiles.open("openapi.json", "w") as f: - await f.write(json.dumps(openapi_json)) - logger.info("Successfully wrote OpenAPI spec to openapi.json") - - return - - # Load from YAML config - config.from_file(pathlib.Path("config.yml")) - - # Parse and override config from args - if args is None: - parser = init_argparser() - args = convert_args_to_dict(parser.parse_args(), parser) - - config.from_args(args) - - developer_config = config.developer_config() - - # Check exllamav2 version and give a descriptive error if it's too old - # Skip if launching unsafely - - if unwrap(developer_config.get("unsafe_launch"), False): - logger.warning( - "UNSAFE: Skipping ExllamaV2 version check.\n" - "If you aren't a developer, please keep this off!" - ) - else: - check_exllama_version() - - # Enable CUDA malloc backend - if unwrap(developer_config.get("cuda_malloc_backend"), False): - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync" - logger.warning("Enabled the experimental CUDA malloc backend.") +async def entrypoint_async(): + """Async entry function for program startup""" network_config = config.network_config() @@ -131,5 +88,64 @@ async def entrypoint(args: Optional[dict] = None): await start_api(host, port) +def entrypoint(arguments: Optional[dict] = None): + setup_logger() + + # Set up signal aborting + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + if do_export_openapi: + openapi_json = export_openapi() + + with open("openapi.json", "w") as f: + f.write(json.dumps(openapi_json)) + logger.info("Successfully wrote OpenAPI spec to openapi.json") + + return + + # Load from YAML config + config.from_file(pathlib.Path("config.yml")) + + # Parse and override config from args + if arguments is None: + parser = init_argparser() + arguments = convert_args_to_dict(parser.parse_args(), parser) + + config.from_args(arguments) + developer_config = config.developer_config() + + # Check exllamav2 version and give a descriptive error if it's too old + # Skip if launching unsafely + + if unwrap(developer_config.get("unsafe_launch"), False): + logger.warning( + "UNSAFE: Skipping ExllamaV2 version check.\n" + "If you aren't a developer, please keep this off!" + ) + else: + check_exllama_version() + + # Enable CUDA malloc backend + if unwrap(developer_config.get("cuda_malloc_backend"), False): + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync" + logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.") + + # Use Uvloop/Winloop + if unwrap(developer_config.get("uvloop"), False): + if platform.system() == "Windows": + from winloop import install + else: + from uvloop import install + + # Set loop event policy + install() + + logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.") + + # Enter into the async event loop + asyncio.run(entrypoint_async()) + + if __name__ == "__main__": - asyncio.run(entrypoint()) + entrypoint() diff --git a/pyproject.toml b/pyproject.toml index ecf6362..031e228 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,10 @@ dependencies = [ "lm-format-enforcer >= 0.9.6", "aiofiles", + # Improved asyncio loops + "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'", + "winloop ; platform_system == 'Windows'", + # TEMP: Remove once 2.x is fixed in upstream "numpy < 2.0.0", diff --git a/start.py b/start.py index d1f8843..ddf4d27 100644 --- a/start.py +++ b/start.py @@ -1,6 +1,5 @@ """Utility to automatically upgrade and start the API""" -import asyncio import argparse import os import pathlib @@ -159,4 +158,5 @@ if __name__ == "__main__": # Import entrypoint after installing all requirements from main import entrypoint - asyncio.run(entrypoint(convert_args_to_dict(args, parser))) + converted_args = convert_args_to_dict(args, parser) + entrypoint(converted_args) From 42bc4adcfb31fea1aead6ca12633643e0f13057e Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 24 Jul 2024 21:47:47 -0400 Subject: [PATCH 35/89] Config: Add option to set priority to realtime Realtime process priority assigns resources to point to tabby's processes. Running as administrator will give realtime priority while running as a normal user will set as high priority. Signed-off-by: kingbri --- config_sample.yml | 5 +++++ main.py | 15 +++++++++++++++ pyproject.toml | 7 +++---- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/config_sample.yml b/config_sample.yml index 3070642..c92f673 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -67,6 +67,11 @@ developer: # NOTE: It's recommended to enable this, but if something breaks, turn this off. #uvloop: False + # Set process to use a higher priority + # For realtime process priority, run as administrator or sudo + # Otherwise, the priority will be set to high + #realtime_process_priority: False + # Options for model overrides and loading # Please read the comments to understand how arguments are handled between initial and API loads model: diff --git a/main.py b/main.py index 2ec98db..c62a381 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,8 @@ import signal from loguru import logger from typing import Optional +import psutil + from common import config, gen_logging, sampling, model from common.args import convert_args_to_dict, init_argparser from common.auth import load_auth_keys @@ -143,6 +145,19 @@ def entrypoint(arguments: Optional[dict] = None): logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.") + # Set the process priority + if unwrap(developer_config.get("realtime_process_priority"), False): + current_process = psutil.Process(os.getpid()) + if platform.system() == "Windows": + current_process.nice(psutil.REALTIME_PRIORITY_CLASS) + else: + current_process.nice(psutil.IOPRIO_CLASS_RT) + + logger.warning( + "EXPERIMENTAL: Process priority set to Realtime. \n" + "If you're not running on administrator/sudo, the priority is set to high." + ) + # Enter into the async event loop asyncio.run(entrypoint_async()) diff --git a/pyproject.toml b/pyproject.toml index 031e228..f89846f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,9 @@ dependencies = [ "tokenizers", "lm-format-enforcer >= 0.9.6", "aiofiles", + "aiohttp", + "huggingface_hub", + "psutil", # Improved asyncio loops "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'", @@ -35,10 +38,6 @@ dependencies = [ # TEMP: Remove once 2.x is fixed in upstream "numpy < 2.0.0", - - # TODO: Maybe move these to a downloader feature? - "aiohttp", - "huggingface_hub", ] [project.urls] From 27f9559d83541a579afeb925064fffdab1139fb3 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 24 Jul 2024 21:50:16 -0400 Subject: [PATCH 36/89] Dependencies: Switch to fastapi-slim Reduces dependency size since the full fastapi package isn't required. Add httptools since it makes requests faster and it was installed with fastapi previously. Signed-off-by: kingbri --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f89846f..d86d865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ version = "0.0.1" description = "An OAI compatible exllamav2 API that's both lightweight and fast" requires-python = ">=3.10" dependencies = [ - "fastapi >= 0.110.0", + "fastapi-slim >= 0.110.0", "pydantic >= 2.0.0", "PyYAML", "rich", @@ -31,6 +31,7 @@ dependencies = [ "aiohttp", "huggingface_hub", "psutil", + "httptools>=0.5.0", # Improved asyncio loops "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'", From a1c3f6cc1cc13d1f52c28bd6a11adb5b05fc8a19 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 24 Jul 2024 22:00:43 -0400 Subject: [PATCH 37/89] Dependencies: Update ExllamaV2 v0.1.8 Signed-off-by: kingbri --- pyproject.toml | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d86d865..38cdbcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,12 +62,12 @@ cu121 = [ "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/bdashore3/flash-attention/releases "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", @@ -89,12 +89,12 @@ cu118 = [ "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", @@ -113,9 +113,9 @@ amd = [ "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", ] # MARK: Ruff options From f20cd330ef3b1e9999c4242db4a83e75eb392050 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 26 Jul 2024 02:45:07 +0000 Subject: [PATCH 38/89] feat: add embeddings support via sentence-transformers --- endpoints/OAI/embeddings.py | 145 +++++++++++++++++++++++++++++++ endpoints/OAI/router.py | 23 +++++ endpoints/OAI/types/embedding.py | 39 +++++++++ pyproject.toml | 3 +- tabbyAPI | 1 + 5 files changed, 210 insertions(+), 1 deletion(-) create mode 100644 endpoints/OAI/embeddings.py create mode 100644 endpoints/OAI/types/embedding.py create mode 160000 tabbyAPI diff --git a/endpoints/OAI/embeddings.py b/endpoints/OAI/embeddings.py new file mode 100644 index 0000000..3cc59de --- /dev/null +++ b/endpoints/OAI/embeddings.py @@ -0,0 +1,145 @@ +""" +This file is derived from +[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py) +and modified. +The changes introduced are: Suppression of progress bar, +typing/pydantic classes moved into this file, +embeddings function declared async. +""" + +import os +import base64 +import numpy as np +from transformers import AutoModel + +embeddings_params_initialized = False + + +def initialize_embedding_params(): + ''' + using 'lazy loading' to avoid circular import + so this function will be executed only once + ''' + global embeddings_params_initialized + if not embeddings_params_initialized: + + global st_model, embeddings_model, embeddings_device + + st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", + 'all-mpnet-base-v2') + embeddings_model = None + # OPENAI_EMBEDDING_DEVICE: auto (best or cpu), + # cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, + # hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, + # hpu, mtia, privateuseone + embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", 'cpu') + if embeddings_device.lower() == 'auto': + embeddings_device = None + + embeddings_params_initialized = True + + +def load_embedding_model(model: str): + try: + from sentence_transformers import SentenceTransformer + except ModuleNotFoundError: + print("The sentence_transformers module has not been found. " + + "Please install it manually with " + + "pip install -U sentence-transformers.") + raise ModuleNotFoundError from None + + initialize_embedding_params() + global embeddings_device, embeddings_model + try: + print(f"Try embedding model: {model} on {embeddings_device}") + if 'jina-embeddings' in model: + # trust_remote_code is needed to use the encode method + embeddings_model = AutoModel.from_pretrained( + model, trust_remote_code=True) + embeddings_model = embeddings_model.to(embeddings_device) + else: + embeddings_model = SentenceTransformer( + model, + device=embeddings_device, + ) + + print(f"Loaded embedding model: {model}") + except Exception as e: + embeddings_model = None + raise Exception(f"Error: Failed to load embedding model: {model}", + internal_message=repr(e)) from None + + +def get_embeddings_model(): + initialize_embedding_params() + global embeddings_model, st_model + if st_model and not embeddings_model: + load_embedding_model(st_model) # lazy load the model + + return embeddings_model + + +def get_embeddings_model_name() -> str: + initialize_embedding_params() + global st_model + return st_model + + +def get_embeddings(input: list) -> np.ndarray: + model = get_embeddings_model() + embedding = model.encode(input, + convert_to_numpy=True, + normalize_embeddings=True, + convert_to_tensor=False, + show_progress_bar=False) + return embedding + + +async def embeddings(input: list, + encoding_format: str, + model: str = None) -> dict: + if model is None: + model = st_model + else: + load_embedding_model(model) + + embeddings = get_embeddings(input) + if encoding_format == "base64": + data = [{ + "object": "embedding", + "embedding": float_list_to_base64(emb), + "index": n + } for n, emb in enumerate(embeddings)] + else: + data = [{ + "object": "embedding", + "embedding": emb.tolist(), + "index": n + } for n, emb in enumerate(embeddings)] + + response = { + "object": "list", + "data": data, + "model": st_model if model is None else model, + "usage": { + "prompt_tokens": 0, + "total_tokens": 0, + } + } + return response + + +def float_list_to_base64(float_array: np.ndarray) -> str: + # Convert the list to a float32 array that the OpenAPI client expects + # float_array = np.array(float_list, dtype="float32") + + # Get raw bytes + bytes_array = float_array.tobytes() + + # Encode bytes into base64 + encoded_bytes = base64.b64encode(bytes_array) + + # Turn raw base64 encoded bytes into ASCII + ascii_string = encoded_bytes.decode('ascii') + return ascii_string + diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 771b7f3..9897353 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,5 +1,6 @@ import asyncio from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse from sse_starlette import EventSourceResponse from sys import maxsize @@ -8,11 +9,16 @@ from common.auth import check_api_key from common.model import check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.utils import unwrap +import endpoints.OAI.embeddings as OAIembeddings from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse from endpoints.OAI.types.chat_completion import ( ChatCompletionRequest, ChatCompletionResponse, ) +from endpoints.OAI.types.embedding import ( + EmbeddingsRequest, + EmbeddingsResponse +) from endpoints.OAI.utils.chat_completion import ( format_prompt_with_template, generate_chat_completion, @@ -125,3 +131,20 @@ async def chat_completion_request( disconnect_message=f"Chat completion {request.state.id} cancelled by user.", ) return response + +# Embeddings endpoint +@router.post( + "/v1/embeddings", + dependencies=[Depends(check_api_key), Depends(check_model_container)], + response_model=EmbeddingsResponse +) +async def handle_embeddings(request: EmbeddingsRequest): + input = request.input + if not input: + raise JSONResponse(status_code=400, + content={"error": "Missing required argument input"}) + model = request.model if request.model else None + response = await OAIembeddings.embeddings(input, request.encoding_format, + model) + return JSONResponse(response) + diff --git a/endpoints/OAI/types/embedding.py b/endpoints/OAI/types/embedding.py new file mode 100644 index 0000000..3dd5e86 --- /dev/null +++ b/endpoints/OAI/types/embedding.py @@ -0,0 +1,39 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + +class EmbeddingsRequest(BaseModel): + input: List[str] = Field( + ..., description="List of input texts to generate embeddings for.") + encoding_format: str = Field( + "float", + description="Encoding format for the embeddings. " + "Can be 'float' or 'base64'.") + model: Optional[str] = Field( + None, + description="Name of the embedding model to use. " + "If not provided, the default model will be used.") + + +class EmbeddingObject(BaseModel): + object: str = Field("embedding", description="Type of the object.") + embedding: List[float] = Field( + ..., description="Embedding values as a list of floats.") + index: int = Field( + ..., + description="Index of the input text corresponding to " + "the embedding.") + + +class EmbeddingsResponse(BaseModel): + object: str = Field("list", description="Type of the response object.") + data: List[EmbeddingObject] = Field( + ..., description="List of embedding objects.") + model: str = Field(..., description="Name of the embedding model used.") + usage: UsageInfo = Field(..., description="Information about token usage.") diff --git a/pyproject.toml b/pyproject.toml index 38cdbcf..f934096 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,8 @@ dependencies = [ [project.optional-dependencies] extras = [ # Heavy dependencies that aren't for everyday use - "outlines" + "outlines", + "sentence-transformers" ] dev = [ "ruff == 0.3.2" diff --git a/tabbyAPI b/tabbyAPI new file mode 160000 index 0000000..1650e6e --- /dev/null +++ b/tabbyAPI @@ -0,0 +1 @@ +Subproject commit 1650e6e6406edf797576c077aaceafcf28895c26 From 765d3593b3e22888149292355a8d5ae03bd2f630 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 26 Jul 2024 02:52:18 +0000 Subject: [PATCH 39/89] remove submodule --- tabbyAPI | 1 - 1 file changed, 1 deletion(-) delete mode 160000 tabbyAPI diff --git a/tabbyAPI b/tabbyAPI deleted file mode 160000 index 1650e6e..0000000 --- a/tabbyAPI +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1650e6e6406edf797576c077aaceafcf28895c26 From 5adfab1cbd7dc4d23501384bebb0ac6566c4fd62 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 26 Jul 2024 02:53:14 +0000 Subject: [PATCH 40/89] ruff: formatting --- endpoints/OAI/embeddings.py | 69 +++++++++++++++----------------- endpoints/OAI/router.py | 17 ++++---- endpoints/OAI/types/embedding.py | 21 +++++----- 3 files changed, 52 insertions(+), 55 deletions(-) diff --git a/endpoints/OAI/embeddings.py b/endpoints/OAI/embeddings.py index 3cc59de..725d7ba 100644 --- a/endpoints/OAI/embeddings.py +++ b/endpoints/OAI/embeddings.py @@ -16,24 +16,22 @@ embeddings_params_initialized = False def initialize_embedding_params(): - ''' + """ using 'lazy loading' to avoid circular import so this function will be executed only once - ''' + """ global embeddings_params_initialized if not embeddings_params_initialized: - global st_model, embeddings_model, embeddings_device - st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", - 'all-mpnet-base-v2') + st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", "all-mpnet-base-v2") embeddings_model = None # OPENAI_EMBEDDING_DEVICE: auto (best or cpu), # cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, # hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, # hpu, mtia, privateuseone - embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", 'cpu') - if embeddings_device.lower() == 'auto': + embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", "cpu") + if embeddings_device.lower() == "auto": embeddings_device = None embeddings_params_initialized = True @@ -43,19 +41,20 @@ def load_embedding_model(model: str): try: from sentence_transformers import SentenceTransformer except ModuleNotFoundError: - print("The sentence_transformers module has not been found. " + - "Please install it manually with " + - "pip install -U sentence-transformers.") + print( + "The sentence_transformers module has not been found. " + + "Please install it manually with " + + "pip install -U sentence-transformers." + ) raise ModuleNotFoundError from None initialize_embedding_params() global embeddings_device, embeddings_model try: print(f"Try embedding model: {model} on {embeddings_device}") - if 'jina-embeddings' in model: + if "jina-embeddings" in model: # trust_remote_code is needed to use the encode method - embeddings_model = AutoModel.from_pretrained( - model, trust_remote_code=True) + embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True) embeddings_model = embeddings_model.to(embeddings_device) else: embeddings_model = SentenceTransformer( @@ -66,8 +65,9 @@ def load_embedding_model(model: str): print(f"Loaded embedding model: {model}") except Exception as e: embeddings_model = None - raise Exception(f"Error: Failed to load embedding model: {model}", - internal_message=repr(e)) from None + raise Exception( + f"Error: Failed to load embedding model: {model}", internal_message=repr(e) + ) from None def get_embeddings_model(): @@ -87,17 +87,17 @@ def get_embeddings_model_name() -> str: def get_embeddings(input: list) -> np.ndarray: model = get_embeddings_model() - embedding = model.encode(input, - convert_to_numpy=True, - normalize_embeddings=True, - convert_to_tensor=False, - show_progress_bar=False) + embedding = model.encode( + input, + convert_to_numpy=True, + normalize_embeddings=True, + convert_to_tensor=False, + show_progress_bar=False, + ) return embedding -async def embeddings(input: list, - encoding_format: str, - model: str = None) -> dict: +async def embeddings(input: list, encoding_format: str, model: str = None) -> dict: if model is None: model = st_model else: @@ -105,17 +105,15 @@ async def embeddings(input: list, embeddings = get_embeddings(input) if encoding_format == "base64": - data = [{ - "object": "embedding", - "embedding": float_list_to_base64(emb), - "index": n - } for n, emb in enumerate(embeddings)] + data = [ + {"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} + for n, emb in enumerate(embeddings) + ] else: - data = [{ - "object": "embedding", - "embedding": emb.tolist(), - "index": n - } for n, emb in enumerate(embeddings)] + data = [ + {"object": "embedding", "embedding": emb.tolist(), "index": n} + for n, emb in enumerate(embeddings) + ] response = { "object": "list", @@ -124,7 +122,7 @@ async def embeddings(input: list, "usage": { "prompt_tokens": 0, "total_tokens": 0, - } + }, } return response @@ -140,6 +138,5 @@ def float_list_to_base64(float_array: np.ndarray) -> str: encoded_bytes = base64.b64encode(bytes_array) # Turn raw base64 encoded bytes into ASCII - ascii_string = encoded_bytes.decode('ascii') + ascii_string = encoded_bytes.decode("ascii") return ascii_string - diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 9897353..039042a 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -15,10 +15,7 @@ from endpoints.OAI.types.chat_completion import ( ChatCompletionRequest, ChatCompletionResponse, ) -from endpoints.OAI.types.embedding import ( - EmbeddingsRequest, - EmbeddingsResponse -) +from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse from endpoints.OAI.utils.chat_completion import ( format_prompt_with_template, generate_chat_completion, @@ -132,19 +129,19 @@ async def chat_completion_request( ) return response + # Embeddings endpoint @router.post( "/v1/embeddings", dependencies=[Depends(check_api_key), Depends(check_model_container)], - response_model=EmbeddingsResponse + response_model=EmbeddingsResponse, ) async def handle_embeddings(request: EmbeddingsRequest): input = request.input if not input: - raise JSONResponse(status_code=400, - content={"error": "Missing required argument input"}) + raise JSONResponse( + status_code=400, content={"error": "Missing required argument input"} + ) model = request.model if request.model else None - response = await OAIembeddings.embeddings(input, request.encoding_format, - model) + response = await OAIembeddings.embeddings(input, request.encoding_format, model) return JSONResponse(response) - diff --git a/endpoints/OAI/types/embedding.py b/endpoints/OAI/types/embedding.py index 3dd5e86..7d5779f 100644 --- a/endpoints/OAI/types/embedding.py +++ b/endpoints/OAI/types/embedding.py @@ -8,32 +8,35 @@ class UsageInfo(BaseModel): total_tokens: int = 0 completion_tokens: Optional[int] = 0 + class EmbeddingsRequest(BaseModel): input: List[str] = Field( - ..., description="List of input texts to generate embeddings for.") + ..., description="List of input texts to generate embeddings for." + ) encoding_format: str = Field( "float", description="Encoding format for the embeddings. " - "Can be 'float' or 'base64'.") + "Can be 'float' or 'base64'.", + ) model: Optional[str] = Field( None, description="Name of the embedding model to use. " - "If not provided, the default model will be used.") + "If not provided, the default model will be used.", + ) class EmbeddingObject(BaseModel): object: str = Field("embedding", description="Type of the object.") embedding: List[float] = Field( - ..., description="Embedding values as a list of floats.") + ..., description="Embedding values as a list of floats." + ) index: int = Field( - ..., - description="Index of the input text corresponding to " - "the embedding.") + ..., description="Index of the input text corresponding to " "the embedding." + ) class EmbeddingsResponse(BaseModel): object: str = Field("list", description="Type of the response object.") - data: List[EmbeddingObject] = Field( - ..., description="List of embedding objects.") + data: List[EmbeddingObject] = Field(..., description="List of embedding objects.") model: str = Field(..., description="Name of the embedding model used.") usage: UsageInfo = Field(..., description="Information about token usage.") From f47d96790ce00a629ba7cf83f45ee099129b726b Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 25 Jul 2024 23:39:52 -0400 Subject: [PATCH 41/89] Dependencies: Update pytorch and flash_attention v2.4.0 and v2.6.3 Also use torch 2.4 wheels. Signed-off-by: kingbri --- pyproject.toml | 84 +++++++++++++++++++++++++------------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 38cdbcf..2110a5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,68 +54,68 @@ dev = [ ] cu121 = [ # Torch (Extra index URLs not support in pyproject.toml) - "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/bdashore3/flash-attention/releases - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] cu118 = [ # Torch - "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] amd = [ # Torch triton for ROCm - "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", # Torch - "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "torch @ https://download.pytorch.org/whl/rocm6.1/torch-2.4.0%2Brocm6.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "torch @ https://download.pytorch.org/whl/rocm6.1/torch-2.4.0%2Brocm6.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "torch @ https://download.pytorch.org/whl/rocm6.1/torch-2.4.0%2Brocm6.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.1.torch2.4.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.1.torch2.4.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.1.torch2.4.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", ] # MARK: Ruff options From 4e808cbed78e61665223ed994da2dcf6b645fb1e Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 15:04:29 -0400 Subject: [PATCH 42/89] Auth: Fix disable auth when checking for key permissions Since authentication is disabled, remove the limited permissions for requests. Signed-off-by: kingbri --- common/auth.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/auth.py b/common/auth.py index 4f7c9f8..174208d 100644 --- a/common/auth.py +++ b/common/auth.py @@ -84,6 +84,10 @@ def get_key_permission(request: Request): Internal only! Use the depends functions for incoming requests. """ + # Give full admin permissions if auth is disabled + if DISABLE_AUTH: + return "admin" + # Hyphens are okay here test_key = coalesce( request.headers.get("x-admin-key"), From b7cb6f0b91c9299ca447107cac82f5dc65ca4d59 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 16:37:30 -0400 Subject: [PATCH 43/89] API: Add KoboldAI server Used for interacting with applications that use KoboldAI's API such as horde. Signed-off-by: kingbri --- backends/exllamav2/model.py | 8 +- endpoints/Kobold/router.py | 103 ++++++++++++++++++ endpoints/Kobold/types/generation.py | 53 ++++++++++ endpoints/Kobold/types/token.py | 15 +++ endpoints/Kobold/utils/generation.py | 151 +++++++++++++++++++++++++++ endpoints/server.py | 3 +- 6 files changed, 330 insertions(+), 3 deletions(-) create mode 100644 endpoints/Kobold/router.py create mode 100644 endpoints/Kobold/types/generation.py create mode 100644 endpoints/Kobold/types/token.py create mode 100644 endpoints/Kobold/utils/generation.py diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 200be6b..3515123 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -828,10 +828,14 @@ class ExllamaV2Container: return dict(zip_longest(top_tokens, cleaned_values)) - async def generate(self, prompt: str, request_id: str, **kwargs): + async def generate( + self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs + ): """Generate a response to a prompt""" generations = [] - async for generation in self.generate_gen(prompt, request_id, **kwargs): + async for generation in self.generate_gen( + prompt, request_id, abort_event, **kwargs + ): generations.append(generation) joined_generation = { diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py new file mode 100644 index 0000000..c3265fc --- /dev/null +++ b/endpoints/Kobold/router.py @@ -0,0 +1,103 @@ +from sys import maxsize +from fastapi import APIRouter, Depends, Request +from sse_starlette import EventSourceResponse + +from common import model +from common.auth import check_api_key +from common.model import check_model_container +from common.utils import unwrap +from endpoints.Kobold.types.generation import ( + AbortRequest, + CheckGenerateRequest, + GenerateRequest, + GenerateResponse, +) +from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse +from endpoints.Kobold.utils.generation import ( + abort_generation, + generation_status, + get_generation, + stream_generation, +) +from endpoints.core.utils.model import get_current_model + + +router = APIRouter(prefix="/api") + + +@router.post( + "/v1/generate", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: + response = await get_generation(data, request) + + return response + + +@router.post( + "/extra/generate/stream", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse: + response = EventSourceResponse(stream_generation(data, request), ping=maxsize) + + return response + + +@router.post( + "/extra/abort", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def abort_generate(data: AbortRequest): + response = await abort_generation(data.genkey) + + return response + + +@router.get( + "/extra/generate/check", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +@router.post( + "/extra/generate/check", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: + response = await generation_status(data.genkey) + + return response + + +@router.get( + "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] +) +async def current_model(): + """Fetches the current model and who owns it.""" + + current_model_card = get_current_model() + return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"} + + +@router.post( + "/extra/tokencount", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def get_tokencount(data: TokenCountRequest): + raw_tokens = model.container.encode_tokens(data.prompt) + tokens = unwrap(raw_tokens, []) + return TokenCountResponse(value=len(tokens), ids=tokens) + + +@router.get("/v1/info/version") +async def get_version(): + """Impersonate KAI United.""" + + return {"result": "1.2.5"} + + +@router.get("/extra/version") +async def get_extra_version(): + """Impersonate Koboldcpp.""" + + return {"result": "KoboldCpp", "version": "1.61"} diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py new file mode 100644 index 0000000..5468741 --- /dev/null +++ b/endpoints/Kobold/types/generation.py @@ -0,0 +1,53 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field +from common.sampling import BaseSamplerRequest, get_default_sampler_value + + +class GenerateRequest(BaseSamplerRequest): + prompt: str + use_default_badwordsids: Optional[bool] = False + genkey: Optional[str] = None + + max_length: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("max_tokens"), + examples=[150], + ) + rep_pen_range: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("penalty_range", -1), + ) + rep_pen: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), + ) + + def to_gen_params(self, **kwargs): + # Swap kobold generation params to OAI/Exl2 ones + self.max_tokens = self.max_length + self.repetition_penalty = self.rep_pen + self.penalty_range = -1 if self.rep_pen_range == 0 else self.rep_pen_range + + return super().to_gen_params(**kwargs) + + +class GenerateResponseResult(BaseModel): + text: str + + +class GenerateResponse(BaseModel): + results: List[GenerateResponseResult] = Field(default_factory=list) + + +class StreamGenerateChunk(BaseModel): + token: str + + +class AbortRequest(BaseModel): + genkey: str + + +class AbortResponse(BaseModel): + success: bool + + +class CheckGenerateRequest(BaseModel): + genkey: str diff --git a/endpoints/Kobold/types/token.py b/endpoints/Kobold/types/token.py new file mode 100644 index 0000000..e6639d9 --- /dev/null +++ b/endpoints/Kobold/types/token.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel +from typing import List + + +class TokenCountRequest(BaseModel): + """Represents a KAI tokenization request.""" + + prompt: str + + +class TokenCountResponse(BaseModel): + """Represents a KAI tokenization response.""" + + value: int + ids: List[int] diff --git a/endpoints/Kobold/utils/generation.py b/endpoints/Kobold/utils/generation.py new file mode 100644 index 0000000..5febcff --- /dev/null +++ b/endpoints/Kobold/utils/generation.py @@ -0,0 +1,151 @@ +import asyncio +from asyncio import CancelledError +from fastapi import HTTPException, Request +from loguru import logger +from sse_starlette import ServerSentEvent + +from common import model +from common.networking import ( + get_generator_error, + handle_request_disconnect, + handle_request_error, + request_disconnect_loop, +) +from common.utils import unwrap +from endpoints.Kobold.types.generation import ( + AbortResponse, + GenerateRequest, + GenerateResponse, + GenerateResponseResult, + StreamGenerateChunk, +) + + +generation_cache = {} + + +async def override_request_id(request: Request, data: GenerateRequest): + """Overrides the request ID with a KAI genkey if present.""" + + if data.genkey: + request.state.id = data.genkey + + +def _create_response(text: str): + results = [GenerateResponseResult(text=text)] + return GenerateResponse(results=results) + + +def _create_stream_chunk(text: str): + return StreamGenerateChunk(token=text) + + +async def _stream_collector(data: GenerateRequest, request: Request): + """Common async generator for generation streams.""" + + abort_event = asyncio.Event() + disconnect_task = asyncio.create_task(request_disconnect_loop(request)) + + # Create a new entry in the cache + generation_cache[data.genkey] = {"abort": abort_event, "text": ""} + + try: + logger.info(f"Received Kobold generation request {data.genkey}") + + generator = model.container.generate_gen( + data.prompt, data.genkey, abort_event, **data.to_gen_params() + ) + async for generation in generator: + if disconnect_task.done(): + abort_event.set() + handle_request_disconnect( + f"Kobold generation {data.genkey} cancelled by user." + ) + + text = generation.get("text") + + # Update the generation cache with the new chunk + if text: + generation_cache[data.genkey]["text"] += text + yield text + + if "finish_reason" in generation: + logger.info(f"Finished streaming Kobold request {data.genkey}") + break + except CancelledError: + # If the request disconnects, break out + if not disconnect_task.done(): + abort_event.set() + handle_request_disconnect( + f"Kobold generation {data.genkey} cancelled by user." + ) + finally: + # Cleanup the cache + del generation_cache[data.genkey] + + +async def stream_generation(data: GenerateRequest, request: Request): + """Wrapper for stream generations.""" + + # If the genkey doesn't exist, set it to the request ID + if not data.genkey: + data.genkey = request.state.id + + try: + async for chunk in _stream_collector(data, request): + response = _create_stream_chunk(chunk) + yield ServerSentEvent( + event="message", data=response.model_dump_json(), sep="\n" + ) + except Exception: + yield get_generator_error( + f"Kobold generation {data.genkey} aborted. " + "Please check the server console." + ) + + +async def get_generation(data: GenerateRequest, request: Request): + """Wrapper to get a static generation.""" + + # If the genkey doesn't exist, set it to the request ID + if not data.genkey: + data.genkey = request.state.id + + try: + full_text = "" + async for chunk in _stream_collector(data, request): + full_text += chunk + + response = _create_response(full_text) + return response + except Exception as exc: + error_message = handle_request_error( + f"Completion {request.state.id} aborted. Maybe the model was unloaded? " + "Please check the server console." + ).error.message + + # Server error if there's a generation exception + raise HTTPException(503, error_message) from exc + + +async def abort_generation(genkey: str): + """Aborts a generation from the cache.""" + + abort_event = unwrap(generation_cache.get(genkey), {}).get("abort") + if abort_event: + abort_event.set() + handle_request_disconnect(f"Kobold generation {genkey} cancelled by user.") + + return AbortResponse(success=True) + + +async def generation_status(genkey: str): + """Fetches the status of a generation from the cache.""" + + current_text = unwrap(generation_cache.get(genkey), {}).get("text") + if current_text: + response = _create_response(current_text) + else: + response = GenerateResponse() + + return response diff --git a/endpoints/server.py b/endpoints/server.py index 401b211..dfb1cdd 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -9,6 +9,7 @@ from common.logger import UVICORN_LOG_CONFIG from common.networking import get_global_depends from common.utils import unwrap from endpoints.core.router import router as CoreRouter +from endpoints.Kobold.router import router as KoboldRouter from endpoints.OAI.router import router as OAIRouter @@ -37,7 +38,7 @@ def setup_app(): api_servers = unwrap(config.network_config().get("api_servers"), []) # Map for API id to server router - router_mapping = {"oai": OAIRouter} + router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter} # Include the OAI api by default if api_servers: From 545e26608f8dcfa9600e7060d53daf3efa112fc3 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 16:45:29 -0400 Subject: [PATCH 44/89] Kobold: Move params to aliases Some of the parameters the API provides are aliases for their OAI equivalents. It makes more sense to move them to the common file. Signed-off-by: kingbri --- common/sampling.py | 14 +++++++++++++- endpoints/Kobold/types/generation.py | 22 +++++----------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/common/sampling.py b/common/sampling.py index bbeddb8..2b3850e 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -16,11 +16,15 @@ class BaseSamplerRequest(BaseModel): max_tokens: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("max_tokens"), + validation_alias=AliasChoices("max_tokens", "max_length"), + description="Aliases: max_length", examples=[150], ) min_tokens: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("min_tokens", 0), + validation_alias=AliasChoices("min_tokens", "min_length"), + description="Aliases: min_length", examples=[0], ) @@ -91,6 +95,8 @@ class BaseSamplerRequest(BaseModel): repetition_penalty: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), + validation_alias=AliasChoices("repetition_penalty", "rep_pen"), + description="Aliases: rep_pen", examples=[1.0], ) @@ -118,6 +124,8 @@ class BaseSamplerRequest(BaseModel): ban_eos_token: Optional[bool] = Field( default_factory=lambda: get_default_sampler_value("ban_eos_token", False), + validation_alias=AliasChoices("ban_eos_token", "ignore_eos"), + description="Aliases: ignore_eos", examples=[False], ) @@ -165,8 +173,12 @@ class BaseSamplerRequest(BaseModel): "penalty_range", "repetition_range", "repetition_penalty_range", + "rep_pen_range", + ), + description=( + "Aliases: repetition_range, repetition_penalty_range, " + "rep_pen_range" ), - description="Aliases: repetition_range, repetition_penalty_range", ) cfg_scale: Optional[float] = Field( diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py index 5468741..eab214c 100644 --- a/endpoints/Kobold/types/generation.py +++ b/endpoints/Kobold/types/generation.py @@ -1,30 +1,18 @@ from typing import List, Optional from pydantic import BaseModel, Field -from common.sampling import BaseSamplerRequest, get_default_sampler_value +from common.sampling import BaseSamplerRequest class GenerateRequest(BaseSamplerRequest): prompt: str - use_default_badwordsids: Optional[bool] = False genkey: Optional[str] = None - - max_length: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("max_tokens"), - examples=[150], - ) - rep_pen_range: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("penalty_range", -1), - ) - rep_pen: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), - ) + use_default_badwordsids: Optional[bool] = False def to_gen_params(self, **kwargs): - # Swap kobold generation params to OAI/Exl2 ones - self.max_tokens = self.max_length - self.repetition_penalty = self.rep_pen - self.penalty_range = -1 if self.rep_pen_range == 0 else self.rep_pen_range + # Exl2 uses -1 to include all tokens in repetition penalty + if self.penalty_range == 0: + self.penalty_range = -1 return super().to_gen_params(**kwargs) From 7522b1447b1e2d569fd9909ac91112d2308f40dd Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 18:23:22 -0400 Subject: [PATCH 45/89] Model: Add support for HuggingFace config and bad_words_ids This is necessary for Kobold's API. Current models use bad_words_ids in generation_config.json, but for some reason, they're also present in the model's config.json. Signed-off-by: kingbri --- backends/exllamav2/model.py | 20 +++++++------- common/transformers_utils.py | 39 ++++++++++++++++++++++++++++ common/utils.py | 6 +++++ endpoints/Kobold/types/generation.py | 12 +++++++++ 4 files changed, 68 insertions(+), 9 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 3515123..3df16b0 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -47,7 +47,7 @@ from common.templating import ( TemplateLoadError, find_template_from_model, ) -from common.transformers_utils import GenerationConfig +from common.transformers_utils import GenerationConfig, HuggingFaceConfig from common.utils import coalesce, unwrap @@ -72,6 +72,7 @@ class ExllamaV2Container: draft_cache_mode: str = "FP16" max_batch_size: int = 20 generation_config: Optional[GenerationConfig] = None + hf_config: Optional[HuggingFaceConfig] = None # GPU split vars gpu_split: Optional[list] = None @@ -186,6 +187,9 @@ class ExllamaV2Container: except AttributeError: pass + # Create the hf_config + self.hf_config = HuggingFaceConfig.from_file(model_directory) + # Then override the base_seq_len if present override_base_seq_len = kwargs.get("override_base_seq_len") if override_base_seq_len: @@ -268,15 +272,8 @@ class ExllamaV2Container: else: self.cache_size = self.config.max_seq_len - # Try to set prompt template - self.prompt_template = self.find_prompt_template( - kwargs.get("prompt_template"), model_directory - ) - # Load generation config overrides - generation_config_path = ( - pathlib.Path(self.config.model_dir) / "generation_config.json" - ) + generation_config_path = model_directory / "generation_config.json" if generation_config_path.exists(): try: self.generation_config = GenerationConfig.from_file( @@ -288,6 +285,11 @@ class ExllamaV2Container: "Skipping generation config load because of an unexpected error." ) + # Try to set prompt template + self.prompt_template = self.find_prompt_template( + kwargs.get("prompt_template"), model_directory + ) + # Catch all for template lookup errors if self.prompt_template: logger.info( diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 62d4622..2431b1b 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -1,6 +1,7 @@ import json import pathlib from typing import List, Optional, Union +from loguru import logger from pydantic import BaseModel @@ -11,6 +12,7 @@ class GenerationConfig(BaseModel): """ eos_token_id: Optional[Union[int, List[int]]] = None + bad_words_ids: Optional[List[List[int]]] = None @classmethod def from_file(self, model_directory: pathlib.Path): @@ -30,3 +32,40 @@ class GenerationConfig(BaseModel): return [self.eos_token_id] else: return self.eos_token_id + + +class HuggingFaceConfig(BaseModel): + """ + An abridged version of HuggingFace's model config. + Will be expanded as needed. + """ + + badwordsids: Optional[str] = None + + @classmethod + def from_file(self, model_directory: pathlib.Path): + """Create an instance from a generation config file.""" + + hf_config_path = model_directory / "config.json" + with open( + hf_config_path, "r", encoding="utf8" + ) as hf_config_json: + hf_config_dict = json.load(hf_config_json) + return self.model_validate(hf_config_dict) + + def get_badwordsids(self): + """Wrapper method to fetch badwordsids.""" + + if self.badwordsids: + try: + bad_words_list = json.loads(self.badwordsids) + return bad_words_list + except json.JSONDecodeError: + logger.warning( + "Skipping badwordsids from config.json " + "since it's not a valid array." + ) + + return [] + else: + return [] diff --git a/common/utils.py b/common/utils.py index 6787f39..b120022 100644 --- a/common/utils.py +++ b/common/utils.py @@ -18,3 +18,9 @@ def prune_dict(input_dict): """Trim out instances of None from a dictionary.""" return {k: v for k, v in input_dict.items() if v is not None} + + +def flat_map(input_list): + """Flattens a list of lists into a single list.""" + + return [item for sublist in input_list for item in sublist] diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py index eab214c..0ee5489 100644 --- a/endpoints/Kobold/types/generation.py +++ b/endpoints/Kobold/types/generation.py @@ -1,7 +1,9 @@ from typing import List, Optional from pydantic import BaseModel, Field +from common import model from common.sampling import BaseSamplerRequest +from common.utils import flat_map, unwrap class GenerateRequest(BaseSamplerRequest): @@ -14,6 +16,16 @@ class GenerateRequest(BaseSamplerRequest): if self.penalty_range == 0: self.penalty_range = -1 + # Move badwordsids into banned tokens for generation + if self.use_default_badwordsids: + bad_words_ids = unwrap( + model.container.generation_config.bad_words_ids, + model.container.hf_config.get_badwordsids() + ) + + if bad_words_ids: + self.banned_tokens += flat_map(bad_words_ids) + return super().to_gen_params(**kwargs) From ea80b62e307b3668553b3fdf9d1b3dfccecaf84f Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 18:32:33 -0400 Subject: [PATCH 46/89] Sampling: Reorder aliased params and add kobold aliases Also add dynatemp range which is an alternative way of calculating min and max temp. Signed-off-by: kingbri --- common/sampling.py | 61 ++++++++++++++-------------- endpoints/Kobold/types/generation.py | 9 +++- 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/common/sampling.py b/common/sampling.py index 2b3850e..8851f00 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -34,13 +34,22 @@ class BaseSamplerRequest(BaseModel): ) stop: Optional[Union[str, List[str]]] = Field( - default_factory=lambda: get_default_sampler_value("stop", []) + default_factory=lambda: get_default_sampler_value("stop", []), + validation_alias=AliasChoices("stop", "stop_sequence"), + description="Aliases: stop_sequence", ) banned_strings: Optional[Union[str, List[str]]] = Field( default_factory=lambda: get_default_sampler_value("banned_strings", []) ) + banned_tokens: Optional[Union[List[int], str]] = Field( + default_factory=lambda: get_default_sampler_value("banned_tokens", []), + validation_alias=AliasChoices("banned_tokens", "custom_token_bans"), + description="Aliases: custom_token_bans", + examples=[[128, 330]], + ) + token_healing: Optional[bool] = Field( default_factory=lambda: get_default_sampler_value("token_healing", False) ) @@ -80,6 +89,13 @@ class BaseSamplerRequest(BaseModel): examples=[1.0], ) + typical: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("typical", 1.0), + validation_alias=AliasChoices("typical", "typical_p"), + description="Aliases: typical_p", + examples=[1.0], + ) + skew: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("skew", 0.0), examples=[0.0], @@ -100,6 +116,20 @@ class BaseSamplerRequest(BaseModel): examples=[1.0], ) + penalty_range: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("penalty_range", -1), + validation_alias=AliasChoices( + "penalty_range", + "repetition_range", + "repetition_penalty_range", + "rep_pen_range", + ), + description=( + "Aliases: repetition_range, repetition_penalty_range, " + "rep_pen_range" + ), + ) + repetition_decay: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("repetition_decay", 0) ) @@ -159,28 +189,6 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("speculative_ngram"), ) - # Aliased variables - typical: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("typical", 1.0), - validation_alias=AliasChoices("typical", "typical_p"), - description="Aliases: typical_p", - examples=[1.0], - ) - - penalty_range: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("penalty_range", -1), - validation_alias=AliasChoices( - "penalty_range", - "repetition_range", - "repetition_penalty_range", - "rep_pen_range", - ), - description=( - "Aliases: repetition_range, repetition_penalty_range, " - "rep_pen_range" - ), - ) - cfg_scale: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0), validation_alias=AliasChoices("cfg_scale", "guidance_scale"), @@ -208,13 +216,6 @@ class BaseSamplerRequest(BaseModel): examples=[1.0], ) - banned_tokens: Optional[Union[List[int], str]] = Field( - default_factory=lambda: get_default_sampler_value("banned_tokens", []), - validation_alias=AliasChoices("banned_tokens", "custom_token_bans"), - description="Aliases: custom_token_bans", - examples=[[128, 330]], - ) - # TODO: Return back to adaptable class-based validation But that's just too much # abstraction compared to simple if statements at the moment def validate_params(self): diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py index 0ee5489..210a914 100644 --- a/endpoints/Kobold/types/generation.py +++ b/endpoints/Kobold/types/generation.py @@ -2,7 +2,7 @@ from typing import List, Optional from pydantic import BaseModel, Field from common import model -from common.sampling import BaseSamplerRequest +from common.sampling import BaseSamplerRequest, get_default_sampler_value from common.utils import flat_map, unwrap @@ -10,12 +10,19 @@ class GenerateRequest(BaseSamplerRequest): prompt: str genkey: Optional[str] = None use_default_badwordsids: Optional[bool] = False + dynatemp_range: Optional[float] = Field( + default_factory=get_default_sampler_value("dynatemp_range") + ) def to_gen_params(self, **kwargs): # Exl2 uses -1 to include all tokens in repetition penalty if self.penalty_range == 0: self.penalty_range = -1 + if self.dynatemp_range: + self.min_temp = self.temperature - self.dynatemp_range + self.max_temp = self.temperature + self.dynatemp_range + # Move badwordsids into banned tokens for generation if self.use_default_badwordsids: bad_words_ids = unwrap( From e8fc13a1f6097f3c881fab09c791f4dfb5f8553d Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 18:33:04 -0400 Subject: [PATCH 47/89] Tree: Format Signed-off-by: kingbri --- common/sampling.py | 3 +-- common/transformers_utils.py | 4 +--- endpoints/Kobold/types/generation.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/common/sampling.py b/common/sampling.py index 8851f00..72552ce 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -125,8 +125,7 @@ class BaseSamplerRequest(BaseModel): "rep_pen_range", ), description=( - "Aliases: repetition_range, repetition_penalty_range, " - "rep_pen_range" + "Aliases: repetition_range, repetition_penalty_range, " "rep_pen_range" ), ) diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 2431b1b..9db8ad2 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -47,9 +47,7 @@ class HuggingFaceConfig(BaseModel): """Create an instance from a generation config file.""" hf_config_path = model_directory / "config.json" - with open( - hf_config_path, "r", encoding="utf8" - ) as hf_config_json: + with open(hf_config_path, "r", encoding="utf8") as hf_config_json: hf_config_dict = json.load(hf_config_json) return self.model_validate(hf_config_dict) diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py index 210a914..310484b 100644 --- a/endpoints/Kobold/types/generation.py +++ b/endpoints/Kobold/types/generation.py @@ -27,7 +27,7 @@ class GenerateRequest(BaseSamplerRequest): if self.use_default_badwordsids: bad_words_ids = unwrap( model.container.generation_config.bad_words_ids, - model.container.hf_config.get_badwordsids() + model.container.hf_config.get_badwordsids(), ) if bad_words_ids: From 884b6f5ecd726686e19508bc09b7345671a48052 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 21:32:05 -0400 Subject: [PATCH 48/89] API: Add log options for initialization Make each API log their respective URLs to help inform users. Signed-off-by: kingbri --- endpoints/Kobold/router.py | 5 +++++ endpoints/OAI/router.py | 5 +++++ endpoints/server.py | 30 ++++++++++++++++++++---------- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index c3265fc..6cfccf5 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -22,7 +22,12 @@ from endpoints.Kobold.utils.generation import ( from endpoints.core.utils.model import get_current_model +api_name = "KoboldAI" router = APIRouter(prefix="/api") +urls = { + "Generation": "http://{host}:{port}/api/v1/generate", + "Streaming": "http://{host}:{port}/api/extra/generate/stream", +} @router.post( diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 771b7f3..55a2205 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -24,7 +24,12 @@ from endpoints.OAI.utils.completion import ( ) +api_name = "OAI" router = APIRouter() +urls = { + "Completions": "http://{host}:{port}/v1/completions", + "Chat completions": "http://{host}:{port}/v1/chat/completions", +} # Completions endpoint diff --git a/endpoints/server.py b/endpoints/server.py index dfb1cdd..f8bfc8a 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -1,4 +1,5 @@ import asyncio +from typing import Optional import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -8,12 +9,12 @@ from common import config from common.logger import UVICORN_LOG_CONFIG from common.networking import get_global_depends from common.utils import unwrap +from endpoints.Kobold import router as KoboldRouter +from endpoints.OAI import router as OAIRouter from endpoints.core.router import router as CoreRouter -from endpoints.Kobold.router import router as KoboldRouter -from endpoints.OAI.router import router as OAIRouter -def setup_app(): +def setup_app(host: Optional[str] = None, port: Optional[int] = None): """Includes the correct routers for startup""" app = FastAPI( @@ -43,11 +44,20 @@ def setup_app(): # Include the OAI api by default if api_servers: for server in api_servers: - server_name = server.lower() - if server_name in router_mapping: - app.include_router(router_mapping[server_name]) + selected_server = router_mapping.get(server.lower()) + + if selected_server: + app.include_router(selected_server.router) + + logger.info(f"Starting {selected_server.api_name} API") + for path, url in selected_server.urls.items(): + formatted_url = url.format(host=host, port=port) + logger.info(f"{path}: {formatted_url}") else: - app.include_router(OAIRouter) + app.include_router(OAIRouter.router) + for path, url in OAIRouter.urls.items(): + formatted_url = url.format(host=host, port=port) + logger.info(f"{path}: {formatted_url}") # Include core API request paths app.include_router(CoreRouter) @@ -67,11 +77,11 @@ async def start_api(host: str, port: int): # TODO: Move OAI API to a separate folder logger.info(f"Developer documentation: http://{host}:{port}/redoc") - logger.info(f"Completions: http://{host}:{port}/v1/completions") - logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions") + # logger.info(f"Completions: http://{host}:{port}/v1/completions") + # logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions") # Setup app - app = setup_app() + app = setup_app(host, port) # Get the current event loop loop = asyncio.get_running_loop() From 2773517a168a7240065a50423e6be19034fb711a Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 22:24:33 -0400 Subject: [PATCH 49/89] API: Add setup function to routers This helps prepare the router before exposing it to the parent app. Signed-off-by: kingbri --- endpoints/Kobold/router.py | 43 ++++++++++++++++++++++++-------------- endpoints/OAI/router.py | 4 ++++ endpoints/server.py | 4 ++-- 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index 6cfccf5..ff0ec8c 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -29,9 +29,20 @@ urls = { "Streaming": "http://{host}:{port}/api/extra/generate/stream", } +kai_router = APIRouter() +extra_kai_router = APIRouter() -@router.post( - "/v1/generate", + +def setup(): + router.include_router(kai_router, prefix="/v1") + router.include_router(kai_router, prefix="/latest", include_in_schema=False) + router.include_router(extra_kai_router, prefix="/extra") + + return router + + +@kai_router.post( + "/generate", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: @@ -40,8 +51,8 @@ async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: return response -@router.post( - "/extra/generate/stream", +@extra_kai_router.post( + "/generate/stream", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse: @@ -50,8 +61,8 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe return response -@router.post( - "/extra/abort", +@extra_kai_router.post( + "/abort", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def abort_generate(data: AbortRequest): @@ -60,12 +71,12 @@ async def abort_generate(data: AbortRequest): return response -@router.get( - "/extra/generate/check", +@extra_kai_router.get( + "/generate/check", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) -@router.post( - "/extra/generate/check", +@extra_kai_router.post( + "/generate/check", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: @@ -74,8 +85,8 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: return response -@router.get( - "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] +@kai_router.get( + "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] ) async def current_model(): """Fetches the current model and who owns it.""" @@ -84,8 +95,8 @@ async def current_model(): return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"} -@router.post( - "/extra/tokencount", +@extra_kai_router.post( + "/tokencount", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def get_tokencount(data: TokenCountRequest): @@ -94,14 +105,14 @@ async def get_tokencount(data: TokenCountRequest): return TokenCountResponse(value=len(tokens), ids=tokens) -@router.get("/v1/info/version") +@kai_router.get("/info/version") async def get_version(): """Impersonate KAI United.""" return {"result": "1.2.5"} -@router.get("/extra/version") +@extra_kai_router.get("/version") async def get_extra_version(): """Impersonate Koboldcpp.""" diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 55a2205..d970161 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -32,6 +32,10 @@ urls = { } +def setup(): + return router + + # Completions endpoint @router.post( "/v1/completions", diff --git a/endpoints/server.py b/endpoints/server.py index f8bfc8a..0b3edfb 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -47,14 +47,14 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None): selected_server = router_mapping.get(server.lower()) if selected_server: - app.include_router(selected_server.router) + app.include_router(selected_server.setup()) logger.info(f"Starting {selected_server.api_name} API") for path, url in selected_server.urls.items(): formatted_url = url.format(host=host, port=port) logger.info(f"{path}: {formatted_url}") else: - app.include_router(OAIRouter.router) + app.include_router(OAIRouter.setup()) for path, url in OAIRouter.urls.items(): formatted_url = url.format(host=host, port=port) logger.info(f"{path}: {formatted_url}") From 3038f668e8bcfa4bc1579b1fc3f9b432f905729b Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 22:50:01 -0400 Subject: [PATCH 50/89] Kobold: Add extra routes for horde compatability Needed to connect to horde. Also do some reordering to clean the router file up. Signed-off-by: kingbri --- endpoints/Kobold/router.py | 50 +++++++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index ff0ec8c..334bae2 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -6,12 +6,15 @@ from common import model from common.auth import check_api_key from common.model import check_model_container from common.utils import unwrap +from endpoints.core.utils.model import get_current_model from endpoints.Kobold.types.generation import ( AbortRequest, + AbortResponse, CheckGenerateRequest, GenerateRequest, GenerateResponse, ) +from endpoints.Kobold.types.model import CurrentModelResponse, MaxLengthResponse from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse from endpoints.Kobold.utils.generation import ( abort_generation, @@ -19,7 +22,6 @@ from endpoints.Kobold.utils.generation import ( get_generation, stream_generation, ) -from endpoints.core.utils.model import get_current_model api_name = "KoboldAI" @@ -65,7 +67,7 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe "/abort", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) -async def abort_generate(data: AbortRequest): +async def abort_generate(data: AbortRequest) -> AbortResponse: response = await abort_generation(data.genkey) return response @@ -88,7 +90,7 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: @kai_router.get( "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] ) -async def current_model(): +async def current_model() -> CurrentModelResponse: """Fetches the current model and who owns it.""" current_model_card = get_current_model() @@ -99,12 +101,31 @@ async def current_model(): "/tokencount", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) -async def get_tokencount(data: TokenCountRequest): +async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse: raw_tokens = model.container.encode_tokens(data.prompt) tokens = unwrap(raw_tokens, []) return TokenCountResponse(value=len(tokens), ids=tokens) +@kai_router.get( + "/config/max_length", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +@kai_router.get( + "/config/max_context_length", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +@extra_kai_router.get( + "/true_max_context_length", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def get_max_length() -> MaxLengthResponse: + """Fetches the max length of the model.""" + + max_length = model.container.get_model_parameters().get("max_seq_len") + return {"value": max_length} + + @kai_router.get("/info/version") async def get_version(): """Impersonate KAI United.""" @@ -117,3 +138,24 @@ async def get_extra_version(): """Impersonate Koboldcpp.""" return {"result": "KoboldCpp", "version": "1.61"} + + +@kai_router.get("/config/soft_prompts_list") +async def get_available_softprompts(): + """Used for KAI compliance.""" + + return {"values": []} + + +@kai_router.get("/config/soft_prompt") +async def get_current_softprompt(): + """Used for KAI compliance.""" + + return {"value": ""} + + +@kai_router.put("/config/soft_prompt") +async def set_current_softprompt(): + """Used for KAI compliance.""" + + return {} From e3226ed93037e84ab30ae018ac8113f71d4de3a7 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 22:57:55 -0400 Subject: [PATCH 51/89] Kobold: Add untracked file Model types weren't added. Signed-off-by: kingbri --- endpoints/Kobold/types/model.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 endpoints/Kobold/types/model.py diff --git a/endpoints/Kobold/types/model.py b/endpoints/Kobold/types/model.py new file mode 100644 index 0000000..761176f --- /dev/null +++ b/endpoints/Kobold/types/model.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +class CurrentModelResponse(BaseModel): + result: str + + +class MaxLengthResponse(BaseModel): + value: str From 7b8b3fe23d66e53ddd454fedb5f56fada8b96c07 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 23:00:26 -0400 Subject: [PATCH 52/89] Kobold: Fix max length type Was mistakenly a string instead of an integer. Signed-off-by: kingbri --- endpoints/Kobold/types/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/endpoints/Kobold/types/model.py b/endpoints/Kobold/types/model.py index 761176f..8f7276e 100644 --- a/endpoints/Kobold/types/model.py +++ b/endpoints/Kobold/types/model.py @@ -6,4 +6,4 @@ class CurrentModelResponse(BaseModel): class MaxLengthResponse(BaseModel): - value: str + value: int From c79e0832d5f231ca9d3cb048ea3862cab7752721 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 28 Jul 2024 13:48:49 -0400 Subject: [PATCH 53/89] Revert "Dependencies: Update pytorch and flash_attention" This reverts commit f47d96790ce00a629ba7cf83f45ee099129b726b. See https://github.com/pytorch/pytorch/issues/131662 for more information. Signed-off-by: kingbri --- pyproject.toml | 84 +++++++++++++++++++++++++------------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2110a5d..38cdbcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,68 +54,68 @@ dev = [ ] cu121 = [ # Torch (Extra index URLs not support in pyproject.toml) - "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/bdashore3/flash-attention/releases - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] cu118 = [ # Torch - "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.0%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] amd = [ # Torch triton for ROCm - "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", # Torch - "torch @ https://download.pytorch.org/whl/rocm6.1/torch-2.4.0%2Brocm6.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "torch @ https://download.pytorch.org/whl/rocm6.1/torch-2.4.0%2Brocm6.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "torch @ https://download.pytorch.org/whl/rocm6.1/torch-2.4.0%2Brocm6.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.1.torch2.4.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.1.torch2.4.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.1.torch2.4.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", ] # MARK: Ruff options From d85414738ddee188b4fc00464cf2dde036c722a0 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 28 Jul 2024 13:50:15 -0400 Subject: [PATCH 54/89] Dependencies: Update Flash Attention 2 v2.6.3 with torch 2.3 wheels. Signed-off-by: kingbri --- pyproject.toml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 38cdbcf..3deace0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,14 +70,14 @@ cu121 = [ "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/bdashore3/flash-attention/releases - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] cu118 = [ # Torch @@ -97,9 +97,9 @@ cu118 = [ "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] amd = [ # Torch triton for ROCm From c9a5d2c363ce2e941937ec34a335a60bc34bbd65 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 28 Jul 2024 14:10:51 -0400 Subject: [PATCH 55/89] OAI: Refactor embeddings Move files and rewrite routes to adhere to Tabby's code style. Signed-off-by: kingbri --- endpoints/OAI/router.py | 17 +++++------------ endpoints/OAI/{ => utils}/embeddings.py | 0 2 files changed, 5 insertions(+), 12 deletions(-) rename endpoints/OAI/{ => utils}/embeddings.py (100%) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 039042a..ffb678b 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,6 +1,5 @@ import asyncio from fastapi import APIRouter, Depends, HTTPException, Request -from fastapi.responses import JSONResponse from sse_starlette import EventSourceResponse from sys import maxsize @@ -9,7 +8,6 @@ from common.auth import check_api_key from common.model import check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.utils import unwrap -import endpoints.OAI.embeddings as OAIembeddings from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse from endpoints.OAI.types.chat_completion import ( ChatCompletionRequest, @@ -25,6 +23,7 @@ from endpoints.OAI.utils.completion import ( generate_completion, stream_generate_completion, ) +from endpoints.OAI.utils.embeddings import embeddings router = APIRouter() @@ -134,14 +133,8 @@ async def chat_completion_request( @router.post( "/v1/embeddings", dependencies=[Depends(check_api_key), Depends(check_model_container)], - response_model=EmbeddingsResponse, ) -async def handle_embeddings(request: EmbeddingsRequest): - input = request.input - if not input: - raise JSONResponse( - status_code=400, content={"error": "Missing required argument input"} - ) - model = request.model if request.model else None - response = await OAIembeddings.embeddings(input, request.encoding_format, model) - return JSONResponse(response) +async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse: + response = await embeddings(data.input, data.encoding_format, data.model) + + return response diff --git a/endpoints/OAI/embeddings.py b/endpoints/OAI/utils/embeddings.py similarity index 100% rename from endpoints/OAI/embeddings.py rename to endpoints/OAI/utils/embeddings.py From 3f21d9ef96a8c80b90e4444024d8bf3d4ac10a5b Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 29 Jul 2024 13:42:03 -0400 Subject: [PATCH 56/89] Embeddings: Switch to Infinity Infinity-emb is an async batching engine for embeddings. This is preferable to sentence-transformers since it handles scalable usecases without the need for external thread intervention. Signed-off-by: kingbri --- common/config.py | 5 + config_sample.yml | 7 ++ endpoints/OAI/router.py | 2 +- endpoints/OAI/utils/embeddings.py | 173 +++++++++++++----------------- 4 files changed, 87 insertions(+), 100 deletions(-) diff --git a/common/config.py b/common/config.py index 972b382..5546240 100644 --- a/common/config.py +++ b/common/config.py @@ -95,3 +95,8 @@ def logging_config(): def developer_config(): """Returns the developer specific config from the global config""" return unwrap(GLOBAL_CONFIG.get("developer"), {}) + + +def embeddings_config(): + """Returns the embeddings config from the global config""" + return unwrap(GLOBAL_CONFIG.get("embeddings"), {}) diff --git a/config_sample.yml b/config_sample.yml index c92f673..053feb6 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -72,6 +72,13 @@ developer: # Otherwise, the priority will be set to high #realtime_process_priority: False +embeddings: + embeddings_model_dir: models + + embeddings_model_name: + + embeddings_device: cpu + # Options for model overrides and loading # Please read the comments to understand how arguments are handled between initial and API loads model: diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index ffb678b..2cad876 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -135,6 +135,6 @@ async def chat_completion_request( dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse: - response = await embeddings(data.input, data.encoding_format, data.model) + response = await embeddings(data) return response diff --git a/endpoints/OAI/utils/embeddings.py b/endpoints/OAI/utils/embeddings.py index 725d7ba..cf5b799 100644 --- a/endpoints/OAI/utils/embeddings.py +++ b/endpoints/OAI/utils/embeddings.py @@ -7,135 +7,110 @@ typing/pydantic classes moved into this file, embeddings function declared async. """ +import asyncio import os import base64 +import pathlib +from loguru import logger import numpy as np from transformers import AutoModel -embeddings_params_initialized = False +from common import config +from common.utils import unwrap +from endpoints.OAI.types.embedding import ( + EmbeddingObject, + EmbeddingsRequest, + EmbeddingsResponse, +) -def initialize_embedding_params(): - """ - using 'lazy loading' to avoid circular import - so this function will be executed only once - """ - global embeddings_params_initialized - if not embeddings_params_initialized: - global st_model, embeddings_model, embeddings_device - - st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", "all-mpnet-base-v2") - embeddings_model = None - # OPENAI_EMBEDDING_DEVICE: auto (best or cpu), - # cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, - # hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, - # hpu, mtia, privateuseone - embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", "cpu") - if embeddings_device.lower() == "auto": - embeddings_device = None - - embeddings_params_initialized = True +embeddings_model = None -def load_embedding_model(model: str): +def load_embedding_model(model_path: pathlib.Path, device: str): try: - from sentence_transformers import SentenceTransformer + from infinity_emb import EngineArgs, AsyncEmbeddingEngine except ModuleNotFoundError: - print( - "The sentence_transformers module has not been found. " - + "Please install it manually with " - + "pip install -U sentence-transformers." + logger.error( + "Skipping embeddings because infinity-emb is not installed.\n" + "Please run the following command in your environment " + "to install extra packages:\n" + "pip install -U .[extras]" ) raise ModuleNotFoundError from None - initialize_embedding_params() - global embeddings_device, embeddings_model + global embeddings_model try: - print(f"Try embedding model: {model} on {embeddings_device}") - if "jina-embeddings" in model: - # trust_remote_code is needed to use the encode method - embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True) - embeddings_model = embeddings_model.to(embeddings_device) - else: - embeddings_model = SentenceTransformer( - model, - device=embeddings_device, - ) - - print(f"Loaded embedding model: {model}") + engine_args = EngineArgs( + model_name_or_path=str(model_path.resolve()), + engine="torch", + device="cpu", + bettertransformer=False, + model_warmup=False, + ) + embeddings_model = AsyncEmbeddingEngine.from_args(engine_args) + logger.info(f"Trying to load embeddings model: {model_path.name} on {device}") except Exception as e: embeddings_model = None - raise Exception( - f"Error: Failed to load embedding model: {model}", internal_message=repr(e) - ) from None + raise e -def get_embeddings_model(): - initialize_embedding_params() - global embeddings_model, st_model - if st_model and not embeddings_model: - load_embedding_model(st_model) # lazy load the model +async def embeddings(data: EmbeddingsRequest) -> dict: + embeddings_config = config.embeddings_config() - return embeddings_model + # Use CPU by default + device = embeddings_config.get("embeddings_device", "cpu") + if device == "auto": + device = None - -def get_embeddings_model_name() -> str: - initialize_embedding_params() - global st_model - return st_model - - -def get_embeddings(input: list) -> np.ndarray: - model = get_embeddings_model() - embedding = model.encode( - input, - convert_to_numpy=True, - normalize_embeddings=True, - convert_to_tensor=False, - show_progress_bar=False, + model_path = pathlib.Path(embeddings_config.get("embeddings_model_dir")) + model_path: pathlib.Path = model_path / embeddings_config.get( + "embeddings_model_name" ) - return embedding + if not model_path: + logger.info("Embeddings model path not found") + load_embedding_model(model_path, device) -async def embeddings(input: list, encoding_format: str, model: str = None) -> dict: - if model is None: - model = st_model - else: - load_embedding_model(model) + async with embeddings_model: + embeddings, usage = await embeddings_model.embed(data.input) - embeddings = get_embeddings(input) - if encoding_format == "base64": - data = [ - {"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} - for n, emb in enumerate(embeddings) - ] - else: - data = [ - {"object": "embedding", "embedding": emb.tolist(), "index": n} - for n, emb in enumerate(embeddings) - ] + # OAI expects a return of base64 if the input is base64 + if data.encoding_format == "base64": + embedding_data = [ + { + "object": "embedding", + "embedding": float_list_to_base64(emb), + "index": n, + } + for n, emb in enumerate(embeddings) + ] + else: + embedding_data = [ + {"object": "embedding", "embedding": emb.tolist(), "index": n} + for n, emb in enumerate(embeddings) + ] - response = { - "object": "list", - "data": data, - "model": st_model if model is None else model, - "usage": { - "prompt_tokens": 0, - "total_tokens": 0, - }, - } - return response + response = { + "object": "list", + "data": embedding_data, + "model": model_path.name, + "usage": { + "prompt_tokens": usage, + "total_tokens": usage, + }, + } + return response def float_list_to_base64(float_array: np.ndarray) -> str: - # Convert the list to a float32 array that the OpenAPI client expects - # float_array = np.array(float_list, dtype="float32") + """ + Converts the provided list to a float32 array for OpenAI + Ex. float_array = np.array(float_list, dtype="float32") + """ - # Get raw bytes - bytes_array = float_array.tobytes() - - # Encode bytes into base64 - encoded_bytes = base64.b64encode(bytes_array) + # Encode raw bytes into base64 + encoded_bytes = base64.b64encode(float_array.tobytes()) # Turn raw base64 encoded bytes into ASCII ascii_string = encoded_bytes.decode("ascii") From ac1afcc5886aba728fbe4df3f049dc131cc19ae3 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 29 Jul 2024 14:15:40 -0400 Subject: [PATCH 57/89] Embeddings: Use response classes instead of dicts Follows the existing code style. Signed-off-by: kingbri --- endpoints/OAI/utils/embeddings.py | 60 +++++++++++++++---------------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/endpoints/OAI/utils/embeddings.py b/endpoints/OAI/utils/embeddings.py index cf5b799..1ce611c 100644 --- a/endpoints/OAI/utils/embeddings.py +++ b/endpoints/OAI/utils/embeddings.py @@ -7,37 +7,41 @@ typing/pydantic classes moved into this file, embeddings function declared async. """ -import asyncio -import os import base64 import pathlib -from loguru import logger import numpy as np -from transformers import AutoModel +from loguru import logger from common import config -from common.utils import unwrap from endpoints.OAI.types.embedding import ( EmbeddingObject, EmbeddingsRequest, EmbeddingsResponse, + UsageInfo, ) +# Conditionally import infinity embeddings engine +# Required so the logger doesn't take over tabby's logging handlers +try: + from infinity_emb import EngineArgs, AsyncEmbeddingEngine + + has_infinity_emb = True +except ImportError: + has_infinity_emb = False + embeddings_model = None def load_embedding_model(model_path: pathlib.Path, device: str): - try: - from infinity_emb import EngineArgs, AsyncEmbeddingEngine - except ModuleNotFoundError: + if not has_infinity_emb: logger.error( "Skipping embeddings because infinity-emb is not installed.\n" "Please run the following command in your environment " "to install extra packages:\n" "pip install -U .[extras]" ) - raise ModuleNotFoundError from None + raise ModuleNotFoundError global embeddings_model try: @@ -76,30 +80,22 @@ async def embeddings(data: EmbeddingsRequest) -> dict: embeddings, usage = await embeddings_model.embed(data.input) # OAI expects a return of base64 if the input is base64 - if data.encoding_format == "base64": - embedding_data = [ - { - "object": "embedding", - "embedding": float_list_to_base64(emb), - "index": n, - } - for n, emb in enumerate(embeddings) - ] - else: - embedding_data = [ - {"object": "embedding", "embedding": emb.tolist(), "index": n} - for n, emb in enumerate(embeddings) - ] + embedding_data = [ + EmbeddingObject( + embedding=float_list_to_base64(emb) + if data.encoding_format == "base64" + else emb.tolist(), + index=n, + ) + for n, emb in enumerate(embeddings) + ] + + response = EmbeddingsResponse( + data=embedding_data, + model=model_path.name, + usage=UsageInfo(prompt_tokens=usage, total_tokens=usage), + ) - response = { - "object": "list", - "data": embedding_data, - "model": model_path.name, - "usage": { - "prompt_tokens": usage, - "total_tokens": usage, - }, - } return response From fbf1455db18a3ac2f4f312796150e99536d2361c Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 11:00:23 -0400 Subject: [PATCH 58/89] Embeddings: Migrate and organize Infinity Use Infinity as a separate backend and handle the model within the common module. This separates out the embeddings model from the endpoint which allows for model loading/unloading in core. Signed-off-by: kingbri --- backends/infinity/model.py | 56 +++++++++++++++ common/model.py | 45 ++++++++++++ common/signals.py | 13 ++++ endpoints/OAI/router.py | 11 ++- endpoints/OAI/utils/embeddings.py | 111 +++++++++--------------------- main.py | 12 ++++ 6 files changed, 165 insertions(+), 83 deletions(-) create mode 100644 backends/infinity/model.py diff --git a/backends/infinity/model.py b/backends/infinity/model.py new file mode 100644 index 0000000..2d4ae83 --- /dev/null +++ b/backends/infinity/model.py @@ -0,0 +1,56 @@ +import gc +import pathlib +import torch +from typing import List, Optional + +from common.utils import unwrap + +# Conditionally import infinity to sidestep its logger +# TODO: Make this prettier +try: + from infinity_emb import EngineArgs, AsyncEmbeddingEngine + + has_infinity_emb = True +except ImportError: + has_infinity_emb = False + + +class InfinityContainer: + model_dir: pathlib.Path + + # Conditionally set the type hint based on importablity + # TODO: Clean this up + if has_infinity_emb: + engine: Optional[AsyncEmbeddingEngine] = None + else: + engine = None + + def __init__(self, model_directory: pathlib.Path): + self.model_dir = model_directory + + async def load(self, **kwargs): + # Use cpu by default + device = unwrap(kwargs.get("device"), "cpu") + + engine_args = EngineArgs( + model_name_or_path=str(self.model_dir), + engine="torch", + device=device, + bettertransformer=False, + model_warmup=False, + ) + + self.engine = AsyncEmbeddingEngine.from_args(engine_args) + await self.engine.astart() + + async def unload(self): + await self.engine.astop() + self.engine = None + + gc.collect() + torch.cuda.empty_cache() + + async def generate(self, sentence_input: List[str]): + result_embeddings, usage = await self.engine.embed(sentence_input) + + return {"embeddings": result_embeddings, "usage": usage} diff --git a/common/model.py b/common/model.py index a6477c2..b4b259e 100644 --- a/common/model.py +++ b/common/model.py @@ -20,6 +20,15 @@ if not do_export_openapi: # Global model container container: Optional[ExllamaV2Container] = None + embeddings_container = None + + # Type hint the infinity emb container if it exists + from backends.infinity.model import has_infinity_emb + + if has_infinity_emb: + from backends.infinity.model import InfinityContainer + + embeddings_container: Optional[InfinityContainer] = None def load_progress(module, modules): @@ -100,6 +109,30 @@ async def unload_loras(): await container.unload(loras_only=True) +async def load_embeddings_model(model_path: pathlib.Path, **kwargs): + global embeddings_container + + # Break out if infinity isn't installed + if not has_infinity_emb: + logger.warning( + "Skipping embeddings because infinity-emb is not installed.\n" + "Please run the following command in your environment " + "to install extra packages:\n" + "pip install -U .[extras]" + ) + return + + embeddings_container = InfinityContainer(model_path) + await embeddings_container.load(**kwargs) + + +async def unload_embeddings_model(): + global embeddings_container + + await embeddings_container.unload() + embeddings_container = None + + def get_config_default(key, fallback=None, is_draft=False): """Fetches a default value from model config if allowed by the user.""" @@ -126,3 +159,15 @@ async def check_model_container(): ).error.message raise HTTPException(400, error_message) + + +async def check_embeddings_container(): + """FastAPI depends that checks if an embeddings model is loaded.""" + + if embeddings_container is None: + error_message = handle_request_error( + "No embeddings models are currently loaded.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) diff --git a/common/signals.py b/common/signals.py index 07d7564..d4b67bc 100644 --- a/common/signals.py +++ b/common/signals.py @@ -1,13 +1,26 @@ +import asyncio import signal import sys from loguru import logger from types import FrameType +from common import model + def signal_handler(*_): """Signal handler for main function. Run before uvicorn starts.""" logger.warning("Shutdown signal called. Exiting gracefully.") + + # Run async unloads for model + loop = asyncio.get_running_loop() + if model.container: + loop.create_task(model.container.unload()) + + if model.embeddings_container: + loop.create_task(model.embeddings_container.unload()) + + # Exit the program sys.exit(0) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 2cad876..b702e52 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -23,7 +23,7 @@ from endpoints.OAI.utils.completion import ( generate_completion, stream_generate_completion, ) -from endpoints.OAI.utils.embeddings import embeddings +from endpoints.OAI.utils.embeddings import get_embeddings router = APIRouter() @@ -134,7 +134,12 @@ async def chat_completion_request( "/v1/embeddings", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) -async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse: - response = await embeddings(data) +async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse: + embeddings_task = asyncio.create_task(get_embeddings(data, request)) + response = await run_with_request_disconnect( + request, + embeddings_task, + f"Embeddings request {request.state.id} cancelled by user.", + ) return response diff --git a/endpoints/OAI/utils/embeddings.py b/endpoints/OAI/utils/embeddings.py index 1ce611c..5b43953 100644 --- a/endpoints/OAI/utils/embeddings.py +++ b/endpoints/OAI/utils/embeddings.py @@ -8,11 +8,11 @@ embeddings function declared async. """ import base64 -import pathlib +from fastapi import Request import numpy as np from loguru import logger -from common import config +from common import model from endpoints.OAI.types.embedding import ( EmbeddingObject, EmbeddingsRequest, @@ -20,84 +20,6 @@ from endpoints.OAI.types.embedding import ( UsageInfo, ) -# Conditionally import infinity embeddings engine -# Required so the logger doesn't take over tabby's logging handlers -try: - from infinity_emb import EngineArgs, AsyncEmbeddingEngine - - has_infinity_emb = True -except ImportError: - has_infinity_emb = False - - -embeddings_model = None - - -def load_embedding_model(model_path: pathlib.Path, device: str): - if not has_infinity_emb: - logger.error( - "Skipping embeddings because infinity-emb is not installed.\n" - "Please run the following command in your environment " - "to install extra packages:\n" - "pip install -U .[extras]" - ) - raise ModuleNotFoundError - - global embeddings_model - try: - engine_args = EngineArgs( - model_name_or_path=str(model_path.resolve()), - engine="torch", - device="cpu", - bettertransformer=False, - model_warmup=False, - ) - embeddings_model = AsyncEmbeddingEngine.from_args(engine_args) - logger.info(f"Trying to load embeddings model: {model_path.name} on {device}") - except Exception as e: - embeddings_model = None - raise e - - -async def embeddings(data: EmbeddingsRequest) -> dict: - embeddings_config = config.embeddings_config() - - # Use CPU by default - device = embeddings_config.get("embeddings_device", "cpu") - if device == "auto": - device = None - - model_path = pathlib.Path(embeddings_config.get("embeddings_model_dir")) - model_path: pathlib.Path = model_path / embeddings_config.get( - "embeddings_model_name" - ) - if not model_path: - logger.info("Embeddings model path not found") - - load_embedding_model(model_path, device) - - async with embeddings_model: - embeddings, usage = await embeddings_model.embed(data.input) - - # OAI expects a return of base64 if the input is base64 - embedding_data = [ - EmbeddingObject( - embedding=float_list_to_base64(emb) - if data.encoding_format == "base64" - else emb.tolist(), - index=n, - ) - for n, emb in enumerate(embeddings) - ] - - response = EmbeddingsResponse( - data=embedding_data, - model=model_path.name, - usage=UsageInfo(prompt_tokens=usage, total_tokens=usage), - ) - - return response - def float_list_to_base64(float_array: np.ndarray) -> str: """ @@ -111,3 +33,32 @@ def float_list_to_base64(float_array: np.ndarray) -> str: # Turn raw base64 encoded bytes into ASCII ascii_string = encoded_bytes.decode("ascii") return ascii_string + + +async def get_embeddings(data: EmbeddingsRequest, request: Request) -> dict: + model_path = model.embeddings_container.model_dir + + logger.info(f"Recieved embeddings request {request.state.id}") + embedding_data = await model.embeddings_container.generate(data.input) + + # OAI expects a return of base64 if the input is base64 + embedding_object = [ + EmbeddingObject( + embedding=float_list_to_base64(emb) + if data.encoding_format == "base64" + else emb.tolist(), + index=n, + ) + for n, emb in enumerate(embedding_data.get("embeddings")) + ] + + usage = embedding_data.get("usage") + response = EmbeddingsResponse( + data=embedding_object, + model=model_path.name, + usage=UsageInfo(prompt_tokens=usage, total_tokens=usage), + ) + + logger.info(f"Finished embeddings request {request.state.id}") + + return response diff --git a/main.py b/main.py index c62a381..56873c4 100644 --- a/main.py +++ b/main.py @@ -87,6 +87,18 @@ async def entrypoint_async(): lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) await model.container.load_loras(lora_dir.resolve(), **lora_config) + # If an initial embedding model name is specified, create a separate container + # and load the model + embedding_config = config.embeddings_config() + embedding_model_name = embedding_config.get("embeddings_model_name") + if embedding_model_name: + embedding_model_path = pathlib.Path( + unwrap(embedding_config.get("embeddings_model_dir"), "models") + ) + embedding_model_path = embedding_model_path / embedding_model_name + + await model.load_embeddings_model(embedding_model_path, **embedding_config) + await start_api(host, port) From 01c77028599f4fdb58a01433903ff5f57a0cf64c Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 11:11:05 -0400 Subject: [PATCH 59/89] Signal: Fix async signal handling Run unload async functions before exiting the program. Signed-off-by: kingbri --- backends/infinity/model.py | 3 +++ common/signals.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/backends/infinity/model.py b/backends/infinity/model.py index 2d4ae83..27fc9e5 100644 --- a/backends/infinity/model.py +++ b/backends/infinity/model.py @@ -1,6 +1,7 @@ import gc import pathlib import torch +from loguru import logger from typing import List, Optional from common.utils import unwrap @@ -50,6 +51,8 @@ class InfinityContainer: gc.collect() torch.cuda.empty_cache() + logger.info("Embedding model unloaded.") + async def generate(self, sentence_input: List[str]): result_embeddings, usage = await self.engine.embed(sentence_input) diff --git a/common/signals.py b/common/signals.py index d4b67bc..f0b7f19 100644 --- a/common/signals.py +++ b/common/signals.py @@ -13,17 +13,20 @@ def signal_handler(*_): logger.warning("Shutdown signal called. Exiting gracefully.") # Run async unloads for model - loop = asyncio.get_running_loop() - if model.container: - loop.create_task(model.container.unload()) - - if model.embeddings_container: - loop.create_task(model.embeddings_container.unload()) + asyncio.ensure_future(signal_handler_async()) # Exit the program sys.exit(0) +async def signal_handler_async(*_): + if model.container: + await model.container.unload() + + if model.embeddings_container: + await model.embeddings_container.unload() + + def uvicorn_signal_handler(signal_event: signal.Signals): """Overrides uvicorn's signal handler.""" From f13d0fb8b3ed36af37b89102c50c6fcc6a7179e5 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 11:17:36 -0400 Subject: [PATCH 60/89] Embeddings: Add model load checks Same as the normal model container. Signed-off-by: kingbri --- backends/infinity/model.py | 7 +++++++ common/model.py | 10 ++++++++-- endpoints/OAI/router.py | 4 ++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/backends/infinity/model.py b/backends/infinity/model.py index 27fc9e5..4c9bb69 100644 --- a/backends/infinity/model.py +++ b/backends/infinity/model.py @@ -18,6 +18,8 @@ except ImportError: class InfinityContainer: model_dir: pathlib.Path + model_is_loading: bool = False + model_loaded: bool = False # Conditionally set the type hint based on importablity # TODO: Clean this up @@ -30,6 +32,8 @@ class InfinityContainer: self.model_dir = model_directory async def load(self, **kwargs): + self.model_is_loading = True + # Use cpu by default device = unwrap(kwargs.get("device"), "cpu") @@ -44,6 +48,9 @@ class InfinityContainer: self.engine = AsyncEmbeddingEngine.from_args(engine_args) await self.engine.astart() + self.model_loaded = True + logger.info("Embedding model successfully loaded.") + async def unload(self): await self.engine.astop() self.engine = None diff --git a/common/model.py b/common/model.py index b4b259e..3776ff9 100644 --- a/common/model.py +++ b/common/model.py @@ -162,9 +162,15 @@ async def check_model_container(): async def check_embeddings_container(): - """FastAPI depends that checks if an embeddings model is loaded.""" + """ + FastAPI depends that checks if an embeddings model is loaded. - if embeddings_container is None: + This is the same as the model container check, but with embeddings instead. + """ + + if embeddings_container is None or not ( + embeddings_container.model_is_loading or embeddings_container.model_loaded + ): error_message = handle_request_error( "No embeddings models are currently loaded.", exc_info=False, diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index b702e52..b428c00 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -5,7 +5,7 @@ from sys import maxsize from common import config, model from common.auth import check_api_key -from common.model import check_model_container +from common.model import check_embeddings_container, check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.utils import unwrap from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse @@ -132,7 +132,7 @@ async def chat_completion_request( # Embeddings endpoint @router.post( "/v1/embeddings", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], ) async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse: embeddings_task = asyncio.create_task(get_embeddings(data, request)) From bfa011e0cea4a1bc934222ce4502e096df2ecad6 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 15:19:27 -0400 Subject: [PATCH 61/89] Embeddings: Add model management Embedding models are managed on a separate backend, but are run in parallel with the model itself. Therefore, manage this in a separate container with separate routes. Signed-off-by: kingbri --- common/model.py | 23 ++++++--- config_sample.yml | 4 +- endpoints/core/router.py | 90 ++++++++++++++++++++++++++++++++++- endpoints/core/types/model.py | 5 ++ endpoints/core/utils/model.py | 23 ++++++--- main.py | 9 ++-- 6 files changed, 135 insertions(+), 19 deletions(-) diff --git a/common/model.py b/common/model.py index 3776ff9..80858d4 100644 --- a/common/model.py +++ b/common/model.py @@ -57,8 +57,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): f'Model "{loaded_model_name}" is already loaded! Aborting.' ) - # Unload the existing model - if container and container.model: logger.info("Unloading existing model.") await unload_model() @@ -109,24 +107,35 @@ async def unload_loras(): await container.unload(loras_only=True) -async def load_embeddings_model(model_path: pathlib.Path, **kwargs): +async def load_embedding_model(model_path: pathlib.Path, **kwargs): global embeddings_container # Break out if infinity isn't installed if not has_infinity_emb: - logger.warning( + raise ImportError( "Skipping embeddings because infinity-emb is not installed.\n" "Please run the following command in your environment " "to install extra packages:\n" "pip install -U .[extras]" ) - return + + # Check if the model is already loaded + if embeddings_container and embeddings_container.engine: + loaded_model_name = embeddings_container.model_dir.name + + if loaded_model_name == model_path.name and embeddings_container.model_loaded: + raise ValueError( + f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.' + ) + + logger.info("Unloading existing embeddings model.") + await unload_embedding_model() embeddings_container = InfinityContainer(model_path) await embeddings_container.load(**kwargs) -async def unload_embeddings_model(): +async def unload_embedding_model(): global embeddings_container await embeddings_container.unload() @@ -172,7 +181,7 @@ async def check_embeddings_container(): embeddings_container.model_is_loading or embeddings_container.model_loaded ): error_message = handle_request_error( - "No embeddings models are currently loaded.", + "No embedding models are currently loaded.", exc_info=False, ).error.message diff --git a/config_sample.yml b/config_sample.yml index 053feb6..71a58d2 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -73,9 +73,9 @@ developer: #realtime_process_priority: False embeddings: - embeddings_model_dir: models + embedding_model_dir: models - embeddings_model_name: + embedding_model_name: embeddings_device: cpu diff --git a/endpoints/core/router.py b/endpoints/core/router.py index cd0ed37..5aabd48 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -7,7 +7,7 @@ from sse_starlette import EventSourceResponse from common import config, model, sampling from common.auth import check_admin_key, check_api_key, get_key_permission from common.downloader import hf_repo_download -from common.model import check_model_container +from common.model import check_embeddings_container, check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.templating import PromptTemplate, get_all_templates from common.utils import unwrap @@ -15,6 +15,7 @@ from endpoints.core.types.auth import AuthPermissionResponse from endpoints.core.types.download import DownloadRequest, DownloadResponse from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse from endpoints.core.types.model import ( + EmbeddingModelLoadRequest, ModelCard, ModelList, ModelLoadRequest, @@ -253,6 +254,93 @@ async def unload_loras(): await model.unload_loras() +@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)]) +async def list_embedding_models(request: Request) -> ModelList: + """ + Lists all embedding models in the model directory. + + Requires an admin key to see all embedding models. + """ + + if get_key_permission(request) == "admin": + embedding_model_dir = unwrap( + config.embeddings_config().get("embedding_model_dir"), "models" + ) + embedding_model_path = pathlib.Path(embedding_model_dir) + + models = get_model_list(embedding_model_path.resolve()) + else: + models = await get_current_model_list(model_type="embedding") + + return models + + +@router.get( + "/v1/model/embedding", + dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], +) +async def get_embedding_model() -> ModelList: + """Returns the currently loaded embedding model.""" + + return get_current_model_list(model_type="embedding")[0] + + +@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)]) +async def load_embedding_model( + request: Request, data: EmbeddingModelLoadRequest +) -> ModelLoadResponse: + # Verify request parameters + if not data.name: + error_message = handle_request_error( + "A model name was not provided for load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + embedding_model_dir = pathlib.Path( + unwrap(config.model_config().get("embedding_model_dir"), "models") + ) + embedding_model_path = embedding_model_dir / data.name + + if not embedding_model_path.exists(): + error_message = handle_request_error( + "Could not find the embedding model path for load. " + + "Check model name or config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + try: + load_task = asyncio.create_task( + model.load_embedding_model(embedding_model_path, **data.model_dump()) + ) + await run_with_request_disconnect( + request, load_task, "Embedding model load request cancelled by user." + ) + except Exception as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + response = ModelLoadResponse( + model_type="embedding_model", module=1, modules=1, status="finished" + ) + + return response + + +@router.post( + "/v1/model/embedding/unload", + dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)], +) +async def unload_embedding_model(): + """Unloads the current embedding model.""" + + await model.unload_embedding_model() + + # Encode tokens endpoint @router.post( "/v1/token/encode", diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 30730b8..c107dde 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -137,6 +137,11 @@ class ModelLoadRequest(BaseModel): skip_queue: Optional[bool] = False +class EmbeddingModelLoadRequest(BaseModel): + name: str + device: Optional[str] = None + + class ModelLoadResponse(BaseModel): """Represents a model load response.""" diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index 0cfb26a..fc61337 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -32,15 +32,26 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N return model_card_list -async def get_current_model_list(is_draft: bool = False): - """Gets the current model in list format and with path only.""" +async def get_current_model_list(model_type: str = "model"): + """ + Gets the current model in list format and with path only. + + Unified for fetching both models and embedding models. + """ + current_models = [] + model_path = None # Make sure the model container exists - if model.container: - model_path = model.container.get_model_path(is_draft) - if model_path: - current_models.append(ModelCard(id=model_path.name)) + if model_type == "model" or model_type == "draft": + if model.container: + model_path = model.container.get_model_path(model_type == "draft") + elif model_type == "embedding": + if model.embeddings_container: + model_path = model.embeddings_container.model_dir + + if model_path: + current_models.append(ModelCard(id=model_path.name)) return ModelList(data=current_models) diff --git a/main.py b/main.py index 56873c4..bae2f98 100644 --- a/main.py +++ b/main.py @@ -90,14 +90,17 @@ async def entrypoint_async(): # If an initial embedding model name is specified, create a separate container # and load the model embedding_config = config.embeddings_config() - embedding_model_name = embedding_config.get("embeddings_model_name") + embedding_model_name = embedding_config.get("embedding_model_name") if embedding_model_name: embedding_model_path = pathlib.Path( - unwrap(embedding_config.get("embeddings_model_dir"), "models") + unwrap(embedding_config.get("embedding_model_dir"), "models") ) embedding_model_path = embedding_model_path / embedding_model_name - await model.load_embeddings_model(embedding_model_path, **embedding_config) + try: + await model.load_embedding_model(embedding_model_path, **embedding_config) + except ImportError as ex: + logger.error(ex.msg) await start_api(host, port) From dc3dcc9c0ddf721ee67a54b2395df271f0393d2a Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 15:32:26 -0400 Subject: [PATCH 62/89] Embeddings: Update config, args, and parameter names Use embeddings_device as the parameter for device to remove ambiguity. Signed-off-by: kingbri --- backends/infinity/model.py | 2 +- common/args.py | 20 ++++++++++++++++++++ common/config.py | 5 +++++ config_sample.yml | 23 ++++++++++++++++------- endpoints/core/types/model.py | 2 +- 5 files changed, 43 insertions(+), 9 deletions(-) diff --git a/backends/infinity/model.py b/backends/infinity/model.py index 4c9bb69..35a4df4 100644 --- a/backends/infinity/model.py +++ b/backends/infinity/model.py @@ -35,7 +35,7 @@ class InfinityContainer: self.model_is_loading = True # Use cpu by default - device = unwrap(kwargs.get("device"), "cpu") + device = unwrap(kwargs.get("embeddings_device"), "cpu") engine_args = EngineArgs( model_name_or_path=str(self.model_dir), diff --git a/common/args.py b/common/args.py index e57de78..0548eaf 100644 --- a/common/args.py +++ b/common/args.py @@ -23,6 +23,7 @@ def init_argparser(): ) add_network_args(parser) add_model_args(parser) + add_embeddings_args(parser) add_logging_args(parser) add_developer_args(parser) add_sampling_args(parser) @@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser): sampling_group.add_argument( "--override-preset", type=str, help="Select a sampler override preset" ) + + +def add_embeddings_args(parser: argparse.ArgumentParser): + """Adds arguments specific to embeddings""" + + embeddings_group = parser.add_argument_group("embeddings") + embeddings_group.add_argument( + "--embedding-model-dir", + type=str, + help="Overrides the directory to look for models", + ) + embeddings_group.add_argument( + "--embedding-model-name", type=str, help="An initial model to load" + ) + embeddings_group.add_argument( + "--embeddings-device", + type=str, + help="Device to use for embeddings. Options: (cpu, auto, cuda)", + ) diff --git a/common/config.py b/common/config.py index 5546240..9b2f654 100644 --- a/common/config.py +++ b/common/config.py @@ -59,6 +59,11 @@ def from_args(args: dict): cur_developer_config = developer_config() GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override} + embeddings_override = args.get("embeddings") + if embeddings_override: + cur_embeddings_config = embeddings_config() + GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override} + def sampling_config(): """Returns the sampling parameter config from the global config""" diff --git a/config_sample.yml b/config_sample.yml index 71a58d2..09ae000 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -72,13 +72,6 @@ developer: # Otherwise, the priority will be set to high #realtime_process_priority: False -embeddings: - embedding_model_dir: models - - embedding_model_name: - - embeddings_device: cpu - # Options for model overrides and loading # Please read the comments to understand how arguments are handled between initial and API loads model: @@ -208,3 +201,19 @@ model: #loras: #- name: lora1 # scaling: 1.0 + +# Options for embedding models and loading. +# NOTE: Embeddings requires the "extras" feature to be installed +# Install it via "pip install .[extras]" +embeddings: + # Overrides directory to look for embedding models (default: models) + embedding_model_dir: models + + # An initial embedding model to load on the infinity backend (default: None) + embedding_model_name: + + # Device to load embedding models on (default: cpu) + # Possible values: cpu, auto, cuda + # NOTE: It's recommended to load embedding models on the CPU. + # If you'd like to load on an AMD gpu, set this value to "cuda" as well. + embeddings_device: cpu diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index c107dde..8b3d83e 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -139,7 +139,7 @@ class ModelLoadRequest(BaseModel): class EmbeddingModelLoadRequest(BaseModel): name: str - device: Optional[str] = None + embeddings_device: Optional[str] = None class ModelLoadResponse(BaseModel): From 46304ce875a6814c68b3cd45f617bc45196e49af Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 18:42:25 -0400 Subject: [PATCH 63/89] Model: Properly pass in max_batch_size from config The override wasn't being passed in before. Also, the default is now none since Exl2 can automatically calculate the max batch size. Signed-off-by: kingbri --- backends/exllamav2/model.py | 5 ++++- config_sample.yml | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 3df16b0..c7c032a 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -70,7 +70,7 @@ class ExllamaV2Container: cache_size: int = None cache_mode: str = "FP16" draft_cache_mode: str = "FP16" - max_batch_size: int = 20 + max_batch_size: Optional[int] = None generation_config: Optional[GenerationConfig] = None hf_config: Optional[HuggingFaceConfig] = None @@ -217,6 +217,9 @@ class ExllamaV2Container: # Enable fasttensors loading if present self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False) + # Set max batch size to the config override + self.max_batch_size = unwrap(kwargs.get("max_batch_size")) + # Check whether the user's configuration supports flash/paged attention # Also check if exl2 has disabled flash attention if ( diff --git a/config_sample.yml b/config_sample.yml index c92f673..e57d947 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -146,11 +146,11 @@ model: # NOTE: Effects vary depending on the model. An ideal value is between 512 and 4096 #chunk_size: 2048 - # Set the maximum amount of prompts to process at one time (batch) - # This will be automatically adjusted depending on the cache size. + # Set the maximum amount of prompts to process at one time (default: None/Automatic) + # This will be automatically calculated if left blank. # A max batch size of 1 processes prompts one at a time. # NOTE: Only available for Nvidia ampere (30 series) and above GPUs - #max_batch_size: 20 + #max_batch_size: # Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None) # If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name From f111052e3905539aa6dc42dfe835e5aeff3c80a1 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 20:46:37 -0400 Subject: [PATCH 64/89] Dependencies: Use hosted pip index instead of Github Installing directly from github causes pip's HTTP cache to not recognize that the correct version of a package is already installed. This causes a redownload. When using the Start.bat script, it updates dependencies automatically to keep users on the latest versions of a package for security reasons. A simple pip cache website helps alleviate this problem and allows pip to find the cached wheels when invoked with an upgrade argument. Signed-off-by: kingbri --- pyproject.toml | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3deace0..19be7f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,22 +62,22 @@ cu121 = [ "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/bdashore3/flash-attention/releases - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] cu118 = [ # Torch @@ -89,17 +89,17 @@ cu118 = [ "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu118/flash-attn/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu118/flash-attn/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu118/flash-attn/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] amd = [ # Torch triton for ROCm @@ -113,9 +113,9 @@ amd = [ "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/rocm/exllamav2/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/rocm/exllamav2/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "exllamav2 @ https://royallab-pip-index.netlify.app/whl/rocm/exllamav2/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", ] # MARK: Ruff options From 9390d362dd99dfe29ed7eb9b44db9998aeb20443 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 1 Aug 2024 00:19:21 -0400 Subject: [PATCH 65/89] Model: Log generation params and metrics after the prompt/response A user's prompt and response can be large in the console. Therefore, always log the smaller payloads (ex. gen params + metrics) after the large chunks. However, it's recommended to keep prompt logging off anyways since it'll result in console spam. Signed-off-by: kingbri --- backends/exllamav2/model.py | 76 ++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index c7c032a..373a753 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1126,31 +1126,6 @@ class ExllamaV2Container: # This is an inverse of skip_special_tokens decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False) - # Log generation options to console - # Some options are too large, so log the args instead - log_generation_params( - request_id=request_id, - max_tokens=max_tokens, - min_tokens=min_tokens, - stream=kwargs.get("stream"), - **gen_settings_log_dict, - token_healing=token_healing, - auto_scale_penalty_range=auto_scale_penalty_range, - generate_window=generate_window, - bos_token_id=self.tokenizer.bos_token_id, - eos_token_id=eos_tokens, - add_bos_token=add_bos_token, - ban_eos_token=ban_eos_token, - skip_special_tokens=not decode_special_tokens, - speculative_ngram=self.generator.speculative_ngram, - logprobs=request_logprobs, - stop_conditions=stop_conditions, - banned_tokens=banned_tokens, - banned_strings=banned_strings, - logit_bias=logit_bias, - filters=grammar_handler.filters, - ) - # Log prompt to console log_prompt(prompt, request_id, negative_prompt) @@ -1181,6 +1156,7 @@ class ExllamaV2Container: max_seq_len = self.config.max_seq_len generated_tokens = 0 full_response = "" + metrics_result = {} # Get the generation status once it's ready try: @@ -1241,16 +1217,8 @@ class ExllamaV2Container: "length" if eos_reason == "max_new_tokens" else "stop" ) - log_metrics( - result.get("time_enqueued"), - result.get("prompt_tokens"), - result.get("cached_tokens"), - result.get("time_prefill"), - result.get("new_tokens"), - result.get("time_generate"), - context_len, - max_seq_len, - ) + # Save the final result for metrics logging + metrics_result = result # Remove the token text generation = { @@ -1274,3 +1242,41 @@ class ExllamaV2Container: asyncio.ensure_future(self.create_generator()) raise ex + finally: + # Log generation options to console + # Some options are too large, so log the args instead + log_generation_params( + request_id=request_id, + max_tokens=max_tokens, + min_tokens=min_tokens, + stream=kwargs.get("stream"), + **gen_settings_log_dict, + token_healing=token_healing, + auto_scale_penalty_range=auto_scale_penalty_range, + generate_window=generate_window, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=eos_tokens, + add_bos_token=add_bos_token, + ban_eos_token=ban_eos_token, + skip_special_tokens=not decode_special_tokens, + speculative_ngram=self.generator.speculative_ngram, + logprobs=request_logprobs, + stop_conditions=stop_conditions, + banned_tokens=banned_tokens, + banned_strings=banned_strings, + logit_bias=logit_bias, + filters=grammar_handler.filters, + ) + + # Log the metrics if present + if metrics_result: + log_metrics( + metrics_result.get("time_enqueued"), + metrics_result.get("prompt_tokens"), + metrics_result.get("cached_tokens"), + metrics_result.get("time_prefill"), + metrics_result.get("new_tokens"), + metrics_result.get("time_generate"), + context_len, + max_seq_len, + ) From 0bcb4e4a7d17340199f08bc078319ca85b15510e Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 1 Aug 2024 00:25:54 -0400 Subject: [PATCH 66/89] Model: Attach request ID to logs If multiple logs come in at once, track which log corresponds to which request. Signed-off-by: kingbri --- backends/exllamav2/model.py | 3 ++- common/gen_logging.py | 10 +++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 373a753..5233229 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1210,7 +1210,7 @@ class ExllamaV2Container: # Second yield if eos is true if result.get("eos"): - log_response(full_response) + log_response(request_id, full_response) eos_reason = result.get("eos_reason") finish_reason = ( @@ -1271,6 +1271,7 @@ class ExllamaV2Container: # Log the metrics if present if metrics_result: log_metrics( + request_id, metrics_result.get("time_enqueued"), metrics_result.get("prompt_tokens"), metrics_result.get("cached_tokens"), diff --git a/common/gen_logging.py b/common/gen_logging.py index 94c4405..9995818 100644 --- a/common/gen_logging.py +++ b/common/gen_logging.py @@ -64,14 +64,18 @@ def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]): logger.info(f"Negative Prompt: {formatted_negative_prompt}\n") -def log_response(response: str): +def log_response(request_id: str, response: str): """Logs the response to console.""" if PREFERENCES.prompt: formatted_response = "\n" + response - logger.info(f"Response: {formatted_response if response else 'Empty'}\n") + logger.info( + f"Response (ID: {request_id}): " + f"{formatted_response if response else 'Empty'}\n" + ) def log_metrics( + request_id: str, queue_time: float, prompt_tokens: int, cached_tokens: int, @@ -82,7 +86,7 @@ def log_metrics( max_seq_len: int, ): initial_response = ( - f"Metrics: {generated_tokens} tokens generated in " + f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in " f"{round(queue_time + prompt_time + generate_time, 2)} seconds" ) itemization = [] From 54aeebaec17e6791defaaa256258915a443d3716 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 1 Aug 2024 13:43:31 -0400 Subject: [PATCH 67/89] API: Fix return of current embeddings model Return a ModelCard instead of a ModelList. Signed-off-by: kingbri --- endpoints/core/router.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 5aabd48..a857a1a 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -279,10 +279,11 @@ async def list_embedding_models(request: Request) -> ModelList: "/v1/model/embedding", dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], ) -async def get_embedding_model() -> ModelList: +async def get_embedding_model() -> ModelCard: """Returns the currently loaded embedding model.""" + models = await get_current_model_list(model_type="embedding") - return get_current_model_list(model_type="embedding")[0] + return models.data[0] @router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)]) From 3e42211c3e646063fa76f73abb0a128ef36af980 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 1 Aug 2024 13:59:49 -0400 Subject: [PATCH 68/89] Config: Embeddings: Make embeddings_device a default when API loading When loading from the API, the fallback for embeddings_device will be the same as the config. Signed-off-by: kingbri --- common/model.py | 18 ++++++++++++++++-- config_sample.yml | 9 ++++++--- endpoints/core/types/model.py | 12 ++++++++---- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/common/model.py b/common/model.py index 80858d4..feedc9f 100644 --- a/common/model.py +++ b/common/model.py @@ -5,6 +5,7 @@ Containers exist as a common interface for backends. """ import pathlib +from enum import Enum from fastapi import HTTPException from loguru import logger from typing import Optional @@ -31,6 +32,12 @@ if not do_export_openapi: embeddings_container: Optional[InfinityContainer] = None +class ModelType(Enum): + MODEL = "model" + DRAFT = "draft" + EMBEDDING = "embedding" + + def load_progress(module, modules): """Wrapper callback for load progress.""" yield module, modules @@ -142,16 +149,23 @@ async def unload_embedding_model(): embeddings_container = None -def get_config_default(key, fallback=None, is_draft=False): +def get_config_default(key: str, fallback=None, model_type: str = "model"): """Fetches a default value from model config if allowed by the user.""" model_config = config.model_config() default_keys = unwrap(model_config.get("use_as_default"), []) + + # Add extra keys to defaults + default_keys.append("embeddings_device") + if key in default_keys: # Is this a draft model load parameter? - if is_draft: + if model_type == "draft": draft_config = config.draft_model_config() return unwrap(draft_config.get(key), fallback) + elif model_type == "embedding": + embeddings_config = config.embeddings_config() + return unwrap(embeddings_config.get(key), fallback) else: return unwrap(model_config.get(key), fallback) else: diff --git a/config_sample.yml b/config_sample.yml index f3a1c51..018ff61 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -209,11 +209,14 @@ embeddings: # Overrides directory to look for embedding models (default: models) embedding_model_dir: models - # An initial embedding model to load on the infinity backend (default: None) - embedding_model_name: - # Device to load embedding models on (default: cpu) # Possible values: cpu, auto, cuda # NOTE: It's recommended to load embedding models on the CPU. # If you'd like to load on an AMD gpu, set this value to "cuda" as well. embeddings_device: cpu + + # The below parameters only apply for initial loads + # All API based loads do NOT inherit these settings unless specified in use_as_default + + # An initial embedding model to load on the infinity backend (default: None) + embedding_model_name: diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 8b3d83e..1e2eb46 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel): # Config arguments draft_rope_scale: Optional[float] = Field( default_factory=lambda: get_config_default( - "draft_rope_scale", 1.0, is_draft=True + "draft_rope_scale", 1.0, model_type="draft" ) ) draft_rope_alpha: Optional[float] = Field( description="Automatically calculated if not present", default_factory=lambda: get_config_default( - "draft_rope_alpha", None, is_draft=True + "draft_rope_alpha", None, model_type="draft" ), examples=[1.0], ) draft_cache_mode: Optional[str] = Field( default_factory=lambda: get_config_default( - "draft_cache_mode", "FP16", is_draft=True + "draft_cache_mode", "FP16", model_type="draft" ) ) @@ -139,7 +139,11 @@ class ModelLoadRequest(BaseModel): class EmbeddingModelLoadRequest(BaseModel): name: str - embeddings_device: Optional[str] = None + embeddings_device: Optional[str] = Field( + default_factory=lambda: get_config_default( + "embeddings_device", model_type="embedding" + ) + ) class ModelLoadResponse(BaseModel): From 56619810bf1e6a05f8a9be55645677132ddccd7e Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 2 Aug 2024 13:34:47 -0400 Subject: [PATCH 69/89] Dependencies: Switch sentence-transformers to infinity-emb Leftover before the transition. Signed-off-by: kingbri --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 547c363..6016864 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dependencies = [ extras = [ # Heavy dependencies that aren't for everyday use "outlines", - "sentence-transformers" + "infinity-emb" ] dev = [ "ruff == 0.3.2" From b124797949a0df56957a8ec32ead8f79f22ed5ca Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 2 Aug 2024 14:35:58 -0400 Subject: [PATCH 70/89] Dependencies: Re-add sentence-transformers This is actually required for infinity to load a model. Signed-off-by: kingbri --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6016864..e53d887 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,8 @@ dependencies = [ extras = [ # Heavy dependencies that aren't for everyday use "outlines", - "infinity-emb" + "infinity-emb", + "sentence-transformers", ] dev = [ "ruff == 0.3.2" From 7bf2b07d4c5b5821f2451371bdd8f3ed8a84e207 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 2 Aug 2024 15:10:27 -0400 Subject: [PATCH 71/89] Signals: Exit on async cleanup The async signal exit function should be the internal for exiting the program. In addition, prevent the handler from being called twice by adding a boolean. May become an asyncio event later on. In addition, make sure to skip_wait when running model.unload. Signed-off-by: kingbri --- common/signals.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/common/signals.py b/common/signals.py index f0b7f19..d4e144c 100644 --- a/common/signals.py +++ b/common/signals.py @@ -7,24 +7,35 @@ from types import FrameType from common import model +SHUTTING_DOWN: bool = False + + def signal_handler(*_): """Signal handler for main function. Run before uvicorn starts.""" + global SHUTTING_DOWN + + if SHUTTING_DOWN: + return + logger.warning("Shutdown signal called. Exiting gracefully.") + SHUTTING_DOWN = True # Run async unloads for model asyncio.ensure_future(signal_handler_async()) - # Exit the program - sys.exit(0) - async def signal_handler_async(*_): + """Internal signal handler. Runs all async code to shut down the program.""" + if model.container: - await model.container.unload() + await model.unload_model(skip_wait=True) if model.embeddings_container: - await model.embeddings_container.unload() + await model.unload_embedding_model() + + # Exit the program + sys.exit(0) def uvicorn_signal_handler(signal_event: signal.Signals): From e66d213aef7a6bfa37c8ce9fc97ab95caa4fcf70 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 11:35:26 -0400 Subject: [PATCH 72/89] Revert "Dependencies: Use hosted pip index instead of Github" This reverts commit f111052e3905539aa6dc42dfe835e5aeff3c80a1. This was a bad idea since the netlify server has limited bandwidth. Signed-off-by: kingbri --- pyproject.toml | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e53d887..7cacebc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,22 +64,22 @@ cu121 = [ "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu121/exllamav2/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/bdashore3/flash-attention/releases - "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu121/flash-attn/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] cu118 = [ # Torch @@ -91,17 +91,17 @@ cu118 = [ "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/cu118/exllamav2/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases - "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu118/flash-attn/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu118/flash-attn/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://royallab-pip-index.netlify.app/whl/cu118/flash-attn/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", ] amd = [ # Torch triton for ROCm @@ -115,9 +115,9 @@ amd = [ "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", # Exl2 - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/rocm/exllamav2/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/rocm/exllamav2/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "exllamav2 @ https://royallab-pip-index.netlify.app/whl/rocm/exllamav2/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", ] # MARK: Ruff options From 7ce46cc2da12fdfab81c8391027b4f6f83770bbf Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 13:03:24 -0400 Subject: [PATCH 73/89] Start: Rewrite start scripts Start scripts now don't update dependencies by default due to mishandling caches from pip. Also add dedicated update scripts and save options to a JSON file instead of a text one. Signed-off-by: kingbri --- .gitignore | 3 + start.bat | 2 + start.py | 143 ++++++++++++++++++------ update_scripts/update_deps.bat | 24 ++++ update_scripts/update_deps.sh | 19 ++++ update_scripts/update_deps_and_pull.bat | 24 ++++ update_scripts/update_deps_and_pull.sh | 19 ++++ 7 files changed, 200 insertions(+), 34 deletions(-) create mode 100644 update_scripts/update_deps.bat create mode 100644 update_scripts/update_deps.sh create mode 100644 update_scripts/update_deps_and_pull.bat create mode 100644 update_scripts/update_deps_and_pull.sh diff --git a/.gitignore b/.gitignore index ebc96c7..eca917d 100644 --- a/.gitignore +++ b/.gitignore @@ -200,5 +200,8 @@ sampler_overrides/* # Gpu lib preferences file gpu_lib.txt +# Start options file +start_options.json + # OpenAPI JSON openapi.json diff --git a/start.bat b/start.bat index c7b8330..4ae3247 100644 --- a/start.bat +++ b/start.bat @@ -19,3 +19,5 @@ if exist "%CONDA_PREFIX%" ( :: Call the python script with batch args call python start.py %* + +pause diff --git a/start.py b/start.py index ddf4d27..aba9bc3 100644 --- a/start.py +++ b/start.py @@ -1,6 +1,7 @@ """Utility to automatically upgrade and start the API""" import argparse +import json import os import pathlib import platform @@ -11,6 +12,9 @@ from shutil import copyfile from common.args import convert_args_to_dict, init_argparser +start_options = {} + + def get_user_choice(question: str, options_dict: dict): """ Gets user input in a commandline script. @@ -39,36 +43,24 @@ def get_install_features(lib_name: str = None): """Fetches the appropriate requirements file depending on the GPU""" install_features = None possible_features = ["cu121", "cu118", "amd"] - saved_lib_path = pathlib.Path("gpu_lib.txt") - if lib_name: - print("Overriding GPU lib name from args.") - else: - # Try getting the GPU lib from file - if saved_lib_path.exists(): - with open(saved_lib_path.resolve(), "r") as f: - lib_name = f.readline().strip() - else: - # Ask the user for the GPU lib - gpu_lib_choices = { - "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"}, - "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"}, - "C": {"pretty": "AMD", "internal": "amd"}, - } - user_input = get_user_choice( - "Select your GPU. If you don't know, select Cuda 12.x (A)", - gpu_lib_choices, - ) + if not lib_name: + # Ask the user for the GPU lib + gpu_lib_choices = { + "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"}, + "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"}, + "C": {"pretty": "AMD", "internal": "amd"}, + } + user_input = get_user_choice( + "Select your GPU. If you don't know, select Cuda 12.x (A)", + gpu_lib_choices, + ) - lib_name = gpu_lib_choices.get(user_input, {}).get("internal") + lib_name = gpu_lib_choices.get(user_input, {}).get("internal") - # Write to a file for subsequent runs - with open(saved_lib_path.resolve(), "w") as f: - f.write(lib_name) - print( - "Saving your choice to gpu_lib.txt. " - "Delete this file and restart if you want to change your selection." - ) + # Write to start options + start_options["gpu_lib"] = lib_name + print("Saving your choice to start options.") # Assume default if the file is invalid if lib_name and lib_name in possible_features: @@ -104,10 +96,16 @@ def add_start_args(parser: argparse.ArgumentParser): """Add start script args to the provided parser""" start_group = parser.add_argument_group("start") start_group.add_argument( - "-iu", - "--ignore-upgrade", + "-ur", + "--update-repository", action="store_true", - help="Ignore requirements upgrade", + help="Update local git repository to latest", + ) + start_group.add_argument( + "-ud", + "--update-deps", + action="store_true", + help="Update all pip dependencies", ) start_group.add_argument( "-nw", @@ -122,6 +120,25 @@ def add_start_args(parser: argparse.ArgumentParser): ) +def migrate_gpu_lib(): + gpu_lib_path = pathlib.Path("gpu_lib.txt") + + if not gpu_lib_path.exists(): + return + + print("Migrating gpu_lib.txt to the new start_options.json") + with open("gpu_lib.txt", "r") as gpu_lib_file: + start_options["gpu_lib"] = gpu_lib_file.readline().strip() + start_options["first_run_done"] = True + + # Remove the old file + gpu_lib_path.unlink() + + print( + "Successfully migrated gpu lib options to start_options. " + "The old file has been deleted." + ) + if __name__ == "__main__": subprocess.run(["pip", "-V"]) @@ -129,6 +146,34 @@ if __name__ == "__main__": parser = init_argparser() add_start_args(parser) args = parser.parse_args() + script_ext = "bat" if platform.system() == "Windows" else "sh" + + start_options_path = pathlib.Path("start_options.json") + if start_options_path.exists(): + with open(start_options_path) as start_options_file: + start_options = json.load(start_options_file) + + if start_options.get("first_run_done"): + first_run = False + else: + print( + "It looks like you're running TabbyAPI for the first time. " + "Getting things ready..." + ) + + # Migrate from old setting storage + migrate_gpu_lib() + + # Set variables that rely on start options + first_run = not start_options.get("first_run_done") + + if args.gpu_lib: + print("Overriding GPU lib name from args.") + gpu_lib = args.gpu_lib + elif "gpu_lib" in start_options: + gpu_lib = start_options.get("gpu_lib") + else: + gpu_lib = None # Create a config if it doesn't exist # This is not necessary to run TabbyAPI, but is new user proof @@ -144,10 +189,13 @@ if __name__ == "__main__": f"Created one at {str(config_path.resolve())}" ) - if args.ignore_upgrade: - print("Ignoring pip dependency upgrade due to user request.") - else: - install_features = None if args.nowheel else get_install_features(args.gpu_lib) + if args.update_repository: + print("Pulling latest changes from Github.") + pull_command = "git pull" + subprocess.run(pull_command.split(" ")) + + if first_run or args.update_deps: + install_features = None if args.nowheel else get_install_features(gpu_lib) features = f"[{install_features}]" if install_features else "" # pip install .[features] @@ -155,8 +203,35 @@ if __name__ == "__main__": print(f"Running install command: {install_command}") subprocess.run(install_command.split(" ")) + if args.update_deps: + print( + f"Dependencies updated. Please run TabbyAPI with `start.{script_ext}`. " + "Exiting." + ) + sys.exit(0) + else: + print( + f"Dependencies installed. Update them with `update_deps.{script_ext}` " + "inside the `update_scripts` folder." + ) + + if first_run: + start_options["first_run_done"] = True + + # Save start options + with open("start_options.json", "w") as start_file: + start_file.write(json.dumps(start_options)) + + print( + "Successfully wrote your start script options to `start_options.json`. \n" + "If something goes wrong, editing or deleting the file will reinstall TabbyAPI " + "as a first-time user." + ) + # Import entrypoint after installing all requirements from main import entrypoint converted_args = convert_args_to_dict(args, parser) + + print("Starting TabbyAPI...") entrypoint(converted_args) diff --git a/update_scripts/update_deps.bat b/update_scripts/update_deps.bat new file mode 100644 index 0000000..e03c827 --- /dev/null +++ b/update_scripts/update_deps.bat @@ -0,0 +1,24 @@ +@echo off + +:: Creates a venv if it doesn't exist and runs the start script for requirements upgrades +:: This is intended for users who want to start the API and have everything upgraded and installed + +:: cd to the parent directory +cd "%~dp0.." + +:: Don't create a venv if a conda environment is active +if exist "%CONDA_PREFIX%" ( + echo It looks like you're in a conda environment. Skipping venv check. +) else ( + if not exist "venv\" ( + echo Venv doesn't exist! Please run start.bat instead. + exit 0 + ) + + call .\venv\Scripts\activate.bat +) + +:: Call the python script with batch args +call python start.py --update-deps %* + +pause diff --git a/update_scripts/update_deps.sh b/update_scripts/update_deps.sh new file mode 100644 index 0000000..becfa49 --- /dev/null +++ b/update_scripts/update_deps.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +cd "$(dirname "$0")/.." || exit + +if [ -n "$CONDA_PREFIX" ]; then + echo "It looks like you're in a conda environment. Skipping venv check." +else + if [ ! -d "venv" ]; then + echo "Venv doesn't exist! Please run start.sh instead." + exit 0 + fi + + echo "Activating venv" + + # shellcheck source=/dev/null + source venv/bin/activate +fi + +python3 start.py --update-deps "$@" diff --git a/update_scripts/update_deps_and_pull.bat b/update_scripts/update_deps_and_pull.bat new file mode 100644 index 0000000..c22866b --- /dev/null +++ b/update_scripts/update_deps_and_pull.bat @@ -0,0 +1,24 @@ +@echo off + +:: Creates a venv if it doesn't exist and runs the start script for requirements upgrades +:: This is intended for users who want to start the API and have everything upgraded and installed + +:: cd to the parent directory +cd "%~dp0.." + +:: Don't create a venv if a conda environment is active +if exist "%CONDA_PREFIX%" ( + echo It looks like you're in a conda environment. Skipping venv check. +) else ( + if not exist "venv\" ( + echo Venv doesn't exist! Please run start.bat instead. + exit 0 + ) + + call .\venv\Scripts\activate.bat +) + +:: Call the python script with batch args +call python start.py --update-deps --update-repository %* + +pause diff --git a/update_scripts/update_deps_and_pull.sh b/update_scripts/update_deps_and_pull.sh new file mode 100644 index 0000000..4582cc5 --- /dev/null +++ b/update_scripts/update_deps_and_pull.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +cd "$(dirname "$0")/.." || exit + +if [ -n "$CONDA_PREFIX" ]; then + echo "It looks like you're in a conda environment. Skipping venv check." +else + if [ ! -d "venv" ]; then + echo "Venv doesn't exist! Please run start.sh instead." + exit 0 + fi + + echo "Activating venv" + + # shellcheck source=/dev/null + source venv/bin/activate +fi + +python3 start.py --update-deps --update-repository "$@" From 65e758e134fc018112a99af64a0649779c878011 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 15:07:01 -0400 Subject: [PATCH 74/89] Tree: Format Signed-off-by: kingbri --- start.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/start.py b/start.py index aba9bc3..f38fa41 100644 --- a/start.py +++ b/start.py @@ -139,6 +139,7 @@ def migrate_gpu_lib(): "The old file has been deleted." ) + if __name__ == "__main__": subprocess.run(["pip", "-V"]) @@ -224,8 +225,8 @@ if __name__ == "__main__": print( "Successfully wrote your start script options to `start_options.json`. \n" - "If something goes wrong, editing or deleting the file will reinstall TabbyAPI " - "as a first-time user." + "If something goes wrong, editing or deleting the file " + "will reinstall TabbyAPI as a first-time user." ) # Import entrypoint after installing all requirements From b795bfc7b2224fb9da17ac47fff43b5d1588e6ca Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 15:14:40 -0400 Subject: [PATCH 75/89] Start: Split some prints up Newlines can be helpful at times. Signed-off-by: kingbri --- start.py | 1 + 1 file changed, 1 insertion(+) diff --git a/start.py b/start.py index f38fa41..54c70b5 100644 --- a/start.py +++ b/start.py @@ -203,6 +203,7 @@ if __name__ == "__main__": install_command = f"pip install -U .{features}" print(f"Running install command: {install_command}") subprocess.run(install_command.split(" ")) + print("\n") if args.update_deps: print( From 8703b23f89d6eb7f3fc78bebeef83322261c5f0c Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 15:19:31 -0400 Subject: [PATCH 76/89] Start: Make linux scripts executable Signed-off-by: kingbri --- update_scripts/update_deps.sh | 0 update_scripts/update_deps_and_pull.sh | 0 2 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 update_scripts/update_deps.sh mode change 100644 => 100755 update_scripts/update_deps_and_pull.sh diff --git a/update_scripts/update_deps.sh b/update_scripts/update_deps.sh old mode 100644 new mode 100755 diff --git a/update_scripts/update_deps_and_pull.sh b/update_scripts/update_deps_and_pull.sh old mode 100644 new mode 100755 From 2a33ebbf2909d6d51fee0782b01acaee37bf2ad3 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 16:05:34 -0400 Subject: [PATCH 77/89] Model: Bypass lock checks when shutting down Previously, when a SIGINT was emitted and a model load is running, the API didn't shut down until the load finished due to waitng for the lock. However, when shutting down, the lock doesn't matter since the process is being killed anyway. Signed-off-by: kingbri --- backends/exllamav2/model.py | 19 ++++++++++++------- common/model.py | 4 ++-- common/signals.py | 2 +- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 5233229..3ded71b 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -734,11 +734,15 @@ class ExllamaV2Container: Free all VRAM resources used by this model """ - try: - await self.load_lock.acquire() + # Shutdown immediately unloads and bypasses all locks + do_shutdown = kwargs.get("shutdown") - # Wait for other jobs to finish - await self.wait_for_jobs(kwargs.get("skip_wait")) + try: + if not do_shutdown: + await self.load_lock.acquire() + + # Wait for other jobs to finish + await self.wait_for_jobs(kwargs.get("skip_wait")) # Delete references held in the grammar module clear_grammar_func_cache() @@ -778,10 +782,11 @@ class ExllamaV2Container: logger.info("Loras unloaded." if loras_only else "Model unloaded.") finally: - self.load_lock.release() + if not do_shutdown: + self.load_lock.release() - async with self.load_condition: - self.load_condition.notify_all() + async with self.load_condition: + self.load_condition.notify_all() def encode_tokens(self, text: str, **kwargs): """Wrapper to encode tokens from a text string""" diff --git a/common/model.py b/common/model.py index feedc9f..0bfbab2 100644 --- a/common/model.py +++ b/common/model.py @@ -43,11 +43,11 @@ def load_progress(module, modules): yield module, modules -async def unload_model(skip_wait: bool = False): +async def unload_model(skip_wait: bool = False, shutdown: bool = False): """Unloads a model""" global container - await container.unload(skip_wait=skip_wait) + await container.unload(skip_wait=skip_wait, shutdown=shutdown) container = None diff --git a/common/signals.py b/common/signals.py index d4e144c..97f595b 100644 --- a/common/signals.py +++ b/common/signals.py @@ -29,7 +29,7 @@ async def signal_handler_async(*_): """Internal signal handler. Runs all async code to shut down the program.""" if model.container: - await model.unload_model(skip_wait=True) + await model.unload_model(skip_wait=True, shutdown=True) if model.embeddings_container: await model.unload_embedding_model() From 5fb9cdc2b1590aa459414b60cb48df21e6575c64 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 17:43:14 -0400 Subject: [PATCH 78/89] Dependencies: Add Python 3.12 specific dependencies Install a prebuilt fastparquet wheel for Windows and add setuptools since torch may require it for some reason. Signed-off-by: kingbri --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 7cacebc..e591a15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,10 @@ dependencies = [ # TEMP: Remove once 2.x is fixed in upstream "numpy < 2.0.0", + + # For python 3.12 + "fastparquet @ https://github.com/theroyallab/fastparquet/releases/download/v2024.5.0/fastparquet-0.1.dev837-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "setuptools ; python_version == '3.12'" ] [project.urls] From 4868fc6b1071a233745da8029d16f5047929deed Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 20:56:40 -0400 Subject: [PATCH 79/89] Update README Signed-off-by: kingbri --- README.md | 50 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 3b1eea1..824826d 100644 --- a/README.md +++ b/README.md @@ -32,20 +32,41 @@ A FastAPI based application that allows for generating text using an LLM (large language model) using the [Exllamav2 backend](https://github.com/turboderp/exllamav2) +TabbyAPI is also the official API backend server for ExllamaV2. + ## Disclaimer -This project is marked rolling release. There may be bugs and changes down the line. Please be aware that you might need to reinstall dependencies if needed. +This project is marked as rolling release. There may be bugs and changes down the line. Please be aware that you might need to reinstall dependencies if needed. -TabbyAPI is a hobby project solely for a small amount of users. It is not meant to run on production servers. For that, please look at other backends that support those workloads. +TabbyAPI is a hobby project made for a small amount of users. It is not meant to run on production servers. For that, please look at other solutions that support those workloads. ## Getting Started > [!IMPORTANT] > -> This README is not for getting started. Please read the Wiki. +> This README does not have instructions for setting up. Please read the Wiki. Read the [Wiki](https://github.com/theroyallab/tabbyAPI/wiki/1.-Getting-Started) for more information. It contains user-facing documentation for installation, configuration, sampling, API usage, and so much more. +## Features + +- OpenAI compatible API +- Loading/unloading models +- HuggingFace model downloading +- Embedding model support +- JSON schema + Regex + EBNF support +- AI Horde support +- Speculative decoding via draft models +- Multi-lora with independent scaling (ex. a weight of 0.9) +- Inbuilt proxy to override client request parameters/samplers +- Flexible Jinja2 template engine for chat completions that conforms to HuggingFace +- Concurrent inference with asyncio +- Utilizes modern python paradigms +- Continuous batching engine using paged attention +- Fast classifer-free guidance + +And much more. If something is missing here, PR it in! + ## Supported Model Types TabbyAPI uses Exllamav2 as a powerful and fast backend for model inference, loading, etc. Therefore, the following types of models are supported: @@ -58,18 +79,6 @@ TabbyAPI uses Exllamav2 as a powerful and fast backend for model inference, load In addition, TabbyAPI supports parallel batching using paged attention for Nvidia Ampere GPUs and higher. -#### Alternative Loaders/Backends - -If you want to use a different model type or quantization method than the ones listed above, here are some alternative backends with their own APIs: - -- GGUF + GGML - [KoboldCPP](https://github.com/lostruins/KoboldCPP) - -- Production ready + Many other quants + batching - [Aphrodite Engine](https://github.com/PygmalionAI/Aphrodite-engine) - -- Production ready + batching - [VLLM](https://github.com/vllm-project/vllm) - -- [Text Generation WebUI](https://github.com/oobabooga/text-generation-webui) - ## Contributing Use the template when creating issues or pull requests, otherwise the developers may not look at your post. @@ -84,6 +93,17 @@ If you have a Pull Request - Describe the pull request in detail, what, and why you are changing something +## Acknowldgements + +TabbyAPI would not exist without the work of other contributors and FOSS projects: + +- [ExllamaV2](https://github.com/turboderp/exllamav2) +- [Aphrodite Engine](https://github.com/PygmalionAI/Aphrodite-engine) +- [infinity-emb](https://github.com/michaelfeil/infinity) +- [FastAPI](https://github.com/fastapi/fastapi) +- [Text Generation WebUI](https://github.com/oobabooga/text-generation-webui) +- [SillyTavern](https://github.com/SillyTavern/SillyTavern) + ## Developers and Permissions Creators/Developers: From 87b6a31fad7fb74baa46d4bed85738a8b79f61fe Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 20:59:28 -0400 Subject: [PATCH 80/89] Update .gitignore Signed-off-by: kingbri --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index eca917d..2761a6b 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,6 @@ start_options.json # OpenAPI JSON openapi.json + +# Infinity-emb cache +.infinity_cache/ From 1aa934664c7f930265be77f62d0a8649e40aac54 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 21:44:45 -0400 Subject: [PATCH 81/89] Issues: Update issue templates Use forms instead of markdown templates. Signed-off-by: kingbri --- .github/ISSUE_TEMPLATE/bug_report.md | 35 -------- .github/ISSUE_TEMPLATE/bug_report.yaml | 97 +++++++++++++++++++++ .github/ISSUE_TEMPLATE/feature_request.md | 26 ------ .github/ISSUE_TEMPLATE/feature_request.yaml | 69 +++++++++++++++ 4 files changed, 166 insertions(+), 61 deletions(-) delete mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/bug_report.yaml delete mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.yaml diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 2bbf60e..0000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,35 +0,0 @@ ---- -name: Bug report -about: Report code related issues -title: "[BUG]" -labels: bug -assignees: '' - ---- - -**Disclaimer:** Github Issues are **only** for code related bugs. If you do not understand how to startup or use TabbyAPI, please ask in the [Discord Server](https://discord.gg/sYQxnuD7Fj) - -**Describe the bug** -A clear and concise description of what the bug is. - -**To Reproduce** -Steps to reproduce the behavior: -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -**Expected behavior** -A clear and concise description of what you expected to happen. - -**Logs** -If applicable, add logs and tracebacks to help explain your problem. - -**System info** (Bugs without this information will go lower on our priority list!) - - OS: [ex. Windows] - - Python version: [ex. 3.11] - - CUDA/ROCm version: [ex. 12.x] - - Python version: [ex. 3.11] - -**Additional context** -Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml new file mode 100644 index 0000000..c520e24 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -0,0 +1,97 @@ +name: Bug report +description: Report code related issues +title: "[BUG]" +labels: bug +body: + +- type: markdown + attributes: + value: | + ### Disclaimer: + Github Issues are **only** for code related bugs. + If you do not understand how to startup or use TabbyAPI, please ask in the [Discord Server](https://discord.gg/sYQxnuD7Fj) + +- type: dropdown + attributes: + label: OS + options: + - Windows + - Linux + validations: + required: true + +- type: dropdown + attributes: + label: GPU Library + description: Ex. CUDA, ROCm + options: + - CUDA 12.x + - CUDA 11.8 + - AMD ROCm + validations: + required: true + +- type: dropdown + attributes: + label: Python version + options: + - '3.12' + - '3.11' + - '3.10' + validations: + required: true + +- type: textarea + attributes: + label: Describe the bug + description: A clear and concise description of what the bug is. + validations: + required: true + +- type: textarea + attributes: + label: Reproduction steps + description: Walk us through how the bug occurred and how to make it happen. + validations: + required: true + +- type: textarea + attributes: + label: Expected behavior + description: What was expected to happen? + validations: + required: true + +- type: textarea + attributes: + label: Logs + description: If applicable, add logs and tracebacks to help explain your problem. + validations: + required: false + +- type: textarea + attributes: + label: Additional context + description: Add any other context about the problem here. + validations: + required: false + +- type: checkboxes + attributes: + label: Acknowledgements + description: Before submitting this issue, please make sure you have completed the following checklist. + options: + - label: I have looked for similar issues before submitting this one. + required: true + - label: I have read the disclaimer, and this issue is related to a code bug. If I have a question, I will use the Discord server. + required: true + - label: I understand that the developers have lives and my issue will be answered when possible. + required: true + - label: I understand the developers of this program are human, and I will ask my questions politely. + required: true + +- type: markdown + attributes: + value: | + ## Thanks! + Well-formatted issues improve TabbyAPI and make the development process smoother. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index e771a8d..0000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,26 +0,0 @@ ---- -name: Feature request -about: Suggest a new idea -title: "[REQUEST]" -labels: '' -assignees: '' - ---- - -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] - -**Describe the solution you'd like** -A clear and concise description of what you want to happen. - -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. - -**Why should this feature be added?** -An explanation of why the feature should be added. Please be as specific as possible to help us understand the reasoning. - -**Examples** -Examples of the feature in action and its significance compared to not having the feature. - -**Additional context** -Add any other context or screenshots about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.yaml b/.github/ISSUE_TEMPLATE/feature_request.yaml new file mode 100644 index 0000000..9cf3d8b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yaml @@ -0,0 +1,69 @@ +name: Feature request +description: Suggest a new idea +title: "[REQUEST]" +body: + +- type: textarea + attributes: + label: Problem + description: Is the feature request related to a problem? If so, please describe. + placeholder: A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + validations: + required: false + +- type: textarea + attributes: + label: Solution + description: Describe the solution you'd like. + placeholder: A clear and concise description of what you want to happen. + validations: + required: true + +- type: textarea + attributes: + label: Alternatives + description: What alternative options did you consider? + validations: + required: false + +- type: textarea + attributes: + label: Explanation + description: Why should this feature be added? + validations: + required: true + +- type: textarea + attributes: + label: Examples + description: | + Examples of the feature in action and its significance. + + Not required, but will make your request easier to understand. + validations: + required: false + +- type: textarea + attributes: + label: Additional context + description: Anything else to add? + validations: + required: false + +- type: checkboxes + attributes: + label: Acknowledgements + description: Before submitting this issue, please make sure you have completed the following checklist. + options: + - label: I have looked for similar requests before submitting this one. + required: true + - label: I understand that the developers have lives and my issue will be answered when possible. + required: true + - label: I understand the developers of this program are human, and I will make my requests politely. + required: true + +- type: markdown + attributes: + value: | + ## Thanks! + Well-formatted issues improve TabbyAPI and make the development process smoother. From b6d2676f1ce7f0fdf8c744ac0e337cb9c53c0190 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 21:56:53 -0400 Subject: [PATCH 82/89] Start: Give the user a hint when a module can't be imported If an ImportError or ModuleNotFoundError is raised, tell the user to run the update scripts. Signed-off-by: kingbri --- backends/exllamav2/utils.py | 13 ++++++++----- common/templating.py | 2 +- start.py | 18 ++++++++++++++---- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index a5e8779..eaaf0e0 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -1,7 +1,8 @@ +import platform +import torch from packaging import version from importlib.metadata import PackageNotFoundError, version as package_version from loguru import logger -import torch def check_exllama_version(): @@ -13,8 +14,9 @@ def check_exllama_version(): unsupported_message = ( f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} " f"or greater. Your current version is {current_version}.\n" - "Please upgrade your environment by running a start script " - "(start.bat or start.sh)\n\n" + "Please update your environment by running an update script " + "(update_scripts/" + f"update_deps.{'bat' if platform.system == 'Windows' else 'sh'})\n\n" "Or you can manually run a requirements update " "using the following command:\n\n" "For CUDA 12.1:\n" @@ -71,8 +73,9 @@ def supports_paged_attn(): "Switching to compatibility mode. \n" "This disables parallel batching " "and features that rely on it (ex. CFG). \n" - "Please upgrade your environment by running a start script " - "(start.bat or start.sh)\n\n" + "Please upgrade your environment by running an update script " + "(update_scripts/" + f"update_deps.{'bat' if platform.system == 'Windows' else 'sh'})\n\n" "Or you can manually run a requirements update " "using the following command:\n\n" "For CUDA 12.1:\n" diff --git a/common/templating.py b/common/templating.py index f742386..7a59946 100644 --- a/common/templating.py +++ b/common/templating.py @@ -51,7 +51,7 @@ class PromptTemplate: raise ImportError( "Parsing these chat completion messages requires jinja2 3.0.0 " f"or greater. Current version: {package_version('jinja2')}\n" - "Please upgrade jinja by running the following command: " + "Please update jinja by running the following command: " "pip install --upgrade jinja2" ) diff --git a/start.py b/start.py index 54c70b5..8356d3d 100644 --- a/start.py +++ b/start.py @@ -8,6 +8,7 @@ import platform import subprocess import sys from shutil import copyfile +import traceback from common.args import convert_args_to_dict, init_argparser @@ -231,9 +232,18 @@ if __name__ == "__main__": ) # Import entrypoint after installing all requirements - from main import entrypoint + try: + from main import entrypoint - converted_args = convert_args_to_dict(args, parser) + converted_args = convert_args_to_dict(args, parser) - print("Starting TabbyAPI...") - entrypoint(converted_args) + print("Starting TabbyAPI...") + entrypoint(converted_args) + except (ModuleNotFoundError, ImportError): + traceback.print_exc() + print( + "\n" + "This error was raised because a package was not found.\n" + "Update your dependencies by running update_scripts/" + f"update_deps.{'bat' if platform.system == 'Windows' else 'sh'}\n\n" + ) From 6a0cfd731bcfe5170e00ffaa65909ec2ae9fbcf0 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 22:00:15 -0400 Subject: [PATCH 83/89] Main: Only import psutil when the experimental function is run Experimental options shouldn't be imported at the top level until the testing period is over. Signed-off-by: kingbri --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index bae2f98..5ed20f3 100644 --- a/main.py +++ b/main.py @@ -9,8 +9,6 @@ import signal from loguru import logger from typing import Optional -import psutil - from common import config, gen_logging, sampling, model from common.args import convert_args_to_dict, init_argparser from common.auth import load_auth_keys @@ -162,6 +160,8 @@ def entrypoint(arguments: Optional[dict] = None): # Set the process priority if unwrap(developer_config.get("realtime_process_priority"), False): + import psutil + current_process = psutil.Process(os.getpid()) if platform.system() == "Windows": current_process.nice(psutil.REALTIME_PRIORITY_CLASS) From 8ff2586d45e9bd4652bfabaff96a59a78627e6a8 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 4 Aug 2024 10:13:19 -0400 Subject: [PATCH 84/89] Start: Fix pip update, method calls, and logging platform.system() was not called in some places, breaking the ternary on Windows. Pip's --upgrade flag does not actually update dependencies to their latest versions. That's what the --upgrade-strategy eager flag is for. Tell the user where their start preferences are coming from. Signed-off-by: kingbri --- backends/exllamav2/utils.py | 4 ++-- start.py | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index eaaf0e0..cf73be4 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -16,7 +16,7 @@ def check_exllama_version(): f"or greater. Your current version is {current_version}.\n" "Please update your environment by running an update script " "(update_scripts/" - f"update_deps.{'bat' if platform.system == 'Windows' else 'sh'})\n\n" + f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n" "Or you can manually run a requirements update " "using the following command:\n\n" "For CUDA 12.1:\n" @@ -75,7 +75,7 @@ def supports_paged_attn(): "and features that rely on it (ex. CFG). \n" "Please upgrade your environment by running an update script " "(update_scripts/" - f"update_deps.{'bat' if platform.system == 'Windows' else 'sh'})\n\n" + f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n" "Or you can manually run a requirements update " "using the following command:\n\n" "For CUDA 12.1:\n" diff --git a/start.py b/start.py index 8356d3d..e42f307 100644 --- a/start.py +++ b/start.py @@ -71,7 +71,7 @@ def get_install_features(lib_name: str = None): print( f"WARN: GPU library {lib_name} not found. " "Skipping GPU-specific dependencies.\n" - "WARN: Please delete gpu_lib.txt and restart " + "WARN: Please remove the `gpu_lib` key from start_options.json and restart " "if you want to change your selection." ) return @@ -154,6 +154,7 @@ if __name__ == "__main__": if start_options_path.exists(): with open(start_options_path) as start_options_file: start_options = json.load(start_options_file) + print("Loaded your saved preferences from `start_options.json`") if start_options.get("first_run_done"): first_run = False @@ -201,10 +202,12 @@ if __name__ == "__main__": features = f"[{install_features}]" if install_features else "" # pip install .[features] - install_command = f"pip install -U .{features}" + # Make sure to use eager upgrade strategy + # to push packages to their latest versions + install_command = f"pip install -U --upgrade-strategy eager .{features}" print(f"Running install command: {install_command}") subprocess.run(install_command.split(" ")) - print("\n") + print() if args.update_deps: print( @@ -245,5 +248,5 @@ if __name__ == "__main__": "\n" "This error was raised because a package was not found.\n" "Update your dependencies by running update_scripts/" - f"update_deps.{'bat' if platform.system == 'Windows' else 'sh'}\n\n" + f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'}\n\n" ) From ab6c3a53b94f54bc05e5f5aab93775aaa1864df8 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 4 Aug 2024 10:50:14 -0400 Subject: [PATCH 85/89] Start: Remove eager upgrade strategy This will upgrade second-level pinned dependencies to their latest versions which is not ideal. Signed-off-by: kingbri --- start.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/start.py b/start.py index e42f307..7eda6e1 100644 --- a/start.py +++ b/start.py @@ -202,9 +202,7 @@ if __name__ == "__main__": features = f"[{install_features}]" if install_features else "" # pip install .[features] - # Make sure to use eager upgrade strategy - # to push packages to their latest versions - install_command = f"pip install -U --upgrade-strategy eager .{features}" + install_command = f"pip install -U .{features}" print(f"Running install command: {install_command}") subprocess.run(install_command.split(" ")) print() From 34281c2e149d2d7a7d4dd611b789f839ea93833e Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 4 Aug 2024 11:14:38 -0400 Subject: [PATCH 86/89] Start: Add --force-reinstall argument Forces a reinstall of dependencies in the event that one is corrupted or broken. Signed-off-by: kingbri --- start.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/start.py b/start.py index 7eda6e1..490570e 100644 --- a/start.py +++ b/start.py @@ -108,6 +108,12 @@ def add_start_args(parser: argparse.ArgumentParser): action="store_true", help="Update all pip dependencies", ) + start_group.add_argument( + "-fr", + "--force-reinstall", + action="store_true", + help="Forces a reinstall of dependencies. Only works with --update-deps", + ) start_group.add_argument( "-nw", "--nowheel", @@ -198,13 +204,19 @@ if __name__ == "__main__": subprocess.run(pull_command.split(" ")) if first_run or args.update_deps: + install_command = ["pip", "install", "-U"] + + # Force a reinstall of the updated dependency if needed + if args.force_reinstall: + install_command.append("--force-reinstall") + install_features = None if args.nowheel else get_install_features(gpu_lib) - features = f"[{install_features}]" if install_features else "" + features = f".[{install_features}]" if install_features else "." + install_command.append(features) # pip install .[features] - install_command = f"pip install -U .{features}" - print(f"Running install command: {install_command}") - subprocess.run(install_command.split(" ")) + print(f"Running install command: {' '.join(install_command)}") + subprocess.run(install_command) print() if args.update_deps: From 63650d2c3c46e44f6690e0689ebf3aba331313be Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 5 Aug 2024 11:08:58 -0400 Subject: [PATCH 87/89] Model: Disable banned strings if grammar is used ExllamaV2 filters don't allow for rewinding which is what banned strings uses. Therefore, constrained generation via LMFE or outlines is not compatible for now. Signed-off-by: kingbri --- backends/exllamav2/model.py | 51 ++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 3ded71b..98b5636 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1018,8 +1018,37 @@ class ExllamaV2Container: kwargs.get("repetition_decay"), fallback_decay, 0 ) - stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) + # Initialize grammar handler + grammar_handler = ExLlamaV2Grammar() + + # Add JSON schema filter if it exists + json_schema = unwrap(kwargs.get("json_schema")) + if json_schema: + grammar_handler.add_json_schema_filter( + json_schema, self.model, self.tokenizer + ) + + # Add regex filter if it exists + regex_pattern = unwrap(kwargs.get("regex_pattern")) + if regex_pattern: + grammar_handler.add_regex_filter(regex_pattern, self.tokenizer) + + # Add EBNF filter if it exists + grammar_string = unwrap(kwargs.get("grammar_string")) + if grammar_string: + grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer) + + # Set banned strings banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), []) + if banned_strings and len(grammar_handler.filters) > 0: + logger.warning( + "Disabling banned_strings because " + "they cannot be used with grammar filters." + ) + + banned_strings = [] + + stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) add_bos_token = unwrap(kwargs.get("add_bos_token"), True) ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False) logit_bias = kwargs.get("logit_bias") @@ -1067,26 +1096,6 @@ class ExllamaV2Container: "in the model's vocab. Skipping." ) - # Initialize grammar handler - grammar_handler = ExLlamaV2Grammar() - - # Add JSON schema filter if it exists - json_schema = unwrap(kwargs.get("json_schema")) - if json_schema: - grammar_handler.add_json_schema_filter( - json_schema, self.model, self.tokenizer - ) - - # Add regex filter if it exists - regex_pattern = unwrap(kwargs.get("regex_pattern")) - if regex_pattern: - grammar_handler.add_regex_filter(regex_pattern, self.tokenizer) - - # Add EBNF filter if it exists - grammar_string = unwrap(kwargs.get("grammar_string")) - if grammar_string: - grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer) - # Fetch EOS tokens from generation_config if they exist eos_tokens = ( self.generation_config.eos_tokens() From 685e3836e9e3c9afa683714eedbb03002ddbedd1 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 8 Aug 2024 16:32:29 -0400 Subject: [PATCH 88/89] Args: Add api-servers to parser Also run OpenAPI export after args/config are parsed. Signed-off-by: kingbri --- common/args.py | 6 ++++++ main.py | 19 ++++++++++--------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/common/args.py b/common/args.py index 0548eaf..a0f19c2 100644 --- a/common/args.py +++ b/common/args.py @@ -72,6 +72,12 @@ def add_network_args(parser: argparse.ArgumentParser): type=str_to_bool, help="Decide whether to send error tracebacks over the API", ) + network_group.add_argument( + "--api-servers", + type=str, + nargs="+", + help="API servers to enable. Options: (OAI, Kobold)", + ) def add_model_args(parser: argparse.ArgumentParser): diff --git a/main.py b/main.py index 5ed20f3..b0c5108 100644 --- a/main.py +++ b/main.py @@ -110,15 +110,6 @@ def entrypoint(arguments: Optional[dict] = None): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - if do_export_openapi: - openapi_json = export_openapi() - - with open("openapi.json", "w") as f: - f.write(json.dumps(openapi_json)) - logger.info("Successfully wrote OpenAPI spec to openapi.json") - - return - # Load from YAML config config.from_file(pathlib.Path("config.yml")) @@ -128,6 +119,16 @@ def entrypoint(arguments: Optional[dict] = None): arguments = convert_args_to_dict(parser.parse_args(), parser) config.from_args(arguments) + + if do_export_openapi: + openapi_json = export_openapi() + + with open("openapi.json", "w") as f: + f.write(json.dumps(openapi_json)) + logger.info("Successfully wrote OpenAPI spec to openapi.json") + + return + developer_config = config.developer_config() # Check exllamav2 version and give a descriptive error if it's too old From 9cc0e700981684f4137ce51ac50a6d8c5db01a7e Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 8 Aug 2024 16:40:50 -0400 Subject: [PATCH 89/89] Actions: Build kobold docs subpage Signed-off-by: kingbri --- .github/workflows/pages.yml | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml index dd59121..a7b3327 100644 --- a/.github/workflows/pages.yml +++ b/.github/workflows/pages.yml @@ -47,12 +47,17 @@ jobs: run: | npm install @redocly/cli -g - name: Export OpenAPI docs - run: EXPORT_OPENAPI=1 python main.py + run: | + EXPORT_OPENAPI=1 python main.py + mv openapi.json openapi-oai.json + EXPORT_OPENAPI=1 python main.py --api-servers kobold + mv openapi.json openapi-kobold.json - name: Build and store Redocly site run: | - redocly build-docs openapi.json mkdir static - mv redoc-static.html static/index.html + mkdir static/kobold + redocly build-docs openapi-oai.json -o static/index.html + redocly build-docs openapi-kobold.json -o static/kobold/index.html - name: Setup Pages uses: actions/configure-pages@v5 - name: Upload artifact