diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 1e05d04..c169b9c 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -2,6 +2,7 @@
 
 import gc
 import pathlib
+import threading
 import time
 
 import torch
@@ -623,14 +624,18 @@ class ExllamaV2Container:
 
         return kwargs
 
-    async def generate_gen(self, prompt: str, **kwargs):
+    async def generate_gen(
+        self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
+    ):
         """Basic async wrapper for completion generator"""
 
-        sync_generator = self.generate_gen_sync(prompt, **kwargs)
+        sync_generator = self.generate_gen_sync(prompt, abort_event, **kwargs)
         async for value in iterate_in_threadpool(sync_generator):
             yield value
 
-    def generate_gen_sync(self, prompt: str, **kwargs):
+    def generate_gen_sync(
+        self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
+    ):
         """
         Create generator function for prompt completion.
 
@@ -893,6 +898,7 @@ class ExllamaV2Container:
                     return_probabilities=request_logprobs > 0,
                     return_top_tokens=request_logprobs,
                     return_logits=request_logprobs > 0,
+                    abort_event=abort_event,
                 )
             else:
                 self.generator.begin_stream_ex(
@@ -903,6 +909,7 @@ class ExllamaV2Container:
                     return_probabilities=request_logprobs > 0,
                     return_top_tokens=request_logprobs,
                     return_logits=request_logprobs > 0,
+                    abort_event=abort_event,
                 )
 
             # Reset offsets for subsequent passes if the context is truncated
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index e33ec7f..562e736 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -2,6 +2,7 @@
 
 from asyncio import CancelledError
 import pathlib
+import threading
 from typing import Optional
 from uuid import uuid4
 
@@ -161,8 +162,11 @@ async def stream_generate_chat_completion(
         """Generator for the generation process."""
         try:
             const_id = f"chatcmpl-{uuid4().hex}"
+            abort_event = threading.Event()
 
-            new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
+            new_generation = model.container.generate_gen(
+                prompt, abort_event, **data.to_gen_params()
+            )
             async for generation in new_generation:
                 response = _create_stream_chunk(const_id, generation, model_path.name)
 
@@ -174,6 +178,7 @@ async def stream_generate_chat_completion(
 
         except CancelledError:
             # Get out if the request gets disconnected
+            abort_event.set()
             handle_request_disconnect("Chat completion generation cancelled by user.")
         except Exception:
             yield get_generator_error(
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index a98d1b7..02b7852 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -2,6 +2,7 @@
 
 import pathlib
 from asyncio import CancelledError
+import threading
 from fastapi import HTTPException
 from typing import Optional
 
@@ -64,8 +65,10 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
     """Streaming generation for completions."""
 
     try:
+        abort_event = threading.Event()
+
         new_generation = model.container.generate_gen(
-            data.prompt, **data.to_gen_params()
+            data.prompt, abort_event, **data.to_gen_params()
         )
         async for generation in new_generation:
             response = _create_response(generation, model_path.name)
 
@@ -78,6 +81,7 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
 
     except CancelledError:
        # Get out if the request gets disconnected
+        abort_event.set()
        handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
        yield get_generator_error(
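
# --- Illustrative sketch (not part of the patch) ------------------------------
# A minimal, self-contained example of the cancellation pattern the patch wires
# up: the streaming endpoint creates one threading.Event per request, hands it
# to the blocking generator running in a worker thread, and sets it when
# asyncio raises CancelledError so the worker stops promptly. All names here
# (fake_token_stream, stream_tokens, main) are hypothetical stand-ins; the real
# code instead passes the event down to ExLlamaV2's begin_stream_ex through
# generate_gen_sync, as shown in the diff above.

import asyncio
import threading
import time
from typing import Optional


def fake_token_stream(prompt: str, abort_event: Optional[threading.Event] = None):
    """Blocking generator standing in for generate_gen_sync (hypothetical)."""

    for i in range(100):
        # Poll the event so a cancelled request stops generating instead of
        # running to completion in the background thread.
        if abort_event is not None and abort_event.is_set():
            return

        time.sleep(0.05)  # Pretend to do one decoding step
        yield f"token_{i}"


async def stream_tokens(prompt: str):
    """Async wrapper mirroring generate_gen: one abort event per request."""

    abort_event = threading.Event()
    sync_generator = fake_token_stream(prompt, abort_event)

    try:
        while True:
            # Run each blocking next() call in a worker thread, the same idea
            # as iterate_in_threadpool in the patched generate_gen.
            token = await asyncio.to_thread(next, sync_generator, None)
            if token is None:
                break
            print(token)
    except asyncio.CancelledError:
        # Mirrors the `except CancelledError: abort_event.set()` branches in
        # the patched endpoints: signal the worker thread, then let the
        # cancellation propagate.
        abort_event.set()
        raise


async def main():
    task = asyncio.create_task(stream_tokens("hello"))
    await asyncio.sleep(0.3)  # Let a few tokens stream out
    task.cancel()             # Simulate the client disconnecting
    await asyncio.gather(task, return_exceptions=True)


if __name__ == "__main__":
    asyncio.run(main())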