async def request_disconnect_loop(request: Request, poll_interval: float = 0.5):
    """Poll a starlette request until the client disconnects.

    Args:
        request: The incoming request whose connection is being watched.
        poll_interval: Seconds to sleep between disconnect checks.

    Returns:
        None, once ``request.is_disconnected()`` reports a disconnect.
    """

    while not await request.is_disconnected():
        await asyncio.sleep(poll_interval)


async def run_with_request_disconnect(
    request: Request, call_task: asyncio.Task, disconnect_message: str
):
    """Run a task, cancelling it if the request disconnects first.

    Races ``call_task`` against a disconnect poller. Whichever finishes
    first wins; the loser is cancelled and awaited so that its cancellation
    is actually delivered before this function returns.

    Args:
        request: The request to watch for a client disconnect.
        call_task: The already-scheduled task doing the real work.
        disconnect_message: Message passed to ``handle_request_disconnect``
            when the work was cancelled.

    Returns:
        The result of ``call_task``, or None if it was cancelled.

    Raises:
        Exception: Any exception raised inside ``call_task`` is re-raised.
    """

    disconnect_task = asyncio.create_task(request_disconnect_loop(request))
    racers = (call_task, disconnect_task)

    try:
        await asyncio.wait(racers, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Runs on normal completion and when this coroutine itself is
        # cancelled, so neither the generation nor the poller is leaked.
        for task in racers:
            if not task.done():
                task.cancel()

        # Let cancelled tasks unwind. return_exceptions=True keeps their
        # CancelledError (or any other error) from propagating from here.
        await asyncio.gather(*racers, return_exceptions=True)

    if call_task.cancelled():
        handle_request_disconnect(disconnect_message)
        return None

    # Re-raises the task's own exception if it failed.
    return call_task.result()
run_with_request_disconnect( + request, + generate_task, + disconnect_message="Completion generation cancelled by user.", + ) return response @@ -494,10 +501,17 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest) ping=maxsize, ) else: - response = await call_with_semaphore( - partial(generate_chat_completion, prompt, data, model_path) + generate_task = asyncio.create_task( + call_with_semaphore( + partial(generate_chat_completion, prompt, data, model_path) + ) ) + response = await run_with_request_disconnect( + request, + generate_task, + disconnect_message="Chat completion generation cancelled by user.", + ) return response