diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 61726e2..6f3521c 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -96,11 +96,13 @@ def _create_stream_chunk(
 ):
     """Create a chat completion stream chunk from the provided text."""

+    index = generation.get("index")
     logprob_response = None

     if "finish_reason" in generation:
         choice = ChatCompletionStreamChoice(
-            finish_reason=generation.get("finish_reason")
+            index=index,
+            finish_reason=generation.get("finish_reason"),
         )
     else:
         message = ChatCompletionMessage(
@@ -125,6 +127,7 @@ def _create_stream_chunk(
             logprob_response = ChatCompletionLogprobs(content=[token_prob_response])

         choice = ChatCompletionStreamChoice(
+            index=index,
             delta=message,
             logprobs=logprob_response,
         )
@@ -199,34 +202,62 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         raise HTTPException(400, error_message) from exc


+async def _stream_collector(
+    task_idx: int,
+    gen_queue: asyncio.Queue,
+    prompt: str,
+    abort_event: asyncio.Event,
+    **kwargs,
+):
+    """Collects a stream and places results in a common queue"""
+
+    new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
+    async for generation in new_generation:
+        generation["index"] = task_idx
+
+        await gen_queue.put(generation)
+
+        if "finish_reason" in generation:
+            break
+
+
 async def stream_generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
+    const_id = f"chatcmpl-{uuid4().hex}"
     abort_event = asyncio.Event()
+    gen_queue = asyncio.Queue()
+    gen_tasks: List[asyncio.Task] = []
+    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

     try:
-        const_id = f"chatcmpl-{uuid4().hex}"
+        gen_params = data.to_gen_params()

-        new_generation = model.container.generate_gen(
-            prompt, abort_event, **data.to_gen_params()
-        )
-        # Create a background task to avoid blocking the loop
-        disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+        for n in range(0, data.n):
+            if n > 0:
+                task_gen_params = deepcopy(gen_params)
+            else:
+                task_gen_params = gen_params

-        async for generation in new_generation:
-            # Sometimes this fires, and sometimes a CancelledError will fire
-            # Keep both implementations in to avoid the headache
+            gen_task = asyncio.create_task(
+                _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
+            )
+
+            gen_tasks.append(gen_task)
+
+        # Consumer loop
+        while True:
             if disconnect_task.done():
                 abort_event.set()
                 handle_request_disconnect("Completion generation cancelled by user.")

+            generation = await gen_queue.get()
             response = _create_stream_chunk(const_id, generation, model_path.name)
-
             yield response.model_dump_json()

-            # Break if the generation is finished
-            if "finish_reason" in generation:
+            # Check if all tasks are completed
+            if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 break
     except CancelledError:
         # Get out if the request gets disconnected
@@ -247,7 +278,6 @@ async def generate_chat_completion(

     try:
         for n in range(0, data.n):
-
             # Deepcopy gen params above the first index
             # to ensure nested structures aren't shared
             if n > 0:
@@ -254,11 +284,9 @@ async def generate_chat_completion(
                 task_gen_params = deepcopy(gen_params)
             else:
                 task_gen_params = gen_params

             gen_tasks.append(
-                asyncio.create_task(
-                    model.container.generate(prompt, **task_gen_params)
-                )
+                asyncio.create_task(model.container.generate(prompt, **task_gen_params))
             )

         generations = await asyncio.gather(*gen_tasks)
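The hunks above replace a single linear generator with a fan-out/fan-in pipeline: one `_stream_collector` task per requested choice (`data.n`), all feeding a shared `asyncio.Queue`, drained by a single consumer loop. The sketch below is a minimal, self-contained illustration of that pattern, not TabbyAPI code: `fake_stream` stands in for `model.container.generate_gen`, and completion is tracked by counting finish markers rather than by polling `task.done()` as the patch does.

```python
import asyncio
from typing import AsyncIterator, List


async def fake_stream(prompt: str) -> AsyncIterator[dict]:
    # Stand-in for model.container.generate_gen: yields text chunks,
    # then a final chunk carrying a finish_reason.
    for token in prompt.split():
        await asyncio.sleep(0.01)
        yield {"text": token + " "}
    yield {"finish_reason": "stop"}


async def collector(task_idx: int, gen_queue: asyncio.Queue, prompt: str):
    # Mirror of _stream_collector: tag each chunk with its stream index
    # and push it onto the shared queue.
    async for generation in fake_stream(prompt):
        generation["index"] = task_idx
        await gen_queue.put(generation)
        if "finish_reason" in generation:
            break


async def main(n: int = 3):
    gen_queue: asyncio.Queue = asyncio.Queue()
    gen_tasks: List[asyncio.Task] = [
        asyncio.create_task(collector(i, gen_queue, "the quick brown fox"))
        for i in range(n)
    ]

    # Consumer loop: chunks from all n streams arrive interleaved; here we
    # count finish markers instead of checking task.done() like the patch does.
    finished = 0
    while finished < n:
        generation = await gen_queue.get()
        print(generation)
        if "finish_reason" in generation:
            finished += 1

    await asyncio.gather(*gen_tasks)


asyncio.run(main())
```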
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 5034fbc..87256c9 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -5,7 +5,7 @@ import pathlib
 from asyncio import CancelledError
 from copy import deepcopy
 from fastapi import HTTPException, Request
-from typing import List, Optional
+from typing import List, Union

 from common import model
 from common.networking import (
@@ -24,13 +24,14 @@ from endpoints.OAI.types.completion import (
 )
 from endpoints.OAI.types.common import UsageStats


-def _create_response(generations: List[dict], model_name: Optional[str]):
-    """Create a completion response from the provided text."""
+def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
+    """Create a completion response from the provided choices."""

-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
+    # Convert the single choice object into a list
+    if not isinstance(generations, list):
+        generations = [generations]

-    choices = []
+    choices: List[CompletionRespChoice] = []
     for index, generation in enumerate(generations):
         logprob_response = None
@@ -46,8 +47,9 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
                 top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
             )

+        # The index can be located in the generation itself
         choice = CompletionRespChoice(
-            index=index,
+            index=unwrap(generation.get("index"), index),
             finish_reason=generation.get("finish_reason"),
             text=unwrap(generation.get("text"), ""),
             logprobs=logprob_response,
@@ -55,9 +57,12 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
         )
         choices.append(choice)

+    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
+
     response = CompletionResponse(
         choices=choices,
-        model=unwrap(model_name, ""),
+        model=model_name,
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
@@ -68,33 +73,64 @@
     return response


+async def _stream_collector(
+    task_idx: int,
+    gen_queue: asyncio.Queue,
+    prompt: str,
+    abort_event: asyncio.Event,
+    **kwargs,
+):
+    """Collects a stream and places results in a common queue"""
+
+    new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
+    async for generation in new_generation:
+        generation["index"] = task_idx
+
+        await gen_queue.put(generation)
+
+        if "finish_reason" in generation:
+            break
+
+
 async def stream_generate_completion(
     data: CompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Streaming generation for completions."""

     abort_event = asyncio.Event()
+    gen_queue = asyncio.Queue()
+    gen_tasks: List[asyncio.Task] = []
+    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

     try:
-        new_generation = model.container.generate_gen(
-            data.prompt, abort_event, **data.to_gen_params()
-        )
+        gen_params = data.to_gen_params()

-        # Create a background task to avoid blocking the loop
-        disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+        for n in range(0, data.n):
+            if n > 0:
+                task_gen_params = deepcopy(gen_params)
+            else:
+                task_gen_params = gen_params

-        async for generation in new_generation:
-            # Sometimes this fires, and sometimes a CancelledError will fire
-            # Keep both implementations in to avoid the headache
+            gen_task = asyncio.create_task(
+                _stream_collector(
+                    n, gen_queue, data.prompt, abort_event, **task_gen_params
+                )
+            )
+
+            gen_tasks.append(gen_task)
+
+        # Consumer loop
+        while True:
             if disconnect_task.done():
                 abort_event.set()
                 handle_request_disconnect("Completion generation cancelled by user.")

-            response = _create_response([generation], model_path.name)
+            generation = await gen_queue.get()
+            response = _create_response(generation, model_path.name)
             yield response.model_dump_json()

-            # Break if the generation is finished
-            if "finish_reason" in generation:
+            # Check if all tasks are completed
+            if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 yield "[DONE]"
                 break
     except CancelledError:
@@ -116,7 +152,6 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)

     try:
         for n in range(0, data.n):
-
             # Deepcopy gen params above the first index
             # to ensure nested structures aren't shared
             if n > 0:
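Since every streamed chunk and every response choice now carries an `index`, a client receiving the interleaved stream can demultiplex it back into `n` separate completions. A small illustrative sketch, with simplified chunk dicts standing in for the actual SSE payloads:

```python
from collections import defaultdict

# Interleaved chunks as a client might receive them over SSE (simplified).
chunks = [
    {"index": 0, "text": "Hello"},
    {"index": 1, "text": "Hi"},
    {"index": 0, "text": " world"},
    {"index": 1, "text": " there"},
    {"index": 0, "finish_reason": "stop"},
    {"index": 1, "finish_reason": "stop"},
]

# Rebuild each completion by appending chunk text under its choice index.
streams = defaultdict(str)
for chunk in chunks:
    streams[chunk["index"]] += chunk.get("text", "")

assert dict(streams) == {0: "Hello world", 1: "Hi there"}
```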