From ed6c962aadb8ae74a21520631cf3027ee9057d1c Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sun, 3 Dec 2023 22:54:34 -0500
Subject: [PATCH] API: Fix sequential requests

FastAPI is kinda weird with queueing. If an endpoint is an async def,
the event loop switches between requests at every await, so requests
aren't executed sequentially. Get sequential requests back by using a
semaphore to limit concurrent execution of the generator functions.

Also scaffold the framework to move generator functions to their own
file.

Signed-off-by: kingbri
---
 generators.py | 10 ++++++++++
 main.py       | 21 ++++++++++++++++-----
 2 files changed, 26 insertions(+), 5 deletions(-)
 create mode 100644 generators.py

diff --git a/generators.py b/generators.py
new file mode 100644
index 0000000..287835c
--- /dev/null
+++ b/generators.py
@@ -0,0 +1,10 @@
+from asyncio import Semaphore
+from typing import AsyncGenerator, Callable
+
+generate_semaphore = Semaphore(1)
+
+# Run an async generator function while holding the semaphore
+async def generate_with_semaphore(generator: Callable[[], AsyncGenerator]):
+    async with generate_semaphore:
+        async for result in generator():
+            yield result
diff --git a/main.py b/main.py
index 16ed0b0..d77e3a4 100644
--- a/main.py
+++ b/main.py
@@ -7,6 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from model import ModelContainer
 from progress.bar import IncrementalBar
+from generators import generate_with_semaphore
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
@@ -200,10 +201,15 @@ async def generate_completion(request: Request, data: CompletionRequest):
                                                        model_path.name)
 
                     yield get_sse_packet(response.json(ensure_ascii=False))
+            except GeneratorExit:
+                print("Completion response aborted")
             except Exception as e:
                 yield get_generator_error(e)
 
-        return StreamingResponse(generator(), media_type = "text/event-stream")
+        return StreamingResponse(
+            generate_with_semaphore(generator),
+            media_type = "text/event-stream"
+        )
     else:
         response_text, prompt_tokens, completion_tokens = model_container.generate(data.prompt, **data.to_gen_params())
         response = create_completion_response(response_text,
@@ -238,12 +244,14 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                         model_path.name
                     )
 
-                yield get_sse_packet(response.json(ensure_ascii=False))
+                    yield get_sse_packet(response.json(ensure_ascii=False))
+            except GeneratorExit:
+                print("Chat completion response aborted")
             except Exception as e:
                 yield get_generator_error(e)
             finally:
-                # Always finish no matter what
+                # FIXME: An error currently fires here since the generator is closed, move this somewhere else
                 finish_response = create_chat_completion_stream_chunk(
                     const_id,
                     finish_reason = "stop"
                 )
@@ -251,7 +259,10 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
 
                 yield get_sse_packet(finish_response.json(ensure_ascii=False))
 
-        return StreamingResponse(generator(), media_type = "text/event-stream")
+        return StreamingResponse(
+            generate_with_semaphore(generator),
+            media_type = "text/event-stream"
+        )
     else:
         response_text, prompt_tokens, completion_tokens = model_container.generate(prompt, **data.to_gen_params())
         response = create_chat_completion_response(response_text,
@@ -283,7 +294,7 @@ if __name__ == "__main__":
     if "model_name" in model_config:
         model_path = pathlib.Path(model_config.get("model_dir") or "models")
         model_path = model_path / model_config.get("model_name")
-    
+
     model_container = ModelContainer(model_path.resolve(), False, **model_config)
     load_status = model_container.load_gen(load_progress)
     for (module, modules) in load_status:
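
Below is a minimal, self-contained sketch of the pattern this patch introduces,
useful for reproducing the queueing behavior in isolation. It is not part of the
patch: the file name, the /stream endpoint, and the fake token loop are
hypothetical stand-ins for tabbyAPI's completion endpoints and model generator.

# sketch.py - standalone illustration of semaphore-serialized streaming
import asyncio
from typing import AsyncGenerator, Callable

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
generate_semaphore = asyncio.Semaphore(1)

async def generate_with_semaphore(generator: Callable[[], AsyncGenerator]):
    # Hold the semaphore for the entire life of the stream, so a second
    # request waits here until the first response finishes or is aborted.
    async with generate_semaphore:
        async for result in generator():
            yield result

@app.get("/stream")
async def stream():
    async def generator():
        try:
            for token in ("one ", "two ", "three "):
                await asyncio.sleep(0.5)  # stand-in for model inference
                yield token
        except GeneratorExit:
            # Raised when the client disconnects mid-stream; the
            # async-with above still releases the semaphore on close.
            print("Response aborted")

    # Pass the generator function itself: it is only called (and so only
    # starts generating) once the semaphore has been acquired.
    return StreamingResponse(
        generate_with_semaphore(generator),
        media_type = "text/event-stream"
    )

Run it with uvicorn (uvicorn sketch:app) and start two curl -N requests against
/stream at the same time: the second stream should not produce output until the
first one completes, which is the sequential behavior the patch restores.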