API: Add ability to use request IDs
Identify which request is being processed to help users disambiguate which logs correspond to which request.

Signed-off-by: kingbri <bdashore3@proton.me>
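Note: the hunks below read the ID from request.state.id; the code that assigns that ID lives outside this file and is not shown in this diff. As a rough, hypothetical sketch of how such an ID could be attached per request in FastAPI (names and placement are assumptions, not the project's actual implementation):

# Hypothetical illustration only -- not part of this commit.
# Shows one way a per-request ID could be placed on request.state
# so that handlers can log it and echo it back in responses.
from uuid import uuid4

from fastapi import FastAPI, Request

app = FastAPI()


@app.middleware("http")
async def assign_request_id(request: Request, call_next):
    # Tag every incoming request with a unique ID before it reaches a handler.
    request.state.id = uuid4().hex
    return await call_next(request)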
@@ -5,7 +5,6 @@ import pathlib
 from asyncio import CancelledError
 from copy import deepcopy
 from typing import List, Optional
-from uuid import uuid4

 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
@@ -30,9 +29,12 @@ from endpoints.OAI.types.chat_completion import (
     ChatCompletionStreamChoice,
 )
 from endpoints.OAI.types.common import UsageStats
+from endpoints.OAI.utils.completion import _stream_collector


-def _create_response(generations: List[dict], model_name: Optional[str]):
+def _create_response(
+    request_id: str, generations: List[dict], model_name: Optional[str]
+):
     """Create a chat completion response from the provided text."""

     prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
@@ -77,6 +79,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
         choices.append(choice)

     response = ChatCompletionResponse(
+        id=f"chatcmpl-{request_id}",
         choices=choices,
         model=unwrap(model_name, ""),
         usage=UsageStats(
@@ -90,7 +93,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):


 def _create_stream_chunk(
-    const_id: str,
+    request_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
     is_usage_chunk: bool = False,
@@ -150,7 +153,7 @@ def _create_stream_chunk(
         choices.append(choice)

     chunk = ChatCompletionStreamChunk(
-        id=const_id,
+        id=f"chatcmpl-{request_id}",
         choices=choices,
         model=unwrap(model_name, ""),
         usage=usage_stats,
@@ -235,39 +238,18 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         raise HTTPException(400, error_message) from exc


-async def _stream_collector(
-    task_idx: int,
-    gen_queue: asyncio.Queue,
-    prompt: str,
-    abort_event: asyncio.Event,
-    **kwargs,
-):
-    """Collects a stream and places results in a common queue"""
-
-    try:
-        new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
-        async for generation in new_generation:
-            generation["index"] = task_idx
-
-            await gen_queue.put(generation)
-
-            if "finish_reason" in generation:
-                break
-    except Exception as e:
-        await gen_queue.put(e)
-
-
 async def stream_generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
-    const_id = f"chatcmpl-{uuid4().hex}"
     abort_event = asyncio.Event()
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))

     try:
+        logger.info(f"Recieved chat completion streaming request {request.state.id}")
+
         gen_params = data.to_gen_params()

         for n in range(0, data.n):
@@ -277,7 +259,14 @@ async def stream_generate_chat_completion(
                 task_gen_params = gen_params

             gen_task = asyncio.create_task(
-                _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
+                _stream_collector(
+                    n,
+                    gen_queue,
+                    prompt,
+                    request.state.id,
+                    abort_event,
+                    **task_gen_params,
+                )
             )

             gen_tasks.append(gen_task)
@@ -286,7 +275,9 @@ async def stream_generate_chat_completion(
         while True:
             if disconnect_task.done():
                 abort_event.set()
-                handle_request_disconnect("Completion generation cancelled by user.")
+                handle_request_disconnect(
+                    f"Chat completion generation {request.state.id} cancelled by user."
+                )

             generation = await gen_queue.get()

@@ -294,7 +285,9 @@ async def stream_generate_chat_completion(
             if isinstance(generation, Exception):
                 raise generation

-            response = _create_stream_chunk(const_id, generation, model_path.name)
+            response = _create_stream_chunk(
+                request.state.id, generation, model_path.name
+            )
             yield response.model_dump_json()

             # Check if all tasks are completed
@@ -302,10 +295,17 @@ async def stream_generate_chat_completion(
                 # Send a usage chunk
                 if data.stream_options and data.stream_options.include_usage:
                     usage_chunk = _create_stream_chunk(
-                        const_id, generation, model_path.name, is_usage_chunk=True
+                        request.state.id,
+                        generation,
+                        model_path.name,
+                        is_usage_chunk=True,
                     )
                     yield usage_chunk.model_dump_json()

+                logger.info(
+                    f"Finished chat completion streaming request {request.state.id}"
+                )
+
                 yield "[DONE]"
                 break
     except CancelledError:
@@ -320,7 +320,7 @@ async def stream_generate_chat_completion(


 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     gen_tasks: List[asyncio.Task] = []
     gen_params = data.to_gen_params()
@@ -335,16 +335,23 @@ async def generate_chat_completion(
                 task_gen_params = gen_params

             gen_tasks.append(
-                asyncio.create_task(model.container.generate(prompt, **task_gen_params))
+                asyncio.create_task(
+                    model.container.generate(
+                        prompt, request.state.id, **task_gen_params
+                    )
+                )
             )

         generations = await asyncio.gather(*gen_tasks)
-        response = _create_response(generations, model_path.name)
+        response = _create_response(request.state.id, generations, model_path.name)
+
+        logger.info(f"Finished chat completion request {request.state.id}")

         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Chat completion aborted. Maybe the model was unloaded? "
+            f"Chat completion {request.state.id} aborted. "
+            "Maybe the model was unloaded? "
             "Please check the server console."
         ).error.message

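Because the completion id is now derived from the same request ID that the server logs, a client can match a response to its log lines. A minimal, hypothetical example against an OpenAI-compatible /v1/chat/completions endpoint (URL, port, model name, and any auth headers are placeholders, not taken from this commit):

# Illustrative only: correlate the response "id" with the server-side request logs.
import requests

resp = requests.post(
    "http://localhost:5000/v1/chat/completions",  # placeholder host/port
    # Add any authentication headers your deployment requires here.
    json={
        "model": "example-model",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)

# The response id has the form "chatcmpl-<request id>"; the suffix is the same
# request ID that appears in server log lines such as
# "Finished chat completion request <request id>".
request_id = resp.json()["id"].removeprefix("chatcmpl-")
print(request_id)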