Files
ComfyUI/app/assets/database/queries/cache_state.py
Luke Mino-Altherr 3fee0962c6 feat: implement two-phase scanning architecture (fast + enrich)
Phase 1 (FAST): Creates stub records with filesystem metadata only
- path, size, mtime - no file content reading
- Populates asset database quickly on startup

Phase 2 (ENRICH): Extracts metadata and computes hashes
- Safetensors header parsing, MIME types
- Optional blake3 hash computation
- Updates existing stub records

Changes:
- Add ScanPhase enum (FAST, ENRICH, FULL)
- Add enrichment_level column to AssetCacheState (0=stub, 1=metadata, 2=hashed)
- Add build_stub_specs() for fast scanning without metadata extraction
- Add get_unenriched_cache_states(), enrich_asset(), enrich_assets_batch()
- Add start_fast(), start_enrich() convenience methods to AssetSeeder
- Update start() to accept phase parameter (defaults to FULL)
- Split _run_scan() into _run_fast_phase() and _run_enrich_phase()
- Add migration 0003_add_enrichment_level.py
- Update tests for new architecture

Amp-Thread-ID: https://ampcode.com/threads/T-019c4eef-1568-778f-aede-38254728f848
Co-authored-by: Amp <amp@ampcode.com>
2026-02-11 17:41:38 -08:00

452 lines
13 KiB
Python

import os
from typing import NamedTuple, Sequence
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string
class CacheStateRow(NamedTuple):
"""Row from cache state query with joined asset data."""
state_id: int
file_path: str
mtime_ns: int | None
needs_verify: bool
asset_id: str
asset_hash: str | None
size_bytes: int
def list_cache_states_by_asset_id(
    session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
    """Return all cache states that belong to one asset.

    Args:
        session: Database session
        asset_id: ID of the asset whose cache states are wanted

    Returns:
        The matching AssetCacheState objects, ordered by ascending id
    """
    stmt = (
        select(AssetCacheState)
        .where(AssetCacheState.asset_id == asset_id)
        .order_by(AssetCacheState.id.asc())
    )
    return session.execute(stmt).scalars().all()
def upsert_cache_state(
    session: Session,
    asset_id: str,
    file_path: str,
    mtime_ns: int,
) -> tuple[bool, bool]:
    """Insert the cache state for ``file_path`` or refresh the existing one.

    An insert is attempted first; a conflict on ``file_path`` is ignored.
    When no row was inserted, the existing row is updated only if it
    differs: other asset_id, NULL or different mtime_ns, or flagged as
    missing. The update therefore also restores a row previously marked
    as missing.

    Args:
        session: Database session
        asset_id: ID of the asset the path should point at
        file_path: Path identifying the cache state
        mtime_ns: Modification time in nanoseconds (coerced with int())

    Returns:
        (created, updated): (True, False) when a row was inserted,
        (False, True) when an existing row was changed, and
        (False, False) when the existing row was already up to date
    """
    mtime = int(mtime_ns)

    insert_stmt = (
        sqlite.insert(AssetCacheState)
        .values(
            asset_id=asset_id,
            file_path=file_path,
            mtime_ns=mtime,
            is_missing=False,
        )
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )
    if int(session.execute(insert_stmt).rowcount or 0) > 0:
        return True, False

    # Only touch the row when something would actually change, so the
    # reported rowcount distinguishes "updated" from "already current".
    row_is_stale = sa.or_(
        AssetCacheState.asset_id != asset_id,
        AssetCacheState.mtime_ns.is_(None),
        AssetCacheState.mtime_ns != mtime,
        AssetCacheState.is_missing == True,  # noqa: E712
    )
    update_stmt = (
        sa.update(AssetCacheState)
        .where(AssetCacheState.file_path == file_path)
        .where(row_is_stale)
        .values(asset_id=asset_id, mtime_ns=mtime, is_missing=False)
    )
    was_updated = int(session.execute(update_stmt).rowcount or 0) > 0
    return False, was_updated
def mark_cache_states_missing_outside_prefixes(
    session: Session, valid_prefixes: list[str]
) -> int:
    """Flag cache states whose file_path lies under none of the valid prefixes.

    Rows are not deleted; only ``is_missing`` is set to True, and only on
    rows that are not already flagged. An empty prefix list is a no-op.

    Args:
        session: Database session
        valid_prefixes: List of absolute directory prefixes that are valid

    Returns:
        Number of cache states marked as missing
    """
    if not valid_prefixes:
        return 0

    under_valid_prefix = []
    for prefix in valid_prefixes:
        # A trailing separator keeps "/models" from matching "/models_old/...".
        root = prefix if prefix.endswith(os.sep) else prefix + os.sep
        pattern, escape_char = escape_sql_like_string(root)
        under_valid_prefix.append(
            AssetCacheState.file_path.like(pattern + "%", escape=escape_char)
        )

    stmt = (
        sa.update(AssetCacheState)
        .where(sa.not_(sa.or_(*under_valid_prefix)))
        .where(AssetCacheState.is_missing == False)  # noqa: E712
        .values(is_missing=True)
    )
    return session.execute(stmt).rowcount
def restore_cache_states_by_paths(session: Session, file_paths: list[str]) -> int:
    """Clear the is_missing flag on cache states for paths that exist again.

    Called when a file path is re-scanned and found to exist. The paths are
    processed in chunks so that a large scan cannot exceed SQLite's limit on
    bound variables in a single statement.

    Args:
        session: Database session
        file_paths: List of file paths that exist and should be restored

    Returns:
        Number of cache states restored (summed over all chunks)
    """
    if not file_paths:
        return 0

    # Hold back a little of the budget for the non-list binds of the
    # statement (the SET value and the is_missing filter).
    # NOTE(review): assumes MAX_BIND_PARAMS is the per-statement bind budget.
    chunk_size = max(1, MAX_BIND_PARAMS - 2)

    restored = 0
    for chunk in iter_chunks(list(file_paths), chunk_size):
        result = session.execute(
            sa.update(AssetCacheState)
            .where(AssetCacheState.file_path.in_(chunk))
            .where(AssetCacheState.is_missing == True)  # noqa: E712
            .values(is_missing=False)
        )
        restored += result.rowcount or 0
    return restored
def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
    """Collect IDs of unhashed assets that no non-missing cache state uses.

    An asset qualifies when its hash is NULL and there is no cache state
    pointing at it with ``is_missing`` False; that covers both "no cache
    states at all" and "every cache state is marked missing".

    Args:
        session: Database session

    Returns:
        List of asset IDs that are unreferenced
    """
    live_state_for_asset = (
        sa.select(sa.literal(1))
        .where(AssetCacheState.asset_id == Asset.id)
        .where(AssetCacheState.is_missing == False)  # noqa: E712
        .correlate(Asset)
        .exists()
    )
    stmt = (
        sa.select(Asset.id)
        .where(Asset.hash.is_(None))
        .where(sa.not_(live_state_for_asset))
    )
    return list(session.execute(stmt).scalars())
def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
    """Delete assets and their AssetInfos by ID.

    The IDs are processed in chunks so that a large list cannot exceed
    SQLite's limit on bound variables in a single statement. Within each
    chunk the AssetInfo rows are removed before the Asset rows.

    Args:
        session: Database session
        asset_ids: List of asset IDs to delete

    Returns:
        Number of assets deleted (AssetInfo rows are not counted)
    """
    if not asset_ids:
        return 0

    deleted = 0
    for chunk in iter_chunks(list(asset_ids), MAX_BIND_PARAMS):
        session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id.in_(chunk)))
        result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk)))
        deleted += result.rowcount or 0
    return deleted
