From 5e54911cc81a9d55f4146c48840e452e6d98d676 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 4 Dec 2023 15:51:25 -0500 Subject: [PATCH] API: Fix semaphore handling and chat completion errors Chat completions previously always yielded a final packet to say that a generation finished. However, this caused errors that a yield was executed after GeneratorExit. This is correctly stated because python's garbage collector can't clean up the generator after exiting due to the finally block executing. In addition, SSE endpoints close off the connection, so the finish packet can only be yielded when the response has completed, so ignore yield on exception. Signed-off-by: kingbri --- main.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index d77e3a4..204cfa5 100644 --- a/main.py +++ b/main.py @@ -201,8 +201,6 @@ async def generate_completion(request: Request, data: CompletionRequest): model_path.name) yield get_sse_packet(response.json(ensure_ascii=False)) - except GeneratorExit: - print("Completion response aborted") except Exception as e: yield get_generator_error(e) @@ -233,6 +231,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest const_id = f"chatcmpl-{uuid4().hex}" async def generator(): try: + raise ValueError("Error!") new_generation = model_container.generate_gen(prompt, **data.to_gen_params()) for (part, _, _) in new_generation: if await request.is_disconnected(): @@ -244,20 +243,17 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest model_path.name ) - yield get_sse_packet(response.json(ensure_ascii=False)) - except GeneratorExit: - print("Chat completion response aborted") - except Exception as e: - yield get_generator_error(e) - finally: - # Always finish no matter what - # FIXME: An error currently fires here since the generator is closed, move this somewhere else + yield get_sse_packet(response.json(ensure_ascii=False)) + + # Yield a finish response on successful generation finish_response = create_chat_completion_stream_chunk( const_id, finish_reason = "stop" ) yield get_sse_packet(finish_response.json(ensure_ascii=False)) + except Exception as e: + yield get_generator_error(e) return StreamingResponse( generate_with_semaphore(generator),