mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-25 16:59:09 +00:00
OAI: Fix request cancellation behavior
Depending on the day of the week, Starlette can work with a CancelledError or using await request.is_disconnected(). Run the same behavior for both cases and allow cancellation. Streaming requests now set an event to cancel the batched job and break out of the generation loop. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
"""Chat completion utilities for OAI server."""
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
from asyncio import CancelledError
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
from jinja2 import TemplateError
|
||||
from loguru import logger
|
||||
|
||||
@@ -192,14 +193,24 @@ def format_prompt_with_template(data: ChatCompletionRequest):
|
||||
|
||||
|
||||
async def stream_generate_chat_completion(
|
||||
prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
|
||||
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
|
||||
):
|
||||
"""Generator for the generation process."""
|
||||
abort_event = asyncio.Event()
|
||||
|
||||
try:
|
||||
const_id = f"chatcmpl-{uuid4().hex}"
|
||||
|
||||
new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
|
||||
new_generation = model.container.generate_gen(
|
||||
prompt, abort_event, **data.to_gen_params()
|
||||
)
|
||||
async for generation in new_generation:
|
||||
# Sometimes this fires, and sometimes a CancelledError will fire
|
||||
# Keep both implementations in to avoid the headache
|
||||
if await request.is_disconnected():
|
||||
abort_event.set()
|
||||
handle_request_disconnect("Completion generation cancelled by user.")
|
||||
|
||||
response = _create_stream_chunk(const_id, generation, model_path.name)
|
||||
|
||||
yield response.model_dump_json()
|
||||
@@ -210,6 +221,7 @@ async def stream_generate_chat_completion(
|
||||
except CancelledError:
|
||||
# Get out if the request gets disconnected
|
||||
|
||||
abort_event.set()
|
||||
handle_request_disconnect("Chat completion generation cancelled by user.")
|
||||
except Exception:
|
||||
yield get_generator_error(
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Completion utilities for OAI server."""
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
from asyncio import CancelledError
|
||||
from fastapi import HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
from typing import Optional
|
||||
|
||||
from common import model
|
||||
@@ -60,14 +61,24 @@ def _create_response(generation: dict, model_name: Optional[str]):
|
||||
return response
|
||||
|
||||
|
||||
async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
|
||||
async def stream_generate_completion(
|
||||
data: CompletionRequest, request: Request, model_path: pathlib.Path
|
||||
):
|
||||
"""Streaming generation for completions."""
|
||||
|
||||
abort_event = asyncio.Event()
|
||||
|
||||
try:
|
||||
new_generation = model.container.generate_gen(
|
||||
data.prompt, **data.to_gen_params()
|
||||
data.prompt, abort_event, **data.to_gen_params()
|
||||
)
|
||||
async for generation in new_generation:
|
||||
# Sometimes this fires, and sometimes a CancelledError will fire
|
||||
# Keep both implementations in to avoid the headache
|
||||
if await request.is_disconnected():
|
||||
abort_event.set()
|
||||
handle_request_disconnect("Completion generation cancelled by user.")
|
||||
|
||||
response = _create_response(generation, model_path.name)
|
||||
yield response.model_dump_json()
|
||||
|
||||
@@ -78,6 +89,7 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
|
||||
except CancelledError:
|
||||
# Get out if the request gets disconnected
|
||||
|
||||
abort_event.set()
|
||||
handle_request_disconnect("Completion generation cancelled by user.")
|
||||
except Exception:
|
||||
yield get_generator_error(
|
||||
|
||||
Reference in New Issue
Block a user