Signal: Fix async signal handling

Run unload async functions before exiting the program.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-07-30 11:11:05 -04:00
parent fbf1455db1
commit 01c7702859
2 changed files with 12 additions and 6 deletions

View File

@@ -1,6 +1,7 @@
import gc import gc
import pathlib import pathlib
import torch import torch
from loguru import logger
from typing import List, Optional from typing import List, Optional
from common.utils import unwrap from common.utils import unwrap
@@ -50,6 +51,8 @@ class InfinityContainer:
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
logger.info("Embedding model unloaded.")
async def generate(self, sentence_input: List[str]): async def generate(self, sentence_input: List[str]):
result_embeddings, usage = await self.engine.embed(sentence_input) result_embeddings, usage = await self.engine.embed(sentence_input)

View File

@@ -13,17 +13,20 @@ def signal_handler(*_):
logger.warning("Shutdown signal called. Exiting gracefully.") logger.warning("Shutdown signal called. Exiting gracefully.")
# Run async unloads for model # Run async unloads for model
loop = asyncio.get_running_loop() asyncio.ensure_future(signal_handler_async())
if model.container:
loop.create_task(model.container.unload())
if model.embeddings_container:
loop.create_task(model.embeddings_container.unload())
# Exit the program # Exit the program
sys.exit(0) sys.exit(0)
async def signal_handler_async(*_):
if model.container:
await model.container.unload()
if model.embeddings_container:
await model.embeddings_container.unload()
def uvicorn_signal_handler(signal_event: signal.Signals): def uvicorn_signal_handler(signal_event: signal.Signals):
"""Overrides uvicorn's signal handler.""" """Overrides uvicorn's signal handler."""