def get_cache_states_for_prefixes(
    session: Session,
    prefixes: list[str],
    *,
    include_missing: bool = False,
) -> list[CacheStateRow]:
    """Fetch cache states located under any of the given directory prefixes.

    Each prefix is made absolute and given a trailing separator, then
    matched against ``file_path`` with an escaped LIKE pattern. Every cache
    state is joined to its asset so hash and size come back in the same row.

    Args:
        session: Database session
        prefixes: List of directory prefixes to match
        include_missing: If False (default), exclude cache states marked as missing

    Returns:
        List of cache state rows with joined asset data, ordered by
        asset_id, then state_id. Empty when ``prefixes`` is empty.
    """
    if not prefixes:
        return []

    def under_prefix(prefix: str):
        root = os.path.abspath(prefix)
        root = root if root.endswith(os.sep) else root + os.sep
        pattern, escape_char = escape_sql_like_string(root)
        return AssetCacheState.file_path.like(pattern + "%", escape=escape_char)

    stmt = sa.select(
        AssetCacheState.id,
        AssetCacheState.file_path,
        AssetCacheState.mtime_ns,
        AssetCacheState.needs_verify,
        AssetCacheState.asset_id,
        Asset.hash,
        Asset.size_bytes,
    ).join(Asset, Asset.id == AssetCacheState.asset_id)
    stmt = stmt.where(sa.or_(*[under_prefix(p) for p in prefixes]))
    if not include_missing:
        stmt = stmt.where(AssetCacheState.is_missing == False)  # noqa: E712
    stmt = stmt.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())

    return [
        CacheStateRow(
            state_id=state_id,
            file_path=path,
            mtime_ns=mtime,
            needs_verify=verify,
            asset_id=owner_id,
            asset_hash=digest,
            # A NULL size is reported as 0.
            size_bytes=int(size or 0),
        )
        for state_id, path, mtime, verify, owner_id, digest, size in session.execute(
            stmt
        ).all()
    ]
