"""Background asset seeder with thread management and cancellation support.""" import logging import os import threading import time from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Callable from app.assets.scanner import ( RootType, build_asset_specs, collect_paths_for_roots, get_all_known_prefixes, get_prefixes_for_root, insert_asset_specs, mark_missing_outside_prefixes_safely, sync_root_safely, ) from app.database.db import dependencies_available if TYPE_CHECKING: pass class State(Enum): """Seeder state machine states.""" IDLE = "IDLE" RUNNING = "RUNNING" CANCELLING = "CANCELLING" @dataclass class Progress: """Progress information for a scan operation.""" scanned: int = 0 total: int = 0 created: int = 0 skipped: int = 0 @dataclass class ScanStatus: """Current status of the asset seeder.""" state: State progress: Progress | None errors: list[str] = field(default_factory=list) ProgressCallback = Callable[[Progress], None] class AssetSeeder: """Singleton class managing background asset scanning. Thread-safe singleton that spawns ephemeral daemon threads for scanning. Each scan creates a new thread that exits when complete. """ _instance: "AssetSeeder | None" = None _instance_lock = threading.Lock() def __new__(cls) -> "AssetSeeder": with cls._instance_lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self) -> None: if self._initialized: return self._initialized = True self._lock = threading.Lock() self._state = State.IDLE self._progress: Progress | None = None self._errors: list[str] = [] self._thread: threading.Thread | None = None self._cancel_event = threading.Event() self._roots: tuple[RootType, ...] = () self._progress_callback: ProgressCallback | None = None def start( self, roots: tuple[RootType, ...] = ("models", "input", "output"), progress_callback: ProgressCallback | None = None, prune_first: bool = False, ) -> bool: """Start a background scan for the given roots. Args: roots: Tuple of root types to scan (models, input, output) progress_callback: Optional callback called with progress updates prune_first: If True, prune orphaned assets before scanning Returns: True if scan was started, False if already running """ with self._lock: if self._state != State.IDLE: return False self._state = State.RUNNING self._progress = Progress() self._errors = [] self._roots = roots self._prune_first = prune_first self._progress_callback = progress_callback self._cancel_event.clear() self._thread = threading.Thread( target=self._run_scan, name="AssetSeeder", daemon=True, ) self._thread.start() return True def cancel(self) -> bool: """Request cancellation of the current scan. Returns: True if cancellation was requested, False if not running """ with self._lock: if self._state != State.RUNNING: return False self._state = State.CANCELLING self._cancel_event.set() return True def wait(self, timeout: float | None = None) -> bool: """Wait for the current scan to complete. Args: timeout: Maximum seconds to wait, or None for no timeout Returns: True if scan completed, False if timeout expired or no scan running """ with self._lock: thread = self._thread if thread is None: return True thread.join(timeout=timeout) return not thread.is_alive() def get_status(self) -> ScanStatus: """Get the current status and progress of the seeder.""" with self._lock: return ScanStatus( state=self._state, progress=Progress( scanned=self._progress.scanned, total=self._progress.total, created=self._progress.created, skipped=self._progress.skipped, ) if self._progress else None, errors=list(self._errors), ) def shutdown(self, timeout: float = 5.0) -> None: """Gracefully shutdown: cancel any running scan and wait for thread. Args: timeout: Maximum seconds to wait for thread to exit """ self.cancel() self.wait(timeout=timeout) with self._lock: self._thread = None def mark_missing_outside_prefixes(self) -> int: """Mark cache states as missing when outside all known root prefixes. This is a non-destructive soft-delete operation. Assets and their metadata are preserved, but cache states are flagged as missing. They can be restored if the file reappears in a future scan. This operation is decoupled from scanning to prevent partial scans from accidentally marking assets belonging to other roots. Should be called explicitly when cleanup is desired, typically after a full scan of all roots or during maintenance. Returns: Number of cache states marked as missing, or 0 if dependencies unavailable or a scan is currently running """ with self._lock: if self._state != State.IDLE: logging.warning( "Cannot mark missing assets while scan is running" ) return 0 if not dependencies_available(): logging.warning( "Database dependencies not available, skipping mark missing" ) return 0 all_prefixes = get_all_known_prefixes() marked = mark_missing_outside_prefixes_safely(all_prefixes) if marked > 0: logging.info("Marked %d cache states as missing", marked) return marked def _is_cancelled(self) -> bool: """Check if cancellation has been requested.""" return self._cancel_event.is_set() def _emit_event(self, event_type: str, data: dict) -> None: """Emit a WebSocket event if server is available.""" try: from server import PromptServer if hasattr(PromptServer, "instance") and PromptServer.instance: PromptServer.instance.send_sync(event_type, data) except Exception: pass def _update_progress( self, scanned: int | None = None, total: int | None = None, created: int | None = None, skipped: int | None = None, ) -> None: """Update progress counters (thread-safe).""" with self._lock: if self._progress is None: return if scanned is not None: self._progress.scanned = scanned if total is not None: self._progress.total = total if created is not None: self._progress.created = created if skipped is not None: self._progress.skipped = skipped if self._progress_callback: try: self._progress_callback( Progress( scanned=self._progress.scanned, total=self._progress.total, created=self._progress.created, skipped=self._progress.skipped, ) ) except Exception: pass def _add_error(self, message: str) -> None: """Add an error message (thread-safe).""" with self._lock: self._errors.append(message) def _log_scan_config(self, roots: tuple[RootType, ...]) -> None: """Log the directories that will be scanned.""" import folder_paths for root in roots: if root == "models": logging.info( "Asset scan [models] directory: %s", os.path.abspath(folder_paths.models_dir), ) else: prefixes = get_prefixes_for_root(root) if prefixes: logging.info("Asset scan [%s] directories: %s", root, prefixes) def _run_scan(self) -> None: """Main scan loop running in background thread.""" t_start = time.perf_counter() roots = self._roots cancelled = False total_created = 0 skipped_existing = 0 total_paths = 0 try: if not dependencies_available(): self._add_error("Database dependencies not available") self._emit_event( "assets.seed.error", {"message": "Database dependencies not available"}, ) return if self._prune_first: all_prefixes = get_all_known_prefixes() marked = mark_missing_outside_prefixes_safely(all_prefixes) if marked > 0: logging.info("Marked %d cache states as missing before scan", marked) if self._is_cancelled(): logging.info("Asset scan cancelled after pruning phase") cancelled = True return self._log_scan_config(roots) existing_paths: set[str] = set() for r in roots: if self._is_cancelled(): logging.info("Asset scan cancelled during sync phase") cancelled = True return existing_paths.update(sync_root_safely(r)) if self._is_cancelled(): logging.info("Asset scan cancelled after sync phase") cancelled = True return paths = collect_paths_for_roots(roots) total_paths = len(paths) self._update_progress(total=total_paths) self._emit_event( "assets.seed.started", {"roots": list(roots), "total": total_paths}, ) specs, tag_pool, skipped_existing = build_asset_specs(paths, existing_paths) self._update_progress(skipped=skipped_existing) if self._is_cancelled(): logging.info("Asset scan cancelled after building specs") cancelled = True return batch_size = 500 last_progress_time = time.perf_counter() progress_interval = 1.0 for i in range(0, len(specs), batch_size): if self._is_cancelled(): logging.info( "Asset scan cancelled after %d/%d files (created=%d)", i, len(specs), total_created, ) cancelled = True return batch = specs[i : i + batch_size] batch_tags = {t for spec in batch for t in spec["tags"]} try: created = insert_asset_specs(batch, batch_tags) total_created += created except Exception as e: self._add_error(f"Batch insert failed at offset {i}: {e}") logging.exception("Batch insert failed at offset %d", i) scanned = i + len(batch) self._update_progress(scanned=scanned, created=total_created) now = time.perf_counter() if now - last_progress_time >= progress_interval: self._emit_event( "assets.seed.progress", { "scanned": scanned, "total": len(specs), "created": total_created, }, ) last_progress_time = now self._update_progress(scanned=len(specs), created=total_created) elapsed = time.perf_counter() - t_start logging.info( "Asset scan(roots=%s) completed in %.3fs (created=%d, skipped=%d, total=%d)", roots, elapsed, total_created, skipped_existing, len(paths), ) self._emit_event( "assets.seed.completed", { "scanned": len(specs), "total": total_paths, "created": total_created, "skipped": skipped_existing, "elapsed": round(elapsed, 3), }, ) except Exception as e: self._add_error(f"Scan failed: {e}") logging.exception("Asset scan failed") self._emit_event("assets.seed.error", {"message": str(e)}) finally: if cancelled: self._emit_event( "assets.seed.cancelled", { "scanned": self._progress.scanned if self._progress else 0, "total": total_paths, "created": total_created, }, ) with self._lock: self._state = State.IDLE asset_seeder = AssetSeeder()