diff --git a/common/args.py b/common/args.py index 736b360..a0745ae 100644 --- a/common/args.py +++ b/common/args.py @@ -135,3 +135,8 @@ def add_developer_args(parser: argparse.ArgumentParser): developer_group.add_argument( "--unsafe-launch", type=str_to_bool, help="Skip Exllamav2 version check" ) + developer_group.add_argument( + "--disable-request-streaming", + type=str_to_bool, + help="Disables API request streaming", + ) diff --git a/config_sample.yml b/config_sample.yml index c50111c..23b7191 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -42,6 +42,10 @@ developer: # WARNING: Don't set this unless you know what you're doing! #unsafe_launch: False + # Disable all request streaming (default: False) + # A kill switch for turning off SSE in the API server + #disable_request_streaming: False + # Options for model overrides and loading model: # Overrides the directory to look for models (default: models) diff --git a/main.py b/main.py index c47a9fc..eaf045a 100644 --- a/main.py +++ b/main.py @@ -449,7 +449,11 @@ async def generate_completion(request: Request, data: CompletionRequest): if isinstance(data.prompt, list): data.prompt = "\n".join(data.prompt) - if data.stream: + disable_request_streaming = unwrap( + get_developer_config().get("disable_request_streaming"), False + ) + + if data.stream and not disable_request_streaming: async def generator(): """Generator for the generation process.""" @@ -531,7 +535,11 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest f"TemplateError: {str(exc)}", ) from exc - if data.stream: + disable_request_streaming = unwrap( + get_developer_config().get("disable_request_streaming"), False + ) + + if data.stream and not disable_request_streaming: const_id = f"chatcmpl-{uuid4().hex}" async def generator():