def bulk_update_needs_verify(session: Session, state_ids: list[int], value: bool) -> int:
    """Set needs_verify flag for multiple cache states.

    The IDs are processed in chunks so that a large list cannot exceed
    SQLite's limit on bound variables in a single statement.

    Args:
        session: Database session
        state_ids: IDs of the cache states to update
        value: New value for the needs_verify flag

    Returns:
        Number of rows updated (summed over all chunks)
    """
    if not state_ids:
        return 0

    # One bind is used by the SET value, so keep it out of the ID budget.
    # NOTE(review): assumes MAX_BIND_PARAMS is the per-statement bind budget.
    chunk_size = max(1, MAX_BIND_PARAMS - 1)

    updated = 0
    for chunk in iter_chunks(list(state_ids), chunk_size):
        result = session.execute(
            sa.update(AssetCacheState)
            .where(AssetCacheState.id.in_(chunk))
            .values(needs_verify=value)
        )
        updated += result.rowcount or 0
    return updated
def bulk_update_is_missing(session: Session, state_ids: list[int], value: bool) -> int:
    """Set is_missing flag for multiple cache states.

    The IDs are processed in chunks so that a large list cannot exceed
    SQLite's limit on bound variables in a single statement.

    Args:
        session: Database session
        state_ids: IDs of the cache states to update
        value: New value for the is_missing flag

    Returns:
        Number of rows updated (summed over all chunks)
    """
    if not state_ids:
        return 0

    # One bind is used by the SET value, so keep it out of the ID budget.
    # NOTE(review): assumes MAX_BIND_PARAMS is the per-statement bind budget.
    chunk_size = max(1, MAX_BIND_PARAMS - 1)

    updated = 0
    for chunk in iter_chunks(list(state_ids), chunk_size):
        result = session.execute(
            sa.update(AssetCacheState)
            .where(AssetCacheState.id.in_(chunk))
            .values(is_missing=value)
        )
        updated += result.rowcount or 0
    return updated
def delete_cache_states_by_ids(session: Session, state_ids: list[int]) -> int:
    """Delete cache states by their IDs.

    The IDs are processed in chunks so that a large list cannot exceed
    SQLite's limit on bound variables in a single statement.

    Args:
        session: Database session
        state_ids: IDs of the cache states to delete

    Returns:
        Number of rows deleted (summed over all chunks)
    """
    if not state_ids:
        return 0

    deleted = 0
    for chunk in iter_chunks(list(state_ids), MAX_BIND_PARAMS):
        result = session.execute(
            sa.delete(AssetCacheState).where(AssetCacheState.id.in_(chunk))
        )
        deleted += result.rowcount or 0
    return deleted
def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
    """Remove an asset together with the AssetInfos that reference it.

    Intended for seed assets (hash is None). The hash is not checked here,
    so the caller is responsible for passing only such assets. AssetInfo
    rows for ``asset_id`` are deleted even when the asset itself is absent.

    Args:
        session: Database session
        asset_id: ID of the asset to delete

    Returns:
        True if asset was deleted, False if not found
    """
    infos_of_asset = sa.delete(AssetInfo).where(AssetInfo.asset_id == asset_id)
    session.execute(infos_of_asset)

    seed = session.get(Asset, asset_id)
    if not seed:
        return False
    session.delete(seed)
    return True
class UnenrichedAssetRow(NamedTuple):
    """Result row describing one cache state that still needs enrichment.

    Produced by ``get_unenriched_cache_states``. Field order is part of the
    contract because instances may be built or unpacked positionally.
    """

    # AssetCacheState.id of the cache state to enrich.
    cache_state_id: int
    # Id of the asset the cache state points at.
    asset_id: str
    # Id of an AssetInfo row joined to that asset.
    asset_info_id: str
    # Path stored on the cache state.
    file_path: str
    # Enrichment level currently stored on the cache state.
    enrichment_level: int
