From 3038f668e8bcfa4bc1579b1fc3f9b432f905729b Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 22:50:01 -0400 Subject: [PATCH] Kobold: Add extra routes for horde compatibility Needed to connect to horde. Also do some reordering to clean the router file up. Signed-off-by: kingbri --- endpoints/Kobold/router.py | 50 +++++++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index ff0ec8c..334bae2 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -6,12 +6,15 @@ from common import model from common.auth import check_api_key from common.model import check_model_container from common.utils import unwrap +from endpoints.core.utils.model import get_current_model from endpoints.Kobold.types.generation import ( AbortRequest, + AbortResponse, CheckGenerateRequest, GenerateRequest, GenerateResponse, ) +from endpoints.Kobold.types.model import CurrentModelResponse, MaxLengthResponse from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse from endpoints.Kobold.utils.generation import ( abort_generation, @@ -19,7 +22,6 @@ from endpoints.Kobold.utils.generation import ( get_generation, stream_generation, ) -from endpoints.core.utils.model import get_current_model api_name = "KoboldAI" @@ -65,7 +67,7 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe "/abort", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) -async def abort_generate(data: AbortRequest): +async def abort_generate(data: AbortRequest) -> AbortResponse: response = await abort_generation(data.genkey) return response @@ -88,7 +90,7 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: @kai_router.get( "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] ) -async def current_model(): +async def current_model() -> CurrentModelResponse: """Fetches the current 
model and who owns it.""" current_model_card = get_current_model() @@ -99,12 +101,31 @@ async def current_model(): "/tokencount", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) -async def get_tokencount(data: TokenCountRequest): +async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse: raw_tokens = model.container.encode_tokens(data.prompt) tokens = unwrap(raw_tokens, []) return TokenCountResponse(value=len(tokens), ids=tokens) +@kai_router.get( + "/config/max_length", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +@kai_router.get( + "/config/max_context_length", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +@extra_kai_router.get( + "/true_max_context_length", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def get_max_length() -> MaxLengthResponse: + """Fetches the max length of the model.""" + + max_length = model.container.get_model_parameters().get("max_seq_len") + return {"value": max_length} + + @kai_router.get("/info/version") async def get_version(): """Impersonate KAI United.""" @@ -117,3 +138,24 @@ async def get_extra_version(): """Impersonate Koboldcpp.""" return {"result": "KoboldCpp", "version": "1.61"} + + +@kai_router.get("/config/soft_prompts_list") +async def get_available_softprompts(): + """Used for KAI compliance.""" + + return {"values": []} + + +@kai_router.get("/config/soft_prompt") +async def get_current_softprompt(): + """Used for KAI compliance.""" + + return {"value": ""} + + +@kai_router.put("/config/soft_prompt") +async def set_current_softprompt(): + """Used for KAI compliance.""" + + return {}