fix: address code review feedback

- Fix missing import for compute_filename_for_reference in ingest.py
- Apply code review fixes across routes, queries, scanner, seeder,
  hashing, ingest, path_utils, main, and server
- Update and add tests for sync references and seeder

Amp-Thread-ID: https://ampcode.com/threads/T-019cb61a-ed54-738c-a05f-9b5242e513f3
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr
2026-03-03 15:51:35 -08:00
parent 3232f48a41
commit 4d4c2cedd3
13 changed files with 675 additions and 218 deletions

View File

@@ -0,0 +1,350 @@
"""Tests for sync_references_with_filesystem in scanner.py."""
import os
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import (
Asset,
AssetReference,
AssetReferenceTag,
Base,
Tag,
)
from app.assets.scanner import sync_references_with_filesystem
from app.assets.services.file_utils import get_mtime_ns
@pytest.fixture
def db_engine():
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
return engine
@pytest.fixture
def session(db_engine):
with Session(db_engine) as sess:
yield sess
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
def _create_file(temp_dir: Path, name: str, content: bytes = b"\x00" * 100) -> str:
"""Create a file and return its absolute path (no symlink resolution)."""
p = temp_dir / name
p.parent.mkdir(parents=True, exist_ok=True)
p.write_bytes(content)
return os.path.abspath(str(p))
def _stat_mtime_ns(path: str) -> int:
return get_mtime_ns(os.stat(path, follow_symlinks=True))
def _make_asset(
session: Session,
asset_id: str,
file_path: str,
ref_id: str,
*,
asset_hash: str | None = None,
size_bytes: int = 100,
mtime_ns: int | None = None,
needs_verify: bool = False,
is_missing: bool = False,
) -> tuple[Asset, AssetReference]:
"""Insert an Asset + AssetReference and flush."""
asset = session.get(Asset, asset_id)
if asset is None:
asset = Asset(id=asset_id, hash=asset_hash, size_bytes=size_bytes)
session.add(asset)
session.flush()
ref = AssetReference(
id=ref_id,
asset_id=asset_id,
name=f"test-{ref_id}",
owner_id="system",
file_path=file_path,
mtime_ns=mtime_ns,
needs_verify=needs_verify,
is_missing=is_missing,
)
session.add(ref)
session.flush()
return asset, ref
def _ensure_missing_tag(session: Session):
"""Ensure the 'missing' tag exists."""
if not session.get(Tag, "missing"):
session.add(Tag(name="missing", tag_type="system"))
session.flush()
class _VerifyCase:
def __init__(self, id, stat_unchanged, needs_verify_before, expect_needs_verify):
self.id = id
self.stat_unchanged = stat_unchanged
self.needs_verify_before = needs_verify_before
self.expect_needs_verify = expect_needs_verify
VERIFY_CASES = [
_VerifyCase(
id="unchanged_clears_verify",
stat_unchanged=True,
needs_verify_before=True,
expect_needs_verify=False,
),
_VerifyCase(
id="unchanged_keeps_clear",
stat_unchanged=True,
needs_verify_before=False,
expect_needs_verify=False,
),
_VerifyCase(
id="changed_sets_verify",
stat_unchanged=False,
needs_verify_before=False,
expect_needs_verify=True,
),
_VerifyCase(
id="changed_keeps_verify",
stat_unchanged=False,
needs_verify_before=True,
expect_needs_verify=True,
),
]
@pytest.mark.parametrize("case", VERIFY_CASES, ids=lambda c: c.id)
def test_needs_verify_toggling(session, temp_dir, case):
"""needs_verify is set/cleared based on mtime+size match."""
fp = _create_file(temp_dir, "model.bin")
real_mtime = _stat_mtime_ns(fp)
mtime_for_db = real_mtime if case.stat_unchanged else real_mtime + 1
_make_asset(
session, "a1", fp, "r1",
asset_hash="blake3:abc",
mtime_ns=mtime_for_db,
needs_verify=case.needs_verify_before,
)
session.commit()
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
sync_references_with_filesystem(session, "models")
session.commit()
session.expire_all()
ref = session.get(AssetReference, "r1")
assert ref.needs_verify is case.expect_needs_verify
class _MissingCase:
def __init__(self, id, file_exists, expect_is_missing):
self.id = id
self.file_exists = file_exists
self.expect_is_missing = expect_is_missing
MISSING_CASES = [
_MissingCase(id="existing_file_not_missing", file_exists=True, expect_is_missing=False),
_MissingCase(id="missing_file_marked_missing", file_exists=False, expect_is_missing=True),
]
@pytest.mark.parametrize("case", MISSING_CASES, ids=lambda c: c.id)
def test_is_missing_flag(session, temp_dir, case):
"""is_missing reflects whether the file exists on disk."""
if case.file_exists:
fp = _create_file(temp_dir, "model.bin")
mtime = _stat_mtime_ns(fp)
else:
fp = str(temp_dir / "gone.bin")
mtime = 999
_make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime)
session.commit()
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
sync_references_with_filesystem(session, "models")
session.commit()
session.expire_all()
ref = session.get(AssetReference, "r1")
assert ref.is_missing is case.expect_is_missing
def test_seed_asset_all_missing_deletes_asset(session, temp_dir):
"""Seed asset with all refs missing gets deleted entirely."""
fp = str(temp_dir / "gone.bin")
_make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999)
session.commit()
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
sync_references_with_filesystem(session, "models")
session.commit()
assert session.get(Asset, "seed1") is None
assert session.get(AssetReference, "r1") is None
def test_seed_asset_some_exist_returns_survivors(session, temp_dir):
"""Seed asset with at least one existing ref survives and is returned."""
fp = _create_file(temp_dir, "model.bin")
mtime = _stat_mtime_ns(fp)
_make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=mtime)
session.commit()
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
survivors = sync_references_with_filesystem(
session, "models", collect_existing_paths=True,
)
session.commit()
assert session.get(Asset, "seed1") is not None
assert os.path.abspath(fp) in survivors
def test_hashed_asset_prunes_missing_refs_when_one_is_ok(session, temp_dir):
"""Hashed asset with one stat-unchanged ref deletes missing refs."""
fp_ok = _create_file(temp_dir, "good.bin")
fp_gone = str(temp_dir / "gone.bin")
mtime = _stat_mtime_ns(fp_ok)
_make_asset(session, "h1", fp_ok, "r_ok", asset_hash="blake3:aaa", mtime_ns=mtime)
# Second ref on same asset, file missing
ref_gone = AssetReference(
id="r_gone", asset_id="h1", name="gone",
owner_id="system", file_path=fp_gone, mtime_ns=999,
)
session.add(ref_gone)
session.commit()
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
sync_references_with_filesystem(session, "models")
session.commit()
session.expire_all()
assert session.get(AssetReference, "r_ok") is not None
assert session.get(AssetReference, "r_gone") is None
def test_hashed_asset_all_missing_keeps_refs(session, temp_dir):
"""Hashed asset with all refs missing keeps refs (no pruning)."""
fp = str(temp_dir / "gone.bin")
_make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999)
session.commit()
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
sync_references_with_filesystem(session, "models")
session.commit()
session.expire_all()
assert session.get(AssetReference, "r1") is not None
ref = session.get(AssetReference, "r1")
assert ref.is_missing is True
def test_missing_tag_added_when_all_refs_gone(session, temp_dir):
"""Missing tag is added to hashed asset when all refs are missing."""
_ensure_missing_tag(session)
fp = str(temp_dir / "gone.bin")
_make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999)
session.commit()
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
sync_references_with_filesystem(
session, "models", update_missing_tags=True,
)
session.commit()
session.expire_all()
tag_link = session.get(AssetReferenceTag, ("r1", "missing"))
assert tag_link is not None
def test_missing_tag_removed_when_ref_ok(session, temp_dir):
"""Missing tag is removed from hashed asset when a ref is stat-unchanged."""
_ensure_missing_tag(session)
fp = _create_file(temp_dir, "model.bin")
mtime = _stat_mtime_ns(fp)
_make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=mtime)
# Pre-add a stale missing tag
session.add(AssetReferenceTag(
asset_reference_id="r1", tag_name="missing", origin="automatic",
))
session.commit()
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
sync_references_with_filesystem(
session, "models", update_missing_tags=True,
)
session.commit()
session.expire_all()
tag_link = session.get(AssetReferenceTag, ("r1", "missing"))
assert tag_link is None
def test_missing_tags_not_touched_when_flag_false(session, temp_dir):
"""Missing tags are not modified when update_missing_tags=False."""
_ensure_missing_tag(session)
fp = str(temp_dir / "gone.bin")
_make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999)
session.commit()
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
sync_references_with_filesystem(
session, "models", update_missing_tags=False,
)
session.commit()
tag_link = session.get(AssetReferenceTag, ("r1", "missing"))
assert tag_link is None # tag was never added
def test_returns_none_when_collect_false(session, temp_dir):
fp = _create_file(temp_dir, "model.bin")
mtime = _stat_mtime_ns(fp)
_make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime)
session.commit()
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
result = sync_references_with_filesystem(
session, "models", collect_existing_paths=False,
)
assert result is None
def test_returns_empty_set_for_no_prefixes(session):
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[]):
result = sync_references_with_filesystem(
session, "models", collect_existing_paths=True,
)
assert result == set()
def test_no_references_is_noop(session, temp_dir):
"""No crash and no side effects when there are no references."""
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
survivors = sync_references_with_filesystem(
session, "models", collect_existing_paths=True,
)
session.commit()
assert survivors == set()

