Mirror of https://github.com/theroyallab/tabbyAPI.git
OAI: Add "n" for non-streaming generations
This adds the ability to return multiple choices from a single generation request. It is only available for non-streaming generations for now; porting it to streaming requires some more work.

Signed-off-by: kingbri <bdashore3@proton.me>
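To illustrate the new behavior from the client side, a chat completion request with n greater than 1 now yields one choice per generation. The sketch below is not part of this commit: the host, port, model name, and omitted API-key header are placeholder assumptions for a locally running tabbyAPI instance.

import requests

# Hypothetical request against the OpenAI-compatible chat completions route;
# URL, port, and model name are placeholders, and any API key header required
# by the server configuration is omitted here.
response = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json={
        "model": "my-local-model",
        "messages": [{"role": "user", "content": "Say hello"}],
        "n": 2,           # request two independent choices
        "stream": False,  # "n" is only supported for non-streaming requests
    },
)
response.raise_for_status()

# Each generation comes back as its own entry in "choices" with a stable index.
for choice in response.json()["choices"]:
    print(choice["index"], choice["message"]["content"])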
@@ -3,7 +3,7 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from typing import Optional
+from typing import List, Optional
 from uuid import uuid4
 
 from fastapi import HTTPException, Request
@@ -31,47 +31,52 @@ from endpoints.OAI.types.chat_completion import (
 from endpoints.OAI.types.common import UsageStats
 
 
-def _create_response(generation: dict, model_name: Optional[str]):
+def _create_response(generations: List[dict], model_name: Optional[str]):
     """Create a chat completion response from the provided text."""
 
-    message = ChatCompletionMessage(
-        role="assistant", content=unwrap(generation.get("text"), "")
-    )
+    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
 
-    logprob_response = None
+    choices = []
+    for index, generation in enumerate(generations):
+        message = ChatCompletionMessage(
+            role="assistant", content=unwrap(generation.get("text"), "")
+        )
 
-    token_probs = unwrap(generation.get("token_probs"), {})
-    if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), [])
+        logprob_response = None
 
-        collected_token_probs = []
-        for index, token in enumerate(token_probs.keys()):
-            top_logprobs = [
-                ChatCompletionLogprob(token=token, logprob=logprob)
-                for token, logprob in logprobs[index].items()
-            ]
+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), [])
 
-            collected_token_probs.append(
-                ChatCompletionLogprob(
-                    token=token,
-                    logprob=token_probs[token],
-                    top_logprobs=top_logprobs,
+            collected_token_probs = []
+            for index, token in enumerate(token_probs.keys()):
+                top_logprobs = [
+                    ChatCompletionLogprob(token=token, logprob=logprob)
+                    for token, logprob in logprobs[index].items()
+                ]
+
+                collected_token_probs.append(
+                    ChatCompletionLogprob(
+                        token=token,
+                        logprob=token_probs[token],
+                        top_logprobs=top_logprobs,
+                    )
                 )
-            )
 
-        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
+            logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
-    choice = ChatCompletionRespChoice(
-        finish_reason=generation.get("finish_reason"),
-        message=message,
-        logprobs=logprob_response,
-    )
+        choice = ChatCompletionRespChoice(
+            index=index,
+            finish_reason=generation.get("finish_reason"),
+            message=message,
+            logprobs=logprob_response,
+        )
 
-    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+        choices.append(choice)
 
     response = ChatCompletionResponse(
-        choices=[choice],
+        choices=choices,
         model=unwrap(model_name, ""),
         usage=UsageStats(
            prompt_tokens=prompt_tokens,
@@ -236,12 +241,18 @@ async def stream_generate_chat_completion(
 async def generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
+    gen_tasks: List[asyncio.Task] = []
+
     try:
-        generation = await model.container.generate(
-            prompt,
-            **data.to_gen_params(),
-        )
-        response = _create_response(generation, model_path.name)
+        for _ in range(0, data.n):
+            gen_tasks.append(
+                asyncio.create_task(
+                    model.container.generate(prompt, **data.to_gen_params())
+                )
+            )
+
+        generations = await asyncio.gather(*gen_tasks)
+        response = _create_response(generations, model_path.name)
 
         return response
     except Exception as exc:
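The handler change above boils down to a fan-out/gather pattern: schedule data.n generation coroutines up front, await them together, and pass the ordered results to _create_response. A minimal standalone sketch of that pattern, with a stub coroutine standing in for model.container.generate, could look like this:

import asyncio
from typing import List


async def fake_generate(prompt: str) -> dict:
    """Stub standing in for model.container.generate()."""
    await asyncio.sleep(0.1)  # simulate inference latency
    return {"text": f"completion for: {prompt}", "generated_tokens": 3}


async def generate_n(prompt: str, n: int) -> List[dict]:
    # Fan out: create n independent tasks so the generations are scheduled concurrently.
    tasks = [asyncio.create_task(fake_generate(prompt)) for _ in range(n)]
    # Gather: results come back in task order, which is what lets each
    # generation map to a stable choice index in the response.
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    for index, result in enumerate(asyncio.run(generate_n("Say hello", 3))):
        print(index, result["text"])

Whether the generations actually run in parallel depends on the model container; the pattern only guarantees that the requests are scheduled concurrently and collected in a stable order.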