Signal: Fix signal handlers for uvicorn

Add the ability to override uvicorn's signal handler in addition
to using main's signal handler for any SIGINTs before the API server
starts.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-03-16 23:18:27 -04:00
committed by Brian Dashore
parent 95e44c20d6
commit 14d8ec2007
3 changed files with 36 additions and 6 deletions

23
common/signals.py Normal file
View File

@@ -0,0 +1,23 @@
import signal
import sys
from loguru import logger
from types import FrameType
def signal_handler(*_):
"""Signal handler for main function. Run before uvicorn starts."""
logger.warning("Shutdown signal called. Exiting gracefully.")
sys.exit(0)
def uvicorn_signal_handler(signal_event: signal.Signals):
"""Overrides uvicorn's signal handler."""
default_signal_handler = signal.getsignal(signal_event)
def wrapped_handler(signum: int, frame: FrameType = None):
logger.warning("Shutdown signal called. Exiting gracefully.")
default_signal_handler(signum, frame)
signal.signal(signal_event, wrapped_handler)

View File

@@ -1,5 +1,7 @@
import pathlib import pathlib
import signal
import uvicorn import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException, Request from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from functools import partial from functools import partial
@@ -14,6 +16,7 @@ from common.concurrency import (
generate_with_semaphore, generate_with_semaphore,
) )
from common.logger import UVICORN_LOG_CONFIG from common.logger import UVICORN_LOG_CONFIG
from common.signals import uvicorn_signal_handler
from common.templating import ( from common.templating import (
get_all_templates, get_all_templates,
get_template_from_file, get_template_from_file,
@@ -55,6 +58,14 @@ from endpoints.OAI.utils.completion import (
from endpoints.OAI.utils.model import get_model_list, stream_model_load from endpoints.OAI.utils.model import get_model_list, stream_model_load
from endpoints.OAI.utils.lora import get_lora_list from endpoints.OAI.utils.lora import get_lora_list
@asynccontextmanager
async def lifespan(_: FastAPI):
uvicorn_signal_handler(signal.SIGINT)
uvicorn_signal_handler(signal.SIGTERM)
yield
app = FastAPI( app = FastAPI(
title="TabbyAPI", title="TabbyAPI",
summary="An OAI compatible exllamav2 API that's both lightweight and fast", summary="An OAI compatible exllamav2 API that's both lightweight and fast",
@@ -62,6 +73,7 @@ app = FastAPI(
"This docs page is not meant to send requests! Please use a service " "This docs page is not meant to send requests! Please use a service "
"like Postman or a frontend UI." "like Postman or a frontend UI."
), ),
lifespan=lifespan,
) )
# ALlow CORS requests # ALlow CORS requests

View File

@@ -4,7 +4,6 @@ import asyncio
import os import os
import pathlib import pathlib
import signal import signal
import sys
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
@@ -13,15 +12,11 @@ from common import config, gen_logging, sampling, model
from common.args import convert_args_to_dict, init_argparser from common.args import convert_args_to_dict, init_argparser
from common.auth import load_auth_keys from common.auth import load_auth_keys
from common.logger import setup_logger from common.logger import setup_logger
from common.signals import signal_handler
from common.utils import is_port_in_use, unwrap from common.utils import is_port_in_use, unwrap
from endpoints.OAI.app import start_api from endpoints.OAI.app import start_api
def signal_handler(*_):
logger.warning("Shutdown signal called. Exiting gracefully.")
sys.exit(0)
async def entrypoint(args: Optional[dict] = None): async def entrypoint(args: Optional[dict] = None):
"""Entry function for program startup""" """Entry function for program startup"""