From 193c1b2dfab7a8ef1732e1ede9fdbade42c1af11 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 31 Aug 2025 16:58:01 -0600 Subject: [PATCH] Add a watcher to constantly check for stop signal from the UI. This will force a stop within 2 seconds instead of having to wait on a long hung process. --- extensions_built_in/sd_trainer/UITrainer.py | 47 ++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py index df940ced..c83a7c7a 100644 --- a/extensions_built_in/sd_trainer/UITrainer.py +++ b/extensions_built_in/sd_trainer/UITrainer.py @@ -5,7 +5,9 @@ import asyncio import concurrent.futures from extensions_built_in.sd_trainer.SDTrainer import SDTrainer from typing import Literal, Optional - +import threading +import time +import signal AITK_Status = Literal["running", "stopped", "error", "completed"] @@ -30,6 +32,49 @@ class UITrainer(SDTrainer): self._async_tasks = [] # Initialize the status self._run_async_operation(self._update_status("running", "Starting")) + self._stop_watcher_started = False + self.start_stop_watcher(interval_sec=2.0) + + def start_stop_watcher(self, interval_sec: float = 5.0): + """ + Start a daemon thread that periodically checks should_stop() + and terminates the process immediately when triggered. + """ + if getattr(self, "_stop_watcher_started", False): + return + self._stop_watcher_started = True + t = threading.Thread( + target=self._stop_watcher_thread, args=(interval_sec,), daemon=True + ) + t.start() + + def _stop_watcher_thread(self, interval_sec: float): + while True: + try: + if self.should_stop(): + # Mark and update status (non-blocking; uses existing infra) + self.is_stopping = True + self._run_async_operation( + self._update_status("stopped", "Job stopped (remote)") + ) + # Best-effort flush pending async ops + try: + asyncio.run(self.wait_for_all_async()) + except RuntimeError: + pass + # Try to stop DB thread pool quickly + try: + self.thread_pool.shutdown(wait=False, cancel_futures=True) + except TypeError: + self.thread_pool.shutdown(wait=False) + print("") + print("****************************************************") + print(" Stop signal received; terminating process. ") + print("****************************************************") + os.kill(os.getpid(), signal.SIGINT) + time.sleep(interval_sec) + except Exception: + time.sleep(interval_sec) def _run_async_operation(self, coro): """Helper method to run an async coroutine and track the task."""