mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-09 15:20:06 +00:00
Add pause/resume/stop/restart controls to AssetSeeder
- Add PAUSED state to state machine - Add pause() method - blocks scan at next checkpoint - Add resume() method - unblocks paused scan - Add stop() method - alias for cancel() - Add restart() method - cancel + wait + start with same/overridden params - Add _check_pause_and_cancel() helper for checkpoint locations - Emit assets.seed.paused and assets.seed.resumed WebSocket events - Update get_object_info to use async seeder instead of blocking seed_assets - Scan all roots (models, input, output) on object_info, not just models Amp-Thread-ID: https://ampcode.com/threads/T-019c4f2b-5801-711c-8d47-bd1525808d77 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -523,3 +523,261 @@ class TestSeederPhases:
|
||||
|
||||
assert len(fast_called) == 1
|
||||
assert len(enrich_called) == 1
|
||||
|
||||
|
||||
class TestSeederPauseResume:
|
||||
"""Test pause/resume behavior."""
|
||||
|
||||
def test_pause_transitions_to_paused(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
|
||||
def slow_collect(*args):
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
time.sleep(0.05)
|
||||
|
||||
paused = fresh_seeder.pause()
|
||||
assert paused is True
|
||||
assert fresh_seeder.get_status().state == State.PAUSED
|
||||
|
||||
barrier.set()
|
||||
|
||||
def test_pause_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
|
||||
paused = fresh_seeder.pause()
|
||||
assert paused is False
|
||||
|
||||
def test_resume_returns_to_running(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
|
||||
def slow_collect(*args):
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
time.sleep(0.05)
|
||||
|
||||
fresh_seeder.pause()
|
||||
assert fresh_seeder.get_status().state == State.PAUSED
|
||||
|
||||
resumed = fresh_seeder.resume()
|
||||
assert resumed is True
|
||||
assert fresh_seeder.get_status().state == State.RUNNING
|
||||
|
||||
barrier.set()
|
||||
|
||||
def test_resume_when_not_paused_returns_false(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
|
||||
def slow_collect(*args):
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
time.sleep(0.05)
|
||||
|
||||
resumed = fresh_seeder.resume()
|
||||
assert resumed is False
|
||||
|
||||
barrier.set()
|
||||
|
||||
def test_cancel_while_paused_works(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached_checkpoint = threading.Event()
|
||||
|
||||
def slow_collect(*args):
|
||||
reached_checkpoint.set()
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
reached_checkpoint.wait(timeout=1.0)
|
||||
|
||||
fresh_seeder.pause()
|
||||
time.sleep(0.05)
|
||||
|
||||
cancelled = fresh_seeder.cancel()
|
||||
assert cancelled is True
|
||||
|
||||
barrier.set()
|
||||
fresh_seeder.wait(timeout=5.0)
|
||||
assert fresh_seeder.get_status().state == State.IDLE
|
||||
|
||||
def test_pause_blocks_scan_until_resume(self, fresh_seeder: AssetSeeder):
|
||||
"""Verify scan blocks at checkpoint while paused."""
|
||||
batch_count = 0
|
||||
pause_detected = threading.Event()
|
||||
resume_signal = threading.Event()
|
||||
|
||||
def counting_insert(specs, tags):
|
||||
nonlocal batch_count
|
||||
batch_count += 1
|
||||
if batch_count == 1:
|
||||
pause_detected.set()
|
||||
resume_signal.wait(timeout=5.0)
|
||||
return len(specs)
|
||||
|
||||
paths = [f"/path/file{i}.safetensors" for i in range(1000)]
|
||||
specs = [
|
||||
{
|
||||
"abs_path": p,
|
||||
"size_bytes": 100,
|
||||
"mtime_ns": 0,
|
||||
"info_name": f"file{i}",
|
||||
"tags": [],
|
||||
"fname": f"file{i}",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": None,
|
||||
}
|
||||
for i, p in enumerate(paths)
|
||||
]
|
||||
|
||||
with (
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
|
||||
patch("app.assets.seeder.build_stub_specs", return_value=(specs, set(), 0)),
|
||||
patch("app.assets.seeder.insert_asset_specs", side_effect=counting_insert),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
pause_detected.wait(timeout=2.0)
|
||||
|
||||
fresh_seeder.pause()
|
||||
count_at_pause = batch_count
|
||||
time.sleep(0.1)
|
||||
assert batch_count == count_at_pause
|
||||
|
||||
fresh_seeder.resume()
|
||||
resume_signal.set()
|
||||
fresh_seeder.wait(timeout=5.0)
|
||||
|
||||
assert batch_count > count_at_pause
|
||||
|
||||
|
||||
class TestSeederStopRestart:
|
||||
"""Test stop and restart behavior."""
|
||||
|
||||
def test_stop_is_alias_for_cancel(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
|
||||
def slow_collect(*args):
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
time.sleep(0.05)
|
||||
|
||||
stopped = fresh_seeder.stop()
|
||||
assert stopped is True
|
||||
assert fresh_seeder.get_status().state == State.CANCELLING
|
||||
|
||||
barrier.set()
|
||||
|
||||
def test_restart_cancels_and_starts_new_scan(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
start_count = 0
|
||||
|
||||
def slow_collect(*args):
|
||||
nonlocal start_count
|
||||
start_count += 1
|
||||
if start_count == 1:
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
time.sleep(0.05)
|
||||
|
||||
barrier.set()
|
||||
restarted = fresh_seeder.restart()
|
||||
assert restarted is True
|
||||
|
||||
fresh_seeder.wait(timeout=5.0)
|
||||
assert start_count == 2
|
||||
|
||||
def test_restart_preserves_previous_params(self, fresh_seeder: AssetSeeder):
|
||||
"""Verify restart uses previous params when not overridden."""
|
||||
collected_roots = []
|
||||
|
||||
def track_collect(roots):
|
||||
collected_roots.append(roots)
|
||||
return []
|
||||
|
||||
with (
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
|
||||
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
):
|
||||
fresh_seeder.start(roots=("input", "output"))
|
||||
fresh_seeder.wait(timeout=5.0)
|
||||
|
||||
fresh_seeder.restart()
|
||||
fresh_seeder.wait(timeout=5.0)
|
||||
|
||||
assert len(collected_roots) == 2
|
||||
assert collected_roots[0] == ("input", "output")
|
||||
assert collected_roots[1] == ("input", "output")
|
||||
|
||||
def test_restart_can_override_params(self, fresh_seeder: AssetSeeder):
|
||||
"""Verify restart can override previous params."""
|
||||
collected_roots = []
|
||||
|
||||
def track_collect(roots):
|
||||
collected_roots.append(roots)
|
||||
return []
|
||||
|
||||
with (
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
|
||||
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
fresh_seeder.wait(timeout=5.0)
|
||||
|
||||
fresh_seeder.restart(roots=("input",))
|
||||
fresh_seeder.wait(timeout=5.0)
|
||||
|
||||
assert len(collected_roots) == 2
|
||||
assert collected_roots[0] == ("models",)
|
||||
assert collected_roots[1] == ("input",)
|
||||
|
||||
Reference in New Issue
Block a user