mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add a watcher to constantly check for stop signal from the UI. This will force a stop within 2 seconds instead of having to wait on a long hung process.
This commit is contained in:
@@ -5,7 +5,9 @@ import asyncio
|
||||
import concurrent.futures
|
||||
from extensions_built_in.sd_trainer.SDTrainer import SDTrainer
|
||||
from typing import Literal, Optional
|
||||
|
||||
import threading
|
||||
import time
|
||||
import signal
|
||||
|
||||
AITK_Status = Literal["running", "stopped", "error", "completed"]
|
||||
|
||||
@@ -30,6 +32,49 @@ class UITrainer(SDTrainer):
|
||||
self._async_tasks = []
|
||||
# Initialize the status
|
||||
self._run_async_operation(self._update_status("running", "Starting"))
|
||||
self._stop_watcher_started = False
|
||||
self.start_stop_watcher(interval_sec=2.0)
|
||||
|
||||
def start_stop_watcher(self, interval_sec: float = 5.0):
|
||||
"""
|
||||
Start a daemon thread that periodically checks should_stop()
|
||||
and terminates the process immediately when triggered.
|
||||
"""
|
||||
if getattr(self, "_stop_watcher_started", False):
|
||||
return
|
||||
self._stop_watcher_started = True
|
||||
t = threading.Thread(
|
||||
target=self._stop_watcher_thread, args=(interval_sec,), daemon=True
|
||||
)
|
||||
t.start()
|
||||
|
||||
def _stop_watcher_thread(self, interval_sec: float):
|
||||
while True:
|
||||
try:
|
||||
if self.should_stop():
|
||||
# Mark and update status (non-blocking; uses existing infra)
|
||||
self.is_stopping = True
|
||||
self._run_async_operation(
|
||||
self._update_status("stopped", "Job stopped (remote)")
|
||||
)
|
||||
# Best-effort flush pending async ops
|
||||
try:
|
||||
asyncio.run(self.wait_for_all_async())
|
||||
except RuntimeError:
|
||||
pass
|
||||
# Try to stop DB thread pool quickly
|
||||
try:
|
||||
self.thread_pool.shutdown(wait=False, cancel_futures=True)
|
||||
except TypeError:
|
||||
self.thread_pool.shutdown(wait=False)
|
||||
print("")
|
||||
print("****************************************************")
|
||||
print(" Stop signal received; terminating process. ")
|
||||
print("****************************************************")
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
time.sleep(interval_sec)
|
||||
except Exception:
|
||||
time.sleep(interval_sec)
|
||||
|
||||
def _run_async_operation(self, coro):
|
||||
"""Helper method to run an async coroutine and track the task."""
|
||||
|
||||
Reference in New Issue
Block a user