mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
OAI: Add "n" for non-streaming generations
This adds the ability to add multiple choices to a generation. This is only available for non-streaming gens for now, it requires some more work to port over to streaming. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -4,7 +4,7 @@ import asyncio
|
||||
import pathlib
|
||||
from asyncio import CancelledError
|
||||
from fastapi import HTTPException, Request
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from common import model
|
||||
from common.networking import (
|
||||
@@ -23,34 +23,39 @@ from endpoints.OAI.types.completion import (
|
||||
from endpoints.OAI.types.common import UsageStats
|
||||
|
||||
|
||||
def _create_response(generation: dict, model_name: Optional[str]):
|
||||
def _create_response(generations: List[dict], model_name: Optional[str]):
|
||||
"""Create a completion response from the provided text."""
|
||||
|
||||
logprob_response = None
|
||||
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
|
||||
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
|
||||
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
if token_probs:
|
||||
logprobs = unwrap(generation.get("logprobs"), [])
|
||||
offset = unwrap(generation.get("offset"), [])
|
||||
choices = []
|
||||
for index, generation in enumerate(generations):
|
||||
logprob_response = None
|
||||
|
||||
logprob_response = CompletionLogProbs(
|
||||
text_offset=offset if isinstance(offset, list) else [offset],
|
||||
token_logprobs=token_probs.values(),
|
||||
tokens=token_probs.keys(),
|
||||
top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
if token_probs:
|
||||
logprobs = unwrap(generation.get("logprobs"), [])
|
||||
offset = unwrap(generation.get("offset"), [])
|
||||
|
||||
logprob_response = CompletionLogProbs(
|
||||
text_offset=offset if isinstance(offset, list) else [offset],
|
||||
token_logprobs=token_probs.values(),
|
||||
tokens=token_probs.keys(),
|
||||
top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
|
||||
)
|
||||
|
||||
choice = CompletionRespChoice(
|
||||
index=index,
|
||||
finish_reason=generation.get("finish_reason"),
|
||||
text=unwrap(generation.get("text"), ""),
|
||||
logprobs=logprob_response,
|
||||
)
|
||||
|
||||
choice = CompletionRespChoice(
|
||||
finish_reason=generation.get("finish_reason"),
|
||||
text=unwrap(generation.get("text"), ""),
|
||||
logprobs=logprob_response,
|
||||
)
|
||||
|
||||
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
|
||||
completion_tokens = unwrap(generation.get("generated_tokens"), 0)
|
||||
choices.append(choice)
|
||||
|
||||
response = CompletionResponse(
|
||||
choices=[choice],
|
||||
choices=choices,
|
||||
model=unwrap(model_name, ""),
|
||||
usage=UsageStats(
|
||||
prompt_tokens=prompt_tokens,
|
||||
@@ -84,7 +89,7 @@ async def stream_generate_completion(
|
||||
abort_event.set()
|
||||
handle_request_disconnect("Completion generation cancelled by user.")
|
||||
|
||||
response = _create_response(generation, model_path.name)
|
||||
response = _create_response([generation], model_path.name)
|
||||
yield response.model_dump_json()
|
||||
|
||||
# Break if the generation is finished
|
||||
@@ -105,9 +110,18 @@ async def stream_generate_completion(
|
||||
async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
|
||||
"""Non-streaming generate for completions"""
|
||||
|
||||
gen_tasks: List[asyncio.Task] = []
|
||||
|
||||
try:
|
||||
generation = await model.container.generate(data.prompt, **data.to_gen_params())
|
||||
response = _create_response(generation, model_path.name)
|
||||
for _ in range(0, data.n):
|
||||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(data.prompt, **data.to_gen_params())
|
||||
)
|
||||
)
|
||||
|
||||
generations = await asyncio.gather(*gen_tasks)
|
||||
response = _create_response(generations, model_path.name)
|
||||
|
||||
return response
|
||||
except Exception as exc:
|
||||
|
||||
Reference in New Issue
Block a user