mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 14:28:54 +00:00
API: Split functions into their own files
Previously, generation function were bundled with the request function causing the overall code structure and API to look ugly and unreadable. Split these up and cleanup a lot of the methods that were previously overlooked in the API itself. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
199
endpoints/OAI/utils/chat_completion.py
Normal file
199
endpoints/OAI/utils/chat_completion.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Chat completion utilities for OAI server."""
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from jinja2 import TemplateError
|
||||
from loguru import logger
|
||||
|
||||
from common import model
|
||||
from common.generators import release_semaphore
|
||||
from common.templating import get_prompt_from_template
|
||||
from common.utils import get_generator_error, handle_request_error, unwrap
|
||||
from endpoints.OAI.types.chat_completion import (
|
||||
ChatCompletionLogprobs,
|
||||
ChatCompletionLogprob,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionRespChoice,
|
||||
ChatCompletionStreamChunk,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamChoice,
|
||||
)
|
||||
from endpoints.OAI.types.common import UsageStats
|
||||
|
||||
|
||||
def _create_response(generation: dict, model_name: Optional[str]):
|
||||
"""Create a chat completion response from the provided text."""
|
||||
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant", content=unwrap(generation.get("text"), "")
|
||||
)
|
||||
|
||||
logprob_response = None
|
||||
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
if token_probs:
|
||||
logprobs = unwrap(generation.get("logprobs"), [])
|
||||
|
||||
collected_token_probs = []
|
||||
for index, token in enumerate(token_probs.keys()):
|
||||
top_logprobs = [
|
||||
ChatCompletionLogprob(token=token, logprob=logprob)
|
||||
for token, logprob in logprobs[index].items()
|
||||
]
|
||||
|
||||
collected_token_probs.append(
|
||||
ChatCompletionLogprob(
|
||||
token=token,
|
||||
logprob=token_probs[token],
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
|
||||
|
||||
choice = ChatCompletionRespChoice(
|
||||
finish_reason="Generated", message=message, logprobs=logprob_response
|
||||
)
|
||||
|
||||
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
|
||||
completion_tokens = unwrap(generation.get("completion_tokens"), 0)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
choices=[choice],
|
||||
model=unwrap(model_name, ""),
|
||||
usage=UsageStats(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def _create_stream_chunk(
|
||||
const_id: str,
|
||||
generation: Optional[dict] = None,
|
||||
model_name: Optional[str] = None,
|
||||
finish_reason: Optional[str] = None,
|
||||
):
|
||||
"""Create a chat completion stream chunk from the provided text."""
|
||||
|
||||
logprob_response = None
|
||||
|
||||
if finish_reason:
|
||||
message = {}
|
||||
else:
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant", content=unwrap(generation.get("text"), "")
|
||||
)
|
||||
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
if token_probs:
|
||||
logprobs = unwrap(generation.get("logprobs"), {})
|
||||
top_logprobs = [
|
||||
ChatCompletionLogprob(token=token, logprob=logprob)
|
||||
for token, logprob in logprobs.items()
|
||||
]
|
||||
|
||||
generated_token = next(iter(token_probs))
|
||||
token_prob_response = ChatCompletionLogprob(
|
||||
token=generated_token,
|
||||
logprob=token_probs[generated_token],
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
|
||||
logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
|
||||
|
||||
# The finish reason can be None
|
||||
choice = ChatCompletionStreamChoice(
|
||||
finish_reason=finish_reason, delta=message, logprobs=logprob_response
|
||||
)
|
||||
|
||||
chunk = ChatCompletionStreamChunk(
|
||||
id=const_id, choices=[choice], model=unwrap(model_name, "")
|
||||
)
|
||||
|
||||
return chunk
|
||||
|
||||
|
||||
def format_prompt_with_template(data: ChatCompletionRequest):
|
||||
try:
|
||||
special_tokens_dict = model.container.get_special_tokens(
|
||||
unwrap(data.add_bos_token, True),
|
||||
unwrap(data.ban_eos_token, False),
|
||||
)
|
||||
|
||||
return get_prompt_from_template(
|
||||
data.messages,
|
||||
model.container.prompt_template,
|
||||
data.add_generation_prompt,
|
||||
special_tokens_dict,
|
||||
)
|
||||
except KeyError as exc:
|
||||
raise HTTPException(
|
||||
400,
|
||||
"Could not find a Conversation from prompt template "
|
||||
f"'{model.container.prompt_template.name}'. "
|
||||
"Check your spelling?",
|
||||
) from exc
|
||||
except TemplateError as exc:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"TemplateError: {str(exc)}",
|
||||
) from exc
|
||||
|
||||
|
||||
async def stream_generate_chat_completion(
|
||||
prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
|
||||
):
|
||||
"""Generator for the generation process."""
|
||||
try:
|
||||
const_id = f"chatcmpl-{uuid4().hex}"
|
||||
|
||||
new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
|
||||
for generation in new_generation:
|
||||
# Get out if the request gets disconnected
|
||||
if await request.is_disconnected():
|
||||
release_semaphore()
|
||||
logger.error("Chat completion generation cancelled by user.")
|
||||
return
|
||||
|
||||
response = _create_stream_chunk(const_id, generation, model_path.name)
|
||||
|
||||
yield response.model_dump_json()
|
||||
|
||||
# Yield a finish response on successful generation
|
||||
finish_response = _create_stream_chunk(const_id, finish_reason="stop")
|
||||
|
||||
yield finish_response.model_dump_json()
|
||||
except Exception:
|
||||
yield get_generator_error(
|
||||
"Chat completion aborted. Please check the server console."
|
||||
)
|
||||
|
||||
|
||||
async def generate_chat_completion(
|
||||
prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
|
||||
):
|
||||
try:
|
||||
generation = await run_in_threadpool(
|
||||
model.container.generate,
|
||||
prompt,
|
||||
**data.to_gen_params(),
|
||||
)
|
||||
response = _create_response(generation, model_path.name)
|
||||
|
||||
return response
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(
|
||||
"Chat completion aborted. Maybe the model was unloaded? "
|
||||
"Please check the server console."
|
||||
).error.message
|
||||
|
||||
# Server error if there's a generation exception
|
||||
raise HTTPException(503, error_message) from exc
|
||||
Reference in New Issue
Block a user