OAI: Add "n" for non-streaming generations

This adds the ability to return multiple choices from a single
generation request. It is only available for non-streaming generations
for now; porting it over to streaming requires some more work.
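
For example, a client can now request several completions in one call.
A minimal sketch (the URL, port, and API key are assumptions about the
deployment, not part of this commit):

    import requests

    resp = requests.post(
        "http://localhost:5000/v1/completions",
        headers={"Authorization": "Bearer <api-key>"},
        json={"prompt": "Once upon a time", "max_tokens": 32, "n": 2},
    )
    for choice in resp.json()["choices"]:
        print(choice["index"], choice["text"])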

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri
2024-05-26 22:22:51 -04:00
committed by Brian Dashore
parent 8d31a5aed1
commit b944f8d756
3 changed files with 89 additions and 64 deletions


@@ -4,7 +4,7 @@ import asyncio
 import pathlib
 from asyncio import CancelledError
 from fastapi import HTTPException, Request
-from typing import Optional
+from typing import List, Optional
 
 from common import model
 from common.networking import (
@@ -23,34 +23,39 @@ from endpoints.OAI.types.completion import (
 from endpoints.OAI.types.common import UsageStats
 
 
-def _create_response(generation: dict, model_name: Optional[str]):
+def _create_response(generations: List[dict], model_name: Optional[str]):
     """Create a completion response from the provided text."""
 
-    logprob_response = None
+    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
 
-    token_probs = unwrap(generation.get("token_probs"), {})
-    if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), [])
-        offset = unwrap(generation.get("offset"), [])
+    choices = []
+    for index, generation in enumerate(generations):
+        logprob_response = None
 
-        logprob_response = CompletionLogProbs(
-            text_offset=offset if isinstance(offset, list) else [offset],
-            token_logprobs=token_probs.values(),
-            tokens=token_probs.keys(),
-            top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
-        )
+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), [])
+            offset = unwrap(generation.get("offset"), [])
 
-    choice = CompletionRespChoice(
-        finish_reason=generation.get("finish_reason"),
-        text=unwrap(generation.get("text"), ""),
-        logprobs=logprob_response,
-    )
+            logprob_response = CompletionLogProbs(
+                text_offset=offset if isinstance(offset, list) else [offset],
+                token_logprobs=token_probs.values(),
+                tokens=token_probs.keys(),
+                top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
+            )
 
-    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+        choice = CompletionRespChoice(
+            index=index,
+            finish_reason=generation.get("finish_reason"),
+            text=unwrap(generation.get("text"), ""),
+            logprobs=logprob_response,
+        )
+
+        choices.append(choice)
 
     response = CompletionResponse(
-        choices=[choice],
+        choices=choices,
         model=unwrap(model_name, ""),
         usage=UsageStats(
            prompt_tokens=prompt_tokens,
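
With this change, a request with n > 1 yields one CompletionRespChoice
per generation, so the serialized response looks roughly like the
following (values illustrative; note the usage stats are taken from the
last generation in the list):

    {
        "choices": [
            {"index": 0, "finish_reason": "stop", "text": "...", "logprobs": null},
            {"index": 1, "finish_reason": "stop", "text": "...", "logprobs": null}
        ],
        "model": "example-model",
        "usage": {"prompt_tokens": 4, "completion_tokens": 32}
    }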
@@ -84,7 +89,7 @@ async def stream_generate_completion(
            abort_event.set()
            handle_request_disconnect("Completion generation cancelled by user.")
 
-        response = _create_response(generation, model_path.name)
+        response = _create_response([generation], model_path.name)
         yield response.model_dump_json()
 
         # Break if the generation is finished
@@ -105,9 +110,18 @@ async def stream_generate_completion(
 async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
     """Non-streaming generate for completions"""
 
+    gen_tasks: List[asyncio.Task] = []
+
     try:
-        generation = await model.container.generate(data.prompt, **data.to_gen_params())
-        response = _create_response(generation, model_path.name)
+        for _ in range(0, data.n):
+            gen_tasks.append(
+                asyncio.create_task(
+                    model.container.generate(data.prompt, **data.to_gen_params())
+                )
+            )
+
+        generations = await asyncio.gather(*gen_tasks)
+        response = _create_response(generations, model_path.name)
 
         return response
     except Exception as exc:
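
The non-streaming path now fans out one task per requested choice and
awaits them together. The same pattern in a self-contained sketch (the
generate stub here stands in for model.container.generate):

    import asyncio

    async def generate(prompt: str) -> dict:
        # Stand-in for model.container.generate: produces one generation dict
        await asyncio.sleep(0.1)
        return {"text": f"completion of {prompt!r}", "finish_reason": "stop"}

    async def generate_n(prompt: str, n: int) -> list:
        # One independent task per requested choice, awaited concurrently,
        # mirroring the asyncio.create_task / asyncio.gather usage above
        tasks = [asyncio.create_task(generate(prompt)) for _ in range(n)]
        return await asyncio.gather(*tasks)

    results = asyncio.run(generate_n("Hello", 3))
    print(len(results))  # 3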