From 9157be3e34ab88cb46b0b0a8e935e96b4691d1fb Mon Sep 17 00:00:00 2001
From: kingbri <8082010+kingbri1@users.noreply.github.com>
Date: Mon, 28 Apr 2025 22:29:48 -0400
Subject: [PATCH] API: Append task index to generations with n > 1

Since jobs are tracked via request IDs now, each generation task should
be uniquely identified in the event of cancellation.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
---
 endpoints/OAI/utils/chat_completion.py | 25 ++++++++++++++++++-------
 endpoints/OAI/utils/completion.py      | 23 ++++++++++++++++++-----
 2 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index dcc0dea..d2fab92 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -29,7 +29,7 @@ from endpoints.OAI.types.chat_completion import (
     ChatCompletionStreamChoice,
 )
 from endpoints.OAI.types.common import UsageStats
-from endpoints.OAI.utils.completion import _stream_collector
+from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
 from endpoints.OAI.utils.tools import ToolCallProcessor
 
 
@@ -326,14 +326,17 @@ async def stream_generate_chat_completion(
     try:
         logger.info(f"Received chat completion streaming request {request.state.id}")
 
-        for n in range(0, data.n):
+        for idx in range(0, data.n):
             task_gen_params = data.model_copy(deep=True)
+            request_id = _parse_gen_request_id(
+                data.n, request.state.id, idx
+            )
 
             gen_task = asyncio.create_task(
                 _stream_collector(
-                    n,
+                    idx,
                     gen_queue,
-                    request.state.id,
+                    request_id,
                     prompt,
                     task_gen_params,
                     abort_event,
@@ -418,11 +421,15 @@ async def generate_chat_completion(
     gen_tasks: List[asyncio.Task] = []
 
     try:
-        for _ in range(0, data.n):
+        for idx in range(0, data.n):
+            request_id = _parse_gen_request_id(
+                data.n, request.state.id, idx
+            )
+
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        request.state.id,
+                        request_id,
                         prompt,
                         data,
                         mm_embeddings=embeddings,
@@ -484,10 +491,14 @@ async def generate_tool_calls(
                     data, current_generations
                 )
 
+            request_id = _parse_gen_request_id(
+                data.n, request.state.id, idx
+            )
+
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        request.state.id,
+                        request_id,
                         pre_tool_prompt,
                         tool_data,
                         embeddings=mm_embeddings,
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 8be249c..01b6276 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -31,6 +31,13 @@ from endpoints.OAI.types.completion import (
 from endpoints.OAI.types.common import UsageStats
 
 
+def _parse_gen_request_id(n: int, request_id: str, task_idx: int):
+    if n > 1:
+        return f"{request_id}-{task_idx}"
+    else:
+        return request_id
+
+
 def _create_response(
     request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
 ):
@@ -193,14 +200,17 @@ async def stream_generate_completion(
     try:
         logger.info(f"Received streaming completion request {request.state.id}")
 
-        for n in range(0, data.n):
+        for idx in range(0, data.n):
             task_gen_params = data.model_copy(deep=True)
+            request_id = _parse_gen_request_id(
+                data.n, request.state.id, idx
+            )
 
             gen_task = asyncio.create_task(
                 _stream_collector(
-                    n,
+                    idx,
                     gen_queue,
-                    request.state.id,
+                    request_id,
                     data.prompt,
                     task_gen_params,
                     abort_event,
@@ -255,13 +265,16 @@ async def generate_completion(
     try:
         logger.info(f"Recieved completion request {request.state.id}")
 
-        for _ in range(0, data.n):
+        for idx in range(0, data.n):
             task_gen_params = data.model_copy(deep=True)
+            request_id = _parse_gen_request_id(
+                data.n, request.state.id, idx
+            )
 
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        request.state.id,
+                        request_id,
                         data.prompt,
                         task_gen_params,
                     )