mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
API: Back to async
According to the FastAPI docs, if a route uses a generic function (one that does not block waiting on anything), declaring it with `async def` makes it more performant. This makes sense, since plain `def` route functions are automatically run through a threadpool by the caller. Tested and everything works fine. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -76,7 +76,9 @@ def load_auth_keys(disable_from_config: bool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
|
async def check_api_key(
|
||||||
|
x_api_key: str = Header(None), authorization: str = Header(None)
|
||||||
|
):
|
||||||
"""Check if the API key is valid."""
|
"""Check if the API key is valid."""
|
||||||
|
|
||||||
# Allow request if auth is disabled
|
# Allow request if auth is disabled
|
||||||
@@ -102,7 +104,9 @@ def check_api_key(x_api_key: str = Header(None), authorization: str = Header(Non
|
|||||||
raise HTTPException(401, "Please provide an API key")
|
raise HTTPException(401, "Please provide an API key")
|
||||||
|
|
||||||
|
|
||||||
def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
|
async def check_admin_key(
|
||||||
|
x_admin_key: str = Header(None), authorization: str = Header(None)
|
||||||
|
):
|
||||||
"""Check if the admin key is valid."""
|
"""Check if the admin key is valid."""
|
||||||
|
|
||||||
# Allow request if auth is disabled
|
# Allow request if auth is disabled
|
||||||
|
|||||||
36
main.py
36
main.py
@@ -92,7 +92,7 @@ app = FastAPI(
|
|||||||
MODEL_CONTAINER: Optional[ExllamaV2Container] = None
|
MODEL_CONTAINER: Optional[ExllamaV2Container] = None
|
||||||
|
|
||||||
|
|
||||||
def _check_model_container():
|
async def _check_model_container():
|
||||||
if MODEL_CONTAINER is None or not (
|
if MODEL_CONTAINER is None or not (
|
||||||
MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
|
MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
|
||||||
):
|
):
|
||||||
@@ -116,7 +116,7 @@ app.add_middleware(
|
|||||||
# Model list endpoint
|
# Model list endpoint
|
||||||
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||||
def list_models():
|
async def list_models():
|
||||||
"""Lists all models in the model directory."""
|
"""Lists all models in the model directory."""
|
||||||
model_config = get_model_config()
|
model_config = get_model_config()
|
||||||
model_dir = unwrap(model_config.get("model_dir"), "models")
|
model_dir = unwrap(model_config.get("model_dir"), "models")
|
||||||
@@ -140,7 +140,7 @@ def list_models():
|
|||||||
"/v1/internal/model/info",
|
"/v1/internal/model/info",
|
||||||
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||||
)
|
)
|
||||||
def get_current_model():
|
async def get_current_model():
|
||||||
"""Returns the currently loaded model."""
|
"""Returns the currently loaded model."""
|
||||||
model_name = MODEL_CONTAINER.get_model_path().name
|
model_name = MODEL_CONTAINER.get_model_path().name
|
||||||
prompt_template = MODEL_CONTAINER.prompt_template
|
prompt_template = MODEL_CONTAINER.prompt_template
|
||||||
@@ -173,7 +173,7 @@ def get_current_model():
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
||||||
def list_draft_models():
|
async def list_draft_models():
|
||||||
"""Lists all draft models in the model directory."""
|
"""Lists all draft models in the model directory."""
|
||||||
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
|
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
|
||||||
draft_model_path = pathlib.Path(draft_model_dir)
|
draft_model_path = pathlib.Path(draft_model_dir)
|
||||||
@@ -225,7 +225,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||||||
|
|
||||||
# Unload the existing model
|
# Unload the existing model
|
||||||
if MODEL_CONTAINER and MODEL_CONTAINER.model:
|
if MODEL_CONTAINER and MODEL_CONTAINER.model:
|
||||||
unload_model()
|
await unload_model()
|
||||||
|
|
||||||
MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
|
MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
|
||||||
|
|
||||||
@@ -235,7 +235,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||||||
try:
|
try:
|
||||||
for module, modules in load_status:
|
for module, modules in load_status:
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
unload_model()
|
await unload_model()
|
||||||
break
|
break
|
||||||
|
|
||||||
if module == 0:
|
if module == 0:
|
||||||
@@ -293,7 +293,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||||||
"/v1/model/unload",
|
"/v1/model/unload",
|
||||||
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||||
)
|
)
|
||||||
def unload_model():
|
async def unload_model():
|
||||||
"""Unloads the currently loaded model."""
|
"""Unloads the currently loaded model."""
|
||||||
global MODEL_CONTAINER
|
global MODEL_CONTAINER
|
||||||
|
|
||||||
@@ -303,7 +303,7 @@ def unload_model():
|
|||||||
|
|
||||||
@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
|
||||||
def get_templates():
|
async def get_templates():
|
||||||
templates = get_all_templates()
|
templates = get_all_templates()
|
||||||
template_strings = list(map(lambda template: template.stem, templates))
|
template_strings = list(map(lambda template: template.stem, templates))
|
||||||
return TemplateList(data=template_strings)
|
return TemplateList(data=template_strings)
|
||||||
@@ -313,7 +313,7 @@ def get_templates():
|
|||||||
"/v1/template/switch",
|
"/v1/template/switch",
|
||||||
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||||
)
|
)
|
||||||
def switch_template(data: TemplateSwitchRequest):
|
async def switch_template(data: TemplateSwitchRequest):
|
||||||
"""Switch the currently loaded template"""
|
"""Switch the currently loaded template"""
|
||||||
if not data.name:
|
if not data.name:
|
||||||
raise HTTPException(400, "New template name not found.")
|
raise HTTPException(400, "New template name not found.")
|
||||||
@@ -329,7 +329,7 @@ def switch_template(data: TemplateSwitchRequest):
|
|||||||
"/v1/template/unload",
|
"/v1/template/unload",
|
||||||
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||||
)
|
)
|
||||||
def unload_template():
|
async def unload_template():
|
||||||
"""Unloads the currently selected template"""
|
"""Unloads the currently selected template"""
|
||||||
|
|
||||||
MODEL_CONTAINER.prompt_template = None
|
MODEL_CONTAINER.prompt_template = None
|
||||||
@@ -338,7 +338,7 @@ def unload_template():
|
|||||||
# Sampler override endpoints
|
# Sampler override endpoints
|
||||||
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
|
||||||
def list_sampler_overrides():
|
async def list_sampler_overrides():
|
||||||
"""API wrapper to list all currently applied sampler overrides"""
|
"""API wrapper to list all currently applied sampler overrides"""
|
||||||
|
|
||||||
return get_sampler_overrides()
|
return get_sampler_overrides()
|
||||||
@@ -348,7 +348,7 @@ def list_sampler_overrides():
|
|||||||
"/v1/sampling/override/switch",
|
"/v1/sampling/override/switch",
|
||||||
dependencies=[Depends(check_admin_key)],
|
dependencies=[Depends(check_admin_key)],
|
||||||
)
|
)
|
||||||
def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
||||||
"""Switch the currently loaded override preset"""
|
"""Switch the currently loaded override preset"""
|
||||||
|
|
||||||
if data.preset:
|
if data.preset:
|
||||||
@@ -370,7 +370,7 @@ def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
|||||||
"/v1/sampling/override/unload",
|
"/v1/sampling/override/unload",
|
||||||
dependencies=[Depends(check_admin_key)],
|
dependencies=[Depends(check_admin_key)],
|
||||||
)
|
)
|
||||||
def unload_sampler_override():
|
async def unload_sampler_override():
|
||||||
"""Unloads the currently selected override preset"""
|
"""Unloads the currently selected override preset"""
|
||||||
|
|
||||||
set_overrides_from_dict({})
|
set_overrides_from_dict({})
|
||||||
@@ -379,7 +379,7 @@ def unload_sampler_override():
|
|||||||
# Lora list endpoint
|
# Lora list endpoint
|
||||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||||
def get_all_loras():
|
async def get_all_loras():
|
||||||
"""Lists all LoRAs in the lora directory."""
|
"""Lists all LoRAs in the lora directory."""
|
||||||
lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
|
lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
|
||||||
loras = get_lora_list(lora_path.resolve())
|
loras = get_lora_list(lora_path.resolve())
|
||||||
@@ -392,7 +392,7 @@ def get_all_loras():
|
|||||||
"/v1/lora",
|
"/v1/lora",
|
||||||
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||||
)
|
)
|
||||||
def get_active_loras():
|
async def get_active_loras():
|
||||||
"""Returns the currently loaded loras."""
|
"""Returns the currently loaded loras."""
|
||||||
active_loras = LoraList(
|
active_loras = LoraList(
|
||||||
data=list(
|
data=list(
|
||||||
@@ -455,7 +455,7 @@ async def load_lora(data: LoraLoadRequest):
|
|||||||
"/v1/lora/unload",
|
"/v1/lora/unload",
|
||||||
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||||
)
|
)
|
||||||
def unload_loras():
|
async def unload_loras():
|
||||||
"""Unloads the currently loaded loras."""
|
"""Unloads the currently loaded loras."""
|
||||||
MODEL_CONTAINER.unload(loras_only=True)
|
MODEL_CONTAINER.unload(loras_only=True)
|
||||||
|
|
||||||
@@ -465,7 +465,7 @@ def unload_loras():
|
|||||||
"/v1/token/encode",
|
"/v1/token/encode",
|
||||||
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||||
)
|
)
|
||||||
def encode_tokens(data: TokenEncodeRequest):
|
async def encode_tokens(data: TokenEncodeRequest):
|
||||||
"""Encodes a string into tokens."""
|
"""Encodes a string into tokens."""
|
||||||
raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
|
raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
|
||||||
tokens = unwrap(raw_tokens, [])
|
tokens = unwrap(raw_tokens, [])
|
||||||
@@ -479,7 +479,7 @@ def encode_tokens(data: TokenEncodeRequest):
|
|||||||
"/v1/token/decode",
|
"/v1/token/decode",
|
||||||
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||||
)
|
)
|
||||||
def decode_tokens(data: TokenDecodeRequest):
|
async def decode_tokens(data: TokenDecodeRequest):
|
||||||
"""Decodes tokens into a string."""
|
"""Decodes tokens into a string."""
|
||||||
message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
|
message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
|
||||||
response = TokenDecodeResponse(text=unwrap(message, ""))
|
response = TokenDecodeResponse(text=unwrap(message, ""))
|
||||||
|
|||||||
Reference in New Issue
Block a user