diff --git a/endpoints/Kobold/utils/generation.py b/endpoints/Kobold/utils/generation.py index f08e758..0d1c166 100644 --- a/endpoints/Kobold/utils/generation.py +++ b/endpoints/Kobold/utils/generation.py @@ -61,10 +61,7 @@ async def _stream_collector(data: GenerateRequest, request: Request): async for generation in generator: if disconnect_task.done(): - abort_event.set() - handle_request_disconnect( - f"Kobold generation {data.genkey} cancelled by user." - ) + raise CancelledError() text = generation.get("text") @@ -78,7 +75,7 @@ async def _stream_collector(data: GenerateRequest, request: Request): break except CancelledError: # If the request disconnects, break out - if not disconnect_task.done(): + if not abort_event.is_set(): abort_event.set() handle_request_disconnect( f"Kobold generation {data.genkey} cancelled by user." diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 4a6c210..b559bb2 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -348,10 +348,7 @@ async def stream_generate_chat_completion( # Consumer loop while True: if disconnect_task.done(): - abort_event.set() - handle_request_disconnect( - f"Chat completion generation {request.state.id} cancelled by user." - ) + raise CancelledError() generation = await gen_queue.get() @@ -401,7 +398,7 @@ async def stream_generate_chat_completion( except CancelledError: # Get out if the request gets disconnected - if not disconnect_task.done(): + if not abort_event.is_set(): abort_event.set() handle_request_disconnect("Chat completion generation cancelled by user.") except Exception: diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index ca51c9c..f66d381 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -226,10 +226,7 @@ async def stream_generate_completion( # Consumer loop while True: if disconnect_task.done(): - abort_event.set() - handle_request_disconnect( - f"Completion generation {request.state.id} cancelled by user." - ) + raise CancelledError() generation = await gen_queue.get() @@ -248,7 +245,7 @@ async def stream_generate_completion( except CancelledError: # Get out if the request gets disconnected - if not disconnect_task.done(): + if not abort_event.is_set(): abort_event.set() handle_request_disconnect( f"Completion generation {request.state.id} cancelled by user."