diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 3eb1ad9ee..3455943ce 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -46,6 +46,7 @@ class _AssetAccumulator(TypedDict): size_db: int states: list[_StateInfo] + RootType = Literal["models", "input", "output"] @@ -200,7 +201,7 @@ def sync_cache_states_with_filesystem( return survivors if collect_existing_paths else None -def _sync_root_safely(root: RootType) -> set[str]: +def sync_root_safely(root: RootType) -> set[str]: """Sync a single root's cache states with the filesystem. Returns survivors (existing paths) or empty set on failure. @@ -220,7 +221,7 @@ def _sync_root_safely(root: RootType) -> set[str]: return set() -def _prune_orphans_safely(prefixes: list[str]) -> int: +def prune_orphans_safely(prefixes: list[str]) -> int: """Prune orphaned assets outside the given prefixes. Returns count pruned or 0 on failure. @@ -235,7 +236,7 @@ def _prune_orphans_safely(prefixes: list[str]) -> int: return 0 -def _collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: +def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: """Collect all file paths for the given roots.""" paths: list[str] = [] if "models" in roots: @@ -247,7 +248,7 @@ def _collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: return paths -def _build_asset_specs( +def build_asset_specs( paths: list[str], existing_paths: set[str], enable_metadata_extraction: bool = True, @@ -303,7 +304,7 @@ def _build_asset_specs( return specs, tag_pool, skipped -def _insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: +def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: """Insert asset specs into database, returning count of created infos.""" if not specs: return 0 @@ -330,11 +331,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No existing_paths: set[str] = set() for r in roots: - existing_paths.update(_sync_root_safely(r)) + existing_paths.update(sync_root_safely(r)) - paths = _collect_paths_for_roots(roots) - specs, tag_pool, skipped_existing = _build_asset_specs(paths, existing_paths) - created = _insert_asset_specs(specs, tag_pool) + paths = collect_paths_for_roots(roots) + specs, tag_pool, skipped_existing = build_asset_specs(paths, existing_paths) + created = insert_asset_specs(specs, tag_pool) if enable_logging: logging.info( diff --git a/app/assets/seeder.py b/app/assets/seeder.py index cdf28f0b9..c63cd7a4c 100644 --- a/app/assets/seeder.py +++ b/app/assets/seeder.py @@ -10,18 +10,18 @@ from typing import TYPE_CHECKING, Callable from app.assets.scanner import ( RootType, - _build_asset_specs, - _collect_paths_for_roots, - _insert_asset_specs, - _prune_orphans_safely, - _sync_root_safely, + build_asset_specs, + collect_paths_for_roots, get_all_known_prefixes, get_prefixes_for_root, + insert_asset_specs, + prune_orphans_safely, + sync_root_safely, ) from app.database.db import dependencies_available if TYPE_CHECKING: - from server import PromptServer + pass class State(Enum): @@ -193,11 +193,13 @@ class AssetSeeder: return 0 if not dependencies_available(): - logging.warning("Database dependencies not available, skipping orphan pruning") + logging.warning( + "Database dependencies not available, skipping orphan pruning" + ) return 0 all_prefixes = get_all_known_prefixes() - pruned = _prune_orphans_safely(all_prefixes) + pruned = prune_orphans_safely(all_prefixes) if pruned > 0: logging.info("Pruned %d orphaned assets", pruned) return pruned @@ -288,7 +290,7 @@ class AssetSeeder: if self._prune_first: all_prefixes = get_all_known_prefixes() - pruned = _prune_orphans_safely(all_prefixes) + pruned = prune_orphans_safely(all_prefixes) if pruned > 0: logging.info("Pruned %d orphaned assets before scan", pruned) @@ -305,14 +307,14 @@ class AssetSeeder: logging.info("Asset scan cancelled during sync phase") cancelled = True return - existing_paths.update(_sync_root_safely(r)) + existing_paths.update(sync_root_safely(r)) if self._is_cancelled(): logging.info("Asset scan cancelled after sync phase") cancelled = True return - paths = _collect_paths_for_roots(roots) + paths = collect_paths_for_roots(roots) total_paths = len(paths) self._update_progress(total=total_paths) @@ -321,7 +323,7 @@ class AssetSeeder: {"roots": list(roots), "total": total_paths}, ) - specs, tag_pool, skipped_existing = _build_asset_specs(paths, existing_paths) + specs, tag_pool, skipped_existing = build_asset_specs(paths, existing_paths) self._update_progress(skipped=skipped_existing) if self._is_cancelled(): @@ -347,7 +349,7 @@ class AssetSeeder: batch = specs[i : i + batch_size] batch_tags = {t for spec in batch for t in spec["tags"]} try: - created = _insert_asset_specs(batch, batch_tags) + created = insert_asset_specs(batch, batch_tags) total_created += created except Exception as e: self._add_error(f"Batch insert failed at offset {i}: {e}") @@ -360,7 +362,11 @@ class AssetSeeder: if now - last_progress_time >= progress_interval: self._emit_event( "assets.seed.progress", - {"scanned": scanned, "total": len(specs), "created": total_created}, + { + "scanned": scanned, + "total": len(specs), + "created": total_created, + }, ) last_progress_time = now diff --git a/tests-unit/seeder_test/test_seeder.py b/tests-unit/seeder_test/test_seeder.py index eecee8f4d..0c1e17bc6 100644 --- a/tests-unit/seeder_test/test_seeder.py +++ b/tests-unit/seeder_test/test_seeder.py @@ -2,7 +2,7 @@ import threading import time -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest @@ -24,10 +24,10 @@ def mock_dependencies(): """Mock all external dependencies for isolated testing.""" with ( patch("app.assets.seeder.dependencies_available", return_value=True), - patch("app.assets.seeder._sync_root_safely", return_value=set()), - patch("app.assets.seeder._collect_paths_for_roots", return_value=[]), - patch("app.assets.seeder._build_asset_specs", return_value=([], set(), 0)), - patch("app.assets.seeder._insert_asset_specs", return_value=0), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), ): yield @@ -56,7 +56,7 @@ class TestSeederStateTransitions: return [] with patch( - "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect ): fresh_seeder.start(roots=("models",)) time.sleep(0.05) @@ -76,7 +76,7 @@ class TestSeederStateTransitions: return [] with patch( - "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect ): fresh_seeder.start(roots=("models",)) time.sleep(0.05) @@ -121,7 +121,7 @@ class TestSeederWait: return [] with patch( - "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect ): fresh_seeder.start(roots=("models",)) completed = fresh_seeder.wait(timeout=0.1) @@ -148,7 +148,7 @@ class TestSeederProgress: return ["/path/file1.safetensors", "/path/file2.safetensors"] with patch( - "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect ): fresh_seeder.start(roots=("models",)) time.sleep(0.05) @@ -168,7 +168,7 @@ class TestSeederProgress: progress_updates.append(p) with patch( - "app.assets.seeder._collect_paths_for_roots", + "app.assets.seeder.collect_paths_for_roots", return_value=[f"/path/file{i}.safetensors" for i in range(10)], ): fresh_seeder.start(roots=("models",), progress_callback=callback) @@ -208,10 +208,12 @@ class TestSeederCancellation: with ( patch("app.assets.seeder.dependencies_available", return_value=True), - patch("app.assets.seeder._sync_root_safely", return_value=set()), - patch("app.assets.seeder._collect_paths_for_roots", return_value=paths), - patch("app.assets.seeder._build_asset_specs", return_value=(specs, set(), 0)), - patch("app.assets.seeder._insert_asset_specs", side_effect=slow_insert), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), + patch( + "app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0) + ), + patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert), ): fresh_seeder.start(roots=("models",)) time.sleep(0.1) @@ -229,13 +231,13 @@ class TestSeederErrorHandling: def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder): with ( patch("app.assets.seeder.dependencies_available", return_value=True), - patch("app.assets.seeder._sync_root_safely", return_value=set()), + patch("app.assets.seeder.sync_root_safely", return_value=set()), patch( - "app.assets.seeder._collect_paths_for_roots", + "app.assets.seeder.collect_paths_for_roots", return_value=["/path/file.safetensors"], ), patch( - "app.assets.seeder._build_asset_specs", + "app.assets.seeder.build_asset_specs", return_value=( [ { @@ -252,7 +254,7 @@ class TestSeederErrorHandling: ), ), patch( - "app.assets.seeder._insert_asset_specs", + "app.assets.seeder.insert_asset_specs", side_effect=Exception("DB connection failed"), ), ): @@ -278,7 +280,7 @@ class TestSeederErrorHandling: with ( patch("app.assets.seeder.dependencies_available", return_value=True), patch( - "app.assets.seeder._sync_root_safely", + "app.assets.seeder.sync_root_safely", side_effect=RuntimeError("Unexpected crash"), ), ): @@ -303,7 +305,7 @@ class TestSeederThreadSafety: return [] with patch( - "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect ): results = [] @@ -330,7 +332,7 @@ class TestSeederThreadSafety: return [] with patch( - "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect ): fresh_seeder.start(roots=("models",)) @@ -341,7 +343,10 @@ class TestSeederThreadSafety: barrier.set() - assert all(s.state in (State.RUNNING, State.IDLE, State.CANCELLING) for s in statuses) + assert all( + s.state in (State.RUNNING, State.IDLE, State.CANCELLING) + for s in statuses + ) class TestSeederPruneOrphans: @@ -350,8 +355,13 @@ class TestSeederPruneOrphans: def test_prune_orphans_when_idle(self, fresh_seeder: AssetSeeder): with ( patch("app.assets.seeder.dependencies_available", return_value=True), - patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models", "/input", "/output"]), - patch("app.assets.seeder._prune_orphans_safely", return_value=5) as mock_prune, + patch( + "app.assets.seeder.get_all_known_prefixes", + return_value=["/models", "/input", "/output"], + ), + patch( + "app.assets.seeder.prune_orphans_safely", return_value=5 + ) as mock_prune, ): result = fresh_seeder.prune_orphans() assert result == 5 @@ -367,7 +377,7 @@ class TestSeederPruneOrphans: return [] with patch( - "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect ): fresh_seeder.start(roots=("models",)) time.sleep(0.05) @@ -400,11 +410,11 @@ class TestSeederPruneOrphans: with ( patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]), - patch("app.assets.seeder._prune_orphans_safely", side_effect=track_prune), - patch("app.assets.seeder._sync_root_safely", side_effect=track_sync), - patch("app.assets.seeder._collect_paths_for_roots", return_value=[]), - patch("app.assets.seeder._build_asset_specs", return_value=([], set(), 0)), - patch("app.assets.seeder._insert_asset_specs", return_value=0), + patch("app.assets.seeder.prune_orphans_safely", side_effect=track_prune), + patch("app.assets.seeder.sync_root_safely", side_effect=track_sync), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), ): fresh_seeder.start(roots=("models",), prune_first=True) fresh_seeder.wait(timeout=5.0)