mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-28 02:01:24 +00:00
Signal: Fix async signal handling
Run unload async functions before exiting the program. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import gc
|
import gc
|
||||||
import pathlib
|
import pathlib
|
||||||
import torch
|
import torch
|
||||||
|
from loguru import logger
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
@@ -50,6 +51,8 @@ class InfinityContainer:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
logger.info("Embedding model unloaded.")
|
||||||
|
|
||||||
async def generate(self, sentence_input: List[str]):
|
async def generate(self, sentence_input: List[str]):
|
||||||
result_embeddings, usage = await self.engine.embed(sentence_input)
|
result_embeddings, usage = await self.engine.embed(sentence_input)
|
||||||
|
|
||||||
|
|||||||
@@ -13,17 +13,20 @@ def signal_handler(*_):
|
|||||||
logger.warning("Shutdown signal called. Exiting gracefully.")
|
logger.warning("Shutdown signal called. Exiting gracefully.")
|
||||||
|
|
||||||
# Run async unloads for model
|
# Run async unloads for model
|
||||||
loop = asyncio.get_running_loop()
|
asyncio.ensure_future(signal_handler_async())
|
||||||
if model.container:
|
|
||||||
loop.create_task(model.container.unload())
|
|
||||||
|
|
||||||
if model.embeddings_container:
|
|
||||||
loop.create_task(model.embeddings_container.unload())
|
|
||||||
|
|
||||||
# Exit the program
|
# Exit the program
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
async def signal_handler_async(*_):
|
||||||
|
if model.container:
|
||||||
|
await model.container.unload()
|
||||||
|
|
||||||
|
if model.embeddings_container:
|
||||||
|
await model.embeddings_container.unload()
|
||||||
|
|
||||||
|
|
||||||
def uvicorn_signal_handler(signal_event: signal.Signals):
|
def uvicorn_signal_handler(signal_event: signal.Signals):
|
||||||
"""Overrides uvicorn's signal handler."""
|
"""Overrides uvicorn's signal handler."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user