feat(assets): add background asset seeder for non-blocking startup

- Add AssetSeeder singleton class with thread management and cancellation
- Support IDLE/RUNNING/CANCELLING state machine with thread-safe access
- Emit WebSocket events for scan progress (started, progress, completed, cancelled, error)
- Update main.py to use non-blocking asset_seeder.start() at startup
- Add shutdown() call in finally block for graceful cleanup
- Update POST /api/assets/seed to return 202 Accepted, support ?wait=true
- Add GET /api/assets/seed/status and POST /api/assets/seed/cancel endpoints
- Update test helper to use ?wait=true for synchronous behavior
- Add 17 unit tests covering state transitions, cancellation, and thread safety
- Log scan configuration (models directory, input/output paths) at scan start

Amp-Thread-ID: https://ampcode.com/threads/T-019c2b45-e6e8-740a-b38b-b11daea8d094
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr
2026-02-04 17:02:57 -08:00
parent 469576ed87
commit 3a096a08ae
6 changed files with 996 additions and 18 deletions

View File

@@ -15,7 +15,7 @@ from app.assets.api.schemas_in import (
UploadError,
)
from app.assets.api.upload import parse_multipart_upload
from app.assets.scanner import seed_assets as scanner_seed_assets
from app.assets.seeder import asset_seeder
from app.assets.services import (
DependencyMissingError,
HashMismatchError,
@@ -620,21 +620,78 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
# NOTE(review): this span is a rendered diff whose +/- markers and indentation
# were lost, so pre-change and post-change lines appear interleaved (two
# docstrings, two `valid_roots` assignments, and a `return` ahead of the
# seeder logic). Presumably the one-line docstring, the list-comprehension
# `valid_roots`, the `scanner_seed_assets` try/except and the
# `{"seeded": ...}` return are the removed side of the diff -- confirm against
# the repository before editing this handler.
@ROUTES.post("/api/assets/seed")
async def seed_assets(request: web.Request) -> web.Response:
"""Trigger asset seeding for specified roots (models, input, output)."""
"""Trigger asset seeding for specified roots (models, input, output).
Query params:
wait: If "true", block until scan completes (synchronous behavior for tests)
Returns:
202 Accepted if scan started
409 Conflict if scan already running
200 OK with final stats if wait=true
"""
# A missing or unparsable JSON body falls back to scanning all three roots.
try:
payload = await request.json()
roots = payload.get("roots", ["models", "input", "output"])
except Exception:
roots = ["models", "input", "output"]
# Unknown root names are dropped; a request with no valid root is rejected with 400.
valid_roots = [r for r in roots if r in ("models", "input", "output")]
valid_roots = tuple(r for r in roots if r in ("models", "input", "output"))
if not valid_roots:
return _build_error_response(400, "INVALID_BODY", "No valid roots specified")
try:
scanner_seed_assets(tuple(valid_roots))
except Exception:
logging.exception("scanner_seed_assets failed for roots=%s", valid_roots)
return _build_error_response(500, "INTERNAL", "Seed operation failed")
# "true", "1" and "yes" (case-insensitive) all enable the synchronous mode.
wait_param = request.query.get("wait", "").lower()
should_wait = wait_param in ("true", "1", "yes")
return web.json_response({"seeded": valid_roots}, status=200)
# start() returns False when the seeder is not idle, which is reported as 409.
started = asset_seeder.start(roots=valid_roots)
if not started:
return web.json_response({"status": "already_running"}, status=409)
if should_wait:
# NOTE(review): wait() joins the scan thread synchronously, so this coroutine
# blocks its event loop until the scan ends -- consider running the join in
# an executor. The "completed" status below is reported unconditionally,
# even if the scan was cancelled or recorded errors.
asset_seeder.wait()
status = asset_seeder.get_status()
return web.json_response(
{
"status": "completed",
"progress": {
"scanned": status.progress.scanned if status.progress else 0,
"total": status.progress.total if status.progress else 0,
"created": status.progress.created if status.progress else 0,
"skipped": status.progress.skipped if status.progress else 0,
},
"errors": status.errors,
},
status=200,
)
return web.json_response({"status": "started"}, status=202)
@ROUTES.get("/api/assets/seed/status")
async def get_seed_status(request: web.Request) -> web.Response:
    """Report the seeder's state, progress counters and collected errors.

    Returns:
        200 with a JSON body of the form
        ``{"state": ..., "progress": {...} | null, "errors": [...]}``.
        ``progress`` is ``null`` when the seeder has no progress object.
    """
    snapshot = asset_seeder.get_status()
    counters = snapshot.progress

    # Serialise the counters only when a progress object exists.
    progress_payload = None
    if counters:
        progress_payload = {
            "scanned": counters.scanned,
            "total": counters.total,
            "created": counters.created,
            "skipped": counters.skipped,
        }

    body = {
        "state": snapshot.state.value,
        "progress": progress_payload,
        "errors": snapshot.errors,
    }
    return web.json_response(body, status=200)
@ROUTES.post("/api/assets/seed/cancel")
async def cancel_seed(request: web.Request) -> web.Response:
    """Ask the seeder to cancel the scan that is currently in progress.

    Returns:
        200 with ``{"status": "cancelling"}`` when the seeder accepted the
        request, otherwise 200 with ``{"status": "idle"}``.
    """
    # cancel() returns a bool; both outcomes are reported with HTTP 200.
    outcome = "cancelling" if asset_seeder.cancel() else "idle"
    return web.json_response({"status": outcome}, status=200)

376
app/assets/seeder.py Normal file
View File

@@ -0,0 +1,376 @@
"""Background asset seeder with thread management and cancellation support."""
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Callable
from app.assets.scanner import (
RootType,
_build_asset_specs,
_collect_paths_for_roots,
_insert_asset_specs,
_prune_orphans_safely,
_sync_root_safely,
get_prefixes_for_root,
)
from app.database.db import dependencies_available
if TYPE_CHECKING:
from server import PromptServer
class State(Enum):
    """Lifecycle phases of the asset seeder."""

    # No scan is in flight; a new scan may be started.
    IDLE = "IDLE"
    # A scan has been started and has not been asked to stop.
    RUNNING = "RUNNING"
    # Cancellation was requested for the running scan.
    CANCELLING = "CANCELLING"
@dataclass
class Progress:
    """Counters describing how far a scan has advanced.

    All counters start at zero.
    """

    # Number of items processed so far.
    scanned: int = 0
    # Number of items the scan expects to look at.
    total: int = 0
    # Number of assets created by the scan.
    created: int = 0
    # Number of items skipped by the scan.
    skipped: int = 0
@dataclass
class ScanStatus:
    """Point-in-time view of the asset seeder."""

    # State machine phase at the time the status was taken.
    state: State
    # Progress counters, or None when no progress object exists.
    progress: Progress | None
    # Error messages collected so far; each instance gets its own list.
    errors: list[str] = field(default_factory=list)
# Signature of the optional progress hook: it receives a Progress object and
# its return value is not used.
ProgressCallback = Callable[[Progress], None]
class AssetSeeder:
    """Singleton managing background asset scanning.

    Every construction returns the same instance. Each scan runs on a fresh
    daemon thread that exits when the scan finishes, fails or is cancelled.
    Mutable state (state machine, progress, errors, thread handle) is guarded
    by ``self._lock``, which is a plain non-reentrant ``threading.Lock``; for
    that reason user callbacks are never invoked while it is held.
    """

    _instance: "AssetSeeder | None" = None
    _instance_lock = threading.Lock()

    def __new__(cls) -> "AssetSeeder":
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                # __init__ runs on every AssetSeeder() call; this flag lets it
                # initialise the shared instance only once.
                cls._instance._initialized = False
            return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self._initialized = True
        self._lock = threading.Lock()
        self._state = State.IDLE
        self._progress: Progress | None = None
        self._errors: list[str] = []
        self._thread: threading.Thread | None = None
        self._cancel_event = threading.Event()
        self._roots: tuple[RootType, ...] = ()
        self._progress_callback: ProgressCallback | None = None

    def start(
        self,
        roots: tuple[RootType, ...] = ("models",),
        progress_callback: ProgressCallback | None = None,
    ) -> bool:
        """Start a background scan for the given roots.

        Args:
            roots: Tuple of root types to scan (models, input, output)
            progress_callback: Optional callback called with progress updates

        Returns:
            True if scan was started, False if already running

        Raises:
            Exception: Whatever ``Thread.start()`` raises if the scan thread
                cannot be started; the seeder is returned to IDLE first.
        """
        with self._lock:
            if self._state != State.IDLE:
                return False
            self._state = State.RUNNING
            self._progress = Progress()
            self._errors = []
            self._roots = roots
            self._progress_callback = progress_callback
            self._cancel_event.clear()
            worker = threading.Thread(
                target=self._run_scan,
                name="AssetSeeder",
                daemon=True,
            )
            self._thread = worker
            try:
                worker.start()
            except Exception:
                # Without this rollback a failed start would leave the seeder
                # stuck in RUNNING with a thread that can never be joined.
                self._state = State.IDLE
                self._thread = None
                raise
            return True

    def cancel(self) -> bool:
        """Request cancellation of the current scan.

        Returns:
            True if cancellation was requested, False if the seeder is not in
            the RUNNING state (idle, or already cancelling)
        """
        with self._lock:
            if self._state != State.RUNNING:
                return False
            self._state = State.CANCELLING
            self._cancel_event.set()
            return True

    def wait(self, timeout: float | None = None) -> bool:
        """Wait for the current scan thread to exit.

        This is a blocking join; do not call it from the scan thread itself.

        Args:
            timeout: Maximum seconds to wait, or None for no timeout

        Returns:
            True if there is no scan thread or it has exited, False if the
            thread is still alive after the timeout expired
        """
        # Only the handle is read under the lock; joining while holding it
        # would block the scan thread's own lock acquisitions.
        with self._lock:
            thread = self._thread
        if thread is None:
            return True
        thread.join(timeout=timeout)
        return not thread.is_alive()

    def get_status(self) -> ScanStatus:
        """Get the current status and progress of the seeder.

        Returns:
            A ScanStatus holding copies of the progress counters and error
            list, so callers cannot mutate the seeder's internal state.
        """
        with self._lock:
            return ScanStatus(
                state=self._state,
                progress=self._copy_progress_locked(),
                errors=list(self._errors),
            )

    def shutdown(self, timeout: float = 5.0) -> None:
        """Gracefully shutdown: cancel any running scan and wait for thread.

        The thread handle is dropped afterwards even if the join timed out.

        Args:
            timeout: Maximum seconds to wait for thread to exit
        """
        self.cancel()
        self.wait(timeout=timeout)
        with self._lock:
            self._thread = None

    def _copy_progress_locked(self) -> Progress | None:
        """Return a copy of the progress counters; caller must hold the lock."""
        current = self._progress
        if not current:
            return None
        return Progress(
            scanned=current.scanned,
            total=current.total,
            created=current.created,
            skipped=current.skipped,
        )

    def _is_cancelled(self) -> bool:
        """Check if cancellation has been requested."""
        return self._cancel_event.is_set()

    def _emit_event(self, event_type: str, data: dict) -> None:
        """Emit a WebSocket event if server is available.

        Best-effort: any failure (including the import) is logged at debug
        level and otherwise ignored so that it cannot abort a scan.
        """
        try:
            from server import PromptServer

            if hasattr(PromptServer, "instance") and PromptServer.instance:
                PromptServer.instance.send_sync(event_type, data)
        except Exception:
            logging.debug("Could not emit %s event", event_type, exc_info=True)

    def _update_progress(
        self,
        scanned: int | None = None,
        total: int | None = None,
        created: int | None = None,
        skipped: int | None = None,
    ) -> None:
        """Update progress counters (thread-safe) and notify the callback.

        Only the counters passed as non-None are changed. Does nothing when
        no progress object exists.
        """
        with self._lock:
            progress = self._progress
            if progress is None:
                return
            if scanned is not None:
                progress.scanned = scanned
            if total is not None:
                progress.total = total
            if created is not None:
                progress.created = created
            if skipped is not None:
                progress.skipped = skipped
            callback = self._progress_callback
            snapshot = self._copy_progress_locked() if callback else None

        if not callback:
            return
        # Invoked after releasing the lock: the lock is not reentrant, so a
        # callback that calls get_status() would otherwise deadlock.
        try:
            callback(snapshot)
        except Exception:
            logging.debug("Asset scan progress callback failed", exc_info=True)

    def _add_error(self, message: str) -> None:
        """Add an error message (thread-safe)."""
        with self._lock:
            self._errors.append(message)

    def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
        """Log the directories that will be scanned."""
        import folder_paths

        for root in roots:
            if root == "models":
                logging.info(
                    "Asset scan [models] directory: %s",
                    os.path.abspath(folder_paths.models_dir),
                )
            else:
                prefixes = get_prefixes_for_root(root)
                if prefixes:
                    logging.info("Asset scan [%s] directories: %s", root, prefixes)

    def _run_scan(self) -> None:
        """Main scan loop running in background thread.

        Phases: dependency check, per-root sync, orphan pruning, path
        collection, spec building, then batched inserts. Cancellation is
        checked between phases and before every batch. Whatever the outcome,
        the seeder is returned to IDLE in the ``finally`` clause.
        """
        t_start = time.perf_counter()
        roots = self._roots
        cancelled = False
        total_created = 0
        skipped_existing = 0
        total_paths = 0

        try:
            if not dependencies_available():
                self._add_error("Database dependencies not available")
                self._emit_event(
                    "assets.seed.error",
                    {"message": "Database dependencies not available"},
                )
                return

            self._log_scan_config(roots)

            existing_paths: set[str] = set()
            for r in roots:
                if self._is_cancelled():
                    logging.info("Asset scan cancelled during sync phase")
                    cancelled = True
                    return
                existing_paths.update(_sync_root_safely(r))

            if self._is_cancelled():
                logging.info("Asset scan cancelled after sync phase")
                cancelled = True
                return

            all_prefixes = [
                os.path.abspath(p) for r in roots for p in get_prefixes_for_root(r)
            ]
            orphans_pruned = _prune_orphans_safely(all_prefixes)

            if self._is_cancelled():
                logging.info("Asset scan cancelled after orphan pruning")
                cancelled = True
                return

            paths = _collect_paths_for_roots(roots)
            total_paths = len(paths)
            self._update_progress(total=total_paths)

            self._emit_event(
                "assets.seed.started",
                {"roots": list(roots), "total": total_paths},
            )

            # The second returned value is not used by the seeder.
            specs, _tag_pool, skipped_existing = _build_asset_specs(
                paths, existing_paths
            )
            self._update_progress(skipped=skipped_existing)

            if self._is_cancelled():
                logging.info("Asset scan cancelled after building specs")
                cancelled = True
                return

            batch_size = 500
            last_progress_time = time.perf_counter()
            # Minimum number of seconds between two progress events.
            progress_interval = 1.0

            for i in range(0, len(specs), batch_size):
                if self._is_cancelled():
                    logging.info(
                        "Asset scan cancelled after %d/%d files (created=%d)",
                        i,
                        len(specs),
                        total_created,
                    )
                    cancelled = True
                    return

                batch = specs[i : i + batch_size]
                batch_tags = {t for spec in batch for t in spec["tags"]}
                # A failed batch is recorded and the scan moves on to the next.
                try:
                    created = _insert_asset_specs(batch, batch_tags)
                    total_created += created
                except Exception as e:
                    self._add_error(f"Batch insert failed at offset {i}: {e}")
                    logging.exception("Batch insert failed at offset %d", i)

                scanned = i + len(batch)
                self._update_progress(scanned=scanned, created=total_created)

                now = time.perf_counter()
                if now - last_progress_time >= progress_interval:
                    self._emit_event(
                        "assets.seed.progress",
                        {"scanned": scanned, "total": len(specs), "created": total_created},
                    )
                    last_progress_time = now

            self._update_progress(scanned=len(specs), created=total_created)

            elapsed = time.perf_counter() - t_start
            logging.info(
                "Asset scan(roots=%s) completed in %.3fs (created=%d, skipped=%d, orphans_pruned=%d, total=%d)",
                roots,
                elapsed,
                total_created,
                skipped_existing,
                orphans_pruned,
                len(paths),
            )

            self._emit_event(
                "assets.seed.completed",
                {
                    "scanned": len(specs),
                    "total": total_paths,
                    "created": total_created,
                    "skipped": skipped_existing,
                    "elapsed": round(elapsed, 3),
                },
            )

        except Exception as e:
            self._add_error(f"Scan failed: {e}")
            logging.exception("Asset scan failed")
            self._emit_event("assets.seed.error", {"message": str(e)})

        finally:
            if cancelled:
                self._emit_event(
                    "assets.seed.cancelled",
                    {
                        "scanned": self._progress.scanned if self._progress else 0,
                        "total": total_paths,
                        "created": total_created,
                    },
                )
            # Always return to IDLE so that a new scan can be started.
            with self._lock:
                self._state = State.IDLE
# Module-level shared seeder; AssetSeeder.__new__ hands back this same
# instance on any later AssetSeeder() construction.
asset_seeder = AssetSeeder()