View File

@@ -1,19 +1,18 @@
"""Unit tests for the AssetSeeder background scanning class."""
"""Unit tests for the _AssetSeeder background scanning class."""
import threading
from unittest.mock import patch
import pytest
from app.assets.seeder import AssetSeeder, Progress, ScanInProgressError, ScanPhase, State
from app.assets.database.queries.asset_reference import UnenrichedReferenceRow
from app.assets.seeder import _AssetSeeder, Progress, ScanInProgressError, ScanPhase, State
@pytest.fixture
def fresh_seeder():
"""Create a fresh AssetSeeder instance for testing (bypasses singleton)."""
seeder = object.__new__(AssetSeeder)
seeder._initialized = False
seeder.__init__()
"""Create a fresh _AssetSeeder instance for testing."""
seeder = _AssetSeeder()
yield seeder
seeder.shutdown(timeout=1.0)
@@ -25,7 +24,7 @@ def mock_dependencies():
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@@ -36,11 +35,11 @@ def mock_dependencies():
class TestSeederStateTransitions:
"""Test state machine transitions."""
def test_initial_state_is_idle(self, fresh_seeder: AssetSeeder):
def test_initial_state_is_idle(self, fresh_seeder: _AssetSeeder):
assert fresh_seeder.get_status().state == State.IDLE
def test_start_transitions_to_running(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@@ -61,7 +60,7 @@ class TestSeederStateTransitions:
barrier.set()
def test_start_while_running_returns_false(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@@ -83,7 +82,7 @@ class TestSeederStateTransitions:
barrier.set()
def test_cancel_transitions_to_cancelling(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@@ -105,12 +104,12 @@ class TestSeederStateTransitions:
barrier.set()
def test_cancel_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
def test_cancel_when_idle_returns_false(self, fresh_seeder: _AssetSeeder):
cancelled = fresh_seeder.cancel()
assert cancelled is False
def test_state_returns_to_idle_after_completion(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=5.0)
@@ -122,7 +121,7 @@ class TestSeederWait:
"""Test wait() behavior."""
def test_wait_blocks_until_complete(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=5.0)
@@ -130,7 +129,7 @@ class TestSeederWait:
assert fresh_seeder.get_status().state == State.IDLE
def test_wait_returns_false_on_timeout(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
@@ -147,7 +146,7 @@ class TestSeederWait:
barrier.set()
def test_wait_when_idle_returns_true(self, fresh_seeder: AssetSeeder):
def test_wait_when_idle_returns_true(self, fresh_seeder: _AssetSeeder):
completed = fresh_seeder.wait(timeout=1.0)
assert completed is True
@@ -156,7 +155,7 @@ class TestSeederProgress:
"""Test progress tracking."""
def test_get_status_returns_progress_during_scan(
self, fresh_seeder: AssetSeeder
self, fresh_seeder: _AssetSeeder
):
barrier = threading.Event()
reached = threading.Event()
@@ -172,7 +171,7 @@ class TestSeederProgress:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch("app.assets.seeder.build_stub_specs", side_effect=slow_build),
patch("app.assets.seeder.build_asset_specs", side_effect=slow_build),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@@ -188,7 +187,7 @@ class TestSeederProgress:
barrier.set()
def test_progress_callback_is_invoked(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
progress_updates: list[Progress] = []
@@ -209,7 +208,7 @@ class TestSeederCancellation:
"""Test cancellation behavior."""
def test_scan_commits_partial_progress_on_cancellation(
self, fresh_seeder: AssetSeeder
self, fresh_seeder: _AssetSeeder
):
insert_count = 0
barrier = threading.Event()
@@ -245,7 +244,7 @@ class TestSeederCancellation:
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch(
"app.assets.seeder.build_stub_specs", return_value=(specs, set(), 0)
"app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0)
),
patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
@@ -264,7 +263,7 @@ class TestSeederCancellation:
class TestSeederErrorHandling:
"""Test error handling behavior."""
def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder):
def test_database_errors_captured_in_status(self, fresh_seeder: _AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
@@ -273,7 +272,7 @@ class TestSeederErrorHandling:
return_value=["/path/file.safetensors"],
),
patch(
"app.assets.seeder.build_stub_specs",
"app.assets.seeder.build_asset_specs",
return_value=(
[
{
@@ -307,7 +306,7 @@ class TestSeederErrorHandling:
assert "DB connection failed" in status.errors[0]
def test_dependencies_unavailable_captured_in_errors(
self, fresh_seeder: AssetSeeder
self, fresh_seeder: _AssetSeeder
):
with patch("app.assets.seeder.dependencies_available", return_value=False):
fresh_seeder.start(roots=("models",))
@@ -317,7 +316,7 @@ class TestSeederErrorHandling:
assert len(status.errors) > 0
assert "dependencies" in status.errors[0].lower()
def test_thread_crash_resets_state_to_idle(self, fresh_seeder: AssetSeeder):
def test_thread_crash_resets_state_to_idle(self, fresh_seeder: _AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch(
@@ -337,7 +336,7 @@ class TestSeederThreadSafety:
"""Test thread safety of concurrent operations."""
def test_concurrent_start_calls_spawn_only_one_thread(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
@@ -364,7 +363,7 @@ class TestSeederThreadSafety:
assert sum(results) == 1
def test_get_status_safe_during_scan(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@@ -395,7 +394,7 @@ class TestSeederThreadSafety:
class TestSeederMarkMissing:
"""Test mark_missing_outside_prefixes behavior."""
def test_mark_missing_when_idle(self, fresh_seeder: AssetSeeder):
def test_mark_missing_when_idle(self, fresh_seeder: _AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch(
@@ -411,7 +410,7 @@ class TestSeederMarkMissing:
mock_mark.assert_called_once_with(["/models", "/input", "/output"])
def test_mark_missing_raises_when_running(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@@ -433,14 +432,14 @@ class TestSeederMarkMissing:
barrier.set()
def test_mark_missing_returns_zero_when_dependencies_unavailable(
self, fresh_seeder: AssetSeeder
self, fresh_seeder: _AssetSeeder
):
with patch("app.assets.seeder.dependencies_available", return_value=False):
result = fresh_seeder.mark_missing_outside_prefixes()
assert result == 0
def test_prune_first_flag_triggers_mark_missing_before_scan(
self, fresh_seeder: AssetSeeder
self, fresh_seeder: _AssetSeeder
):
call_order = []
@@ -458,7 +457,7 @@ class TestSeederMarkMissing:
patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark),
patch("app.assets.seeder.sync_root_safely", side_effect=track_sync),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@@ -473,7 +472,7 @@ class TestSeederMarkMissing:
class TestSeederPhases:
"""Test phased scanning behavior."""
def test_start_fast_only_runs_fast_phase(self, fresh_seeder: AssetSeeder):
def test_start_fast_only_runs_fast_phase(self, fresh_seeder: _AssetSeeder):
"""Verify start_fast only runs the fast phase."""
fast_called = []
enrich_called = []
@@ -490,7 +489,7 @@ class TestSeederPhases:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast),
patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@@ -501,7 +500,7 @@ class TestSeederPhases:
assert len(fast_called) == 1
assert len(enrich_called) == 0
def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: AssetSeeder):
def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: _AssetSeeder):
"""Verify start_enrich only runs the enrich phase."""
fast_called = []
enrich_called = []
@@ -518,7 +517,7 @@ class TestSeederPhases:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast),
patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@@ -529,7 +528,7 @@ class TestSeederPhases:
assert len(fast_called) == 0
assert len(enrich_called) == 1
def test_full_scan_runs_both_phases(self, fresh_seeder: AssetSeeder):
def test_full_scan_runs_both_phases(self, fresh_seeder: _AssetSeeder):
"""Verify full scan runs both fast and enrich phases."""
fast_called = []
enrich_called = []
@@ -546,7 +545,7 @@ class TestSeederPhases:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast),
patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@@ -562,7 +561,7 @@ class TestSeederPauseResume:
"""Test pause/resume behavior."""
def test_pause_transitions_to_paused(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@@ -584,12 +583,12 @@ class TestSeederPauseResume:
barrier.set()
def test_pause_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
def test_pause_when_idle_returns_false(self, fresh_seeder: _AssetSeeder):
paused = fresh_seeder.pause()
assert paused is False
def test_resume_returns_to_running(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@@ -615,7 +614,7 @@ class TestSeederPauseResume:
barrier.set()
def test_resume_when_not_paused_returns_false(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@@ -637,7 +636,7 @@ class TestSeederPauseResume:
barrier.set()
def test_cancel_while_paused_works(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached_checkpoint = threading.Event()
@@ -667,7 +666,7 @@ class TestSeederStopRestart:
"""Test stop and restart behavior."""
def test_stop_is_alias_for_cancel(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@@ -690,7 +689,7 @@ class TestSeederStopRestart:
barrier.set()
def test_restart_cancels_and_starts_new_scan(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@@ -717,7 +716,7 @@ class TestSeederStopRestart:
fresh_seeder.wait(timeout=5.0)
assert start_count == 2
def test_restart_preserves_previous_params(self, fresh_seeder: AssetSeeder):
def test_restart_preserves_previous_params(self, fresh_seeder: _AssetSeeder):
"""Verify restart uses previous params when not overridden."""
collected_roots = []
@@ -729,7 +728,7 @@ class TestSeederStopRestart:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@@ -744,7 +743,7 @@ class TestSeederStopRestart:
assert collected_roots[0] == ("input", "output")
assert collected_roots[1] == ("input", "output")
def test_restart_can_override_params(self, fresh_seeder: AssetSeeder):
def test_restart_can_override_params(self, fresh_seeder: _AssetSeeder):
"""Verify restart can override previous params."""
collected_roots = []
@@ -756,7 +755,7 @@ class TestSeederStopRestart:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@@ -770,3 +769,132 @@ class TestSeederStopRestart:
assert len(collected_roots) == 2
assert collected_roots[0] == ("models",)
assert collected_roots[1] == ("input",)
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
return UnenrichedReferenceRow(
reference_id=ref_id, asset_id=asset_id,
file_path=f"/fake/{ref_id}.bin", enrichment_level=0,
)
class TestEnrichPhaseDefensiveLogic:
"""Test skip_ids filtering and consecutive_empty termination."""
def test_failed_refs_are_skipped_on_subsequent_batches(
self, fresh_seeder: _AssetSeeder,
):
"""References that fail enrichment are filtered out of future batches."""
row_a = _make_row("r1")
row_b = _make_row("r2")
call_count = 0
def fake_get_unenriched(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count <= 2:
return [row_a, row_b]
return []
enriched_refs: list[list[str]] = []
def fake_enrich(rows, **kwargs):
ref_ids = [r.reference_id for r in rows]
enriched_refs.append(ref_ids)
# r1 always fails, r2 succeeds
failed = [r.reference_id for r in rows if r.reference_id == "r1"]
enriched = len(rows) - len(failed)
return enriched, failed
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
):
fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
fresh_seeder.wait(timeout=5.0)
# First batch: both refs attempted
assert "r1" in enriched_refs[0]
assert "r2" in enriched_refs[0]
# Second batch: r1 filtered out
assert "r1" not in enriched_refs[1]
assert "r2" in enriched_refs[1]
def test_stops_after_consecutive_empty_batches(
self, fresh_seeder: _AssetSeeder,
):
"""Enrich phase terminates after 3 consecutive batches with zero progress."""
row = _make_row("r1")
batch_count = 0
def fake_get_unenriched(*args, **kwargs):
nonlocal batch_count
batch_count += 1
# Always return the same row (simulating a permanently failing ref)
return [row]
def fake_enrich(rows, **kwargs):
# Always fail — zero enriched, all failed
return 0, [r.reference_id for r in rows]
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
):
fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
fresh_seeder.wait(timeout=5.0)
# Should stop after exactly 3 consecutive empty batches
# Batch 1: returns row, enrich fails → filtered out in batch 2+
# But get_unenriched keeps returning it, filter removes it → empty → break
# Actually: batch 1 has row, fails. Batch 2 get_unenriched returns [row],
# skip_ids filters it → empty list → breaks via `if not unenriched: break`
# So it terminates in 2 calls to get_unenriched.
assert batch_count == 2
def test_consecutive_empty_counter_resets_on_success(
self, fresh_seeder: _AssetSeeder,
):
"""A successful batch resets the consecutive empty counter."""
call_count = 0
def fake_get_unenriched(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count <= 6:
return [_make_row(f"r{call_count}", f"a{call_count}")]
return []
def fake_enrich(rows, **kwargs):
ref_id = rows[0].reference_id
# Fail batches 1-2, succeed batch 3, fail batches 4-5, succeed batch 6
if ref_id in ("r1", "r2", "r4", "r5"):
return 0, [ref_id]
return 1, []
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
):
fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
fresh_seeder.wait(timeout=5.0)
# All 6 batches should run + 1 final call returning empty
assert call_count == 7
status = fresh_seeder.get_status()
assert status.state == State.IDLE