Tree: Switch to async generators

Async generators remove many of the roadblocks that come with managing
generation tasks on threads. They allow for abortable requests and fit
modern asyncio paradigms.

NOTE: Exllamav2 itself is not an asynchronous library. It has simply
been wrapped into tabby's async architecture to allow for a fast and
concurrent API server. Whether to run stream_ex in a separate thread or
to manually yield control with asyncio.sleep(0) is still being debated.
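
To make that debate concrete, here is a minimal sketch of the two
options (blocking_token_stream is a hypothetical stand-in for the
synchronous stream_ex loop, not exllamav2's actual API):

    import asyncio

    def blocking_token_stream():
        """Hypothetical synchronous token generator."""
        for token in ["Hello", " ", "world"]:
            yield token

    async def stream_via_thread():
        # Option 1: pull each token on a worker thread so the event
        # loop is never blocked between tokens.
        gen = blocking_token_stream()
        while (token := await asyncio.to_thread(next, gen, None)) is not None:
            yield token

    async def stream_via_sleep_zero():
        # Option 2: iterate on the event loop, but hand control back
        # to the scheduler after every token with sleep(0).
        for token in blocking_token_stream():
            yield token
            await asyncio.sleep(0)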

Signed-off-by: kingbri <bdashore3@proton.me>
commit 7fded4f183
parent 33e2df50b7
Author:    kingbri <bdashore3@proton.me>
Date:      2024-03-14 10:27:39 -04:00
Committer: Brian Dashore

10 changed files with 84 additions and 88 deletions


@@ -1,18 +1,21 @@
 """Chat completion utilities for OAI server."""
+from asyncio import CancelledError
 import pathlib
 from typing import Optional
 from uuid import uuid4
-from fastapi import HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
+from fastapi import HTTPException
 from jinja2 import TemplateError
 from loguru import logger
 from common import model
-from common.generators import release_semaphore
 from common.templating import get_prompt_from_template
-from common.utils import get_generator_error, handle_request_error, unwrap
+from common.utils import (
+    get_generator_error,
+    handle_request_disconnect,
+    handle_request_error,
+    unwrap,
+)
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionLogprobs,
     ChatCompletionLogprob,

@@ -150,20 +153,14 @@ def format_prompt_with_template(data: ChatCompletionRequest):
 async def stream_generate_chat_completion(
-    prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
     try:
         const_id = f"chatcmpl-{uuid4().hex}"
         new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
-        for generation in new_generation:
-            # Get out if the request gets disconnected
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error("Chat completion generation cancelled by user.")
-                return
+        async for generation in new_generation:
             response = _create_stream_chunk(const_id, generation, model_path.name)
             yield response.model_dump_json()

@@ -172,6 +169,10 @@ async def stream_generate_chat_completion(
         finish_response = _create_stream_chunk(const_id, finish_reason="stop")
         yield finish_response.model_dump_json()
+    except CancelledError:
+        # Get out if the request gets disconnected
+        handle_request_disconnect("Chat completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
             "Chat completion aborted. Please check the server console."

@@ -179,11 +180,10 @@ async def stream_generate_chat_completion(
 async def generate_chat_completion(
-    prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
     try:
-        generation = await run_in_threadpool(
-            model.container.generate,
+        generation = await model.container.generate(
             prompt,
             **data.to_gen_params(),
         )
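
The reason the per-chunk request.is_disconnected() polling can be
deleted: when a streaming client drops, the server cancels the task
consuming the generator, and asyncio raises CancelledError inside the
async generator itself, which the new except block handles. A
self-contained sketch of that mechanism (toy names, not tabby's code):

    import asyncio

    async def token_stream():
        try:
            while True:
                yield "token"
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            # Cleanup lands here on cancellation, mirroring the new
            # except CancelledError blocks in the handlers above.
            print("stream cancelled, cleaning up")
            raise

    async def main():
        async def consume():
            async for _ in token_stream():
                pass

        task = asyncio.create_task(consume())
        await asyncio.sleep(0.25)  # simulate a client that disconnects
        task.cancel()
        await asyncio.gather(task, return_exceptions=True)

    asyncio.run(main())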


@@ -1,14 +1,17 @@
 """Completion utilities for OAI server."""
+from asyncio import CancelledError
 import pathlib
-from fastapi import HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
-from loguru import logger
+from fastapi import HTTPException
 from typing import Optional
 from common import model
-from common.generators import release_semaphore
-from common.utils import get_generator_error, handle_request_error, unwrap
+from common.utils import (
+    get_generator_error,
+    handle_request_disconnect,
+    handle_request_error,
+    unwrap,
+)
 from endpoints.OAI.types.completion import (
     CompletionRequest,
     CompletionResponse,

@@ -57,28 +60,24 @@ def _create_response(generation: dict, model_name: Optional[str]):
     return response

-async def stream_generate_completion(
-    request: Request, data: CompletionRequest, model_path: pathlib.Path
-):
+async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
     """Streaming generation for completions."""
     try:
         new_generation = model.container.generate_gen(
             data.prompt, **data.to_gen_params()
         )
-        for generation in new_generation:
-            # Get out if the request gets disconnected
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error("Completion generation cancelled by user.")
-                return
+        async for generation in new_generation:
             response = _create_response(generation, model_path.name)
             yield response.model_dump_json()

         # Yield a finish response on successful generation
         yield "[DONE]"
+    except CancelledError:
+        # Get out if the request gets disconnected
+        handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
             "Completion aborted. Please check the server console."

@@ -89,9 +88,7 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
     """Non-streaming generate for completions"""
    try:
-        generation = await run_in_threadpool(
-            model.container.generate, data.prompt, **data.to_gen_params()
-        )
+        generation = await model.container.generate(data.prompt, **data.to_gen_params())
         response = _create_response(generation, model_path.name)
         return response
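
The non-streaming handlers change too: generate is awaited directly
instead of going through run_in_threadpool, which implies that
model.container.generate is now a coroutine. A hedged sketch of one way
to make a blocking method awaitable (ExampleContainer is hypothetical,
not tabby's actual container):

    import asyncio

    class ExampleContainer:
        """Hypothetical stand-in for model.container."""

        def _generate_blocking(self, prompt: str, **params) -> dict:
            # Placeholder for the synchronous exllamav2 generation call.
            return {"text": f"echo: {prompt}"}

        async def generate(self, prompt: str, **params) -> dict:
            # Expose the blocking call as a coroutine by running it on
            # a worker thread, so callers can simply await it.
            return await asyncio.to_thread(
                self._generate_blocking, prompt, **params
            )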


@@ -9,6 +9,6 @@ def get_lora_list(lora_path: pathlib.Path):
     for path in lora_path.iterdir():
         if path.is_dir():
             lora_card = LoraCard(id=path.name)
-            lora_list.data.append(lora_card)  # pylint: disable=no-member
+            lora_list.data.append(lora_card)
     return lora_list


@@ -1,12 +1,9 @@
 import pathlib
 from asyncio import CancelledError
-from fastapi import Request
-from loguru import logger
 from typing import Optional
 from common import model
-from common.generators import release_semaphore
-from common.utils import get_generator_error
+from common.utils import get_generator_error, handle_request_disconnect
 from endpoints.OAI.types.model import (
     ModelCard,

@@ -35,7 +32,6 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
 async def stream_model_load(
-    request: Request,
     data: ModelLoadRequest,
     model_path: pathlib.Path,
     draft_model_path: str,

@@ -50,14 +46,6 @@ async def stream_model_load(
     load_status = model.load_model_gen(model_path, **load_data)
     try:
         async for module, modules, model_type in load_status:
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error(
-                    "Model load cancelled by user. "
-                    "Please make sure to run unload to free up resources."
-                )
-                return
             if module != 0:
                 response = ModelLoadResponse(
                     model_type=model_type,

@@ -78,7 +66,9 @@ async def stream_model_load(
            yield response.model_dump_json()
     except CancelledError:
-        logger.error(
+        # Get out if the request gets disconnected
+        handle_request_disconnect(
             "Model load cancelled by user. "
             "Please make sure to run unload to free up resources."
         )
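
Throughout the diff, handle_request_disconnect replaces the repeated
release_semaphore() + logger.error() pairs. Its real implementation
lives in common.utils and is not shown here; based on the code it
replaces, a plausible shape (an assumption, not the actual source):

    from loguru import logger

    from common.generators import release_semaphore  # assumed import

    def handle_request_disconnect(message: str):
        """Assumed helper: free the generation slot, log the disconnect."""
        release_semaphore()
        logger.error(message)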