Tree: Switch to async generators

Async generation helps remove many roadblocks to managing tasks
using threads. It should allow for abortables and modern-day paradigms.

NOTE: Exllamav2 itself is not an asynchronous library. It's just
been added into tabby's async nature to allow for a fast and concurrent
API server. It's still being debated to run stream_ex in a separate
thread or manually manage it using asyncio.sleep(0)

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-03-14 10:27:39 -04:00
committed by Brian Dashore
parent 33e2df50b7
commit 7fded4f183
10 changed files with 84 additions and 88 deletions

View File

@@ -1,18 +1,21 @@
"""Chat completion utilities for OAI server."""
from asyncio import CancelledError
import pathlib
from typing import Optional
from uuid import uuid4
from fastapi import HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi import HTTPException
from jinja2 import TemplateError
from loguru import logger
from common import model
from common.generators import release_semaphore
from common.templating import get_prompt_from_template
from common.utils import get_generator_error, handle_request_error, unwrap
from common.utils import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
unwrap,
)
from endpoints.OAI.types.chat_completion import (
ChatCompletionLogprobs,
ChatCompletionLogprob,
@@ -150,20 +153,14 @@ def format_prompt_with_template(data: ChatCompletionRequest):
async def stream_generate_chat_completion(
prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
):
"""Generator for the generation process."""
try:
const_id = f"chatcmpl-{uuid4().hex}"
new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
for generation in new_generation:
# Get out if the request gets disconnected
if await request.is_disconnected():
release_semaphore()
logger.error("Chat completion generation cancelled by user.")
return
async for generation in new_generation:
response = _create_stream_chunk(const_id, generation, model_path.name)
yield response.model_dump_json()
@@ -172,6 +169,10 @@ async def stream_generate_chat_completion(
finish_response = _create_stream_chunk(const_id, finish_reason="stop")
yield finish_response.model_dump_json()
except CancelledError:
# Get out if the request gets disconnected
handle_request_disconnect("Chat completion generation cancelled by user.")
except Exception:
yield get_generator_error(
"Chat completion aborted. Please check the server console."
@@ -179,11 +180,10 @@ async def stream_generate_chat_completion(
async def generate_chat_completion(
prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
):
try:
generation = await run_in_threadpool(
model.container.generate,
generation = await model.container.generate(
prompt,
**data.to_gen_params(),
)