OAI: Add "n" support for streaming generations

Use a queue-based system to get choices independently and send them
in the overall streaming payload. This method allows for unordered
streaming of generations.

The system is a bit redundant, so the code could be optimized
in the future.
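
For reference, a minimal sketch of the fan-in pattern (plain asyncio; the
collector and the chunk shape are hypothetical stand-ins for the real
model.container.generate_gen stream):

    import asyncio

    async def collector(idx: int, queue: asyncio.Queue):
        # Hypothetical per-choice stream: each chunk is tagged with the
        # index of the choice it belongs to.
        for token in ("Hello", " world"):
            await queue.put({"index": idx, "text": token})
            await asyncio.sleep(0)  # yield so the streams interleave
        await queue.put({"index": idx, "finish_reason": "stop"})

    async def main():
        queue: asyncio.Queue = asyncio.Queue()
        tasks = [asyncio.create_task(collector(i, queue)) for i in range(2)]

        # Single consumer: chunks from every choice arrive through one
        # queue, in whatever order the producers emit them.
        while True:
            generation = await queue.get()
            print(generation)
            if all(task.done() for task in tasks) and queue.empty():
                break

    asyncio.run(main())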

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-05-27 23:53:25 -04:00
Committed by: Brian Dashore
Parent: c8371e0f50
Commit: e2a8b6e8ae
2 changed files with 100 additions and 37 deletions

@@ -96,11 +96,13 @@ def _create_stream_chunk(
 ):
     """Create a chat completion stream chunk from the provided text."""
 
+    index = generation.get("index")
     logprob_response = None
 
     if "finish_reason" in generation:
         choice = ChatCompletionStreamChoice(
-            finish_reason=generation.get("finish_reason")
+            index=index,
+            finish_reason=generation.get("finish_reason"),
         )
     else:
         message = ChatCompletionMessage(
@@ -125,6 +127,7 @@ def _create_stream_chunk(
             logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
 
         choice = ChatCompletionStreamChoice(
+            index=index,
             delta=message,
             logprobs=logprob_response,
         )
@@ -199,34 +202,62 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         raise HTTPException(400, error_message) from exc
 
 
+async def _stream_collector(
+    task_idx: int,
+    gen_queue: asyncio.Queue,
+    prompt: str,
+    abort_event: asyncio.Event,
+    **kwargs,
+):
+    """Collects a stream and places results in a common queue"""
+
+    new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
+    async for generation in new_generation:
+        generation["index"] = task_idx
+
+        await gen_queue.put(generation)
+
+        if "finish_reason" in generation:
+            break
+
+
 async def stream_generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
 
+    const_id = f"chatcmpl-{uuid4().hex}"
     abort_event = asyncio.Event()
+    gen_queue = asyncio.Queue()
+    gen_tasks: List[asyncio.Task] = []
+    disconnect_task = asyncio.create_task(request_disconnect_loop(request))
 
     try:
-        const_id = f"chatcmpl-{uuid4().hex}"
+        gen_params = data.to_gen_params()
 
-        new_generation = model.container.generate_gen(
-            prompt, abort_event, **data.to_gen_params()
-        )
-        # Create a background task to avoid blocking the loop
-        disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+        for n in range(0, data.n):
+            if n > 0:
+                task_gen_params = deepcopy(gen_params)
+            else:
+                task_gen_params = gen_params
 
-        async for generation in new_generation:
-            # Sometimes this fires, and sometimes a CancelledError will fire
-            # Keep both implementations in to avoid the headache
+            gen_task = asyncio.create_task(
+                _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
+            )
+            gen_tasks.append(gen_task)
+
+        # Consumer loop
+        while True:
             if disconnect_task.done():
                 abort_event.set()
                 handle_request_disconnect("Completion generation cancelled by user.")
 
+            generation = await gen_queue.get()
             response = _create_stream_chunk(const_id, generation, model_path.name)
             yield response.model_dump_json()
 
-            # Break if the generation is finished
-            if "finish_reason" in generation:
+            # Check if all tasks are completed
+            if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 break
     except CancelledError:
         # Get out if the request gets disconnected
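
Since each streamed choice now carries its index, a client can reassemble the
n generations even though chunks arrive unordered. A hypothetical client-side
sketch (chunk layout follows the OpenAI streaming format):

    from collections import defaultdict

    def collect(chunks):
        # Accumulate streamed deltas per choice index; the finish chunk
        # has no delta, so treat a missing delta as empty.
        texts = defaultdict(str)
        for chunk in chunks:
            for choice in chunk["choices"]:
                delta = choice.get("delta") or {}
                texts[choice["index"]] += delta.get("content") or ""
        return dict(texts)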
@@ -247,7 +278,6 @@ async def generate_chat_completion(
     try:
-
         for n in range(0, data.n):
             # Deepcopy gen params above the first index
             # to ensure nested structures aren't shared
             if n > 0:
@@ -256,9 +286,7 @@ async def generate_chat_completion(
                 task_gen_params = gen_params
 
             gen_tasks.append(
-                asyncio.create_task(
-                    model.container.generate(prompt, **task_gen_params)
-                )
+                asyncio.create_task(model.container.generate(prompt, **task_gen_params))
             )
 
         generations = await asyncio.gather(*gen_tasks)
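
The deepcopy above the first index matters because gen params can hold nested
structures; sharing them across tasks would let one task's mutation leak into
the others. An illustrative sketch (the keys are hypothetical, not the real
output of data.to_gen_params()):

    from copy import deepcopy

    gen_params = {"max_tokens": 100, "stop": ["\n"]}

    task_b = deepcopy(gen_params)  # tasks above index 0 get deep copies
    task_b["stop"].append("###")
    assert gen_params["stop"] == ["\n"]  # the original is unaffected

    # A shallow copy would still share the nested list:
    shallow = dict(gen_params)
    shallow["stop"].append("###")
    assert gen_params["stop"] == ["\n", "###"]  # mutation leaked through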