Files
ComfyUI/app/assets/seeder.py
Luke Mino-Altherr 6443cf016e refactor: make scanner helper functions public
Rename _sync_root_safely, _prune_orphans_safely, _collect_paths_for_roots,
_build_asset_specs, and _insert_asset_specs to remove underscore prefix
since they are used by seeder.py as part of the public API.

Amp-Thread-ID: https://ampcode.com/threads/T-019c3037-df32-7138-99d8-b4b824d896b3
Co-authored-by: Amp <amp@ampcode.com>
2026-02-11 17:41:38 -08:00

415 lines
13 KiB
Python

"""Background asset seeder with thread management and cancellation support."""
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Callable
from app.assets.scanner import (
RootType,
build_asset_specs,
collect_paths_for_roots,
get_all_known_prefixes,
get_prefixes_for_root,
insert_asset_specs,
prune_orphans_safely,
sync_root_safely,
)
from app.database.db import dependencies_available
if TYPE_CHECKING:
    # Placeholder for imports needed only by static type checkers; nothing is
    # imported here at the moment.
    pass
class State(Enum):
    """Lifecycle states of the asset seeder.

    AssetSeeder.start() moves IDLE -> RUNNING, AssetSeeder.cancel() moves
    RUNNING -> CANCELLING, and the worker thread returns to IDLE on exit.
    """

    IDLE = "IDLE"  # no scan in progress
    RUNNING = "RUNNING"  # a background scan is active
    CANCELLING = "CANCELLING"  # cancellation requested, worker not yet finished
@dataclass
class Progress:
    """Counters describing how far a scan operation has advanced.

    Every counter defaults to zero, so ``Progress()`` is a fresh, empty record.
    """

    # Asset specs processed by the batched insertion phase so far.
    scanned: int = 0
    # Number of file paths collected for the scanned roots.
    total: int = 0
    # Assets actually inserted into the database.
    created: int = 0
    # Paths skipped because they were already known.
    skipped: int = 0
@dataclass
class ScanStatus:
    """Point-in-time snapshot of the seeder, as returned by AssetSeeder.get_status()."""

    # Lifecycle state at the moment the snapshot was taken.
    state: State
    # Copy of the progress counters, or None if no scan has been started yet.
    progress: Progress | None
    # Error messages collected during the current or most recent scan.
    errors: list[str] = field(default_factory=list)
