mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-25 00:49:13 +00:00
refactoring: use the same code for "scan task" and realtime DB population
This commit is contained in:
@@ -7,10 +7,11 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional, Sequence
|
||||
|
||||
import folder_paths
|
||||
|
||||
from . import assets_manager
|
||||
from .api import schemas_out
|
||||
|
||||
import folder_paths
|
||||
from ._assets_helpers import get_comfy_models_folders
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,7 +37,7 @@ class ScanProgress:
|
||||
errors: int = 0
|
||||
last_error: Optional[str] = None
|
||||
|
||||
# Optional details for diagnostics
|
||||
# Optional details for diagnostics (e.g., files per bucket)
|
||||
details: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -49,8 +50,6 @@ def _new_scan_id(root: RootType) -> str:
|
||||
|
||||
|
||||
def current_statuses() -> schemas_out.AssetScanStatusResponse:
|
||||
# make shallow copies to avoid external mutation
|
||||
states = [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
|
||||
return schemas_out.AssetScanStatusResponse(
|
||||
scans=[
|
||||
schemas_out.AssetScanStatus(
|
||||
@@ -65,7 +64,7 @@ def current_statuses() -> schemas_out.AssetScanStatusResponse:
|
||||
errors=s.errors,
|
||||
last_error=s.last_error,
|
||||
)
|
||||
for s in states
|
||||
for s in [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
|
||||
]
|
||||
)
|
||||
|
||||
@@ -94,15 +93,12 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes
|
||||
results: list[ScanProgress] = []
|
||||
for root in normalized:
|
||||
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
|
||||
# already running; return the live progress object
|
||||
results.append(PROGRESS_BY_ROOT[root])
|
||||
continue
|
||||
|
||||
# Create fresh progress
|
||||
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
|
||||
PROGRESS_BY_ROOT[root] = prog
|
||||
|
||||
# Start task
|
||||
task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}")
|
||||
RUNNING_TASKS[root] = task
|
||||
results.append(prog)
|
||||
@@ -151,24 +147,21 @@ async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
|
||||
prog.last_error = str(exc)
|
||||
finally:
|
||||
prog.finished_at = time.time()
|
||||
# Drop the task entry if it's the current one
|
||||
t = RUNNING_TASKS.get(root)
|
||||
if t and t.done():
|
||||
RUNNING_TASKS.pop(root, None)
|
||||
|
||||
|
||||
async def _scan_models(prog: ScanProgress) -> None:
|
||||
# Iterate all folder_names whose base paths lie under the Comfy 'models' directory
|
||||
models_root = os.path.abspath(os.path.join(folder_paths.base_path, "models"))
|
||||
"""
|
||||
Scan all configured model buckets from folder_paths.folder_names_and_paths,
|
||||
restricted to entries whose base paths lie under folder_paths.models_dir
|
||||
(per get_comfy_models_folders). We trust those mappings and do not try to
|
||||
infer anything else here.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = get_comfy_models_folders()
|
||||
|
||||
# Build list of (folder_name, base_paths[]) that are configured for this category.
|
||||
# If any path for the category lies under 'models', include the category.
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
|
||||
plans: list[tuple[str, str]] = [] # (abs_path, file_name_for_tags)
|
||||
plans: list[str] = [] # absolute file paths to ingest
|
||||
per_bucket: dict[str, int] = {}
|
||||
|
||||
for folder_name, bases in targets:
|
||||
@@ -198,13 +191,12 @@ async def _scan_models(prog: ScanProgress) -> None:
|
||||
|
||||
try:
|
||||
if not os.path.getsize(abs_path):
|
||||
continue
|
||||
continue # skip empty files
|
||||
except OSError as e:
|
||||
LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e)
|
||||
continue
|
||||
|
||||
file_name_for_tags = os.path.join(folder_name, rel_path)
|
||||
plans.append((abs_path, file_name_for_tags))
|
||||
plans.append(abs_path)
|
||||
count_valid += 1
|
||||
|
||||
if count_valid:
|
||||
@@ -221,16 +213,12 @@ async def _scan_models(prog: ScanProgress) -> None:
|
||||
sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
|
||||
tasks: list[asyncio.Task] = []
|
||||
|
||||
for abs_path, name_for_tags in plans:
|
||||
async def worker(fp_abs: str = abs_path, fn_rel: str = name_for_tags):
|
||||
for abs_path in plans:
|
||||
async def worker(fp_abs: str = abs_path):
|
||||
try:
|
||||
# Offload sync ingestion into a thread
|
||||
await asyncio.to_thread(
|
||||
assets_manager.populate_db_with_asset,
|
||||
["models"],
|
||||
fn_rel,
|
||||
fp_abs,
|
||||
)
|
||||
# Offload sync ingestion into a thread; populate_db_with_asset
|
||||
# derives name and tags from the path using _assets_helpers.
|
||||
await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
|
||||
except Exception as e:
|
||||
prog.errors += 1
|
||||
prog.last_error = str(e)
|
||||
@@ -260,7 +248,10 @@ def _count_files_in_tree(base_abs: str) -> int:
|
||||
|
||||
|
||||
async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None:
|
||||
# Guard: base_dir must be a directory
|
||||
"""
|
||||
Generic scanner for input/output roots. We pass only the absolute path to
|
||||
populate_db_with_asset and let it derive the relative name and tags.
|
||||
"""
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs)
|
||||
@@ -272,24 +263,27 @@ async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress
|
||||
tasks: list[asyncio.Task] = []
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
rel = os.path.relpath(os.path.join(dirpath, name), base_abs)
|
||||
abs_path = os.path.join(base_abs, rel)
|
||||
abs_path = os.path.abspath(os.path.join(dirpath, name))
|
||||
|
||||
# Safety: ensure within base
|
||||
try:
|
||||
if os.path.commonpath([os.path.abspath(abs_path), base_abs]) != base_abs:
|
||||
if os.path.commonpath([abs_path, base_abs]) != base_abs:
|
||||
LOGGER.warning("Skipping path outside root %s: %s", root, abs_path)
|
||||
continue
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
async def worker(fp_abs: str = abs_path, fn_rel: str = rel):
|
||||
# Skip empty files and handle stat errors
|
||||
try:
|
||||
if not os.path.getsize(abs_path):
|
||||
continue
|
||||
except OSError as e:
|
||||
LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e)
|
||||
continue
|
||||
|
||||
async def worker(fp_abs: str = abs_path):
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
assets_manager.populate_db_with_asset,
|
||||
[root],
|
||||
fn_rel,
|
||||
fp_abs,
|
||||
)
|
||||
await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
|
||||
except Exception as e:
|
||||
prog.errors += 1
|
||||
prog.last_error = str(e)
|
||||
|
||||
Reference in New Issue
Block a user