Tree: Switch to async generators

Async generation removes many of the roadblocks that come with managing
tasks across threads. It should allow for abortable requests and more
modern concurrency paradigms.

NOTE: Exllamav2 itself is not an asynchronous library. It has simply
been wrapped into tabby's async architecture to allow for a fast and
concurrent API server. It is still being debated whether to run
stream_ex in a separate thread or to manage it manually with
asyncio.sleep(0).
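
As a rough sketch of the asyncio.sleep(0) approach under debate (the
names and the per-token call here are assumptions for illustration, not
tabby's actual code), a blocking per-token step can be driven from an
async generator that hands control back to the event loop between tokens:

```python
import asyncio
from typing import AsyncIterator, Callable, Tuple

async def stream_tokens(step: Callable[[], Tuple[str, bool]]) -> AsyncIterator[str]:
    """Drive a blocking, per-token generation step from async code.

    A minimal sketch: `step` is assumed to be a synchronous callable
    (e.g. a thin wrapper around exllamav2's streaming generator) that
    returns (chunk, eos) on each call.
    """
    while True:
        chunk, eos = step()  # blocking call, runs on the event loop thread
        yield chunk
        if eos:
            break
        # Hand control back to the event loop so other requests can progress
        await asyncio.sleep(0)
```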

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri
2024-03-14 10:27:39 -04:00
committed by Brian Dashore
parent 33e2df50b7
commit 7fded4f183
10 changed files with 84 additions and 88 deletions


@@ -1,7 +1,6 @@
 import pathlib
 import uvicorn
 from fastapi import FastAPI, Depends, HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
 from fastapi.middleware.cors import CORSMiddleware
 from functools import partial
 from loguru import logger
@@ -10,7 +9,7 @@ from sys import maxsize
 from common import config, model, gen_logging, sampling
 from common.auth import check_admin_key, check_api_key
-from common.generators import (
+from common.concurrency import (
     call_with_semaphore,
     generate_with_semaphore,
 )
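
For reference, the helpers imported from common.concurrency gate
generation behind a shared semaphore. A minimal sketch of what such
wrappers could look like, consistent with how they are called below
(the actual module may differ):

```python
import asyncio

# Assumed limit: one concurrent generation at a time
generate_semaphore = asyncio.Semaphore(1)

async def call_with_semaphore(callback):
    """Await a coroutine callback while holding the generation semaphore."""
    async with generate_semaphore:
        return await callback()

async def generate_with_semaphore(generator_callback):
    """Re-yield from an async generator callback while holding the semaphore."""
    async with generate_semaphore:
        async for item in generator_callback():
            yield item
```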
@@ -181,9 +180,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")
 
-    load_callback = partial(
-        stream_model_load, request, data, model_path, draft_model_path
-    )
+    load_callback = partial(stream_model_load, data, model_path, draft_model_path)
 
     # Wrap in a semaphore if the queue isn't being skipped
     if data.skip_queue:
@@ -333,9 +330,7 @@ async def load_lora(data: LoraLoadRequest):
"A parent lora directory does not exist. Check your config.yml?",
)
load_callback = partial(
run_in_threadpool, model.load_loras, lora_dir, **data.model_dump()
)
load_callback = partial(model.load_loras, lora_dir, **data.model_dump())
# Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue:
@@ -409,9 +404,7 @@ async def completion_request(request: Request, data: CompletionRequest):
     )
 
     if data.stream and not disable_request_streaming:
-        generator_callback = partial(
-            stream_generate_completion, request, data, model_path
-        )
+        generator_callback = partial(stream_generate_completion, data, model_path)
 
         return EventSourceResponse(
             generate_with_semaphore(generator_callback),
@@ -452,7 +445,7 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
     if data.stream and not disable_request_streaming:
         generator_callback = partial(
-            stream_generate_chat_completion, prompt, request, data, model_path
+            stream_generate_chat_completion, prompt, data, model_path
         )
 
         return EventSourceResponse(
@@ -461,13 +454,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
         )
     else:
         response = await call_with_semaphore(
-            partial(generate_chat_completion, prompt, request, data, model_path)
+            partial(generate_chat_completion, prompt, data, model_path)
         )
 
         return response
 
 
-def start_api(host: str, port: int):
+async def start_api(host: str, port: int):
     """Isolated function to start the API server"""
 
     # TODO: Move OAI API to a separate folder
@@ -475,9 +468,12 @@ def start_api(host: str, port: int):
     logger.info(f"Completions: http://{host}:{port}/v1/completions")
     logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
 
-    uvicorn.run(
+    config = uvicorn.Config(
         app,
         host=host,
         port=port,
         log_config=UVICORN_LOG_CONFIG,
     )
+    server = uvicorn.Server(config)
+    await server.serve()
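
Since start_api is now a coroutine, its caller has to run it on an event
loop. A minimal sketch of such an entrypoint (the host and port values
here are hypothetical):

```python
import asyncio

if __name__ == "__main__":
    # start_api no longer blocks on uvicorn.run(); await it from an event loop
    asyncio.run(start_api("127.0.0.1", 5000))
```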