def get_unenriched_cache_states(
    session: Session,
    prefixes: list[str],
    max_level: int = 0,
    limit: int = 1000,
) -> list[UnenrichedAssetRow]:
    """Fetch non-missing cache states whose enrichment_level is <= max_level.

    Each prefix is made absolute and given a trailing separator, then
    matched against ``file_path`` with an escaped LIKE pattern. Cache states
    are inner-joined to their asset and to AssetInfo, so a cache state shows
    up once per AssetInfo row of its asset and not at all when the asset has
    none; ``limit`` applies to the joined rows.

    Args:
        session: Database session
        prefixes: List of directory prefixes to scan
        max_level: Maximum enrichment level to include (0=stubs, 1=metadata done)
        limit: Maximum number of rows to return

    Returns:
        List of unenriched asset rows with file paths, ordered by cache
        state id. Empty when ``prefixes`` is empty.
    """
    if not prefixes:
        return []

    def under_prefix(prefix: str):
        root = os.path.abspath(prefix)
        root = root if root.endswith(os.sep) else root + os.sep
        pattern, escape_char = escape_sql_like_string(root)
        return AssetCacheState.file_path.like(pattern + "%", escape=escape_char)

    stmt = (
        sa.select(
            AssetCacheState.id,
            AssetCacheState.asset_id,
            AssetInfo.id,
            AssetCacheState.file_path,
            AssetCacheState.enrichment_level,
        )
        .join(Asset, Asset.id == AssetCacheState.asset_id)
        .join(AssetInfo, AssetInfo.asset_id == Asset.id)
        .where(sa.or_(*[under_prefix(p) for p in prefixes]))
        .where(AssetCacheState.is_missing == False)  # noqa: E712
        .where(AssetCacheState.enrichment_level <= max_level)
        .order_by(AssetCacheState.id.asc())
        .limit(limit)
    )

    return [
        UnenrichedAssetRow(
            cache_state_id=state_id,
            asset_id=owner_id,
            asset_info_id=info_id,
            file_path=path,
            enrichment_level=level,
        )
        for state_id, owner_id, info_id, path, level in session.execute(stmt).all()
    ]
def update_enrichment_level(
    session: Session,
    cache_state_id: int,
    level: int,
) -> None:
    """Store a new enrichment level on a single cache state.

    Args:
        session: Database session
        cache_state_id: ID of the cache state to update
        level: Enrichment level to write
    """
    stmt = sa.update(AssetCacheState).where(AssetCacheState.id == cache_state_id)
    session.execute(stmt.values(enrichment_level=level))
def bulk_update_enrichment_level(
    session: Session,
    cache_state_ids: list[int],
    level: int,
) -> int:
    """Update enrichment level for multiple cache states.

    The IDs are processed in chunks so that a large batch cannot exceed
    SQLite's limit on bound variables in a single statement.

    Args:
        session: Database session
        cache_state_ids: IDs of the cache states to update
        level: Enrichment level to write

    Returns:
        Number of rows updated (summed over all chunks)
    """
    if not cache_state_ids:
        return 0

    # One bind is used by the SET value, so keep it out of the ID budget.
    # NOTE(review): assumes MAX_BIND_PARAMS is the per-statement bind budget.
    chunk_size = max(1, MAX_BIND_PARAMS - 1)

    updated = 0
    for chunk in iter_chunks(list(cache_state_ids), chunk_size):
        result = session.execute(
            sa.update(AssetCacheState)
            .where(AssetCacheState.id.in_(chunk))
            .values(enrichment_level=level)
        )
        updated += result.rowcount or 0
    return updated
def bulk_insert_cache_states_ignore_conflicts(
    session: Session,
    rows: list[dict],
) -> None:
    """Insert many cache states, skipping any whose file_path already exists.

    Every row is copied with ``is_missing`` forced to False, then inserted
    in batches using ON CONFLICT DO NOTHING on file_path. The input dicts
    are not modified.

    Args:
        session: Database session
        rows: Dicts that should each have: asset_id, file_path, mtime_ns
    """
    if not rows:
        return

    insert_stmt = sqlite.insert(AssetCacheState).on_conflict_do_nothing(
        index_elements=[AssetCacheState.file_path]
    )
    payload = [dict(row, is_missing=False) for row in rows]
    # 4 = asset_id, file_path, mtime_ns plus the added is_missing.
    batch_size = calculate_rows_per_statement(4)
    for batch in iter_chunks(payload, batch_size):
        session.execute(insert_stmt, batch)
def get_cache_states_by_paths_and_asset_ids(
    session: Session,
    path_to_asset: dict[str, str],
) -> set[str]:
    """Query cache states to find paths where our asset_id won the insert.

    Only the paths are sent to the database, so each chunk binds one
    parameter per path. Stored rows are then compared against
    ``path_to_asset`` pair by pair, so a path counts as a winner only when
    the row for that exact path carries the asset_id we tried to insert.

    Args:
        session: Database session
        path_to_asset: Mapping of file_path -> asset_id we tried to insert

    Returns:
        Set of file_paths where our asset_id is present
    """
    if not path_to_asset:
        return set()

    winners: set[str] = set()
    for chunk in iter_chunks(list(path_to_asset), MAX_BIND_PARAMS):
        stored = session.execute(
            select(AssetCacheState.file_path, AssetCacheState.asset_id).where(
                AssetCacheState.file_path.in_(chunk)
            )
        ).all()
        # Match path and asset_id together; two independent IN lists would
        # also accept a path owned by some other path's asset_id.
        winners.update(
            path for path, owner_id in stored if path_to_asset.get(path) == owner_id
        )
    return winners