API + Model: Add blocks and checks for various load requests

Add a sequential lock and wait until jobs are completed before executing
any loading requests that directly alter the model. However, we also
need to block any new requests that come in until the load is finished,
so add a condition that triggers once the lock is free.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-05-25 18:24:11 -04:00
committed by Brian Dashore
parent 408c66a1f2
commit 43cd7f57e8
5 changed files with 268 additions and 249 deletions

View File

@@ -1,18 +1,12 @@
import asyncio
import pathlib
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from functools import partial
from loguru import logger
from sse_starlette import EventSourceResponse
from sys import maxsize
from typing import Optional
from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key, validate_key_permission
from common.concurrency import (
call_with_semaphore,
generate_with_semaphore,
)
from common.downloader import hf_repo_download
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
@@ -141,7 +135,7 @@ async def list_draft_models():
# Load model endpoint
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(request: Request, data: ModelLoadRequest):
async def load_model(data: ModelLoadRequest):
"""Loads a model into the model container."""
# Verify request parameters
@@ -178,18 +172,9 @@ async def load_model(request: Request, data: ModelLoadRequest):
raise HTTPException(400, error_message)
load_callback = partial(stream_model_load, data, model_path, draft_model_path)
# Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue:
logger.warning(
"Model load request is skipping the completions queue. "
"Unexpected results may occur."
)
else:
load_callback = partial(generate_with_semaphore, load_callback)
return EventSourceResponse(load_callback(), ping=maxsize)
return EventSourceResponse(
stream_model_load(data, model_path, draft_model_path), ping=maxsize
)
# Unload model endpoint
@@ -199,7 +184,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
)
async def unload_model():
"""Unloads the currently loaded model."""
await model.unload_model()
await model.unload_model(skip_wait=True)
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@@ -335,15 +320,13 @@ async def get_all_loras():
async def get_active_loras():
"""Returns the currently loaded loras."""
active_loras = LoraList(
data=list(
map(
lambda lora: LoraCard(
id=pathlib.Path(lora.lora_path).parent.name,
scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
),
model.container.active_loras,
data=[
LoraCard(
id=pathlib.Path(lora.lora_path).parent.name,
scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
)
)
for lora in model.container.get_loras()
]
)
return active_loras
@@ -374,18 +357,9 @@ async def load_lora(data: LoraLoadRequest):
raise HTTPException(400, error_message)
load_callback = partial(model.load_loras, lora_dir, **data.model_dump())
# Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue:
logger.warning(
"Lora load request is skipping the completions queue. "
"Unexpected results may occur."
)
else:
load_callback = partial(call_with_semaphore, load_callback)
load_result = await load_callback()
load_result = await model.load_loras(
lora_dir, **data.model_dump(), skip_wait=data.skip_queue
)
return LoraLoadResponse(
success=unwrap(load_result.get("success"), []),
@@ -401,7 +375,7 @@ async def load_lora(data: LoraLoadRequest):
async def unload_loras():
"""Unloads the currently loaded loras."""
model.unload_loras()
await model.unload_loras()
# Encode tokens endpoint
@@ -494,16 +468,12 @@ async def completion_request(request: Request, data: CompletionRequest):
data.json_schema = {"type": "object"}
if data.stream and not disable_request_streaming:
generator_callback = partial(stream_generate_completion, data, model_path)
return EventSourceResponse(
generate_with_semaphore(generator_callback),
stream_generate_completion(data, model_path),
ping=maxsize,
)
else:
generate_task = asyncio.create_task(
call_with_semaphore(partial(generate_completion, data, model_path))
)
generate_task = asyncio.create_task(generate_completion(data, model_path))
response = await run_with_request_disconnect(
request,
@@ -545,19 +515,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
)
if data.stream and not disable_request_streaming:
generator_callback = partial(
stream_generate_chat_completion, prompt, data, model_path
)
return EventSourceResponse(
generate_with_semaphore(generator_callback),
stream_generate_chat_completion(prompt, data, model_path),
ping=maxsize,
)
else:
generate_task = asyncio.create_task(
call_with_semaphore(
partial(generate_chat_completion, prompt, data, model_path)
)
generate_chat_completion(prompt, data, model_path)
)
response = await run_with_request_disconnect(