mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
API: Split functions into their own files
Previously, generation function were bundled with the request function causing the overall code structure and API to look ugly and unreadable. Split these up and cleanup a lot of the methods that were previously overlooked in the API itself. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,17 +1,15 @@
|
||||
""" Utility functions for the OpenAI server. """
|
||||
"""Completion utilities for OAI server."""
|
||||
import pathlib
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.chat_completion import (
|
||||
ChatCompletionLogprobs,
|
||||
ChatCompletionLogprob,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRespChoice,
|
||||
ChatCompletionStreamChunk,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamChoice,
|
||||
)
|
||||
from common import model
|
||||
from common.generators import release_semaphore
|
||||
from common.utils import get_generator_error, handle_request_error, unwrap
|
||||
from endpoints.OAI.types.completion import (
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionRespChoice,
|
||||
CompletionLogProbs,
|
||||
@@ -19,7 +17,7 @@ from endpoints.OAI.types.completion import (
|
||||
from endpoints.OAI.types.common import UsageStats
|
||||
|
||||
|
||||
def create_completion_response(generation: dict, model_name: Optional[str]):
|
||||
def _create_response(generation: dict, model_name: Optional[str]):
|
||||
"""Create a completion response from the provided text."""
|
||||
|
||||
logprob_response = None
|
||||
@@ -58,97 +56,49 @@ def create_completion_response(generation: dict, model_name: Optional[str]):
|
||||
return response
|
||||
|
||||
|
||||
def create_chat_completion_response(generation: dict, model_name: Optional[str]):
|
||||
"""Create a chat completion response from the provided text."""
|
||||
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant", content=unwrap(generation.get("text"), "")
|
||||
)
|
||||
|
||||
logprob_response = None
|
||||
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
if token_probs:
|
||||
logprobs = unwrap(generation.get("logprobs"), [])
|
||||
|
||||
collected_token_probs = []
|
||||
for index, token in enumerate(token_probs.keys()):
|
||||
top_logprobs = [
|
||||
ChatCompletionLogprob(token=token, logprob=logprob)
|
||||
for token, logprob in logprobs[index].items()
|
||||
]
|
||||
|
||||
collected_token_probs.append(
|
||||
ChatCompletionLogprob(
|
||||
token=token,
|
||||
logprob=token_probs[token],
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
|
||||
|
||||
choice = ChatCompletionRespChoice(
|
||||
finish_reason="Generated", message=message, logprobs=logprob_response
|
||||
)
|
||||
|
||||
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
|
||||
completion_tokens = unwrap(generation.get("completion_tokens"), 0)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
choices=[choice],
|
||||
model=unwrap(model_name, ""),
|
||||
usage=UsageStats(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def create_chat_completion_stream_chunk(
|
||||
const_id: str,
|
||||
generation: Optional[dict] = None,
|
||||
model_name: Optional[str] = None,
|
||||
finish_reason: Optional[str] = None,
|
||||
async def stream_generate_completion(
|
||||
request: Request, data: CompletionRequest, model_path: pathlib.Path
|
||||
):
|
||||
"""Create a chat completion stream chunk from the provided text."""
|
||||
"""Streaming generation for completions."""
|
||||
|
||||
logprob_response = None
|
||||
try:
|
||||
new_generation = model.container.generate_gen(
|
||||
data.prompt, **data.to_gen_params()
|
||||
)
|
||||
for generation in new_generation:
|
||||
# Get out if the request gets disconnected
|
||||
if await request.is_disconnected():
|
||||
release_semaphore()
|
||||
logger.error("Completion generation cancelled by user.")
|
||||
return
|
||||
|
||||
if finish_reason:
|
||||
message = {}
|
||||
else:
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant", content=unwrap(generation.get("text"), "")
|
||||
response = _create_response(generation, model_path.name)
|
||||
|
||||
yield response.model_dump_json()
|
||||
|
||||
# Yield a finish response on successful generation
|
||||
yield "[DONE]"
|
||||
except Exception:
|
||||
yield get_generator_error(
|
||||
"Completion aborted. Please check the server console."
|
||||
)
|
||||
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
if token_probs:
|
||||
logprobs = unwrap(generation.get("logprobs"), {})
|
||||
top_logprobs = [
|
||||
ChatCompletionLogprob(token=token, logprob=logprob)
|
||||
for token, logprob in logprobs.items()
|
||||
]
|
||||
|
||||
generated_token = next(iter(token_probs))
|
||||
token_prob_response = ChatCompletionLogprob(
|
||||
token=generated_token,
|
||||
logprob=token_probs[generated_token],
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
|
||||
"""Non-streaming generate for completions"""
|
||||
|
||||
logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
|
||||
try:
|
||||
generation = await run_in_threadpool(
|
||||
model.container.generate, data.prompt, **data.to_gen_params()
|
||||
)
|
||||
|
||||
# The finish reason can be None
|
||||
choice = ChatCompletionStreamChoice(
|
||||
finish_reason=finish_reason, delta=message, logprobs=logprob_response
|
||||
)
|
||||
response = _create_response(generation, model_path.name)
|
||||
return response
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(
|
||||
"Completion aborted. Maybe the model was unloaded? "
|
||||
"Please check the server console."
|
||||
).error.message
|
||||
|
||||
chunk = ChatCompletionStreamChunk(
|
||||
id=const_id, choices=[choice], model=unwrap(model_name, "")
|
||||
)
|
||||
|
||||
return chunk
|
||||
# Server error if there's a generation exception
|
||||
raise HTTPException(503, error_message) from exc
|
||||
|
||||
Reference in New Issue
Block a user