API: Split functions into their own files

Previously, generation functions were bundled with the request functions,
causing the overall code structure and API to look cluttered and unreadable.

Split these up and clean up many of the methods that were previously
overlooked in the API itself.

Signed-off-by: kingbri <bdashore3@proton.me>
Author:    kingbri <bdashore3@proton.me>
Date:      2024-03-12 01:31:59 -04:00
Committer: Brian Dashore
Parent:    104a6121cb
Commit:    6f03be9523

5 changed files with 376 additions and 294 deletions
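For orientation before the diff: the split keeps request handling in the API layer and moves response building and generation into a dedicated completion utilities module. The sketch below shows how an endpoint might dispatch to the new helpers. It is not part of this commit's diff; the route path, module paths, `data.stream` field, and model-path handling are assumptions, while `generate_completion` and `stream_generate_completion` are the functions introduced in the diff that follows.

```python
# Hypothetical router sketch (not from this commit): request handling stays in
# the API layer and simply dispatches to the split-out generation helpers.
import pathlib

from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse

# Assumed module/import paths for the file changed in this diff
from endpoints.OAI.types.completion import CompletionRequest
from endpoints.OAI.utils.completion import (
    generate_completion,
    stream_generate_completion,
)

router = APIRouter()


@router.post("/v1/completions")
async def completion_endpoint(request: Request, data: CompletionRequest):
    # Assumed: the server tracks the currently loaded model's path somewhere
    model_path = pathlib.Path("models") / "current-model"

    # `data.stream` is assumed to be the usual OpenAI-style streaming flag
    if data.stream:
        # stream_generate_completion yields JSON chunks and a final "[DONE]",
        # so wrap it as a server-sent event stream
        generator = stream_generate_completion(request, data, model_path)
        return StreamingResponse(
            (f"data: {chunk}\n\n" async for chunk in generator),
            media_type="text/event-stream",
        )

    return await generate_completion(data, model_path)
```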

@@ -1,17 +1,15 @@
-""" Utility functions for the OpenAI server. """
+"""Completion utilities for OAI server."""
+import pathlib
+from fastapi import HTTPException, Request
+from fastapi.concurrency import run_in_threadpool
+from loguru import logger
 from typing import Optional
 
-from common.utils import unwrap
-from endpoints.OAI.types.chat_completion import (
-    ChatCompletionLogprobs,
-    ChatCompletionLogprob,
-    ChatCompletionMessage,
-    ChatCompletionRespChoice,
-    ChatCompletionStreamChunk,
-    ChatCompletionResponse,
-    ChatCompletionStreamChoice,
-)
+from common import model
+from common.generators import release_semaphore
+from common.utils import get_generator_error, handle_request_error, unwrap
 from endpoints.OAI.types.completion import (
+    CompletionRequest,
     CompletionResponse,
     CompletionRespChoice,
     CompletionLogProbs,
@@ -19,7 +17,7 @@ from endpoints.OAI.types.completion import (
 from endpoints.OAI.types.common import UsageStats
 
 
-def create_completion_response(generation: dict, model_name: Optional[str]):
+def _create_response(generation: dict, model_name: Optional[str]):
     """Create a completion response from the provided text."""
 
     logprob_response = None
@@ -58,97 +56,49 @@ def create_completion_response(generation: dict, model_name: Optional[str]):
     return response
 
-def create_chat_completion_response(generation: dict, model_name: Optional[str]):
-    """Create a chat completion response from the provided text."""
-    message = ChatCompletionMessage(
-        role="assistant", content=unwrap(generation.get("text"), "")
-    )
-    logprob_response = None
-    token_probs = unwrap(generation.get("token_probs"), {})
-    if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), [])
-        collected_token_probs = []
-        for index, token in enumerate(token_probs.keys()):
-            top_logprobs = [
-                ChatCompletionLogprob(token=token, logprob=logprob)
-                for token, logprob in logprobs[index].items()
-            ]
-            collected_token_probs.append(
-                ChatCompletionLogprob(
-                    token=token,
-                    logprob=token_probs[token],
-                    top_logprobs=top_logprobs,
-                )
-            )
-        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
-    choice = ChatCompletionRespChoice(
-        finish_reason="Generated", message=message, logprobs=logprob_response
-    )
-    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generation.get("completion_tokens"), 0)
-    response = ChatCompletionResponse(
-        choices=[choice],
-        model=unwrap(model_name, ""),
-        usage=UsageStats(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        ),
-    )
-    return response
-
-def create_chat_completion_stream_chunk(
-    const_id: str,
-    generation: Optional[dict] = None,
-    model_name: Optional[str] = None,
-    finish_reason: Optional[str] = None,
+async def stream_generate_completion(
+    request: Request, data: CompletionRequest, model_path: pathlib.Path
 ):
-    """Create a chat completion stream chunk from the provided text."""
+    """Streaming generation for completions."""
 
-    logprob_response = None
+    try:
+        new_generation = model.container.generate_gen(
+            data.prompt, **data.to_gen_params()
+        )
+        for generation in new_generation:
+            # Get out if the request gets disconnected
+            if await request.is_disconnected():
+                release_semaphore()
+                logger.error("Completion generation cancelled by user.")
+                return
 
-    if finish_reason:
-        message = {}
-    else:
-        message = ChatCompletionMessage(
-            role="assistant", content=unwrap(generation.get("text"), "")
+            response = _create_response(generation, model_path.name)
+            yield response.model_dump_json()
+
+        # Yield a finish response on successful generation
+        yield "[DONE]"
+    except Exception:
+        yield get_generator_error(
+            "Completion aborted. Please check the server console."
         )
 
-    token_probs = unwrap(generation.get("token_probs"), {})
-    if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), {})
-        top_logprobs = [
-            ChatCompletionLogprob(token=token, logprob=logprob)
-            for token, logprob in logprobs.items()
-        ]
-        generated_token = next(iter(token_probs))
-        token_prob_response = ChatCompletionLogprob(
-            token=generated_token,
-            logprob=token_probs[generated_token],
-            top_logprobs=top_logprobs,
-        )
+async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
+    """Non-streaming generate for completions"""
 
-        logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
+    try:
+        generation = await run_in_threadpool(
+            model.container.generate, data.prompt, **data.to_gen_params()
+        )
 
-    # The finish reason can be None
-    choice = ChatCompletionStreamChoice(
-        finish_reason=finish_reason, delta=message, logprobs=logprob_response
-    )
+        response = _create_response(generation, model_path.name)
+        return response
+    except Exception as exc:
+        error_message = handle_request_error(
+            "Completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
+        ).error.message
 
-    chunk = ChatCompletionStreamChunk(
-        id=const_id, choices=[choice], model=unwrap(model_name, "")
-    )
-    return chunk
+        # Server error if there's a generation exception
+        raise HTTPException(503, error_message) from exc
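
For reference, the fields these helpers read from the `generation` dict via `unwrap(generation.get(...))` suggest a payload roughly like the sketch below. This shape is inferred from the code in the diff, not an official schema, and the concrete tokens and values are made up.

```python
# Illustrative `generation` payload, inferred from the fields accessed above
# (text, prompt_tokens, completion_tokens, token_probs, logprobs).
example_generation = {
    "text": "Hello there!",
    "prompt_tokens": 5,
    "completion_tokens": 3,
    # generated token -> its logprob
    "token_probs": {"Hello": -0.11, " there": -0.42, "!": -0.03},
    # per-position candidate logprobs (the list form used by the chat helper above)
    "logprobs": [
        {"Hello": -0.11, "Hi": -2.31},
        {" there": -0.42, " world": -1.70},
        {"!": -0.03, ".": -3.88},
    ],
}

# A response object can then be built without touching the request layer, e.g.:
# response = _create_response(example_generation, "my-model")
```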