diff --git a/endpoints/OAI/app.py b/endpoints/OAI/app.py index d6bd159..09303b0 100644 --- a/endpoints/OAI/app.py +++ b/endpoints/OAI/app.py @@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware from functools import partial from loguru import logger from sse_starlette import EventSourceResponse +from sys import maxsize from common import config, model, gen_logging, sampling from common.auth import check_admin_key, check_api_key @@ -193,7 +194,7 @@ async def load_model(request: Request, data: ModelLoadRequest): else: load_callback = partial(generate_with_semaphore, load_callback) - return EventSourceResponse(load_callback()) + return EventSourceResponse(load_callback(), ping=maxsize) # Unload model endpoint @@ -412,7 +413,10 @@ async def completion_request(request: Request, data: CompletionRequest): stream_generate_completion, request, data, model_path ) - return EventSourceResponse(generate_with_semaphore(generator_callback)) + return EventSourceResponse( + generate_with_semaphore(generator_callback), + ping=maxsize, + ) else: response = await call_with_semaphore( partial(generate_completion, data, model_path) @@ -451,7 +455,10 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest) stream_generate_chat_completion, prompt, request, data, model_path ) - return EventSourceResponse(generate_with_semaphore(generator_callback)) + return EventSourceResponse( + generate_with_semaphore(generator_callback), + ping=maxsize, + ) else: response = await call_with_semaphore( partial(generate_chat_completion, prompt, request, data, model_path)