mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-18 22:20:03 +00:00
feat(assets): add background asset seeder for non-blocking startup
- Add AssetSeeder singleton class with thread management and cancellation - Support IDLE/RUNNING/CANCELLING state machine with thread-safe access - Emit WebSocket events for scan progress (started, progress, completed, cancelled, error) - Update main.py to use non-blocking asset_seeder.start() at startup - Add shutdown() call in finally block for graceful cleanup - Update POST /api/assets/seed to return 202 Accepted, support ?wait=true - Add GET /api/assets/seed/status and POST /api/assets/seed/cancel endpoints - Update test helper to use ?wait=true for synchronous behavior - Add 17 unit tests covering state transitions, cancellation, and thread safety - Log scan configuration (models directory, input/output paths) at scan start Amp-Thread-ID: https://ampcode.com/threads/T-019c2b45-e6e8-740a-b38b-b11daea8d094 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -15,7 +15,7 @@ from app.assets.api.schemas_in import (
|
||||
UploadError,
|
||||
)
|
||||
from app.assets.api.upload import parse_multipart_upload
|
||||
from app.assets.scanner import seed_assets as scanner_seed_assets
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.services import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
@@ -620,21 +620,78 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
|
||||
@ROUTES.post("/api/assets/seed")
async def seed_assets(request: web.Request) -> web.Response:
    """Trigger asset seeding for specified roots (models, input, output).

    Query params:
        wait: If "true", block until scan completes (synchronous behavior for tests)

    Returns:
        202 Accepted if scan started
        409 Conflict if scan already running
        200 OK with final stats if wait=true
    """
    # Tolerate a missing or malformed JSON body by falling back to all roots.
    try:
        payload = await request.json()
        roots = payload.get("roots", ["models", "input", "output"])
    except Exception:
        roots = ["models", "input", "output"]

    # Silently drop unknown root names; reject the request only if none remain.
    valid_roots = tuple(r for r in roots if r in ("models", "input", "output"))
    if not valid_roots:
        return _build_error_response(400, "INVALID_BODY", "No valid roots specified")

    wait_param = request.query.get("wait", "").lower()
    should_wait = wait_param in ("true", "1", "yes")

    started = asset_seeder.start(roots=valid_roots)
    if not started:
        return web.json_response({"status": "already_running"}, status=409)

    if should_wait:
        # asset_seeder.wait() is a blocking Thread.join; run it in a worker
        # thread so the aiohttp event loop keeps serving other requests.
        import asyncio

        await asyncio.to_thread(asset_seeder.wait)
        status = asset_seeder.get_status()
        return web.json_response(
            {
                "status": "completed",
                "progress": {
                    "scanned": status.progress.scanned if status.progress else 0,
                    "total": status.progress.total if status.progress else 0,
                    "created": status.progress.created if status.progress else 0,
                    "skipped": status.progress.skipped if status.progress else 0,
                },
                "errors": status.errors,
            },
            status=200,
        )

    return web.json_response({"status": "started"}, status=202)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/seed/status")
async def get_seed_status(request: web.Request) -> web.Response:
    """Report the seeder's state, progress counters, and accumulated errors."""
    snapshot = asset_seeder.get_status()

    # Progress is null until a scan has ever populated it.
    progress_payload = None
    if snapshot.progress is not None:
        progress_payload = {
            "scanned": snapshot.progress.scanned,
            "total": snapshot.progress.total,
            "created": snapshot.progress.created,
            "skipped": snapshot.progress.skipped,
        }

    body = {
        "state": snapshot.state.value,
        "progress": progress_payload,
        "errors": snapshot.errors,
    }
    return web.json_response(body, status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed/cancel")
async def cancel_seed(request: web.Request) -> web.Response:
    """Request cancellation of in-progress scan."""
    was_running = asset_seeder.cancel()
    outcome = "cancelling" if was_running else "idle"
    return web.json_response({"status": outcome}, status=200)
|
||||
|
||||
376
app/assets/seeder.py
Normal file
376
app/assets/seeder.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""Background asset seeder with thread management and cancellation support."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from app.assets.scanner import (
|
||||
RootType,
|
||||
_build_asset_specs,
|
||||
_collect_paths_for_roots,
|
||||
_insert_asset_specs,
|
||||
_prune_orphans_safely,
|
||||
_sync_root_safely,
|
||||
get_prefixes_for_root,
|
||||
)
|
||||
from app.database.db import dependencies_available
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
class State(Enum):
    """Lifecycle states of the background seeder's scan thread."""

    IDLE = "IDLE"
    RUNNING = "RUNNING"
    CANCELLING = "CANCELLING"
