refactor: make scanner helper functions public

Rename _sync_root_safely, _prune_orphans_safely, _collect_paths_for_roots,
_build_asset_specs, and _insert_asset_specs to drop the leading underscore,
since seeder.py imports them from the scanner module as part of its public API.

Amp-Thread-ID: https://ampcode.com/threads/T-019c3037-df32-7138-99d8-b4b824d896b3
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr
2026-02-05 19:01:46 -08:00
parent 28c4b58dd6
commit ebb2f5b0e9
3 changed files with 70 additions and 53 deletions

View File

@@ -46,6 +46,7 @@ class _AssetAccumulator(TypedDict):
size_db: int
states: list[_StateInfo]
RootType = Literal["models", "input", "output"]
@@ -200,7 +201,7 @@ def sync_cache_states_with_filesystem(
return survivors if collect_existing_paths else None
def _sync_root_safely(root: RootType) -> set[str]:
def sync_root_safely(root: RootType) -> set[str]:
"""Sync a single root's cache states with the filesystem.
Returns survivors (existing paths) or empty set on failure.
@@ -220,7 +221,7 @@ def _sync_root_safely(root: RootType) -> set[str]:
return set()
def _prune_orphans_safely(prefixes: list[str]) -> int:
def prune_orphans_safely(prefixes: list[str]) -> int:
"""Prune orphaned assets outside the given prefixes.
Returns count pruned or 0 on failure.
@@ -235,7 +236,7 @@ def _prune_orphans_safely(prefixes: list[str]) -> int:
return 0
def _collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
"""Collect all file paths for the given roots."""
paths: list[str] = []
if "models" in roots:
@@ -247,7 +248,7 @@ def _collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
return paths
def _build_asset_specs(
def build_asset_specs(
paths: list[str],
existing_paths: set[str],
enable_metadata_extraction: bool = True,
@@ -303,7 +304,7 @@ def _build_asset_specs(
return specs, tag_pool, skipped
def _insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created infos."""
if not specs:
return 0
@@ -330,11 +331,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
existing_paths: set[str] = set()
for r in roots:
existing_paths.update(_sync_root_safely(r))
existing_paths.update(sync_root_safely(r))
paths = _collect_paths_for_roots(roots)
specs, tag_pool, skipped_existing = _build_asset_specs(paths, existing_paths)
created = _insert_asset_specs(specs, tag_pool)
paths = collect_paths_for_roots(roots)
specs, tag_pool, skipped_existing = build_asset_specs(paths, existing_paths)
created = insert_asset_specs(specs, tag_pool)
if enable_logging:
logging.info(

View File

@@ -10,18 +10,18 @@ from typing import TYPE_CHECKING, Callable
from app.assets.scanner import (
RootType,
_build_asset_specs,
_collect_paths_for_roots,
_insert_asset_specs,
_prune_orphans_safely,
_sync_root_safely,
build_asset_specs,
collect_paths_for_roots,
get_all_known_prefixes,
get_prefixes_for_root,
insert_asset_specs,
prune_orphans_safely,
sync_root_safely,
)
from app.database.db import dependencies_available
if TYPE_CHECKING:
from server import PromptServer
pass
class State(Enum):
@@ -193,11 +193,13 @@ class AssetSeeder:
return 0
if not dependencies_available():
logging.warning("Database dependencies not available, skipping orphan pruning")
logging.warning(
"Database dependencies not available, skipping orphan pruning"
)
return 0
all_prefixes = get_all_known_prefixes()
pruned = _prune_orphans_safely(all_prefixes)
pruned = prune_orphans_safely(all_prefixes)
if pruned > 0:
logging.info("Pruned %d orphaned assets", pruned)
return pruned
@@ -288,7 +290,7 @@ class AssetSeeder:
if self._prune_first:
all_prefixes = get_all_known_prefixes()
pruned = _prune_orphans_safely(all_prefixes)
pruned = prune_orphans_safely(all_prefixes)
if pruned > 0:
logging.info("Pruned %d orphaned assets before scan", pruned)
@@ -305,14 +307,14 @@ class AssetSeeder:
logging.info("Asset scan cancelled during sync phase")
cancelled = True
return
existing_paths.update(_sync_root_safely(r))
existing_paths.update(sync_root_safely(r))
if self._is_cancelled():
logging.info("Asset scan cancelled after sync phase")
cancelled = True
return
paths = _collect_paths_for_roots(roots)
paths = collect_paths_for_roots(roots)
total_paths = len(paths)
self._update_progress(total=total_paths)
@@ -321,7 +323,7 @@ class AssetSeeder:
{"roots": list(roots), "total": total_paths},
)
specs, tag_pool, skipped_existing = _build_asset_specs(paths, existing_paths)
specs, tag_pool, skipped_existing = build_asset_specs(paths, existing_paths)
self._update_progress(skipped=skipped_existing)
if self._is_cancelled():
@@ -347,7 +349,7 @@ class AssetSeeder:
batch = specs[i : i + batch_size]
batch_tags = {t for spec in batch for t in spec["tags"]}
try:
created = _insert_asset_specs(batch, batch_tags)
created = insert_asset_specs(batch, batch_tags)
total_created += created
except Exception as e:
self._add_error(f"Batch insert failed at offset {i}: {e}")
@@ -360,7 +362,11 @@ class AssetSeeder:
if now - last_progress_time >= progress_interval:
self._emit_event(
"assets.seed.progress",
{"scanned": scanned, "total": len(specs), "created": total_created},
{
"scanned": scanned,
"total": len(specs),
"created": total_created,
},
)
last_progress_time = now

View File

@@ -2,7 +2,7 @@
import threading
import time
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
@@ -24,10 +24,10 @@ def mock_dependencies():
"""Mock all external dependencies for isolated testing."""
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder._sync_root_safely", return_value=set()),
patch("app.assets.seeder._collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder._build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder._insert_asset_specs", return_value=0),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
):
yield
@@ -56,7 +56,7 @@ class TestSeederStateTransitions:
return []
with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
time.sleep(0.05)
@@ -76,7 +76,7 @@ class TestSeederStateTransitions:
return []
with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
time.sleep(0.05)
@@ -121,7 +121,7 @@ class TestSeederWait:
return []
with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=0.1)
@@ -148,7 +148,7 @@ class TestSeederProgress:
return ["/path/file1.safetensors", "/path/file2.safetensors"]
with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
time.sleep(0.05)
@@ -168,7 +168,7 @@ class TestSeederProgress:
progress_updates.append(p)
with patch(
"app.assets.seeder._collect_paths_for_roots",
"app.assets.seeder.collect_paths_for_roots",
return_value=[f"/path/file{i}.safetensors" for i in range(10)],
):
fresh_seeder.start(roots=("models",), progress_callback=callback)
@@ -208,10 +208,12 @@ class TestSeederCancellation:
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder._sync_root_safely", return_value=set()),
patch("app.assets.seeder._collect_paths_for_roots", return_value=paths),
patch("app.assets.seeder._build_asset_specs", return_value=(specs, set(), 0)),
patch("app.assets.seeder._insert_asset_specs", side_effect=slow_insert),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch(
"app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0)
),
patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert),
):
fresh_seeder.start(roots=("models",))
time.sleep(0.1)
@@ -229,13 +231,13 @@ class TestSeederErrorHandling:
def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder._sync_root_safely", return_value=set()),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch(
"app.assets.seeder._collect_paths_for_roots",
"app.assets.seeder.collect_paths_for_roots",
return_value=["/path/file.safetensors"],
),
patch(
"app.assets.seeder._build_asset_specs",
"app.assets.seeder.build_asset_specs",
return_value=(
[
{
@@ -252,7 +254,7 @@ class TestSeederErrorHandling:
),
),
patch(
"app.assets.seeder._insert_asset_specs",
"app.assets.seeder.insert_asset_specs",
side_effect=Exception("DB connection failed"),
),
):
@@ -278,7 +280,7 @@ class TestSeederErrorHandling:
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch(
"app.assets.seeder._sync_root_safely",
"app.assets.seeder.sync_root_safely",
side_effect=RuntimeError("Unexpected crash"),
),
):
@@ -303,7 +305,7 @@ class TestSeederThreadSafety:
return []
with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
results = []
@@ -330,7 +332,7 @@ class TestSeederThreadSafety:
return []
with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
@@ -341,7 +343,10 @@ class TestSeederThreadSafety:
barrier.set()
assert all(s.state in (State.RUNNING, State.IDLE, State.CANCELLING) for s in statuses)
assert all(
s.state in (State.RUNNING, State.IDLE, State.CANCELLING)
for s in statuses
)
class TestSeederPruneOrphans:
@@ -350,8 +355,13 @@ class TestSeederPruneOrphans:
def test_prune_orphans_when_idle(self, fresh_seeder: AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models", "/input", "/output"]),
patch("app.assets.seeder._prune_orphans_safely", return_value=5) as mock_prune,
patch(
"app.assets.seeder.get_all_known_prefixes",
return_value=["/models", "/input", "/output"],
),
patch(
"app.assets.seeder.prune_orphans_safely", return_value=5
) as mock_prune,
):
result = fresh_seeder.prune_orphans()
assert result == 5
@@ -367,7 +377,7 @@ class TestSeederPruneOrphans:
return []
with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
time.sleep(0.05)
@@ -400,11 +410,11 @@ class TestSeederPruneOrphans:
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]),
patch("app.assets.seeder._prune_orphans_safely", side_effect=track_prune),
patch("app.assets.seeder._sync_root_safely", side_effect=track_sync),
patch("app.assets.seeder._collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder._build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder._insert_asset_specs", return_value=0),
patch("app.assets.seeder.prune_orphans_safely", side_effect=track_prune),
patch("app.assets.seeder.sync_root_safely", side_effect=track_sync),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
):
fresh_seeder.start(roots=("models",), prune_first=True)
fresh_seeder.wait(timeout=5.0)