Files
tabbyAPI/common/networking.py
turboderp 79d581e1f5 OAI endpoints: More rework
- remove disconnect_task
- move disconnect logic to a per-request handler that wraps cleanup operation and directly polls the request state with throttling
- exclusively signal disconnect with CancelledError
- rework completions endpoint to follow same approach as chat completions, share some code
- refactor OAI endpoints a bit
- correct behavior for batched completion requests
- make sure logprobs work for completion and streaming completion requests
- more tests
2026-04-02 01:26:44 +02:00

198 lines
5.5 KiB
Python

"""Common utility functions"""
import asyncio
import json
import socket
import sys
import time
import traceback
from typing import Optional
from uuid import uuid4

from fastapi import Depends, HTTPException, Request
from loguru import logger
from pydantic import BaseModel

from common.logger import xlogger
from common.tabby_config import config
class TabbyRequestErrorMessage(BaseModel):
    """Error payload sent to the client: a message plus an optional traceback."""

    # Human-readable description of what went wrong
    message: str
    # Traceback text; only populated when config.network.send_tracebacks allows it
    trace: Optional[str] = None
class TabbyRequestError(BaseModel):
    """Common request error type."""

    # Envelope wrapping the payload, serialized as {"error": {...}}
    error: TabbyRequestErrorMessage
def get_generator_error(message: str, exc_info: bool = True):
    """Build a request error and serialize it to JSON for a streaming generator."""

    return handle_request_error(message, exc_info).model_dump_json()
def handle_request_error(message: str, exc_info: bool = True):
    """Log a request error to the console and build the client-facing error.

    Args:
        message: Human-readable error description sent back to the client.
        exc_info: When True, also log the active exception's traceback.

    Returns:
        TabbyRequestError wrapping the message (and the traceback, when
        config.network.send_tracebacks permits sharing it with clients).
    """

    # Only capture a traceback if an exception is actually being handled.
    # Outside an except block, traceback.format_exc() returns the truthy
    # junk string "NoneType: None\n", which would otherwise be logged and
    # even sent to clients whenever send_tracebacks is enabled.
    trace = traceback.format_exc() if sys.exc_info()[0] is not None else None
    send_trace = config.network.send_tracebacks
    error_message = TabbyRequestErrorMessage(message=message, trace=trace if send_trace else None)
    request_error = TabbyRequestError(error=error_message)

    # Log the error and provided message to the console
    if trace and exc_info:
        xlogger.error("Error", {"trace": trace, "message": message}, details=trace)
    logger.error(f"Sent to request: {message}")
    return request_error
def handle_request_disconnect(message: str):
    """Log a request-disconnect event to the console."""
    xlogger.error(message)
class DisconnectHandler:
    """Per-request disconnect watcher.

    Wraps a starlette Request so long-running endpoint code can poll for a
    client disconnect with throttling. On disconnect it sets the abort
    event, runs all scheduled cleanup tasks, and raises
    asyncio.CancelledError toward the endpoint function.
    """

    def __init__(
        self,
        request: Request,
        description: str,
    ):
        self.request = request
        self.abort_event = asyncio.Event()
        # Use a monotonic clock for throttling: time.time() follows the wall
        # clock, which can jump backward/forward (NTP, DST), stalling or
        # bursting the poll rate. Seed the timestamp in the past so the very
        # first poll() always checks the request.
        self.last_poll = time.monotonic() - 10
        self.disconnected = False
        # key -> (async func, args tuple); each entry runs exactly once
        self.cleanup_tasks = {}
        self.description = description

    async def poll(self):
        """
        Poll the request status a maximum of 20 times per second. Once request is disconnected
        runs scheduled cleanup tasks and raises asyncio.CancelledError. Caller is responsible for
        forwarding the error back to the endpoint function. The endpoint fn should call poll() at
        least once before returning a non-canceled response
        """
        now = time.monotonic()
        # Throttle: skip the (potentially costly) disconnect check if the
        # last real poll was less than 50 ms ago
        if now < self.last_poll + 0.05:
            return
        self.last_poll = now

        # Check if request has disconnected
        if await self.request.is_disconnected():
            # Set abort signal
            if self.abort_event is not None:
                self.abort_event.set()

            # Trigger any cleanup tasks
            await self.cleanup()

            # Log and raise (log only the first time, in case poll() is
            # reached again on a later iteration)
            if not self.disconnected:
                xlogger.error(f"Request disconnected: {self.description}")
            self.disconnected = True
            raise asyncio.CancelledError(f"Request disconnected: {self.description}")

    async def add_cleanup_task(self, key, func, args):
        """Schedule an async cleanup callback under a unique key."""
        # Intentionally strict: a duplicate key indicates a caller bug
        assert key not in self.cleanup_tasks
        self.cleanup_tasks[key] = (func, args)

    async def finish(self, key):
        """Unschedule a cleanup task whose operation completed normally."""
        # Intentionally strict: raises KeyError if already finished/cleaned
        del self.cleanup_tasks[key]

    # Safe to call redundantly, each cleanup task must be called exactly once
    async def cleanup(self):
        """Run every pending cleanup task once, then clear the schedule."""
        for func, args in self.cleanup_tasks.values():
            await func(*args)
        self.cleanup_tasks = {}
async def request_disconnect_loop(request: Request):
    """Polls for a starlette request disconnect."""
    while True:
        if await request.is_disconnected():
            break
        await asyncio.sleep(0.5)
async def run_with_request_disconnect(
    request: Request, call_task: asyncio.Task, disconnect_message: str
):
    """Utility function to cancel if a request is disconnected.

    Races call_task against a polling watcher for the client disconnecting.
    Returns the task's result on success; raises HTTP 422 if the client went
    away (or call_task was cancelled) first.
    """
    # Wait until either the work finishes or the client disconnects
    _, unfinished = await asyncio.wait(
        [
            call_task,
            asyncio.create_task(request_disconnect_loop(request)),
        ],
        return_when=asyncio.FIRST_COMPLETED,
    )
    # Cancel whichever side lost the race (the watcher on success, the
    # work task on disconnect)
    for task in unfinished:
        task.cancel()
    try:
        # InvalidStateError: the watcher won and call_task has no result
        # yet; CancelledError: call_task itself was cancelled. Both are
        # treated as a client disconnect.
        return call_task.result()
    except (asyncio.CancelledError, asyncio.InvalidStateError) as ex:
        handle_request_disconnect(disconnect_message)
        raise HTTPException(422, disconnect_message) from ex
def is_port_in_use(port: int) -> bool:
    """
    Checks if a port is in use
    From https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.settimeout(1)
        return probe.connect_ex(("localhost", port)) == 0
async def add_request_id(request: Request):
    """FastAPI depends to tag the incoming request's state with a unique hex ID."""
    unique_id = uuid4().hex
    request.state.id = unique_id
    return request
async def log_request(request: Request):
    """FastAPI depends to log a request to the user."""
    lines = [
        f"Information for {request.method} request {request.state.id}:",
        f"URL: {request.url}",
        f"Headers: {dict(request.headers)}",
    ]

    # GET requests carry no body worth logging
    if request.method != "GET":
        raw_body = await request.body()
        if raw_body:
            parsed = json.loads(raw_body.decode("utf-8"))
            lines.append(f"Body: {dict(parsed)}")

    xlogger.info("Request", dict(request), details="\n".join(lines))
def get_global_depends():
    """Assemble the global dependency list for a FastAPI app."""
    extra = [Depends(log_request)] if config.logging.log_requests else []
    return [Depends(add_request_id), *extra]