mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
At any point for any request cancellation, the semaphore will be decremented. This is an issue since an arbitrary request can desync the semaphore, causing multiple tasks to be processed at once and break generation. Remove this from the networking handlers and therefore, remove the release_semaphore function itself. Signed-off-by: kingbri <bdashore3@proton.me>
96 lines
2.4 KiB
Python
96 lines
2.4 KiB
Python
"""Common utility functions"""
|
|
|
|
import asyncio
|
|
import socket
|
|
import traceback
|
|
from fastapi import Request
|
|
from loguru import logger
|
|
from pydantic import BaseModel
|
|
from typing import Optional
|
|
|
|
|
|
class TabbyRequestErrorMessage(BaseModel):
|
|
"""Common request error type."""
|
|
|
|
message: str
|
|
trace: Optional[str] = None
|
|
|
|
|
|
class TabbyRequestError(BaseModel):
|
|
"""Common request error type."""
|
|
|
|
error: TabbyRequestErrorMessage
|
|
|
|
|
|
def get_generator_error(message: str, exc_info: bool = True):
|
|
"""Get a generator error."""
|
|
|
|
generator_error = handle_request_error(message, exc_info)
|
|
|
|
return generator_error.model_dump_json()
|
|
|
|
|
|
def handle_request_error(message: str, exc_info: bool = True):
|
|
"""Log a request error to the console."""
|
|
|
|
error_message = TabbyRequestErrorMessage(
|
|
message=message, trace=traceback.format_exc()
|
|
)
|
|
|
|
request_error = TabbyRequestError(error=error_message)
|
|
|
|
# Log the error and provided message to the console
|
|
if error_message.trace and exc_info:
|
|
logger.error(error_message.trace)
|
|
|
|
logger.error(f"Sent to request: {message}")
|
|
|
|
return request_error
|
|
|
|
|
|
def handle_request_disconnect(message: str):
|
|
"""Wrapper for handling for request disconnection."""
|
|
|
|
logger.error(message)
|
|
|
|
|
|
async def request_disconnect_loop(request: Request):
|
|
"""Polls for a starlette request disconnect."""
|
|
|
|
while not await request.is_disconnected():
|
|
await asyncio.sleep(0.5)
|
|
|
|
|
|
async def run_with_request_disconnect(
|
|
request: Request, call_task: asyncio.Task, disconnect_message: str
|
|
):
|
|
"""Utility function to cancel if a request is disconnected."""
|
|
|
|
_, unfinished = await asyncio.wait(
|
|
[
|
|
call_task,
|
|
asyncio.create_task(request_disconnect_loop(request)),
|
|
],
|
|
return_when=asyncio.FIRST_COMPLETED,
|
|
)
|
|
for task in unfinished:
|
|
task.cancel()
|
|
|
|
try:
|
|
return call_task.result()
|
|
except (asyncio.CancelledError, asyncio.InvalidStateError):
|
|
handle_request_disconnect(disconnect_message)
|
|
|
|
|
|
def is_port_in_use(port: int) -> bool:
|
|
"""
|
|
Checks if a port is in use
|
|
|
|
From https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python
|
|
"""
|
|
|
|
test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
test_socket.settimeout(1)
|
|
with test_socket:
|
|
return test_socket.connect_ex(("localhost", port)) == 0
|