|
||||
|
||||
|
||||
@dataclass
class Progress:
    """Mutable counters describing how far a scan has progressed."""

    scanned: int = 0  # files processed so far
    total: int = 0  # files discovered for this scan
    created: int = 0  # asset records inserted
    skipped: int = 0  # files skipped (already known)
|
||||
|
||||
|
||||
@dataclass
class ScanStatus:
    """Immutable-ish snapshot of the seeder: state value, progress, errors."""

    state: State
    progress: Progress | None
    errors: list[str] = field(default_factory=list)


# Signature of the optional per-update progress hook passed to start().
ProgressCallback = Callable[[Progress], None]
|
||||
|
||||
|
||||
class AssetSeeder:
    """Singleton class managing background asset scanning.

    Thread-safe singleton that spawns ephemeral daemon threads for scanning.
    Each scan creates a new thread that exits when complete.

    Locking discipline: all mutable state (_state, _progress, _errors,
    _thread) is guarded by the non-reentrant ``self._lock``. No user-supplied
    code (progress callbacks) and no blocking call (``Thread.join``) is ever
    invoked while the lock is held.
    """

    _instance: "AssetSeeder | None" = None
    _instance_lock = threading.Lock()

    def __new__(cls) -> "AssetSeeder":
        # Creation is rare, so always taking the lock is fine.
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every AssetSeeder() call; only initialize once.
        if self._initialized:
            return
        self._initialized = True
        self._lock = threading.Lock()
        self._state = State.IDLE
        self._progress: Progress | None = None
        self._errors: list[str] = []
        self._thread: threading.Thread | None = None
        self._cancel_event = threading.Event()
        self._roots: tuple[RootType, ...] = ()
        self._progress_callback: ProgressCallback | None = None

    def start(
        self,
        roots: tuple[RootType, ...] = ("models",),
        progress_callback: ProgressCallback | None = None,
    ) -> bool:
        """Start a background scan for the given roots.

        Args:
            roots: Tuple of root types to scan (models, input, output)
            progress_callback: Optional callback called with progress updates

        Returns:
            True if scan was started, False if already running
        """
        with self._lock:
            if self._state != State.IDLE:
                return False
            self._state = State.RUNNING
            self._progress = Progress()
            self._errors = []
            self._roots = roots
            self._progress_callback = progress_callback
            self._cancel_event.clear()
            self._thread = threading.Thread(
                target=self._run_scan,
                name="AssetSeeder",
                daemon=True,
            )
            self._thread.start()
            return True

    def cancel(self) -> bool:
        """Request cancellation of the current scan.

        Returns:
            True if cancellation was requested, False if not running
        """
        with self._lock:
            if self._state != State.RUNNING:
                return False
            self._state = State.CANCELLING
            self._cancel_event.set()
            return True

    def wait(self, timeout: float | None = None) -> bool:
        """Wait for the current scan to complete.

        Args:
            timeout: Maximum seconds to wait, or None for no timeout

        Returns:
            True if scan completed, False if timeout expired or no scan running
        """
        with self._lock:
            thread = self._thread
        if thread is None:
            return True
        # Join OUTSIDE the lock: the worker thread needs the same lock to
        # publish progress and to reset state in its finally block, so joining
        # while holding it would deadlock.
        thread.join(timeout=timeout)
        return not thread.is_alive()

    def get_status(self) -> ScanStatus:
        """Get the current status and progress of the seeder."""
        with self._lock:
            return ScanStatus(
                state=self._state,
                # Return a copy so callers can't mutate our internal counters.
                progress=Progress(
                    scanned=self._progress.scanned,
                    total=self._progress.total,
                    created=self._progress.created,
                    skipped=self._progress.skipped,
                )
                if self._progress
                else None,
                errors=list(self._errors),
            )

    def shutdown(self, timeout: float = 5.0) -> None:
        """Gracefully shutdown: cancel any running scan and wait for thread.

        Args:
            timeout: Maximum seconds to wait for thread to exit
        """
        self.cancel()
        self.wait(timeout=timeout)
        with self._lock:
            self._thread = None

    def _is_cancelled(self) -> bool:
        """Check if cancellation has been requested."""
        return self._cancel_event.is_set()

    def _emit_event(self, event_type: str, data: dict) -> None:
        """Emit a WebSocket event if server is available (best-effort)."""
        try:
            from server import PromptServer

            if hasattr(PromptServer, "instance") and PromptServer.instance:
                PromptServer.instance.send_sync(event_type, data)
        except Exception:
            # Events are purely informational; never let them break a scan.
            pass

    def _update_progress(
        self,
        scanned: int | None = None,
        total: int | None = None,
        created: int | None = None,
        skipped: int | None = None,
    ) -> None:
        """Update progress counters (thread-safe).

        The progress callback is invoked OUTSIDE the lock with a snapshot
        copy: self._lock is non-reentrant, so a callback that itself calls
        get_status() would otherwise deadlock.
        """
        callback: ProgressCallback | None = None
        snapshot: Progress | None = None
        with self._lock:
            if self._progress is None:
                return
            if scanned is not None:
                self._progress.scanned = scanned
            if total is not None:
                self._progress.total = total
            if created is not None:
                self._progress.created = created
            if skipped is not None:
                self._progress.skipped = skipped
            callback = self._progress_callback
            if callback is not None:
                snapshot = Progress(
                    scanned=self._progress.scanned,
                    total=self._progress.total,
                    created=self._progress.created,
                    skipped=self._progress.skipped,
                )
        if callback is not None and snapshot is not None:
            try:
                callback(snapshot)
            except Exception:
                # User callbacks are best-effort; a bad one must not kill the scan.
                pass

    def _add_error(self, message: str) -> None:
        """Add an error message (thread-safe)."""
        with self._lock:
            self._errors.append(message)

    def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
        """Log the directories that will be scanned."""
        import folder_paths

        for root in roots:
            if root == "models":
                logging.info(
                    "Asset scan [models] directory: %s",
                    os.path.abspath(folder_paths.models_dir),
                )
            else:
                prefixes = get_prefixes_for_root(root)
                if prefixes:
                    logging.info("Asset scan [%s] directories: %s", root, prefixes)

    def _run_scan(self) -> None:
        """Main scan loop running in background thread.

        Phases: dependency check -> root sync -> orphan pruning -> path
        collection -> spec building -> batched inserts. Cancellation is
        checked between phases and between batches; partial progress stays
        committed. The finally block always resets state to IDLE.
        """
        t_start = time.perf_counter()
        roots = self._roots
        cancelled = False
        total_created = 0
        skipped_existing = 0
        total_paths = 0

        try:
            if not dependencies_available():
                self._add_error("Database dependencies not available")
                self._emit_event(
                    "assets.seed.error",
                    {"message": "Database dependencies not available"},
                )
                return

            self._log_scan_config(roots)

            existing_paths: set[str] = set()
            for r in roots:
                if self._is_cancelled():
                    logging.info("Asset scan cancelled during sync phase")
                    cancelled = True
                    return
                existing_paths.update(_sync_root_safely(r))

            if self._is_cancelled():
                logging.info("Asset scan cancelled after sync phase")
                cancelled = True
                return

            all_prefixes = [
                os.path.abspath(p) for r in roots for p in get_prefixes_for_root(r)
            ]
            orphans_pruned = _prune_orphans_safely(all_prefixes)

            if self._is_cancelled():
                logging.info("Asset scan cancelled after orphan pruning")
                cancelled = True
                return

            paths = _collect_paths_for_roots(roots)
            total_paths = len(paths)
            self._update_progress(total=total_paths)

            self._emit_event(
                "assets.seed.started",
                {"roots": list(roots), "total": total_paths},
            )

            specs, _tag_pool, skipped_existing = _build_asset_specs(paths, existing_paths)
            self._update_progress(skipped=skipped_existing)

            if self._is_cancelled():
                logging.info("Asset scan cancelled after building specs")
                cancelled = True
                return

            batch_size = 500
            last_progress_time = time.perf_counter()
            progress_interval = 1.0  # seconds between progress WebSocket events

            for i in range(0, len(specs), batch_size):
                if self._is_cancelled():
                    logging.info(
                        "Asset scan cancelled after %d/%d files (created=%d)",
                        i,
                        len(specs),
                        total_created,
                    )
                    cancelled = True
                    return

                batch = specs[i : i + batch_size]
                batch_tags = {t for spec in batch for t in spec["tags"]}
                try:
                    created = _insert_asset_specs(batch, batch_tags)
                    total_created += created
                except Exception as e:
                    # A failed batch is recorded but does not abort the scan.
                    self._add_error(f"Batch insert failed at offset {i}: {e}")
                    logging.exception("Batch insert failed at offset %d", i)

                scanned = i + len(batch)
                self._update_progress(scanned=scanned, created=total_created)

                # Throttle progress events to roughly one per second.
                now = time.perf_counter()
                if now - last_progress_time >= progress_interval:
                    self._emit_event(
                        "assets.seed.progress",
                        {"scanned": scanned, "total": len(specs), "created": total_created},
                    )
                    last_progress_time = now

            self._update_progress(scanned=len(specs), created=total_created)

            elapsed = time.perf_counter() - t_start
            logging.info(
                "Asset scan(roots=%s) completed in %.3fs (created=%d, skipped=%d, orphans_pruned=%d, total=%d)",
                roots,
                elapsed,
                total_created,
                skipped_existing,
                orphans_pruned,
                len(paths),
            )

            self._emit_event(
                "assets.seed.completed",
                {
                    "scanned": len(specs),
                    "total": total_paths,
                    "created": total_created,
                    "skipped": skipped_existing,
                    "elapsed": round(elapsed, 3),
                },
            )

        except Exception as e:
            self._add_error(f"Scan failed: {e}")
            logging.exception("Asset scan failed")
            self._emit_event("assets.seed.error", {"message": str(e)})
        finally:
            if cancelled:
                # Read the counter under the lock; emit outside it.
                with self._lock:
                    scanned_so_far = self._progress.scanned if self._progress else 0
                self._emit_event(
                    "assets.seed.cancelled",
                    {
                        "scanned": scanned_so_far,
                        "total": total_paths,
                        "created": total_created,
                    },
                )
            with self._lock:
                self._state = State.IDLE


