tabbyAPI/common/utils.py
kingbri c82697fef2 API: Fix issues with concurrent requests and queueing
This is the first of many future commits that will overhaul the API
to be more robust and concurrent. The model is admin-first: the admin
can do anything in case something goes awry.

Previously, calls to long-running synchronous background tasks would
block the entire API, making it ignore any terminal signals until
generation was completed.

To fix this, leverage FastAPI's run_in_threadpool to offload the
long-running tasks to another thread. However, abort signals still
left the background thread running and made the terminal hang.
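
A minimal sketch of the offloading idea follows. It uses the standard library's `asyncio.to_thread` (Python 3.9+) as a stand-in for FastAPI's `run_in_threadpool`; both run a blocking callable in a worker thread and let the coroutine await the result. `slow_generate` and `handler` are hypothetical names, not tabbyAPI code.

```python
import asyncio
import time


def slow_generate(prompt: str) -> str:
    """Hypothetical blocking task standing in for a long-running generation."""
    time.sleep(0.2)
    return prompt.upper()


async def handler(prompt: str) -> str:
    # Run the blocking call in a worker thread so the event loop stays
    # free to serve other requests while this one is generating.
    return await asyncio.to_thread(slow_generate, prompt)


async def main():
    # Two requests overlap instead of running back to back.
    start = time.perf_counter()
    results = await asyncio.gather(handler("a"), handler("b"))
    elapsed = time.perf_counter() - start
    return list(results), elapsed


if __name__ == "__main__":
    results, elapsed = asyncio.run(main())
    print(f"{results} in {elapsed:.2f}s")
```

Because both handlers sleep in separate worker threads, the total wall time stays close to one task's duration rather than the sum of both.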

This was due to an issue with Uvicorn not propagating SIGINT across
threads in its event loop. To fix this in a catch-all way, run the
API in a separate thread so the main thread can still kill the
process if needed.
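
The general pattern can be sketched with the standard `threading` module, assuming a blocking server call that can be passed in as a function. `fake_server` is a hypothetical stand-in for something like a Uvicorn run loop, and this is a generic illustration rather than the exact code in this commit.

```python
import threading
import time


def start_api_thread(serve, *args) -> threading.Thread:
    """Run a blocking server function in a daemon thread."""
    # daemon=True means this thread will not keep the process alive
    # once the main thread decides to exit.
    thread = threading.Thread(target=serve, args=args, daemon=True)
    thread.start()
    return thread


def wait_for_exit(thread: threading.Thread, poll_interval: float = 0.5) -> bool:
    """Block in the main thread; return True if interrupted by Ctrl+C."""
    try:
        # Joining with a timeout wakes the main thread periodically, so a
        # pending KeyboardInterrupt is not stuck behind an indefinite wait.
        while thread.is_alive():
            thread.join(poll_interval)
    except KeyboardInterrupt:
        # Python runs signal handlers (including the default SIGINT ->
        # KeyboardInterrupt) only in the main thread. Returning here lets
        # the interpreter exit, and daemon threads are stopped with it.
        return True
    return False


def fake_server(duration: float) -> None:
    """Hypothetical stand-in for a blocking server run loop."""
    time.sleep(duration)


if __name__ == "__main__":
    api_thread = start_api_thread(fake_server, 3600)
    if wait_for_exit(api_thread):
        print("Interrupted; exiting main thread")
```

The key design choice is that the main thread does nothing but wait, so it is always free to react to a terminal signal regardless of what the server thread is doing.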

In addition, make request error logging more robust and refer to the
console for full error logs rather than sending a long message to the
client.

Finally, add state checks to see if a model is fully loaded before
generating a completion.
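
A state check of this kind might look like the sketch below. `ModelContainer`, its `model_loaded` flag, `ModelNotReadyError`, and `ensure_model_ready` are all hypothetical names chosen for illustration; they are not taken from this commit.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelContainer:
    """Hypothetical container tracking whether a model finished loading."""

    name: str
    model_loaded: bool = False


class ModelNotReadyError(RuntimeError):
    """Raised when a completion is requested before a model is usable."""


def ensure_model_ready(container: Optional[ModelContainer]) -> ModelContainer:
    """Reject generation unless a model exists and has fully loaded."""
    if container is None:
        raise ModelNotReadyError("No model is loaded.")
    if not container.model_loaded:
        # A load may still be in progress on another thread.
        raise ModelNotReadyError(f"Model {container.name!r} is still loading.")
    return container
```

A completion endpoint would call a check like this before starting generation and turn the exception into a request error.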

Signed-off-by: kingbri <bdashore3@proton.me>
2024-03-04 23:21:40 -05:00

76 lines
1.6 KiB
Python

"""Common utility functions"""

import traceback
from pydantic import BaseModel
from typing import Optional

from common.logger import init_logger

logger = init_logger(__name__)


def load_progress(module, modules):
    """Wrapper callback for load progress."""
    yield module, modules


class TabbyRequestErrorMessage(BaseModel):
    """Common request error message type."""

    message: str
    trace: Optional[str] = None


class TabbyRequestError(BaseModel):
    """Common request error type."""

    error: TabbyRequestErrorMessage


def get_generator_error(message: str):
    """Get a generator error as an SSE packet."""
    generator_error = handle_request_error(message)
    return get_sse_packet(generator_error.model_dump_json())


def handle_request_error(message: str):
    """Log a request error to the console and return the error model."""
    error_message = TabbyRequestErrorMessage(
        message=message, trace=traceback.format_exc()
    )

    request_error = TabbyRequestError(error=error_message)

    # Log the error and provided message to the console
    logger.error(error_message.trace)
    logger.error(message)

    return request_error


def get_sse_packet(json_data: str):
    """Get an SSE packet."""
    return f"data: {json_data}\n\n"


def unwrap(wrapped, default=None):
    """Unwrap function for Optionals."""
    if wrapped is None:
        return default

    return wrapped


def coalesce(*args):
    """Coalesce function for multiple unwraps."""
    return next((arg for arg in args if arg is not None), None)


def prune_dict(input_dict):
    """Return a copy of a dictionary with None values removed."""
    return {k: v for k, v in input_dict.items() if v is not None}
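
`get_sse_packet` frames each JSON payload as a `data:` line followed by a blank line, and `get_generator_error` uses the same framing for errors shaped like `TabbyRequestError`, i.e. `{"error": {"message": ..., "trace": ...}}`. The sketch below is a hypothetical client-side parser, not part of tabbyAPI, that splits a buffered stream back into payloads and surfaces error packets. It only handles single-line `data: ` fields, which is all `get_sse_packet` emits, and the `"text"` key in the sample is an assumed payload shape.

```python
import json
from typing import Iterator


def iter_sse_payloads(stream: str) -> Iterator[dict]:
    """Yield the JSON payload of each 'data: ...' event in an SSE stream."""
    # Events are separated by a blank line ("\n\n").
    for event in stream.split("\n\n"):
        for line in event.splitlines():
            if line.startswith("data: "):
                yield json.loads(line[len("data: "):])


def raise_on_error(payload: dict) -> dict:
    """Turn a TabbyRequestError-shaped payload into a Python exception."""
    if "error" in payload:
        error = payload["error"]
        raise RuntimeError(error.get("message", "unknown error"))
    return payload


if __name__ == "__main__":
    sample = (
        'data: {"text": "Hello"}\n\n'
        'data: {"error": {"message": "Generation aborted", "trace": null}}\n\n'
    )
    for payload in iter_sse_payloads(sample):
        try:
            print(raise_on_error(payload))
        except RuntimeError as exc:
            print(f"Server error: {exc}")
```

Keeping the full traceback out of the client-facing message, as `handle_request_error` does by also logging it to the console, means a parser like this only needs the short `message` field.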