From 3a096a08aeb748e7457d637dbe46f2b1e9e6f52a Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Wed, 4 Feb 2026 17:02:57 -0800 Subject: [PATCH] feat(assets): add background asset seeder for non-blocking startup - Add AssetSeeder singleton class with thread management and cancellation - Support IDLE/RUNNING/CANCELLING state machine with thread-safe access - Emit WebSocket events for scan progress (started, progress, completed, cancelled, error) - Update main.py to use non-blocking asset_seeder.start() at startup - Add shutdown() call in finally block for graceful cleanup - Update POST /api/assets/seed to return 202 Accepted, support ?wait=true - Add GET /api/assets/seed/status and POST /api/assets/seed/cancel endpoints - Update test helper to use ?wait=true for synchronous behavior - Add 17 unit tests covering state transitions, cancellation, and thread safety - Log scan configuration (models directory, input/output paths) at scan start Amp-Thread-ID: https://ampcode.com/threads/T-019c2b45-e6e8-740a-b38b-b11daea8d094 Co-authored-by: Amp --- app/assets/api/routes.py | 75 ++++- app/assets/seeder.py | 376 +++++++++++++++++++++++++ docs/design/background-asset-seeder.md | 195 +++++++++++++ main.py | 10 +- tests-unit/assets_test/helpers.py | 11 +- tests-unit/seeder_test/test_seeder.py | 347 +++++++++++++++++++++++ 6 files changed, 996 insertions(+), 18 deletions(-) create mode 100644 app/assets/seeder.py create mode 100644 docs/design/background-asset-seeder.md create mode 100644 tests-unit/seeder_test/test_seeder.py diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 688b9f2db..d3039e894 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -15,7 +15,7 @@ from app.assets.api.schemas_in import ( UploadError, ) from app.assets.api.upload import parse_multipart_upload -from app.assets.scanner import seed_assets as scanner_seed_assets +from app.assets.seeder import asset_seeder from app.assets.services import ( DependencyMissingError, HashMismatchError, @@ -620,21 +620,78 @@ async def delete_asset_tags(request: web.Request) -> web.Response: @ROUTES.post("/api/assets/seed") async def seed_assets(request: web.Request) -> web.Response: - """Trigger asset seeding for specified roots (models, input, output).""" + """Trigger asset seeding for specified roots (models, input, output). 
+ + Query params: + wait: If "true", block until scan completes (synchronous behavior for tests) + + Returns: + 202 Accepted if scan started + 409 Conflict if scan already running + 200 OK with final stats if wait=true + """ try: payload = await request.json() roots = payload.get("roots", ["models", "input", "output"]) except Exception: roots = ["models", "input", "output"] - valid_roots = [r for r in roots if r in ("models", "input", "output")] + valid_roots = tuple(r for r in roots if r in ("models", "input", "output")) if not valid_roots: return _build_error_response(400, "INVALID_BODY", "No valid roots specified") - try: - scanner_seed_assets(tuple(valid_roots)) - except Exception: - logging.exception("scanner_seed_assets failed for roots=%s", valid_roots) - return _build_error_response(500, "INTERNAL", "Seed operation failed") + wait_param = request.query.get("wait", "").lower() + should_wait = wait_param in ("true", "1", "yes") - return web.json_response({"seeded": valid_roots}, status=200) + started = asset_seeder.start(roots=valid_roots) + if not started: + return web.json_response({"status": "already_running"}, status=409) + + if should_wait: + asset_seeder.wait() + status = asset_seeder.get_status() + return web.json_response( + { + "status": "completed", + "progress": { + "scanned": status.progress.scanned if status.progress else 0, + "total": status.progress.total if status.progress else 0, + "created": status.progress.created if status.progress else 0, + "skipped": status.progress.skipped if status.progress else 0, + }, + "errors": status.errors, + }, + status=200, + ) + + return web.json_response({"status": "started"}, status=202) + + +@ROUTES.get("/api/assets/seed/status") +async def get_seed_status(request: web.Request) -> web.Response: + """Get current scan status and progress.""" + status = asset_seeder.get_status() + return web.json_response( + { + "state": status.state.value, + "progress": { + "scanned": status.progress.scanned, + "total": status.progress.total, + "created": status.progress.created, + "skipped": status.progress.skipped, + } + if status.progress + else None, + "errors": status.errors, + }, + status=200, + ) + + +@ROUTES.post("/api/assets/seed/cancel") +async def cancel_seed(request: web.Request) -> web.Response: + """Request cancellation of in-progress scan.""" + cancelled = asset_seeder.cancel() + if cancelled: + return web.json_response({"status": "cancelling"}, status=200) + return web.json_response({"status": "idle"}, status=200) diff --git a/app/assets/seeder.py b/app/assets/seeder.py new file mode 100644 index 000000000..8722e3d1a --- /dev/null +++ b/app/assets/seeder.py @@ -0,0 +1,376 @@ +"""Background asset seeder with thread management and cancellation support.""" + +import logging +import os +import threading +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Callable + +from app.assets.scanner import ( + RootType, + _build_asset_specs, + _collect_paths_for_roots, + _insert_asset_specs, + _prune_orphans_safely, + _sync_root_safely, + get_prefixes_for_root, +) +from app.database.db import dependencies_available + +if TYPE_CHECKING: + from server import PromptServer + + +class State(Enum): + """Seeder state machine states.""" + + IDLE = "IDLE" + RUNNING = "RUNNING" + CANCELLING = "CANCELLING" + + +@dataclass +class Progress: + """Progress information for a scan operation.""" + + scanned: int = 0 + total: int = 0 + created: int = 0 + skipped: int = 0 + + +@dataclass +class ScanStatus: + 
"""Current status of the asset seeder.""" + + state: State + progress: Progress | None + errors: list[str] = field(default_factory=list) + + +ProgressCallback = Callable[[Progress], None] + + +class AssetSeeder: + """Singleton class managing background asset scanning. + + Thread-safe singleton that spawns ephemeral daemon threads for scanning. + Each scan creates a new thread that exits when complete. + """ + + _instance: "AssetSeeder | None" = None + _instance_lock = threading.Lock() + + def __new__(cls) -> "AssetSeeder": + with cls._instance_lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self) -> None: + if self._initialized: + return + self._initialized = True + self._lock = threading.Lock() + self._state = State.IDLE + self._progress: Progress | None = None + self._errors: list[str] = [] + self._thread: threading.Thread | None = None + self._cancel_event = threading.Event() + self._roots: tuple[RootType, ...] = () + self._progress_callback: ProgressCallback | None = None + + def start( + self, + roots: tuple[RootType, ...] = ("models",), + progress_callback: ProgressCallback | None = None, + ) -> bool: + """Start a background scan for the given roots. + + Args: + roots: Tuple of root types to scan (models, input, output) + progress_callback: Optional callback called with progress updates + + Returns: + True if scan was started, False if already running + """ + with self._lock: + if self._state != State.IDLE: + return False + self._state = State.RUNNING + self._progress = Progress() + self._errors = [] + self._roots = roots + self._progress_callback = progress_callback + self._cancel_event.clear() + self._thread = threading.Thread( + target=self._run_scan, + name="AssetSeeder", + daemon=True, + ) + self._thread.start() + return True + + def cancel(self) -> bool: + """Request cancellation of the current scan. + + Returns: + True if cancellation was requested, False if not running + """ + with self._lock: + if self._state != State.RUNNING: + return False + self._state = State.CANCELLING + self._cancel_event.set() + return True + + def wait(self, timeout: float | None = None) -> bool: + """Wait for the current scan to complete. + + Args: + timeout: Maximum seconds to wait, or None for no timeout + + Returns: + True if scan completed, False if timeout expired or no scan running + """ + with self._lock: + thread = self._thread + if thread is None: + return True + thread.join(timeout=timeout) + return not thread.is_alive() + + def get_status(self) -> ScanStatus: + """Get the current status and progress of the seeder.""" + with self._lock: + return ScanStatus( + state=self._state, + progress=Progress( + scanned=self._progress.scanned, + total=self._progress.total, + created=self._progress.created, + skipped=self._progress.skipped, + ) + if self._progress + else None, + errors=list(self._errors), + ) + + def shutdown(self, timeout: float = 5.0) -> None: + """Gracefully shutdown: cancel any running scan and wait for thread. 
+ + Args: + timeout: Maximum seconds to wait for thread to exit + """ + self.cancel() + self.wait(timeout=timeout) + with self._lock: + self._thread = None + + def _is_cancelled(self) -> bool: + """Check if cancellation has been requested.""" + return self._cancel_event.is_set() + + def _emit_event(self, event_type: str, data: dict) -> None: + """Emit a WebSocket event if server is available.""" + try: + from server import PromptServer + + if hasattr(PromptServer, "instance") and PromptServer.instance: + PromptServer.instance.send_sync(event_type, data) + except Exception: + pass + + def _update_progress( + self, + scanned: int | None = None, + total: int | None = None, + created: int | None = None, + skipped: int | None = None, + ) -> None: + """Update progress counters (thread-safe).""" + with self._lock: + if self._progress is None: + return + if scanned is not None: + self._progress.scanned = scanned + if total is not None: + self._progress.total = total + if created is not None: + self._progress.created = created + if skipped is not None: + self._progress.skipped = skipped + if self._progress_callback: + try: + self._progress_callback( + Progress( + scanned=self._progress.scanned, + total=self._progress.total, + created=self._progress.created, + skipped=self._progress.skipped, + ) + ) + except Exception: + pass + + def _add_error(self, message: str) -> None: + """Add an error message (thread-safe).""" + with self._lock: + self._errors.append(message) + + def _log_scan_config(self, roots: tuple[RootType, ...]) -> None: + """Log the directories that will be scanned.""" + import folder_paths + + for root in roots: + if root == "models": + logging.info( + "Asset scan [models] directory: %s", + os.path.abspath(folder_paths.models_dir), + ) + else: + prefixes = get_prefixes_for_root(root) + if prefixes: + logging.info("Asset scan [%s] directories: %s", root, prefixes) + + def _run_scan(self) -> None: + """Main scan loop running in background thread.""" + t_start = time.perf_counter() + roots = self._roots + cancelled = False + total_created = 0 + skipped_existing = 0 + total_paths = 0 + + try: + if not dependencies_available(): + self._add_error("Database dependencies not available") + self._emit_event( + "assets.seed.error", + {"message": "Database dependencies not available"}, + ) + return + + self._log_scan_config(roots) + + existing_paths: set[str] = set() + for r in roots: + if self._is_cancelled(): + logging.info("Asset scan cancelled during sync phase") + cancelled = True + return + existing_paths.update(_sync_root_safely(r)) + + if self._is_cancelled(): + logging.info("Asset scan cancelled after sync phase") + cancelled = True + return + + all_prefixes = [ + os.path.abspath(p) for r in roots for p in get_prefixes_for_root(r) + ] + orphans_pruned = _prune_orphans_safely(all_prefixes) + + if self._is_cancelled(): + logging.info("Asset scan cancelled after orphan pruning") + cancelled = True + return + + paths = _collect_paths_for_roots(roots) + total_paths = len(paths) + self._update_progress(total=total_paths) + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "total": total_paths}, + ) + + specs, tag_pool, skipped_existing = _build_asset_specs(paths, existing_paths) + self._update_progress(skipped=skipped_existing) + + if self._is_cancelled(): + logging.info("Asset scan cancelled after building specs") + cancelled = True + return + + batch_size = 500 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + for i in range(0, len(specs), 
batch_size): + if self._is_cancelled(): + logging.info( + "Asset scan cancelled after %d/%d files (created=%d)", + i, + len(specs), + total_created, + ) + cancelled = True + return + + batch = specs[i : i + batch_size] + batch_tags = {t for spec in batch for t in spec["tags"]} + try: + created = _insert_asset_specs(batch, batch_tags) + total_created += created + except Exception as e: + self._add_error(f"Batch insert failed at offset {i}: {e}") + logging.exception("Batch insert failed at offset %d", i) + + scanned = i + len(batch) + self._update_progress(scanned=scanned, created=total_created) + + now = time.perf_counter() + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + {"scanned": scanned, "total": len(specs), "created": total_created}, + ) + last_progress_time = now + + self._update_progress(scanned=len(specs), created=total_created) + + elapsed = time.perf_counter() - t_start + logging.info( + "Asset scan(roots=%s) completed in %.3fs (created=%d, skipped=%d, orphans_pruned=%d, total=%d)", + roots, + elapsed, + total_created, + skipped_existing, + orphans_pruned, + len(paths), + ) + + self._emit_event( + "assets.seed.completed", + { + "scanned": len(specs), + "total": total_paths, + "created": total_created, + "skipped": skipped_existing, + "elapsed": round(elapsed, 3), + }, + ) + + except Exception as e: + self._add_error(f"Scan failed: {e}") + logging.exception("Asset scan failed") + self._emit_event("assets.seed.error", {"message": str(e)}) + finally: + if cancelled: + self._emit_event( + "assets.seed.cancelled", + { + "scanned": self._progress.scanned if self._progress else 0, + "total": total_paths, + "created": total_created, + }, + ) + with self._lock: + self._state = State.IDLE + + +asset_seeder = AssetSeeder() diff --git a/docs/design/background-asset-seeder.md b/docs/design/background-asset-seeder.md new file mode 100644 index 000000000..496560fca --- /dev/null +++ b/docs/design/background-asset-seeder.md @@ -0,0 +1,195 @@ +# Background Asset Seeder Design Document + +## Problem Statement + +The `seed_assets` function in `app/assets/scanner.py` scans filesystem directories and imports assets into the database. Currently it runs synchronously, which causes two problems: + +1. **Startup blocking**: When ComfyUI starts, `setup_database()` calls `seed_assets(["models"])` synchronously. If the models directory contains thousands of files, startup is delayed significantly before the UI becomes available. + +2. **API request blocking**: The `POST /api/assets/seed` endpoint runs synchronously, blocking the HTTP request until scanning completes. For large directories, this causes request timeouts and a poor user experience. 
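+
+For illustration, here is a minimal sketch of the intended non-blocking call pattern. The `asset_seeder` singleton and its `start()` / `get_status()` methods are the ones introduced by this design (see the Tasks section below); the snippet assumes nothing beyond those names.
+
+```python
+# Sketch: startup (or an API handler) kicks off the scan and returns immediately,
+# while any other thread can read a consistent snapshot of progress.
+from app.assets.seeder import asset_seeder
+
+# Non-blocking: True if a scan thread was spawned, False if one is already running.
+if asset_seeder.start(roots=("models",)):
+    print("background asset scan started")
+
+# Later, from any thread: state is IDLE/RUNNING/CANCELLING; progress may be None.
+status = asset_seeder.get_status()
+print(status.state.value, status.progress.scanned if status.progress else 0)
+```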
+
+## Goals
+
+- Move asset scanning to a background thread so startup and API requests return immediately
+- Provide visibility into scan progress via API and WebSocket events
+- Support graceful cancellation of in-progress scans
+- Maintain backward compatibility for tests that rely on synchronous behavior
+- Ensure thread safety when accessing shared state
+
+## Non-Goals
+
+- Targeted/priority scanning of specific paths (addressed separately via synchronous `scan_paths()`)
+- Parallel scanning across multiple threads
+- Persistent scan state across restarts
+
+---
+
+## Architecture Overview
+
+### Component: AssetSeeder Singleton
+
+A new `AssetSeeder` class in `app/assets/seeder.py` manages background scanning with the following responsibilities:
+
+- Owns a single background `threading.Thread` for scanning work
+- Tracks scan state and progress in a thread-safe manner
+- Provides cancellation support via `threading.Event`
+- Emits WebSocket events for UI progress updates
+
+### State Machine
+
+```
+IDLE ──start()──► RUNNING ──(completes)──► IDLE
+                     │
+                  cancel()
+                     │
+                     ▼
+                 CANCELLING ──(thread exits)──► IDLE
+```
+
+### Integration Points
+
+| Component | Change |
+|-----------|--------|
+| `main.py` | Call `asset_seeder.start()` (non-blocking) instead of `seed_assets()` |
+| `main.py` | Call `asset_seeder.shutdown()` in `finally` block alongside `cleanup_temp()` |
+| API routes | New endpoints for status and cancellation |
+| WebSocket | Emit progress events during scanning |
+| Test helper | Use `?wait=true` query param for synchronous behavior |
+
+---
+
+## Tasks
+
+### Task 1: Create AssetSeeder Class
+
+**Description**: Implement the core `AssetSeeder` singleton class with thread management, state tracking, and cancellation support.
+
+**Acceptance Criteria**:
+- [ ] `AssetSeeder` is a singleton accessible via module-level instance
+- [ ] `State` enum with values: `IDLE`, `RUNNING`, `CANCELLING`
+- [ ] `start(roots)` method spawns a daemon thread and returns immediately
+- [ ] `start()` is idempotent—calling while already running is a no-op and returns `False`
+- [ ] `cancel()` method signals the thread to stop gracefully
+- [ ] `wait(timeout)` method blocks until thread completes or timeout expires
+- [ ] `get_status()` returns current state and progress information
+- [ ] All state access is protected by `threading.Lock`
+- [ ] Thread creates its own database sessions (no session sharing across threads)
+- [ ] Progress tuple tracks: `(scanned, total, created, skipped)`
+- [ ] Errors during scanning are captured and available via `get_status()`
+
+### Task 2: Add Cancellation Checkpoints
+
+**Description**: Modify the scanning logic to check for cancellation between batches, allowing graceful early termination.
+
+**Acceptance Criteria**:
+- [ ] `threading.Event` is checked between batch operations
+- [ ] When cancellation is requested, current batch completes before stopping
+- [ ] Partial progress is committed (assets already scanned remain in database)
+- [ ] State transitions to `IDLE` after cancellation completes
+- [ ] Cancellation is logged with partial progress statistics
+
+### Task 3: Update Startup Integration
+
+**Description**: Modify `main.py` to use non-blocking asset seeding at startup.
+ +**Acceptance Criteria**: +- [ ] `setup_database()` calls `asset_seeder.start(roots=["models"])` instead of `seed_assets()` +- [ ] Startup proceeds immediately without waiting for scan completion +- [ ] `asset_seeder.shutdown()` called in `finally` block alongside `cleanup_temp()` +- [ ] `--disable-assets-autoscan` flag continues to skip seeding entirely +- [ ] Startup logs indicate background scan was initiated (not completed) + +### Task 4: Create API Endpoints + +**Description**: Add REST endpoints for triggering, monitoring, and cancelling background scans. + +**Endpoints**: + +#### POST /api/assets/seed +Trigger a background scan for specified roots. + +**Acceptance Criteria**: +- [ ] Accepts `{"roots": ["models", "input", "output"]}` in request body +- [ ] Returns `202 Accepted` with `{"status": "started"}` when scan begins +- [ ] Returns `409 Conflict` with `{"status": "already_running"}` if scan in progress +- [ ] Supports `?wait=true` query param for synchronous behavior (blocks until complete) +- [ ] With `?wait=true`, returns `200 OK` with final statistics on completion + +#### GET /api/assets/seed/status +Get current scan status and progress. + +**Acceptance Criteria**: +- [ ] Returns `{"state": "IDLE|RUNNING|CANCELLING", "progress": {...}, "errors": [...]}` +- [ ] Progress object includes: `scanned`, `total`, `created`, `skipped` +- [ ] When idle, progress reflects last completed scan (or null if never run) +- [ ] Errors array contains messages from any failures during last/current scan + +#### POST /api/assets/seed/cancel +Request cancellation of in-progress scan. + +**Acceptance Criteria**: +- [ ] Returns `200 OK` with `{"status": "cancelling"}` if scan was running +- [ ] Returns `200 OK` with `{"status": "idle"}` if no scan was running +- [ ] Cancellation is graceful—does not corrupt database state + +### Task 5: Add WebSocket Progress Events + +**Description**: Emit WebSocket events during scanning so the UI can display progress. + +**Acceptance Criteria**: +- [ ] Event type: `assets.seed.started` with `{"roots": [...], "total": N}` +- [ ] Event type: `assets.seed.progress` with `{"scanned": N, "total": M, "created": C}` +- [ ] Event type: `assets.seed.completed` with final statistics +- [ ] Event type: `assets.seed.cancelled` if scan was cancelled +- [ ] Event type: `assets.seed.error` if scan failed with error message +- [ ] Progress events emitted at reasonable intervals (not every file, ~every 100 files or 1 second) + +### Task 6: Update Test Helper + +**Description**: Modify the test helper to use synchronous behavior via query parameter. + +**Acceptance Criteria**: +- [ ] `trigger_sync_seed_assets()` uses `?wait=true` query parameter +- [ ] Tests continue to pass with synchronous blocking behavior +- [ ] Remove artificial `time.sleep(0.2)` delay (no longer needed with `wait=true`) + +### Task 7: Unit Tests for AssetSeeder + +**Description**: Add unit tests covering the seeder state machine and thread safety. 
+ +**Acceptance Criteria**: +- [ ] Test: `start()` transitions state from IDLE to RUNNING +- [ ] Test: `start()` while RUNNING returns False (idempotent) +- [ ] Test: `cancel()` transitions state from RUNNING to CANCELLING +- [ ] Test: `wait()` blocks until thread completes +- [ ] Test: `wait(timeout)` returns False if timeout expires +- [ ] Test: `get_status()` returns correct progress during scan +- [ ] Test: Concurrent `start()` calls are safe (only one thread spawned) +- [ ] Test: Scan commits partial progress on cancellation +- [ ] Test: Database errors are captured in status, don't crash thread + +--- + +## Thread Safety Considerations + +| Shared Resource | Protection Strategy | +|-----------------|---------------------| +| `_state` enum | Protected by `threading.Lock` | +| `_progress` tuple | Protected by `threading.Lock` | +| `_errors` list | Protected by `threading.Lock` | +| `_thread` reference | Protected by `threading.Lock` | +| `_cancel_event` | `threading.Event` (inherently thread-safe) | +| Database sessions | Created per-operation inside thread (no sharing) | + +## Error Handling + +- Database connection failures: Log error, set state to IDLE, populate errors list +- Individual file scan failures: Log warning, continue with next file, increment error count +- Thread crashes: Caught by outer try/except, state reset to IDLE, error captured + +## Future Considerations + +- **Priority queue**: If targeted scans need to be non-blocking in the future, the seeder could be extended with a priority queue +- **Persistent state**: Scan progress could be persisted to allow resume after restart +- **Parallel scanning**: Multiple threads could scan different roots concurrently (requires careful session management) +- **Throttling**: If scanning competes with generation (e.g., disk I/O contention when hashing large files), add configurable sleep between batches. Currently considered low risk since scanning is I/O-bound and generation is GPU-bound. diff --git a/main.py b/main.py index 92d705b4d..b9659b6e8 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,7 @@ import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger -from app.assets.scanner import seed_assets +from app.assets.seeder import asset_seeder import itertools import utils.extra_config import logging @@ -357,7 +357,8 @@ def setup_database(): if dependencies_available(): init_db() if not args.disable_assets_autoscan: - seed_assets(["models"], enable_logging=True) + if asset_seeder.start(roots=("models",)): + logging.info("Background asset scan initiated for models") except Exception as e: logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. 
If the error persists, please report this as in future the database will be required: {e}") @@ -441,5 +442,6 @@ if __name__ == "__main__": event_loop.run_until_complete(x) except KeyboardInterrupt: logging.info("\nStopped server") - - cleanup_temp() + finally: + asset_seeder.shutdown() + cleanup_temp() diff --git a/tests-unit/assets_test/helpers.py b/tests-unit/assets_test/helpers.py index 1a486581b..72d875bea 100644 --- a/tests-unit/assets_test/helpers.py +++ b/tests-unit/assets_test/helpers.py @@ -1,13 +1,14 @@ """Helper functions for assets integration tests.""" -import time - import requests def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: - """Force a fast sync/seed pass by calling the seed endpoint.""" - session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30) - time.sleep(0.2) + """Force a synchronous sync/seed pass by calling the seed endpoint with wait=true.""" + session.post( + base_url + "/api/assets/seed?wait=true", + json={"roots": ["models", "input", "output"]}, + timeout=60, + ) def get_asset_filename(asset_hash: str, extension: str) -> str: diff --git a/tests-unit/seeder_test/test_seeder.py b/tests-unit/seeder_test/test_seeder.py new file mode 100644 index 000000000..bfb3c5839 --- /dev/null +++ b/tests-unit/seeder_test/test_seeder.py @@ -0,0 +1,347 @@ +"""Unit tests for the AssetSeeder background scanning class.""" + +import threading +import time +from unittest.mock import MagicMock, patch + +import pytest + +from app.assets.seeder import AssetSeeder, Progress, State + + +@pytest.fixture +def fresh_seeder(): + """Create a fresh AssetSeeder instance for testing (bypasses singleton).""" + seeder = object.__new__(AssetSeeder) + seeder._initialized = False + seeder.__init__() + yield seeder + seeder.shutdown(timeout=1.0) + + +@pytest.fixture +def mock_dependencies(): + """Mock all external dependencies for isolated testing.""" + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder._sync_root_safely", return_value=set()), + patch("app.assets.seeder._prune_orphans_safely", return_value=0), + patch("app.assets.seeder._collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder._build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder._insert_asset_specs", return_value=0), + ): + yield + + +class TestSeederStateTransitions: + """Test state machine transitions.""" + + def test_initial_state_is_idle(self, fresh_seeder: AssetSeeder): + assert fresh_seeder.get_status().state == State.IDLE + + def test_start_transitions_to_running( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + started = fresh_seeder.start(roots=("models",)) + assert started is True + status = fresh_seeder.get_status() + assert status.state in (State.RUNNING, State.IDLE) + + def test_start_while_running_returns_false( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + time.sleep(0.05) + + second_start = fresh_seeder.start(roots=("models",)) + assert second_start is False + + barrier.set() + + def test_cancel_transitions_to_cancelling( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + 
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + time.sleep(0.05) + + cancelled = fresh_seeder.cancel() + assert cancelled is True + assert fresh_seeder.get_status().state == State.CANCELLING + + barrier.set() + + def test_cancel_when_idle_returns_false(self, fresh_seeder: AssetSeeder): + cancelled = fresh_seeder.cancel() + assert cancelled is False + + def test_state_returns_to_idle_after_completion( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=5.0) + assert completed is True + assert fresh_seeder.get_status().state == State.IDLE + + +class TestSeederWait: + """Test wait() behavior.""" + + def test_wait_blocks_until_complete( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=5.0) + assert completed is True + assert fresh_seeder.get_status().state == State.IDLE + + def test_wait_returns_false_on_timeout( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=10.0) + return [] + + with patch( + "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=0.1) + assert completed is False + + barrier.set() + + def test_wait_when_idle_returns_true(self, fresh_seeder: AssetSeeder): + completed = fresh_seeder.wait(timeout=1.0) + assert completed is True + + +class TestSeederProgress: + """Test progress tracking.""" + + def test_get_status_returns_progress_during_scan( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + progress_seen = [] + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return ["/path/file1.safetensors", "/path/file2.safetensors"] + + with patch( + "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + time.sleep(0.05) + + status = fresh_seeder.get_status() + assert status.progress is not None + progress_seen.append(status.progress) + + barrier.set() + + def test_progress_callback_is_invoked( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + progress_updates: list[Progress] = [] + + def callback(p: Progress): + progress_updates.append(p) + + with patch( + "app.assets.seeder._collect_paths_for_roots", + return_value=[f"/path/file{i}.safetensors" for i in range(10)], + ): + fresh_seeder.start(roots=("models",), progress_callback=callback) + fresh_seeder.wait(timeout=5.0) + + assert len(progress_updates) > 0 + + +class TestSeederCancellation: + """Test cancellation behavior.""" + + def test_scan_commits_partial_progress_on_cancellation( + self, fresh_seeder: AssetSeeder + ): + insert_count = 0 + barrier = threading.Event() + + def slow_insert(specs, tags): + nonlocal insert_count + insert_count += 1 + if insert_count >= 2: + barrier.wait(timeout=5.0) + return len(specs) + + paths = [f"/path/file{i}.safetensors" for i in range(1500)] + specs = [ + { + "abs_path": p, + "size_bytes": 100, + "mtime_ns": 0, + "info_name": f"file{i}", + "tags": [], + "fname": f"file{i}", + } + for i, p in enumerate(paths) + ] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder._sync_root_safely", return_value=set()), + patch("app.assets.seeder._prune_orphans_safely", return_value=0), + 
patch("app.assets.seeder._collect_paths_for_roots", return_value=paths), + patch("app.assets.seeder._build_asset_specs", return_value=(specs, set(), 0)), + patch("app.assets.seeder._insert_asset_specs", side_effect=slow_insert), + ): + fresh_seeder.start(roots=("models",)) + time.sleep(0.1) + + fresh_seeder.cancel() + barrier.set() + fresh_seeder.wait(timeout=5.0) + + assert insert_count >= 1 + + +class TestSeederErrorHandling: + """Test error handling behavior.""" + + def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder._sync_root_safely", return_value=set()), + patch("app.assets.seeder._prune_orphans_safely", return_value=0), + patch( + "app.assets.seeder._collect_paths_for_roots", + return_value=["/path/file.safetensors"], + ), + patch( + "app.assets.seeder._build_asset_specs", + return_value=( + [ + { + "abs_path": "/path/file.safetensors", + "size_bytes": 100, + "mtime_ns": 0, + "info_name": "file", + "tags": [], + "fname": "file", + } + ], + set(), + 0, + ), + ), + patch( + "app.assets.seeder._insert_asset_specs", + side_effect=Exception("DB connection failed"), + ), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert len(status.errors) > 0 + assert "DB connection failed" in status.errors[0] + + def test_dependencies_unavailable_captured_in_errors( + self, fresh_seeder: AssetSeeder + ): + with patch("app.assets.seeder.dependencies_available", return_value=False): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert len(status.errors) > 0 + assert "dependencies" in status.errors[0].lower() + + def test_thread_crash_resets_state_to_idle(self, fresh_seeder: AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch( + "app.assets.seeder._sync_root_safely", + side_effect=RuntimeError("Unexpected crash"), + ), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert status.state == State.IDLE + assert len(status.errors) > 0 + + +class TestSeederThreadSafety: + """Test thread safety of concurrent operations.""" + + def test_concurrent_start_calls_spawn_only_one_thread( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + ): + results = [] + + def try_start(): + results.append(fresh_seeder.start(roots=("models",))) + + threads = [threading.Thread(target=try_start) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + barrier.set() + + assert sum(results) == 1 + + def test_get_status_safe_during_scan( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + + statuses = [] + for _ in range(100): + statuses.append(fresh_seeder.get_status()) + time.sleep(0.001) + + barrier.set() + + assert all(s.state in (State.RUNNING, State.IDLE, State.CANCELLING) for s in statuses)