# Signature of the optional callback accepted by AssetSeeder.start(); it is
# called with a fresh Progress copy whenever the seeder updates its counters.
ProgressCallback = Callable[[Progress], None]
class AssetSeeder:
    """Singleton class managing background asset scanning.

    Thread-safe singleton that spawns ephemeral daemon threads for scanning.
    Each scan creates a new thread that exits when complete. Mutable state is
    guarded by ``self._lock``; cancellation is signalled through
    ``self._cancel_event`` so the worker can poll it without taking the lock.
    """

    _instance: "AssetSeeder | None" = None
    _instance_lock = threading.Lock()

    # Number of asset specs inserted per database batch.
    _BATCH_SIZE = 500
    # Minimum number of seconds between "assets.seed.progress" WebSocket events.
    _PROGRESS_INTERVAL_S = 1.0

    def __new__(cls) -> "AssetSeeder":
        # Create the single shared instance exactly once, under a class-level lock.
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every AssetSeeder() call; initialize state only once.
        if self._initialized:
            return
        self._initialized = True
        self._lock = threading.Lock()
        self._state = State.IDLE
        self._progress: Progress | None = None
        self._errors: list[str] = []
        self._thread: threading.Thread | None = None
        self._cancel_event = threading.Event()
        self._roots: tuple[RootType, ...] = ()
        self._prune_first = False
        self._progress_callback: ProgressCallback | None = None

    def start(
        self,
        roots: tuple[RootType, ...] = ("models", "input", "output"),
        progress_callback: ProgressCallback | None = None,
        prune_first: bool = False,
    ) -> bool:
        """Start a background scan for the given roots.

        Args:
            roots: Tuple of root types to scan (models, input, output)
            progress_callback: Optional callback called with progress updates.
                It is invoked outside the seeder's lock, so it may safely call
                get_status().
            prune_first: If True, prune orphaned assets before scanning

        Returns:
            True if scan was started, False if already running

        Raises:
            Whatever ``threading.Thread.start`` raises if the worker thread
            cannot be spawned; the seeder is reset to IDLE first.
        """
        with self._lock:
            if self._state != State.IDLE:
                return False
            self._state = State.RUNNING
            self._progress = Progress()
            self._errors = []
            self._roots = roots
            self._prune_first = prune_first
            self._progress_callback = progress_callback
            self._cancel_event.clear()
            self._thread = threading.Thread(
                target=self._run_scan,
                name="AssetSeeder",
                daemon=True,
            )
            try:
                self._thread.start()
            except Exception:
                # Without this rollback the seeder would stay RUNNING forever
                # and every later start() would be refused.
                self._state = State.IDLE
                self._thread = None
                raise
        return True

    def cancel(self) -> bool:
        """Request cancellation of the current scan.

        Returns:
            True if cancellation was requested, False if not running
        """
        with self._lock:
            if self._state != State.RUNNING:
                return False
            self._state = State.CANCELLING
            self._cancel_event.set()
            return True

    def wait(self, timeout: float | None = None) -> bool:
        """Wait for the current scan thread to finish.

        Args:
            timeout: Maximum seconds to wait, or None for no timeout

        Returns:
            True if there is no scan thread or it finished within ``timeout``;
            False if the timeout expired while the thread was still alive.
        """
        with self._lock:
            thread = self._thread
        # Join outside the lock: the worker needs the lock to finish.
        if thread is None:
            return True
        thread.join(timeout=timeout)
        return not thread.is_alive()

    def get_status(self) -> ScanStatus:
        """Get the current status and progress of the seeder."""
        with self._lock:
            return ScanStatus(
                state=self._state,
                progress=self._snapshot_progress_locked(),
                errors=list(self._errors),
            )

    def shutdown(self, timeout: float = 5.0) -> None:
        """Gracefully shutdown: cancel any running scan and wait for thread.

        The thread reference is only dropped once the thread has exited, so a
        later wait() still reflects a worker that outlived ``timeout``.

        Args:
            timeout: Maximum seconds to wait for thread to exit
        """
        self.cancel()
        if not self.wait(timeout=timeout):
            logging.warning(
                "Asset seeder thread did not exit within %.1f seconds", timeout
            )
            return
        with self._lock:
            # Only forget a finished thread; a new scan may have started meanwhile.
            if self._thread is not None and not self._thread.is_alive():
                self._thread = None

    def prune_orphans(self) -> int:
        """Prune orphaned assets that are outside all known root prefixes.

        This operation is decoupled from scanning to prevent partial scans
        from accidentally deleting assets belonging to other roots.
        Should be called explicitly when cleanup is desired, typically after
        a full scan of all roots or during maintenance.

        Returns:
            Number of orphaned assets pruned, or 0 if dependencies unavailable
            or a scan is currently running
        """
        with self._lock:
            if self._state != State.IDLE:
                logging.warning("Cannot prune orphans while scan is running")
                return 0
        # NOTE(review): the lock is released before pruning, so a scan started
        # concurrently via start() could overlap with this prune.
        if not dependencies_available():
            logging.warning(
                "Database dependencies not available, skipping orphan pruning"
            )
            return 0
        all_prefixes = get_all_known_prefixes()
        pruned = prune_orphans_safely(all_prefixes)
        if pruned > 0:
            logging.info("Pruned %d orphaned assets", pruned)
        return pruned

    def _is_cancelled(self) -> bool:
        """Check if cancellation has been requested."""
        return self._cancel_event.is_set()

    def _snapshot_progress_locked(self) -> Progress | None:
        """Return a copy of the progress counters; caller must hold ``self._lock``."""
        current = self._progress
        if current is None:
            return None
        return Progress(
            scanned=current.scanned,
            total=current.total,
            created=current.created,
            skipped=current.skipped,
        )

    def _emit_event(self, event_type: str, data: dict) -> None:
        """Emit a WebSocket event if server is available (best-effort)."""
        try:
            from server import PromptServer

            instance = getattr(PromptServer, "instance", None)
            if instance:
                instance.send_sync(event_type, data)
        except ImportError:
            # Running without the web server (e.g. tooling/tests): nothing to notify.
            pass
        except Exception:
            # Event delivery must never break the scan, but keep a trace for debugging.
            logging.debug(
                "Failed to emit asset seeder event %s", event_type, exc_info=True
            )

    def _update_progress(
        self,
        scanned: int | None = None,
        total: int | None = None,
        created: int | None = None,
        skipped: int | None = None,
    ) -> None:
        """Update progress counters (thread-safe) and notify the callback.

        Only counters passed as non-None are changed. The progress callback is
        invoked after the lock is released: ``threading.Lock`` is not
        re-entrant, so a callback calling get_status() would otherwise deadlock.
        """
        with self._lock:
            progress = self._progress
            if progress is None:
                return
            if scanned is not None:
                progress.scanned = scanned
            if total is not None:
                progress.total = total
            if created is not None:
                progress.created = created
            if skipped is not None:
                progress.skipped = skipped
            callback = self._progress_callback
            snapshot = self._snapshot_progress_locked()
        if callback:
            try:
                callback(snapshot)
            except Exception:
                # A faulty user callback must not abort the scan.
                logging.debug("Asset seeder progress callback raised", exc_info=True)

    def _add_error(self, message: str) -> None:
        """Add an error message (thread-safe)."""
        with self._lock:
            self._errors.append(message)

    def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
        """Log the directories that will be scanned."""
        import folder_paths

        for root in roots:
            if root == "models":
                logging.info(
                    "Asset scan [models] directory: %s",
                    os.path.abspath(folder_paths.models_dir),
                )
            else:
                prefixes = get_prefixes_for_root(root)
                if prefixes:
                    logging.info("Asset scan [%s] directories: %s", root, prefixes)

    def _run_scan(self) -> None:
        """Main scan loop running in the background thread.

        Phases: optional orphan prune, per-root sync, path collection, spec
        building, then batched insertion. The cancel event is polled between
        phases and before each batch. The state is always reset to IDLE on exit.
        """
        t_start = time.perf_counter()
        roots = self._roots
        cancelled = False
        total_created = 0
        skipped_existing = 0
        total_paths = 0
        try:
            if not dependencies_available():
                self._add_error("Database dependencies not available")
                self._emit_event(
                    "assets.seed.error",
                    {"message": "Database dependencies not available"},
                )
                return

            if self._prune_first:
                all_prefixes = get_all_known_prefixes()
                pruned = prune_orphans_safely(all_prefixes)
                if pruned > 0:
                    logging.info("Pruned %d orphaned assets before scan", pruned)
                if self._is_cancelled():
                    logging.info("Asset scan cancelled after pruning phase")
                    cancelled = True
                    return

            self._log_scan_config(roots)

            # Sync phase: collect the paths each root already has registered.
            existing_paths: set[str] = set()
            for r in roots:
                if self._is_cancelled():
                    logging.info("Asset scan cancelled during sync phase")
                    cancelled = True
                    return
                existing_paths.update(sync_root_safely(r))
            if self._is_cancelled():
                logging.info("Asset scan cancelled after sync phase")
                cancelled = True
                return

            paths = collect_paths_for_roots(roots)
            total_paths = len(paths)
            self._update_progress(total=total_paths)
            self._emit_event(
                "assets.seed.started",
                {"roots": list(roots), "total": total_paths},
            )

            specs, _tag_pool, skipped_existing = build_asset_specs(
                paths, existing_paths
            )
            self._update_progress(skipped=skipped_existing)
            if self._is_cancelled():
                logging.info("Asset scan cancelled after building specs")
                cancelled = True
                return

            # Insertion phase: insert in batches so cancellation stays responsive.
            batch_size = self._BATCH_SIZE
            progress_interval = self._PROGRESS_INTERVAL_S
            last_progress_time = time.perf_counter()
            for i in range(0, len(specs), batch_size):
                if self._is_cancelled():
                    logging.info(
                        "Asset scan cancelled after %d/%d files (created=%d)",
                        i,
                        len(specs),
                        total_created,
                    )
                    cancelled = True
                    return
                batch = specs[i : i + batch_size]
                batch_tags = {t for spec in batch for t in spec["tags"]}
                try:
                    created = insert_asset_specs(batch, batch_tags)
                    total_created += created
                except Exception as e:
                    # A failed batch is recorded but does not abort the scan.
                    self._add_error(f"Batch insert failed at offset {i}: {e}")
                    logging.exception("Batch insert failed at offset %d", i)
                scanned = i + len(batch)
                self._update_progress(scanned=scanned, created=total_created)
                # Throttle WebSocket progress events to one per interval.
                now = time.perf_counter()
                if now - last_progress_time >= progress_interval:
                    self._emit_event(
                        "assets.seed.progress",
                        {
                            "scanned": scanned,
                            "total": len(specs),
                            "created": total_created,
                        },
                    )
                    last_progress_time = now

            self._update_progress(scanned=len(specs), created=total_created)
            elapsed = time.perf_counter() - t_start
            logging.info(
                "Asset scan(roots=%s) completed in %.3fs (created=%d, skipped=%d, total=%d)",
                roots,
                elapsed,
                total_created,
                skipped_existing,
                total_paths,
            )
            self._emit_event(
                "assets.seed.completed",
                {
                    "scanned": len(specs),
                    "total": total_paths,
                    "created": total_created,
                    "skipped": skipped_existing,
                    "elapsed": round(elapsed, 3),
                },
            )
        except Exception as e:
            self._add_error(f"Scan failed: {e}")
            logging.exception("Asset scan failed")
            self._emit_event("assets.seed.error", {"message": str(e)})
        finally:
            if cancelled:
                with self._lock:
                    scanned_so_far = self._progress.scanned if self._progress else 0
                self._emit_event(
                    "assets.seed.cancelled",
                    {
                        "scanned": scanned_so_far,
                        "total": total_paths,
                        "created": total_created,
                    },
                )
            with self._lock:
                self._state = State.IDLE
# Module-level singleton; AssetSeeder.__new__ ensures every AssetSeeder() call
# returns this same instance.
asset_seeder = AssetSeeder()