"""Background asset seeder with thread management and cancellation support.""" import logging import os import threading import time from dataclasses import dataclass, field from enum import Enum from typing import Callable from app.assets.scanner import ( RootType, build_asset_specs, collect_paths_for_roots, get_all_known_prefixes, get_prefixes_for_root, insert_asset_specs, mark_missing_outside_prefixes_safely, sync_root_safely, ) from app.database.db import dependencies_available class State(Enum): """Seeder state machine states.""" IDLE = "IDLE" RUNNING = "RUNNING" CANCELLING = "CANCELLING" @dataclass class Progress: """Progress information for a scan operation.""" scanned: int = 0 total: int = 0 created: int = 0 skipped: int = 0 @dataclass class ScanStatus: """Current status of the asset seeder.""" state: State progress: Progress | None errors: list[str] = field(default_factory=list) ProgressCallback = Callable[[Progress], None] class AssetSeeder: """Singleton class managing background asset scanning. Thread-safe singleton that spawns ephemeral daemon threads for scanning. Each scan creates a new thread that exits when complete. """ _instance: "AssetSeeder | None" = None _instance_lock = threading.Lock() def __new__(cls) -> "AssetSeeder": with cls._instance_lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self) -> None: if self._initialized: return self._initialized = True self._lock = threading.Lock() self._state = State.IDLE self._progress: Progress | None = None self._errors: list[str] = [] self._thread: threading.Thread | None = None self._cancel_event = threading.Event() self._roots: tuple[RootType, ...] = () self._compute_hashes: bool = False self._progress_callback: ProgressCallback | None = None def start( self, roots: tuple[RootType, ...] = ("models", "input", "output"), progress_callback: ProgressCallback | None = None, prune_first: bool = False, compute_hashes: bool = False, ) -> bool: """Start a background scan for the given roots. Args: roots: Tuple of root types to scan (models, input, output) progress_callback: Optional callback called with progress updates prune_first: If True, prune orphaned assets before scanning compute_hashes: If True, compute blake3 hashes for each file (slow for large files) Returns: True if scan was started, False if already running """ with self._lock: if self._state != State.IDLE: return False self._state = State.RUNNING self._progress = Progress() self._errors = [] self._roots = roots self._prune_first = prune_first self._compute_hashes = compute_hashes self._progress_callback = progress_callback self._cancel_event.clear() self._thread = threading.Thread( target=self._run_scan, name="AssetSeeder", daemon=True, ) self._thread.start() return True def cancel(self) -> bool: """Request cancellation of the current scan. Returns: True if cancellation was requested, False if not running """ with self._lock: if self._state != State.RUNNING: return False self._state = State.CANCELLING self._cancel_event.set() return True def wait(self, timeout: float | None = None) -> bool: """Wait for the current scan to complete. 

    def wait(self, timeout: float | None = None) -> bool:
        """Wait for the current scan to complete.

        Args:
            timeout: Maximum seconds to wait, or None for no timeout

        Returns:
            True if scan completed, False if timeout expired or no scan running
        """
        with self._lock:
            thread = self._thread
        if thread is None:
            return True
        # Join outside the lock: the scan thread acquires the same lock to
        # publish progress, so holding it across the join could deadlock.
        thread.join(timeout=timeout)
        return not thread.is_alive()

    def get_status(self) -> ScanStatus:
        """Get the current status and progress of the seeder."""
        with self._lock:
            return ScanStatus(
                state=self._state,
                progress=Progress(
                    scanned=self._progress.scanned,
                    total=self._progress.total,
                    created=self._progress.created,
                    skipped=self._progress.skipped,
                )
                if self._progress
                else None,
                errors=list(self._errors),
            )

    def shutdown(self, timeout: float = 5.0) -> None:
        """Gracefully shutdown: cancel any running scan and wait for thread.

        Args:
            timeout: Maximum seconds to wait for thread to exit
        """
        self.cancel()
        self.wait(timeout=timeout)
        with self._lock:
            self._thread = None

    def mark_missing_outside_prefixes(self) -> int:
        """Mark cache states as missing when outside all known root prefixes.

        This is a non-destructive soft-delete operation. Assets and their
        metadata are preserved, but cache states are flagged as missing.
        They can be restored if the file reappears in a future scan.

        This operation is decoupled from scanning to prevent partial scans
        from accidentally marking assets belonging to other roots. Should be
        called explicitly when cleanup is desired, typically after a full
        scan of all roots or during maintenance.

        Returns:
            Number of cache states marked as missing, or 0 if dependencies
            unavailable or a scan is currently running
        """
        with self._lock:
            if self._state != State.IDLE:
                logging.warning("Cannot mark missing assets while scan is running")
                return 0
            self._state = State.RUNNING
        try:
            if not dependencies_available():
                logging.warning(
                    "Database dependencies not available, skipping mark missing"
                )
                return 0
            all_prefixes = get_all_known_prefixes()
            marked = mark_missing_outside_prefixes_safely(all_prefixes)
            if marked > 0:
                logging.info("Marked %d cache states as missing", marked)
            return marked
        finally:
            with self._lock:
                self._state = State.IDLE

    def _is_cancelled(self) -> bool:
        """Check if cancellation has been requested."""
        return self._cancel_event.is_set()

    def _emit_event(self, event_type: str, data: dict) -> None:
        """Emit a WebSocket event if server is available."""
        try:
            from server import PromptServer

            if hasattr(PromptServer, "instance") and PromptServer.instance:
                PromptServer.instance.send_sync(event_type, data)
        except Exception:
            # Event emission is best-effort; a scan must never fail just
            # because the server is unavailable.
            pass

    def _update_progress(
        self,
        scanned: int | None = None,
        total: int | None = None,
        created: int | None = None,
        skipped: int | None = None,
    ) -> None:
        """Update progress counters (thread-safe)."""
        callback: ProgressCallback | None = None
        progress: Progress | None = None
        with self._lock:
            if self._progress is None:
                return
            if scanned is not None:
                self._progress.scanned = scanned
            if total is not None:
                self._progress.total = total
            if created is not None:
                self._progress.created = created
            if skipped is not None:
                self._progress.skipped = skipped
            if self._progress_callback:
                callback = self._progress_callback
                progress = Progress(
                    scanned=self._progress.scanned,
                    total=self._progress.total,
                    created=self._progress.created,
                    skipped=self._progress.skipped,
                )
        # Invoke the callback on a snapshot, outside the lock, so a slow
        # callback cannot block the scanner.
        if callback and progress:
            try:
                callback(progress)
            except Exception:
                pass

    def _add_error(self, message: str) -> None:
        """Add an error message (thread-safe)."""
        with self._lock:
            self._errors.append(message)
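
    # Note on locking: `_lock` guards all mutable seeder state (`_state`,
    # `_progress`, `_errors`, `_thread`). Long-running operations (the thread
    # join in `wait`, progress callbacks, WebSocket emission, and the scan
    # batches themselves) deliberately run outside the lock so the worker
    # thread and its callers cannot block or deadlock each other.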

    def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
        """Log the directories that will be scanned."""
        import folder_paths

        for root in roots:
            if root == "models":
                logging.info(
                    "Asset scan [models] directory: %s",
                    os.path.abspath(folder_paths.models_dir),
                )
            else:
                prefixes = get_prefixes_for_root(root)
                if prefixes:
                    logging.info("Asset scan [%s] directories: %s", root, prefixes)

    def _run_scan(self) -> None:
        """Main scan loop running in background thread."""
        t_start = time.perf_counter()
        roots = self._roots
        cancelled = False
        total_created = 0
        skipped_existing = 0
        total_paths = 0
        try:
            if not dependencies_available():
                self._add_error("Database dependencies not available")
                self._emit_event(
                    "assets.seed.error",
                    {"message": "Database dependencies not available"},
                )
                return

            if self._prune_first:
                all_prefixes = get_all_known_prefixes()
                marked = mark_missing_outside_prefixes_safely(all_prefixes)
                if marked > 0:
                    logging.info(
                        "Marked %d cache states as missing before scan", marked
                    )
                if self._is_cancelled():
                    logging.info("Asset scan cancelled after pruning phase")
                    cancelled = True
                    return

            self._log_scan_config(roots)

            existing_paths: set[str] = set()
            for r in roots:
                if self._is_cancelled():
                    logging.info("Asset scan cancelled during sync phase")
                    cancelled = True
                    return
                existing_paths.update(sync_root_safely(r))

            if self._is_cancelled():
                logging.info("Asset scan cancelled after sync phase")
                cancelled = True
                return

            paths = collect_paths_for_roots(roots)
            total_paths = len(paths)
            self._update_progress(total=total_paths)
            self._emit_event(
                "assets.seed.started",
                {"roots": list(roots), "total": total_paths},
            )

            specs, tag_pool, skipped_existing = build_asset_specs(
                paths, existing_paths, compute_hashes=self._compute_hashes
            )
            self._update_progress(skipped=skipped_existing)

            if self._is_cancelled():
                logging.info("Asset scan cancelled after building specs")
                cancelled = True
                return

            batch_size = 500
            last_progress_time = time.perf_counter()
            progress_interval = 1.0
            for i in range(0, len(specs), batch_size):
                if self._is_cancelled():
                    logging.info(
                        "Asset scan cancelled after %d/%d files (created=%d)",
                        i,
                        len(specs),
                        total_created,
                    )
                    cancelled = True
                    return
                batch = specs[i : i + batch_size]
                batch_tags = {t for spec in batch for t in spec["tags"]}
                try:
                    created = insert_asset_specs(batch, batch_tags)
                    total_created += created
                except Exception as e:
                    self._add_error(f"Batch insert failed at offset {i}: {e}")
                    logging.exception("Batch insert failed at offset %d", i)
                scanned = i + len(batch)
                now = time.perf_counter()
                self._update_progress(scanned=scanned, created=total_created)
                if now - last_progress_time >= progress_interval:
                    self._emit_event(
                        "assets.seed.progress",
                        {
                            "scanned": scanned,
                            "total": len(specs),
                            "created": total_created,
                        },
                    )
                    last_progress_time = now

            self._update_progress(scanned=len(specs), created=total_created)
            elapsed = time.perf_counter() - t_start
            logging.info(
                "Asset scan(roots=%s) completed in %.3fs (created=%d, skipped=%d, total=%d)",
                roots,
                elapsed,
                total_created,
                skipped_existing,
                len(paths),
            )
            self._emit_event(
                "assets.seed.completed",
                {
                    "scanned": len(specs),
                    "total": total_paths,
                    "created": total_created,
                    "skipped": skipped_existing,
                    "elapsed": round(elapsed, 3),
                },
            )
        except Exception as e:
            self._add_error(f"Scan failed: {e}")
            logging.exception("Asset scan failed")
            self._emit_event("assets.seed.error", {"message": str(e)})
        finally:
            if cancelled:
                self._emit_event(
                    "assets.seed.cancelled",
                    {
                        "scanned": self._progress.scanned if self._progress else 0,
                        "total": total_paths,
                        "created": total_created,
                    },
                )
            with self._lock:
                self._state = State.IDLE


asset_seeder = AssetSeeder()
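

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; in practice this module is
    # driven by the server, and this block assumes a configured database).
    # It starts a scan of the models root, streams progress to stdout via a
    # hypothetical callback, and waits for completion.
    logging.basicConfig(level=logging.INFO)

    def _print_progress(p: Progress) -> None:
        # Any Callable[[Progress], None] is accepted as a progress callback.
        print(
            f"scanned {p.scanned}/{p.total} "
            f"(created={p.created}, skipped={p.skipped})"
        )

    if asset_seeder.start(roots=("models",), progress_callback=_print_progress):
        asset_seeder.wait(timeout=300.0)
        final = asset_seeder.get_status()
        print(f"state={final.state.value} errors={final.errors}")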