# Module-level singleton used by main.py and the API routes.
asset_seeder = AssetSeeder()
|
||||
195
docs/design/background-asset-seeder.md
Normal file
195
docs/design/background-asset-seeder.md
Normal file
@@ -0,0 +1,195 @@
|
||||
# Background Asset Seeder Design Document
|
||||
|
||||
## Problem Statement
|
||||
|
||||
The `seed_assets` function in `app/assets/scanner.py` scans filesystem directories and imports assets into the database. Currently it runs synchronously, which causes two problems:
|
||||
|
||||
1. **Startup blocking**: When ComfyUI starts, `setup_database()` calls `seed_assets(["models"])` synchronously. If the models directory contains thousands of files, startup is delayed significantly before the UI becomes available.
|
||||
|
||||
2. **API request blocking**: The `POST /api/assets/seed` endpoint runs synchronously, blocking the HTTP request until scanning completes. For large directories, this causes request timeouts and a poor user experience.
|
||||
|
||||
## Goals
|
||||
|
||||
- Move asset scanning to a background thread so startup and API requests return immediately
|
||||
- Provide visibility into scan progress via API and WebSocket events
|
||||
- Support graceful cancellation of in-progress scans
|
||||
- Maintain backward compatibility for tests that rely on synchronous behavior
|
||||
- Ensure thread safety when accessing shared state
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- Targeted/priority scanning of specific paths (addressed separately via synchronous `scan_paths()`)
|
||||
- Parallel scanning across multiple threads
|
||||
- Persistent scan state across restarts
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Component: AssetSeeder Singleton
|
||||
|
||||
A new `AssetSeeder` class in `app/assets/seeder.py` manages background scanning with the following responsibilities:
|
||||
|
||||
- Owns a single background `threading.Thread` for scanning work
|
||||
- Tracks scan state and progress in a thread-safe manner
|
||||
- Provides cancellation support via `threading.Event`
|
||||
- Emits WebSocket events for UI progress updates
|
||||
|
||||
### State Machine
|
||||
|
||||
```
|
||||
IDLE ──start()──► RUNNING ──(completes)──► IDLE
|
||||
│
|
||||
cancel()
|
||||
│
|
||||
▼
|
||||
CANCELLING ──(thread exits)──► IDLE
|
||||
```
|
||||
|
||||
### Integration Points
|
||||
|
||||
| Component | Change |
|
||||
|-----------|--------|
|
||||
| `main.py` | Call `asset_seeder.start()` (non-blocking) instead of `seed_assets()` |
|
||||
| `main.py` | Call `asset_seeder.shutdown()` in `finally` block alongside `cleanup_temp()` |
|
||||
| API routes | New endpoints for status and cancellation |
|
||||
| WebSocket | Emit progress events during scanning |
|
||||
| Test helper | Use `?wait=true` query param for synchronous behavior |
|
||||
|
||||
---
|
||||
|
||||
## Tasks
|
||||
|
||||
### Task 1: Create AssetSeeder Class
|
||||
|
||||
**Description**: Implement the core `AssetSeeder` singleton class with thread management, state tracking, and cancellation support.
|
||||
|
||||
**Acceptance Criteria**:
|
||||
- [ ] `AssetSeeder` is a singleton accessible via module-level instance
|
||||
- [ ] `State` enum with values: `IDLE`, `RUNNING`, `CANCELLING`
|
||||
- [ ] `start(roots)` method spawns a daemon thread and returns immediately
|
||||
- [ ] `start()` is idempotent—calling while already running is a no-op and returns `False`
|
||||
- [ ] `cancel()` method signals the thread to stop gracefully
|
||||
- [ ] `wait(timeout)` method blocks until thread completes or timeout expires
|
||||
- [ ] `get_status()` returns current state and progress information
|
||||
- [ ] All state access is protected by `threading.Lock`
|
||||
- [ ] Thread creates its own database sessions (no session sharing across threads)
|
||||
- [ ] Progress tuple tracks: `(scanned, total, created, skipped)`
|
||||
- [ ] Errors during scanning are captured and available via `get_status()`
|
||||
|
||||
### Task 2: Add Cancellation Checkpoints
|
||||
|
||||
**Description**: Modify the scanning logic to check for cancellation between batches, allowing graceful early termination.
|
||||
|
||||
**Acceptance Criteria**:
|
||||
- [ ] `threading.Event` is checked between batch operations
|
||||
- [ ] When cancellation is requested, current batch completes before stopping
|
||||
- [ ] Partial progress is committed (assets already scanned remain in database)
|
||||
- [ ] State transitions to `IDLE` after cancellation completes
|
||||
- [ ] Cancellation is logged with partial progress statistics
|
||||
|
||||
### Task 3: Update Startup Integration
|
||||
|
||||
**Description**: Modify `main.py` to use non-blocking asset seeding at startup.
|
||||
|
||||
**Acceptance Criteria**:
|
||||
- [ ] `setup_database()` calls `asset_seeder.start(roots=["models"])` instead of `seed_assets()`
|
||||
- [ ] Startup proceeds immediately without waiting for scan completion
|
||||
- [ ] `asset_seeder.shutdown()` called in `finally` block alongside `cleanup_temp()`
|
||||
- [ ] `--disable-assets-autoscan` flag continues to skip seeding entirely
|
||||
- [ ] Startup logs indicate background scan was initiated (not completed)
|
||||
|
||||
### Task 4: Create API Endpoints
|
||||
|
||||
**Description**: Add REST endpoints for triggering, monitoring, and cancelling background scans.
|
||||
|
||||
**Endpoints**:
|
||||
|
||||
#### POST /api/assets/seed
|
||||
Trigger a background scan for specified roots.
|
||||
|
||||
**Acceptance Criteria**:
|
||||
- [ ] Accepts `{"roots": ["models", "input", "output"]}` in request body
|
||||
- [ ] Returns `202 Accepted` with `{"status": "started"}` when scan begins
|
||||
- [ ] Returns `409 Conflict` with `{"status": "already_running"}` if scan in progress
|
||||
- [ ] Supports `?wait=true` query param for synchronous behavior (blocks until complete)
|
||||
- [ ] With `?wait=true`, returns `200 OK` with final statistics on completion
|
||||
|
||||
#### GET /api/assets/seed/status
|
||||
Get current scan status and progress.
|
||||
|
||||
**Acceptance Criteria**:
|
||||
- [ ] Returns `{"state": "IDLE|RUNNING|CANCELLING", "progress": {...}, "errors": [...]}`
|
||||
- [ ] Progress object includes: `scanned`, `total`, `created`, `skipped`
|
||||
- [ ] When idle, progress reflects last completed scan (or null if never run)
|
||||
- [ ] Errors array contains messages from any failures during last/current scan
|
||||
|
||||
#### POST /api/assets/seed/cancel
|
||||
Request cancellation of in-progress scan.
|
||||
|
||||
**Acceptance Criteria**:
|
||||
- [ ] Returns `200 OK` with `{"status": "cancelling"}` if scan was running
|
||||
- [ ] Returns `200 OK` with `{"status": "idle"}` if no scan was running
|
||||
- [ ] Cancellation is graceful—does not corrupt database state
|
||||
|
||||
### Task 5: Add WebSocket Progress Events
|
||||
|
||||
**Description**: Emit WebSocket events during scanning so the UI can display progress.
|
||||
|
||||
**Acceptance Criteria**:
|
||||
- [ ] Event type: `assets.seed.started` with `{"roots": [...], "total": N}`
|
||||
- [ ] Event type: `assets.seed.progress` with `{"scanned": N, "total": M, "created": C}`
|
||||
- [ ] Event type: `assets.seed.completed` with final statistics
|
||||
- [ ] Event type: `assets.seed.cancelled` if scan was cancelled
|
||||
- [ ] Event type: `assets.seed.error` if scan failed with error message
|
||||
- [ ] Progress events emitted at reasonable intervals (not every file, ~every 100 files or 1 second)
|
||||
|
||||
### Task 6: Update Test Helper
|
||||
|
||||
**Description**: Modify the test helper to use synchronous behavior via query parameter.
|
||||
|
||||
**Acceptance Criteria**:
|
||||
- [ ] `trigger_sync_seed_assets()` uses `?wait=true` query parameter
|
||||
- [ ] Tests continue to pass with synchronous blocking behavior
|
||||
- [ ] Remove artificial `time.sleep(0.2)` delay (no longer needed with `wait=true`)
|
||||
|
||||
### Task 7: Unit Tests for AssetSeeder
|
||||
|
||||
**Description**: Add unit tests covering the seeder state machine and thread safety.
|
||||
|
||||
**Acceptance Criteria**:
|
||||
- [ ] Test: `start()` transitions state from IDLE to RUNNING
|
||||
- [ ] Test: `start()` while RUNNING returns False (idempotent)
|
||||
- [ ] Test: `cancel()` transitions state from RUNNING to CANCELLING
|
||||
- [ ] Test: `wait()` blocks until thread completes
|
||||
- [ ] Test: `wait(timeout)` returns False if timeout expires
|
||||
- [ ] Test: `get_status()` returns correct progress during scan
|
||||
- [ ] Test: Concurrent `start()` calls are safe (only one thread spawned)
|
||||
- [ ] Test: Scan commits partial progress on cancellation
|
||||
- [ ] Test: Database errors are captured in status, don't crash thread
|
||||
|
||||
---
|
||||
|
||||
## Thread Safety Considerations
|
||||
|
||||
| Shared Resource | Protection Strategy |
|
||||
|-----------------|---------------------|
|
||||
| `_state` enum | Protected by `threading.Lock` |
|
||||
| `_progress` tuple | Protected by `threading.Lock` |
|
||||
| `_errors` list | Protected by `threading.Lock` |
|
||||
| `_thread` reference | Protected by `threading.Lock` |
|
||||
| `_cancel_event` | `threading.Event` (inherently thread-safe) |
|
||||
| Database sessions | Created per-operation inside thread (no sharing) |
|
||||
|
||||
## Error Handling
|
||||
|
||||
- Database connection failures: Log error, set state to IDLE, populate errors list
|
||||
- Individual file scan failures: Log warning, continue with next file, increment error count
|
||||
- Thread crashes: Caught by outer try/except, state reset to IDLE, error captured
|
||||
|
||||
## Future Considerations
|
||||
|
||||
- **Priority queue**: If targeted scans need to be non-blocking in the future, the seeder could be extended with a priority queue
|
||||
- **Persistent state**: Scan progress could be persisted to allow resume after restart
|
||||
- **Parallel scanning**: Multiple threads could scan different roots concurrently (requires careful session management)
|
||||
- **Throttling**: If scanning competes with generation (e.g., disk I/O contention when hashing large files), add configurable sleep between batches. Currently considered low risk since scanning is I/O-bound and generation is GPU-bound.
|
||||
10
main.py
10
main.py
@@ -7,7 +7,7 @@ import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from app.logger import setup_logger
|
||||
from app.assets.scanner import seed_assets
|
||||
from app.assets.seeder import asset_seeder
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
import logging
|
||||
@@ -357,7 +357,8 @@ def setup_database():
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
if not args.disable_assets_autoscan:
|
||||
seed_assets(["models"], enable_logging=True)
|
||||
if asset_seeder.start(roots=("models",)):
|
||||
logging.info("Background asset scan initiated for models")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
|
||||
@@ -441,5 +442,6 @@ if __name__ == "__main__":
|
||||
event_loop.run_until_complete(x)
|
||||
except KeyboardInterrupt:
|
||||
logging.info("\nStopped server")
|
||||
|
||||
cleanup_temp()
|
||||
finally:
|
||||
asset_seeder.shutdown()
|
||||
cleanup_temp()
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Helper functions for assets integration tests."""
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
    """Force a synchronous sync/seed pass by calling the seed endpoint with wait=true.

    With ?wait=true the server blocks until the scan completes, so no
    artificial sleep is needed afterwards.

    Args:
        session: Authenticated requests session to use for the call.
        base_url: Server base URL, without a trailing slash.
    """
    session.post(
        base_url + "/api/assets/seed?wait=true",
        json={"roots": ["models", "input", "output"]},
        timeout=60,
    )
|
||||
|
||||
def get_asset_filename(asset_hash: str, extension: str) -> str:
|
||||
|
||||
347
tests-unit/seeder_test/test_seeder.py
Normal file
347
tests-unit/seeder_test/test_seeder.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""Unit tests for the AssetSeeder background scanning class."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.seeder import AssetSeeder, Progress, State
|
||||
|
||||
|
||||
@pytest.fixture
def fresh_seeder():
    """Yield a brand-new AssetSeeder, sidestepping the singleton __new__."""
    instance = object.__new__(AssetSeeder)
    instance._initialized = False
    instance.__init__()
    yield instance
    # Make sure no daemon thread outlives the test.
    instance.shutdown(timeout=1.0)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_dependencies():
    """Stub out every scanner/database dependency so tests run in isolation."""
    patchers = [
        patch("app.assets.seeder.dependencies_available", return_value=True),
        patch("app.assets.seeder._sync_root_safely", return_value=set()),
        patch("app.assets.seeder._prune_orphans_safely", return_value=0),
        patch("app.assets.seeder._collect_paths_for_roots", return_value=[]),
        patch("app.assets.seeder._build_asset_specs", return_value=([], set(), 0)),
        patch("app.assets.seeder._insert_asset_specs", return_value=0),
    ]
    for p in patchers:
        p.start()
    try:
        yield
    finally:
        for p in patchers:
            p.stop()
|
||||
|
||||
|
||||
class TestSeederStateTransitions:
    """Exercise the IDLE / RUNNING / CANCELLING state machine."""

    def test_initial_state_is_idle(self, fresh_seeder: AssetSeeder):
        status = fresh_seeder.get_status()
        assert status.state == State.IDLE

    def test_start_transitions_to_running(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        assert fresh_seeder.start(roots=("models",)) is True
        # The mocked scan may already have finished on a fast machine.
        assert fresh_seeder.get_status().state in (State.RUNNING, State.IDLE)

    def test_start_while_running_returns_false(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        gate = threading.Event()

        def stalled_collect(*_args):
            gate.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder._collect_paths_for_roots", side_effect=stalled_collect
        ):
            fresh_seeder.start(roots=("models",))
            time.sleep(0.05)

            assert fresh_seeder.start(roots=("models",)) is False

            gate.set()

    def test_cancel_transitions_to_cancelling(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        gate = threading.Event()

        def stalled_collect(*_args):
            gate.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder._collect_paths_for_roots", side_effect=stalled_collect
        ):
            fresh_seeder.start(roots=("models",))
            time.sleep(0.05)

            assert fresh_seeder.cancel() is True
            assert fresh_seeder.get_status().state == State.CANCELLING

            gate.set()

    def test_cancel_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
        assert fresh_seeder.cancel() is False

    def test_state_returns_to_idle_after_completion(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        fresh_seeder.start(roots=("models",))
        assert fresh_seeder.wait(timeout=5.0) is True
        assert fresh_seeder.get_status().state == State.IDLE
|
||||
|
||||
|
||||
class TestSeederWait:
    """Exercise wait() blocking, timeout, and idle behavior."""

    def test_wait_blocks_until_complete(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        fresh_seeder.start(roots=("models",))
        assert fresh_seeder.wait(timeout=5.0) is True
        assert fresh_seeder.get_status().state == State.IDLE

    def test_wait_returns_false_on_timeout(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        gate = threading.Event()

        def stalled_collect(*_args):
            gate.wait(timeout=10.0)
            return []

        with patch(
            "app.assets.seeder._collect_paths_for_roots", side_effect=stalled_collect
        ):
            fresh_seeder.start(roots=("models",))
            # The scan is parked on the gate, so a tiny timeout must expire.
            assert fresh_seeder.wait(timeout=0.1) is False

            gate.set()

    def test_wait_when_idle_returns_true(self, fresh_seeder: AssetSeeder):
        assert fresh_seeder.wait(timeout=1.0) is True
|
||||
|
||||
|
||||
class TestSeederProgress:
    """Test progress tracking."""

    def test_get_status_returns_progress_during_scan(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        """get_status() exposes a non-None progress object while a scan is running."""
        progress_seen = []
        # Hold the worker inside path collection so the scan is guaranteed
        # to still be in flight when the status is sampled.
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=5.0)
            return ["/path/file1.safetensors", "/path/file2.safetensors"]

        with patch(
            "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect
        ):
            fresh_seeder.start(roots=("models",))
            time.sleep(0.05)  # give the worker thread time to actually start

            status = fresh_seeder.get_status()
            assert status.progress is not None
            progress_seen.append(status.progress)

            barrier.set()  # release the worker so teardown can finish

    def test_progress_callback_is_invoked(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        """The progress_callback given to start() receives at least one update."""
        progress_updates: list[Progress] = []

        def callback(p: Progress):
            progress_updates.append(p)

        # 10 fake paths so the scan has real work to report progress on.
        with patch(
            "app.assets.seeder._collect_paths_for_roots",
            return_value=[f"/path/file{i}.safetensors" for i in range(10)],
        ):
            fresh_seeder.start(roots=("models",), progress_callback=callback)
            fresh_seeder.wait(timeout=5.0)

        assert len(progress_updates) > 0
|
||||
|
||||
|
||||
class TestSeederCancellation:
    """Test cancellation behavior."""

    def test_scan_commits_partial_progress_on_cancellation(
        self, fresh_seeder: AssetSeeder
    ):
        """Cancelling mid-scan must not discard batches already inserted."""
        insert_count = 0
        # Block the insert path from the second batch onward so cancel()
        # lands after at least one batch has been committed.
        barrier = threading.Event()

        def slow_insert(specs, tags):
            nonlocal insert_count
            insert_count += 1
            if insert_count >= 2:
                barrier.wait(timeout=5.0)
            return len(specs)

        # 1500 paths — presumably enough to span multiple insert batches;
        # TODO confirm against the seeder's actual batch size.
        paths = [f"/path/file{i}.safetensors" for i in range(1500)]
        specs = [
            {
                "abs_path": p,
                "size_bytes": 100,
                "mtime_ns": 0,
                "info_name": f"file{i}",
                "tags": [],
                "fname": f"file{i}",
            }
            for i, p in enumerate(paths)
        ]

        # Stub every pipeline stage except insertion, which is throttled above.
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder._sync_root_safely", return_value=set()),
            patch("app.assets.seeder._prune_orphans_safely", return_value=0),
            patch("app.assets.seeder._collect_paths_for_roots", return_value=paths),
            patch("app.assets.seeder._build_asset_specs", return_value=(specs, set(), 0)),
            patch("app.assets.seeder._insert_asset_specs", side_effect=slow_insert),
        ):
            fresh_seeder.start(roots=("models",))
            time.sleep(0.1)  # allow the first batch to be inserted

            fresh_seeder.cancel()
            barrier.set()  # unblock the worker so it can observe the cancel
            fresh_seeder.wait(timeout=5.0)

            # At least one batch made it through before cancellation took effect.
            assert insert_count >= 1
|
||||
|
||||
|
||||
class TestSeederErrorHandling:
    """Test error handling behavior."""

    def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder):
        """An exception raised during insertion ends up in status.errors."""
        # Stub the pipeline up to insertion, then make insertion blow up.
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder._sync_root_safely", return_value=set()),
            patch("app.assets.seeder._prune_orphans_safely", return_value=0),
            patch(
                "app.assets.seeder._collect_paths_for_roots",
                return_value=["/path/file.safetensors"],
            ),
            patch(
                "app.assets.seeder._build_asset_specs",
                return_value=(
                    [
                        {
                            "abs_path": "/path/file.safetensors",
                            "size_bytes": 100,
                            "mtime_ns": 0,
                            "info_name": "file",
                            "tags": [],
                            "fname": "file",
                        }
                    ],
                    set(),
                    0,
                ),
            ),
            patch(
                "app.assets.seeder._insert_asset_specs",
                side_effect=Exception("DB connection failed"),
            ),
        ):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)

            status = fresh_seeder.get_status()
            assert len(status.errors) > 0
            assert "DB connection failed" in status.errors[0]

    def test_dependencies_unavailable_captured_in_errors(
        self, fresh_seeder: AssetSeeder
    ):
        """Missing dependencies abort the scan and are reported via status.errors."""
        with patch("app.assets.seeder.dependencies_available", return_value=False):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)

            status = fresh_seeder.get_status()
            assert len(status.errors) > 0
            assert "dependencies" in status.errors[0].lower()

    def test_thread_crash_resets_state_to_idle(self, fresh_seeder: AssetSeeder):
        """An unexpected crash in the worker still resets state to IDLE."""
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch(
                "app.assets.seeder._sync_root_safely",
                side_effect=RuntimeError("Unexpected crash"),
            ),
        ):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)

            status = fresh_seeder.get_status()
            assert status.state == State.IDLE
            assert len(status.errors) > 0
|
||||
|
||||
|
||||
class TestSeederThreadSafety:
    """Test thread safety of concurrent operations."""

    def test_concurrent_start_calls_spawn_only_one_thread(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        """Of many racing start() calls, exactly one must win."""
        # Park the first (winning) scan in path collection so every other
        # start() call races against a still-RUNNING seeder.
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect
        ):
            results = []

            def try_start():
                results.append(fresh_seeder.start(roots=("models",)))

            threads = [threading.Thread(target=try_start) for _ in range(10)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            barrier.set()  # release the winner so teardown can finish

            # True == 1, False == 0: exactly one start() must have succeeded.
            assert sum(results) == 1

    def test_get_status_safe_during_scan(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        """Polling get_status() concurrently with a scan never yields a bad state."""
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect
        ):
            fresh_seeder.start(roots=("models",))

            # Hammer get_status() while the worker is alive.
            statuses = []
            for _ in range(100):
                statuses.append(fresh_seeder.get_status())
                time.sleep(0.001)

            barrier.set()  # release the worker so teardown can finish

            # Every observed state must be a member of the valid state machine.
            assert all(s.state in (State.RUNNING, State.IDLE, State.CANCELLING) for s in statuses)
|
||||
Reference in New Issue
Block a user