refactoring: use the same code for "scan task" and realtime DB population

This commit is contained in:
bigcat88
2025-08-25 13:31:56 +03:00
parent d7464e9e73
commit 09dabf95bc
7 changed files with 178 additions and 98 deletions

99
app/_assets_helpers.py Normal file
View File

@@ -0,0 +1,99 @@
import os
from pathlib import Path
from typing import Optional, Literal, Sequence
import folder_paths
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, normalize_tags([root_category, *parent_parts])
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]

View File

@@ -1,7 +1,7 @@
import logging
import mimetypes
import os
from typing import Optional, Sequence
from pathlib import Path
from comfy.cli_args import args
from comfy_api.internal import async_to_sync
@@ -26,6 +26,7 @@ from .database.services import (
create_asset_info_for_existing_asset,
)
from .api import schemas_out
from ._assets_helpers import get_name_and_tags_from_asset_path
async def asset_exists(*, asset_hash: str) -> bool:
@@ -33,16 +34,20 @@ async def asset_exists(*, asset_hash: str) -> bool:
return await asset_exists_by_hash(session, asset_hash=asset_hash)
def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None:
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
if not args.disable_model_processing:
p = Path(file_name)
dir_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
add_local_asset,
tags=list(dict.fromkeys([*tags, *dir_parts])),
file_name=p.name,
file_path=file_path,
)
if tags is None:
tags = []
try:
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
add_local_asset,
tags=list(dict.fromkeys([*path_tags, *tags])),
file_name=asset_name,
file_path=file_path,
)
except ValueError:
logging.exception("Cant parse '%s' as an asset file path.", file_path)
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:

View File

