Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-04-28 02:01:24 +00:00
API + Model: Add blocks and checks for various load requests
Add a sequential lock and wait until jobs are completed before executing any loading requests that directly alter the model. However, we also need to block any new requests that come in until the load is finished, so add a condition that triggers once the lock is free.

Signed-off-by: kingbri <bdashore3@proton.me>
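A minimal asyncio sketch of the gating described above (illustrative only; the real logic lives in common/model.py, and the names load_lock, load_condition, active_jobs, gated_load, and wait_for_load are invented here):

import asyncio

# Invented names for illustration; not the actual common/model.py API.
load_lock = asyncio.Lock()              # serializes load/unload requests
load_condition = asyncio.Condition()    # wakes requests blocked during a load
active_jobs: set[asyncio.Task] = set()  # stand-in for in-flight generation jobs


async def gated_load(loader, *args, skip_wait: bool = False, **kwargs):
    """Run a load request sequentially, then release any waiting requests."""
    if not skip_wait:
        # Let outstanding generation jobs finish before altering the model.
        await asyncio.gather(*active_jobs, return_exceptions=True)

    async with load_lock:
        result = await loader(*args, **kwargs)

    # Trigger the condition once the lock is free so queued requests resume.
    async with load_condition:
        load_condition.notify_all()

    return result


async def wait_for_load():
    """Called by incoming requests: block until no load is in progress."""
    async with load_condition:
        await load_condition.wait_for(lambda: not load_lock.locked())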
@@ -1,18 +1,12 @@
import asyncio
import pathlib
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from functools import partial
from loguru import logger
from sse_starlette import EventSourceResponse
from sys import maxsize
from typing import Optional

from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key, validate_key_permission
from common.concurrency import (
    call_with_semaphore,
    generate_with_semaphore,
)
from common.downloader import hf_repo_download
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
@@ -141,7 +135,7 @@ async def list_draft_models():

# Load model endpoint
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-async def load_model(request: Request, data: ModelLoadRequest):
+async def load_model(data: ModelLoadRequest):
    """Loads a model into the model container."""

    # Verify request parameters
@@ -178,18 +172,9 @@ async def load_model(request: Request, data: ModelLoadRequest):

        raise HTTPException(400, error_message)

-    load_callback = partial(stream_model_load, data, model_path, draft_model_path)
-
-    # Wrap in a semaphore if the queue isn't being skipped
-    if data.skip_queue:
-        logger.warning(
-            "Model load request is skipping the completions queue. "
-            "Unexpected results may occur."
-        )
-    else:
-        load_callback = partial(generate_with_semaphore, load_callback)
-
-    return EventSourceResponse(load_callback(), ping=maxsize)
+    return EventSourceResponse(
+        stream_model_load(data, model_path, draft_model_path), ping=maxsize
+    )


# Unload model endpoint
@@ -199,7 +184,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
)
async def unload_model():
    """Unloads the currently loaded model."""
-    await model.unload_model()
+    await model.unload_model(skip_wait=True)


@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@@ -335,15 +320,13 @@ async def get_all_loras():
async def get_active_loras():
    """Returns the currently loaded loras."""
    active_loras = LoraList(
-        data=list(
-            map(
-                lambda lora: LoraCard(
-                    id=pathlib.Path(lora.lora_path).parent.name,
-                    scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
-                ),
-                model.container.active_loras,
+        data=[
+            LoraCard(
+                id=pathlib.Path(lora.lora_path).parent.name,
+                scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
            )
-        )
+            for lora in model.container.get_loras()
+        ]
    )

    return active_loras
@@ -374,18 +357,9 @@ async def load_lora(data: LoraLoadRequest):

        raise HTTPException(400, error_message)

-    load_callback = partial(model.load_loras, lora_dir, **data.model_dump())
-
-    # Wrap in a semaphore if the queue isn't being skipped
-    if data.skip_queue:
-        logger.warning(
-            "Lora load request is skipping the completions queue. "
-            "Unexpected results may occur."
-        )
-    else:
-        load_callback = partial(call_with_semaphore, load_callback)
-
-    load_result = await load_callback()
+    load_result = await model.load_loras(
+        lora_dir, **data.model_dump(), skip_wait=data.skip_queue
+    )

    return LoraLoadResponse(
        success=unwrap(load_result.get("success"), []),
@@ -401,7 +375,7 @@ async def load_lora(data: LoraLoadRequest):
async def unload_loras():
    """Unloads the currently loaded loras."""

-    model.unload_loras()
+    await model.unload_loras()


# Encode tokens endpoint
@@ -494,16 +468,12 @@ async def completion_request(request: Request, data: CompletionRequest):
        data.json_schema = {"type": "object"}

    if data.stream and not disable_request_streaming:
-        generator_callback = partial(stream_generate_completion, data, model_path)
-
        return EventSourceResponse(
-            generate_with_semaphore(generator_callback),
+            stream_generate_completion(data, model_path),
            ping=maxsize,
        )
    else:
-        generate_task = asyncio.create_task(
-            call_with_semaphore(partial(generate_completion, data, model_path))
-        )
+        generate_task = asyncio.create_task(generate_completion(data, model_path))

        response = await run_with_request_disconnect(
            request,
@@ -545,19 +515,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
    )

    if data.stream and not disable_request_streaming:
-        generator_callback = partial(
-            stream_generate_chat_completion, prompt, data, model_path
-        )
-
        return EventSourceResponse(
-            generate_with_semaphore(generator_callback),
+            stream_generate_chat_completion(prompt, data, model_path),
            ping=maxsize,
        )
    else:
        generate_task = asyncio.create_task(
-            call_with_semaphore(
-                partial(generate_chat_completion, prompt, data, model_path)
-            )
+            generate_chat_completion(prompt, data, model_path)
        )

        response = await run_with_request_disconnect(
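Continuing the same invented sketch from the commit message, the endpoint changes above would map onto the gate roughly as follows. This is hypothetical glue code, not the project's handlers: generate_completion is the existing helper referenced in the diff, while backend_unload is a stand-in name.

async def completion_request_gated(data, model_path):
    # A completion that arrives mid-load blocks here until the load finishes.
    await wait_for_load()
    return await generate_completion(data, model_path)


async def unload_model_skip_wait():
    # Mirrors "await model.unload_model(skip_wait=True)" from the diff: skip
    # waiting on in-flight jobs and tear the model down immediately.
    await gated_load(backend_unload, skip_wait=True)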