"""Common utility functions""" import asyncio import json import socket import time import traceback from fastapi import Depends, HTTPException, Request from loguru import logger from common.logger import xlogger from pydantic import BaseModel from typing import Optional from uuid import uuid4 from common.tabby_config import config class TabbyRequestErrorMessage(BaseModel): """Common request error type.""" message: str trace: Optional[str] = None class TabbyRequestError(BaseModel): """Common request error type.""" error: TabbyRequestErrorMessage def get_generator_error(message: str, exc_info: bool = True): """Get a generator error.""" generator_error = handle_request_error(message, exc_info) return generator_error.model_dump_json() def handle_request_error(message: str, exc_info: bool = True): """Log a request error to the console.""" trace = traceback.format_exc() send_trace = config.network.send_tracebacks error_message = TabbyRequestErrorMessage(message=message, trace=trace if send_trace else None) request_error = TabbyRequestError(error=error_message) # Log the error and provided message to the console if trace and exc_info: xlogger.error("Error", {"trace": trace, "message": message}, details=trace) logger.error(f"Sent to request: {message}") return request_error def handle_request_disconnect(message: str): """Wrapper for handling for request disconnection.""" xlogger.error(message) class DisconnectHandler: def __init__( self, request: Request, description: str, ): self.request = request self.abort_event = asyncio.Event() self.last_poll = time.time() - 10 self.disconnected = False self.cleanup_tasks = {} self.description = description async def poll(self): """ Poll the request status a maximum of 20 times per second. Once request is disconnected runs scheduled cleanup tasks and raises asyncio.CancelledError. Caller is responsible for forwarding the error back to the endpoint function. The endpoint fn should call poll() at least once before returning a non-canceled response """ now = time.time() if now < self.last_poll + 0.05: return self.last_poll = now # Check if request has disconnected if await self.request.is_disconnected(): # Set abort signal if self.abort_event is not None: self.abort_event.set() # Trigger any cleanup tasks await self.cleanup() # Log and raise if not self.disconnected: xlogger.error(f"Request disconnected: {self.description}") self.disconnected = True raise asyncio.CancelledError(f"Request disconnected: {self.description}") async def add_cleanup_task(self, key, func, args): # Intentionally strict assert key not in self.cleanup_tasks self.cleanup_tasks[key] = (func, args) async def finish(self, key): # Intentionally strict del self.cleanup_tasks[key] # Safe to call redundantly, each cleanup task must be called exactly once async def cleanup(self): for func, args in self.cleanup_tasks.values(): await func(*args) self.cleanup_tasks = {} async def request_disconnect_loop(request: Request): """Polls for a starlette request disconnect.""" while not await request.is_disconnected(): await asyncio.sleep(0.5) async def run_with_request_disconnect( request: Request, call_task: asyncio.Task, disconnect_message: str ): """Utility function to cancel if a request is disconnected.""" _, unfinished = await asyncio.wait( [ call_task, asyncio.create_task(request_disconnect_loop(request)), ], return_when=asyncio.FIRST_COMPLETED, ) for task in unfinished: task.cancel() try: return call_task.result() except (asyncio.CancelledError, asyncio.InvalidStateError) as ex: handle_request_disconnect(disconnect_message) raise HTTPException(422, disconnect_message) from ex def is_port_in_use(port: int) -> bool: """ Checks if a port is in use From https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python """ test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) test_socket.settimeout(1) with test_socket: return test_socket.connect_ex(("localhost", port)) == 0 async def add_request_id(request: Request): """FastAPI depends to add a UUID to a request's state.""" request.state.id = uuid4().hex return request async def log_request(request: Request): """FastAPI depends to log a request to the user.""" log_message = [f"Information for {request.method} request {request.state.id}:"] log_message.append(f"URL: {request.url}") log_message.append(f"Headers: {dict(request.headers)}") if request.method != "GET": body_bytes = await request.body() if body_bytes: body = json.loads(body_bytes.decode("utf-8")) log_message.append(f"Body: {dict(body)}") xlogger.info("Request", dict(request), details="\n".join(log_message)) def get_global_depends(): """Returns global dependencies for a FastAPI app.""" depends = [Depends(add_request_id)] if config.logging.log_requests: depends.append(Depends(log_request)) return depends