@@ -7,10 +7,11 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Optional, Sequence
import folder_paths
from . import assets_manager
from .api import schemas_out
import folder_paths
from ._assets_helpers import get_comfy_models_folders
LOGGER = logging.getLogger(__name__)
@@ -36,7 +37,7 @@ class ScanProgress:
errors: int = 0
last_error: Optional[str] = None
# Optional details for diagnostics
# Optional details for diagnostics (e.g., files per bucket)
details: dict[str, int] = field(default_factory=dict)
@@ -49,8 +50,6 @@ def _new_scan_id(root: RootType) -> str:
def current_statuses() -> schemas_out.AssetScanStatusResponse:
# make shallow copies to avoid external mutation
states = [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
return schemas_out.AssetScanStatusResponse(
scans=[
schemas_out.AssetScanStatus(
@@ -65,7 +64,7 @@ def current_statuses() -> schemas_out.AssetScanStatusResponse:
errors=s.errors,
last_error=s.last_error,
)
for s in states
for s in [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
]
)
@@ -94,15 +93,12 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes
results: list[ScanProgress] = []
for root in normalized:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
# already running; return the live progress object
results.append(PROGRESS_BY_ROOT[root])
continue
# Create fresh progress
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
# Start task
task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}")
RUNNING_TASKS[root] = task
results.append(prog)
@@ -151,24 +147,21 @@ async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
prog.last_error = str(exc)
finally:
prog.finished_at = time.time()
# Drop the task entry if it's the current one
t = RUNNING_TASKS.get(root)
if t and t.done():
RUNNING_TASKS.pop(root, None)
async def _scan_models(prog: ScanProgress) -> None:
# Iterate all folder_names whose base paths lie under the Comfy 'models' directory
models_root = os.path.abspath(os.path.join(folder_paths.base_path, "models"))
"""
Scan all configured model buckets from folder_paths.folder_names_and_paths,
restricted to entries whose base paths lie under folder_paths.models_dir
(per get_comfy_models_folders). We trust those mappings and do not try to
infer anything else here.
"""
targets: list[tuple[str, list[str]]] = get_comfy_models_folders()
# Build list of (folder_name, base_paths[]) that are configured for this category.
# If any path for the category lies under 'models', include the category.
targets: list[tuple[str, list[str]]] = []
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
plans: list[tuple[str, str]] = [] # (abs_path, file_name_for_tags)
plans: list[str] = [] # absolute file paths to ingest
per_bucket: dict[str, int] = {}
for folder_name, bases in targets:
@@ -198,13 +191,12 @@ async def _scan_models(prog: ScanProgress) -> None:
try:
if not os.path.getsize(abs_path):
continue
continue # skip empty files
except OSError as e:
LOGGER.warning("Could not stat %s: %s skipping", abs_path, e)
continue
file_name_for_tags = os.path.join(folder_name, rel_path)
plans.append((abs_path, file_name_for_tags))
plans.append(abs_path)
count_valid += 1
if count_valid:
@@ -221,16 +213,12 @@ async def _scan_models(prog: ScanProgress) -> None:
sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
tasks: list[asyncio.Task] = []
for abs_path, name_for_tags in plans:
async def worker(fp_abs: str = abs_path, fn_rel: str = name_for_tags):
for abs_path in plans:
async def worker(fp_abs: str = abs_path):
try:
# Offload sync ingestion into a thread
await asyncio.to_thread(
assets_manager.populate_db_with_asset,
["models"],
fn_rel,
fp_abs,
)
# Offload sync ingestion into a thread; populate_db_with_asset
# derives name and tags from the path using _assets_helpers.
await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
except Exception as e:
prog.errors += 1
prog.last_error = str(e)
@@ -260,7 +248,10 @@ def _count_files_in_tree(base_abs: str) -> int:
async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None:
# Guard: base_dir must be a directory
"""
Generic scanner for input/output roots. We pass only the absolute path to
populate_db_with_asset and let it derive the relative name and tags.
"""
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs)
@@ -272,24 +263,27 @@ async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress
tasks: list[asyncio.Task] = []
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
rel = os.path.relpath(os.path.join(dirpath, name), base_abs)
abs_path = os.path.join(base_abs, rel)
abs_path = os.path.abspath(os.path.join(dirpath, name))
# Safety: ensure within base
try:
if os.path.commonpath([os.path.abspath(abs_path), base_abs]) != base_abs:
if os.path.commonpath([abs_path, base_abs]) != base_abs:
LOGGER.warning("Skipping path outside root %s: %s", root, abs_path)
continue
except ValueError:
continue
async def worker(fp_abs: str = abs_path, fn_rel: str = rel):
# Skip empty files and handle stat errors
try:
if not os.path.getsize(abs_path):
continue
except OSError as e:
LOGGER.warning("Could not stat %s: %s skipping", abs_path, e)
continue
async def worker(fp_abs: str = abs_path):
try:
await asyncio.to_thread(
assets_manager.populate_db_with_asset,
[root],
fn_rel,
fp_abs,
)
await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
except Exception as e:
prog.errors += 1
prog.last_error = str(e)

View File

@@ -14,7 +14,7 @@ from sqlalchemy.exc import IntegrityError
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
from .timeutil import utcnow
from .._assets_helpers import normalize_tags
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
@@ -471,7 +471,7 @@ async def set_asset_info_tags(
Replace the tag set on an AssetInfo with `tags`. Idempotent.
Creates missing tag names as 'user'.
"""
desired = _normalize_tags(tags)
desired = normalize_tags(tags)
# current links
current = set(
@@ -691,7 +691,7 @@ async def add_tags_to_asset_info(
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = _normalize_tags(tags)
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
@@ -753,7 +753,7 @@ async def remove_tags_from_asset_info(
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = _normalize_tags(tags)
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
@@ -784,12 +784,8 @@ async def remove_tags_from_asset_info(
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def _normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
wanted = _normalize_tags(list(names))
wanted = normalize_tags(list(names))
if not wanted:
return []
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
@@ -808,8 +804,8 @@ def _apply_tag_filters(
exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = _normalize_tags(include_tags)
exclude_tags = _normalize_tags(exclude_tags)
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags: