API: Fix semaphore handling and chat completion errors

Chat completions previously always yielded a final packet to say that
a generation finished. However, this caused errors that a yield was
executed after GeneratorExit. This is correctly stated because python's
garbage collector can't clean up the generator after exiting due to the
finally block executing.

In addition, SSE endpoints close off the connection, so the finish packet
can only be yielded when the response has completed, so ignore yield on
exception.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-04 15:51:25 -05:00
parent 30fc5b3d29
commit 5e54911cc8

16
main.py
View File

@@ -201,8 +201,6 @@ async def generate_completion(request: Request, data: CompletionRequest):
model_path.name)
yield get_sse_packet(response.json(ensure_ascii=False))
except GeneratorExit:
print("Completion response aborted")
except Exception as e:
yield get_generator_error(e)
@@ -233,6 +231,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
const_id = f"chatcmpl-{uuid4().hex}"
async def generator():
try:
raise ValueError("Error!")
new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
for (part, _, _) in new_generation:
if await request.is_disconnected():
@@ -244,20 +243,17 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
model_path.name
)
yield get_sse_packet(response.json(ensure_ascii=False))
except GeneratorExit:
print("Chat completion response aborted")
except Exception as e:
yield get_generator_error(e)
finally:
# Always finish no matter what
# FIXME: An error currently fires here since the generator is closed, move this somewhere else
yield get_sse_packet(response.json(ensure_ascii=False))
# Yield a finish response on successful generation
finish_response = create_chat_completion_stream_chunk(
const_id,
finish_reason = "stop"
)
yield get_sse_packet(finish_response.json(ensure_ascii=False))
except Exception as e:
yield get_generator_error(e)
return StreamingResponse(
generate_with_semaphore(generator),