diff --git a/app/assets/seeder.py b/app/assets/seeder.py index 73a663f11..0c65032a1 100644 --- a/app/assets/seeder.py +++ b/app/assets/seeder.py @@ -30,6 +30,7 @@ class State(Enum): IDLE = "IDLE" RUNNING = "RUNNING" + PAUSED = "PAUSED" CANCELLING = "CANCELLING" @@ -90,6 +91,8 @@ class AssetSeeder: self._errors: list[str] = [] self._thread: threading.Thread | None = None self._cancel_event = threading.Event() + self._pause_event = threading.Event() + self._pause_event.set() # Start unpaused (set = running, clear = paused) self._roots: tuple[RootType, ...] = () self._phase: ScanPhase = ScanPhase.FULL self._compute_hashes: bool = False @@ -127,6 +130,7 @@ class AssetSeeder: self._compute_hashes = compute_hashes self._progress_callback = progress_callback self._cancel_event.clear() + self._pause_event.set() # Ensure unpaused when starting self._thread = threading.Thread( target=self._run_scan, name="AssetSeeder", @@ -187,15 +191,94 @@ class AssetSeeder: """Request cancellation of the current scan. Returns: - True if cancellation was requested, False if not running + True if cancellation was requested, False if not running or paused + """ + with self._lock: + if self._state not in (State.RUNNING, State.PAUSED): + return False + self._state = State.CANCELLING + self._cancel_event.set() + self._pause_event.set() # Unblock if paused so thread can exit + return True + + def stop(self) -> bool: + """Stop the current scan (alias for cancel). + + Returns: + True if stop was requested, False if not running + """ + return self.cancel() + + def pause(self) -> bool: + """Pause the current scan. + + The scan will complete its current batch before pausing. + + Returns: + True if pause was requested, False if not running """ with self._lock: if self._state != State.RUNNING: return False - self._state = State.CANCELLING - self._cancel_event.set() + self._state = State.PAUSED + self._pause_event.clear() return True + def resume(self) -> bool: + """Resume a paused scan. + + Returns: + True if resumed, False if not paused + """ + with self._lock: + if self._state != State.PAUSED: + return False + self._state = State.RUNNING + self._pause_event.set() + self._emit_event("assets.seed.resumed", {}) + return True + + def restart( + self, + roots: tuple[RootType, ...] | None = None, + phase: ScanPhase | None = None, + progress_callback: ProgressCallback | None = None, + prune_first: bool | None = None, + compute_hashes: bool | None = None, + timeout: float = 5.0, + ) -> bool: + """Cancel any running scan and start a new one. + + Args: + roots: Roots to scan (defaults to previous roots) + phase: Scan phase (defaults to previous phase) + progress_callback: Progress callback (defaults to previous) + prune_first: Prune before scan (defaults to previous) + compute_hashes: Compute hashes (defaults to previous) + timeout: Max seconds to wait for current scan to stop + + Returns: + True if new scan was started, False if failed to stop previous + """ + with self._lock: + prev_roots = self._roots + prev_phase = self._phase + prev_callback = self._progress_callback + prev_prune = getattr(self, "_prune_first", False) + prev_hashes = self._compute_hashes + + self.cancel() + if not self.wait(timeout=timeout): + return False + + return self.start( + roots=roots if roots is not None else prev_roots, + phase=phase if phase is not None else prev_phase, + progress_callback=progress_callback if progress_callback is not None else prev_callback, + prune_first=prune_first if prune_first is not None else prev_prune, + compute_hashes=compute_hashes if compute_hashes is not None else prev_hashes, + ) + def wait(self, timeout: float | None = None) -> bool: """Wait for the current scan to complete. @@ -284,6 +367,21 @@ class AssetSeeder: """Check if cancellation has been requested.""" return self._cancel_event.is_set() + def _check_pause_and_cancel(self) -> bool: + """Block while paused, then check if cancelled. + + Call this at checkpoint locations in scan loops. It will: + 1. Block indefinitely while paused (until resume or cancel) + 2. Return True if cancelled, False to continue + + Returns: + True if scan should stop, False to continue + """ + if not self._pause_event.is_set(): + self._emit_event("assets.seed.paused", {}) + self._pause_event.wait() # Blocks if paused + return self._is_cancelled() + def _emit_event(self, event_type: str, data: dict) -> None: """Emit a WebSocket event if server is available.""" try: @@ -377,7 +475,7 @@ class AssetSeeder: if marked > 0: logging.info("Marked %d cache states as missing before scan", marked) - if self._is_cancelled(): + if self._check_pause_and_cancel(): logging.info("Asset scan cancelled after pruning phase") cancelled = True return @@ -388,7 +486,7 @@ class AssetSeeder: if phase in (ScanPhase.FAST, ScanPhase.FULL): total_created, skipped_existing, total_paths = self._run_fast_phase(roots) - if self._is_cancelled(): + if self._check_pause_and_cancel(): cancelled = True return @@ -404,7 +502,7 @@ class AssetSeeder: # Phase 2: Enrichment scan (metadata + hashes) if phase in (ScanPhase.ENRICH, ScanPhase.FULL): - if self._is_cancelled(): + if self._check_pause_and_cancel(): cancelled = True return @@ -469,11 +567,11 @@ class AssetSeeder: existing_paths: set[str] = set() for r in roots: - if self._is_cancelled(): + if self._check_pause_and_cancel(): return total_created, skipped_existing, 0 existing_paths.update(sync_root_safely(r)) - if self._is_cancelled(): + if self._check_pause_and_cancel(): return total_created, skipped_existing, 0 paths = collect_paths_for_roots(roots) @@ -489,7 +587,7 @@ class AssetSeeder: specs, tag_pool, skipped_existing = build_stub_specs(paths, existing_paths) self._update_progress(skipped=skipped_existing) - if self._is_cancelled(): + if self._check_pause_and_cancel(): return total_created, skipped_existing, total_paths batch_size = 500 @@ -497,7 +595,7 @@ class AssetSeeder: progress_interval = 1.0 for i in range(0, len(specs), batch_size): - if self._is_cancelled(): + if self._check_pause_and_cancel(): logging.info( "Fast scan cancelled after %d/%d files (created=%d)", i, @@ -554,7 +652,7 @@ class AssetSeeder: ) while True: - if self._is_cancelled(): + if self._check_pause_and_cancel(): logging.info("Enrich scan cancelled after %d assets", total_enriched) break diff --git a/server.py b/server.py index 2300393b2..7ed34244c 100644 --- a/server.py +++ b/server.py @@ -33,7 +33,7 @@ import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal -from app.assets.scanner import seed_assets +from app.assets.seeder import asset_seeder from app.assets.api.routes import register_assets_system from app.user_manager import UserManager @@ -691,10 +691,7 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): - try: - seed_assets(["models"]) - except Exception as e: - logging.error(f"Failed to seed assets: {e}") + asset_seeder.start(roots=("models", "input", "output")) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-unit/seeder_test/test_seeder.py b/tests-unit/seeder_test/test_seeder.py index 121896119..750b381da 100644 --- a/tests-unit/seeder_test/test_seeder.py +++ b/tests-unit/seeder_test/test_seeder.py @@ -523,3 +523,261 @@ class TestSeederPhases: assert len(fast_called) == 1 assert len(enrich_called) == 1 + + +class TestSeederPauseResume: + """Test pause/resume behavior.""" + + def test_pause_transitions_to_paused( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + time.sleep(0.05) + + paused = fresh_seeder.pause() + assert paused is True + assert fresh_seeder.get_status().state == State.PAUSED + + barrier.set() + + def test_pause_when_idle_returns_false(self, fresh_seeder: AssetSeeder): + paused = fresh_seeder.pause() + assert paused is False + + def test_resume_returns_to_running( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + time.sleep(0.05) + + fresh_seeder.pause() + assert fresh_seeder.get_status().state == State.PAUSED + + resumed = fresh_seeder.resume() + assert resumed is True + assert fresh_seeder.get_status().state == State.RUNNING + + barrier.set() + + def test_resume_when_not_paused_returns_false( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + time.sleep(0.05) + + resumed = fresh_seeder.resume() + assert resumed is False + + barrier.set() + + def test_cancel_while_paused_works( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached_checkpoint = threading.Event() + + def slow_collect(*args): + reached_checkpoint.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + reached_checkpoint.wait(timeout=1.0) + + fresh_seeder.pause() + time.sleep(0.05) + + cancelled = fresh_seeder.cancel() + assert cancelled is True + + barrier.set() + fresh_seeder.wait(timeout=5.0) + assert fresh_seeder.get_status().state == State.IDLE + + def test_pause_blocks_scan_until_resume(self, fresh_seeder: AssetSeeder): + """Verify scan blocks at checkpoint while paused.""" + batch_count = 0 + pause_detected = threading.Event() + resume_signal = threading.Event() + + def counting_insert(specs, tags): + nonlocal batch_count + batch_count += 1 + if batch_count == 1: + pause_detected.set() + resume_signal.wait(timeout=5.0) + return len(specs) + + paths = [f"/path/file{i}.safetensors" for i in range(1000)] + specs = [ + { + "abs_path": p, + "size_bytes": 100, + "mtime_ns": 0, + "info_name": f"file{i}", + "tags": [], + "fname": f"file{i}", + "metadata": None, + "hash": None, + "mime_type": None, + } + for i, p in enumerate(paths) + ] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), + patch("app.assets.seeder.build_stub_specs", return_value=(specs, set(), 0)), + patch("app.assets.seeder.insert_asset_specs", side_effect=counting_insert), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + pause_detected.wait(timeout=2.0) + + fresh_seeder.pause() + count_at_pause = batch_count + time.sleep(0.1) + assert batch_count == count_at_pause + + fresh_seeder.resume() + resume_signal.set() + fresh_seeder.wait(timeout=5.0) + + assert batch_count > count_at_pause + + +class TestSeederStopRestart: + """Test stop and restart behavior.""" + + def test_stop_is_alias_for_cancel( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + time.sleep(0.05) + + stopped = fresh_seeder.stop() + assert stopped is True + assert fresh_seeder.get_status().state == State.CANCELLING + + barrier.set() + + def test_restart_cancels_and_starts_new_scan( + self, fresh_seeder: AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + start_count = 0 + + def slow_collect(*args): + nonlocal start_count + start_count += 1 + if start_count == 1: + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + time.sleep(0.05) + + barrier.set() + restarted = fresh_seeder.restart() + assert restarted is True + + fresh_seeder.wait(timeout=5.0) + assert start_count == 2 + + def test_restart_preserves_previous_params(self, fresh_seeder: AssetSeeder): + """Verify restart uses previous params when not overridden.""" + collected_roots = [] + + def track_collect(roots): + collected_roots.append(roots) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), + patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("input", "output")) + fresh_seeder.wait(timeout=5.0) + + fresh_seeder.restart() + fresh_seeder.wait(timeout=5.0) + + assert len(collected_roots) == 2 + assert collected_roots[0] == ("input", "output") + assert collected_roots[1] == ("input", "output") + + def test_restart_can_override_params(self, fresh_seeder: AssetSeeder): + """Verify restart can override previous params.""" + collected_roots = [] + + def track_collect(roots): + collected_roots.append(roots) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), + patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + fresh_seeder.restart(roots=("input",)) + fresh_seeder.wait(timeout=5.0) + + assert len(collected_roots) == 2 + assert collected_roots[0] == ("models",) + assert collected_roots[1] == ("input",)