ComfyUI/app/assets/database/queries/cache_state.py
Luke Mino-Altherr f13b7aeba2 feat: non-destructive asset pruning with is_missing flag
- Add is_missing column to AssetCacheState for soft-delete
- Replace hard-delete pruning with mark_cache_states_missing_outside_prefixes
- Auto-restore missing cache states when files are re-scanned
- Filter out missing cache states from queries by default
- Rename functions for clarity:
  - mark_cache_states_missing_outside_prefixes (was delete_cache_states_outside_prefixes)
  - get_unreferenced_unhashed_asset_ids (was get_orphaned_seed_asset_ids)
  - mark_assets_missing_outside_prefixes (was prune_orphaned_assets)
  - mark_missing_outside_prefixes_safely (was prune_orphans_safely)
- Add restore_cache_states_by_paths for explicit restoration
- Add cleanup_unreferenced_assets for explicit hard-delete when needed
- Update API endpoint /api/assets/prune to use new soft-delete behavior

This preserves user metadata (tags, etc.) when base directories change,
allowing assets to be restored when the original paths become available again.
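
In rough terms, the intended flow looks like this (illustrative sketch, not
code from this commit; session, configured_dirs, and rescanned_paths are
assumed to come from the caller):

    # Base directories changed: soft-delete rows outside them so tags and
    # other metadata survive.
    mark_cache_states_missing_outside_prefixes(session, configured_dirs)

    # A later scan finds the files again: flip the flag back.
    restore_cache_states_by_paths(session, rescanned_paths)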

Amp-Thread-ID: https://ampcode.com/threads/T-019c3114-bf28-73a9-a4d2-85b208fd5462
Co-authored-by: Amp <amp@ampcode.com>
2026-02-11 17:41:38 -08:00


import os
from typing import NamedTuple, Sequence
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session

from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries.common import (
    MAX_BIND_PARAMS,
    calculate_rows_per_statement,
    iter_chunks,
)
from app.assets.helpers import escape_sql_like_string


class CacheStateRow(NamedTuple):
    """Row from cache state query with joined asset data."""

    state_id: int
    file_path: str
    mtime_ns: int | None
    needs_verify: bool
    asset_id: str
    asset_hash: str | None
    size_bytes: int


def list_cache_states_by_asset_id(
    session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
    return (
        session.execute(
            select(AssetCacheState)
            .where(AssetCacheState.asset_id == asset_id)
            .order_by(AssetCacheState.id.asc())
        )
        .scalars()
        .all()
    )


def upsert_cache_state(
    session: Session,
    asset_id: str,
    file_path: str,
    mtime_ns: int,
) -> tuple[bool, bool]:
    """Upsert a cache state by file_path. Returns (created, updated).

    Also restores cache states that were previously marked as missing.
    """
    vals = {
        "asset_id": asset_id,
        "file_path": file_path,
        "mtime_ns": int(mtime_ns),
        "is_missing": False,
    }
    ins = (
        sqlite.insert(AssetCacheState)
        .values(**vals)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )
    res = session.execute(ins)
    created = int(res.rowcount or 0) > 0
    if created:
        return True, False
    upd = (
        sa.update(AssetCacheState)
        .where(AssetCacheState.file_path == file_path)
        .where(
            sa.or_(
                AssetCacheState.asset_id != asset_id,
                AssetCacheState.mtime_ns.is_(None),
                AssetCacheState.mtime_ns != int(mtime_ns),
                AssetCacheState.is_missing == True,  # noqa: E712
            )
        )
        .values(asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False)
    )
    res2 = session.execute(upd)
    updated = int(res2.rowcount or 0) > 0
    return False, updated
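
# Illustrative usage (not part of the original module): a scanner that has
# just stat()'d a file might record it like this, assuming an open Session
# and an existing asset row:
#
#     created, updated = upsert_cache_state(
#         session, asset_id=asset.id, file_path=abs_path, mtime_ns=st.st_mtime_ns
#     )
#     # created=True on first sight; updated=True if the path moved to another
#     # asset, the mtime changed, or a previously-missing state was restored.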


def mark_cache_states_missing_outside_prefixes(
    session: Session, valid_prefixes: list[str]
) -> int:
    """Mark cache states as missing when file_path doesn't match any valid prefix.

    This is a non-destructive soft-delete that preserves user metadata.
    Cache states can be restored if the file reappears in a future scan.

    Args:
        session: Database session
        valid_prefixes: List of absolute directory prefixes that are valid

    Returns:
        Number of cache states marked as missing
    """
    if not valid_prefixes:
        return 0

    def make_prefix_condition(prefix: str):
        base = prefix if prefix.endswith(os.sep) else prefix + os.sep
        escaped, esc = escape_sql_like_string(base)
        return AssetCacheState.file_path.like(escaped + "%", escape=esc)

    matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes])
    result = session.execute(
        sa.update(AssetCacheState)
        .where(~matches_valid_prefix)
        .where(AssetCacheState.is_missing == False)  # noqa: E712
        .values(is_missing=True)
    )
    return result.rowcount
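
# Illustrative usage (sketch, not from the original module): pruning after
# the configured base directories change; configured_dirs is an assumed
# caller-side list, and transaction handling is left to the caller:
#
#     n = mark_cache_states_missing_outside_prefixes(
#         session, [os.path.abspath(d) for d in configured_dirs]
#     )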


def restore_cache_states_by_paths(session: Session, file_paths: list[str]) -> int:
    """Restore cache states that were previously marked as missing.

    Called when a file path is re-scanned and found to exist.

    Args:
        session: Database session
        file_paths: List of file paths that exist and should be restored

    Returns:
        Number of cache states restored
    """
    if not file_paths:
        return 0
    result = session.execute(
        sa.update(AssetCacheState)
        .where(AssetCacheState.file_path.in_(file_paths))
        .where(AssetCacheState.is_missing == True)  # noqa: E712
        .values(is_missing=False)
    )
    return result.rowcount
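
# Illustrative counterpart to the pruning above (sketch): when a rescan finds
# paths on disk again, restore them; existing_paths is an assumed caller-side
# list of absolute paths, matching how file_path is stored:
#
#     restored = restore_cache_states_by_paths(session, existing_paths)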


def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
    """Get IDs of unhashed assets (hash=None) with no active cache states.

    An asset is considered unreferenced if it has no cache states,
    or all its cache states are marked as missing.

    Returns:
        List of asset IDs that are unreferenced
    """
    active_cache_state_exists = (
        sa.select(sa.literal(1))
        .where(AssetCacheState.asset_id == Asset.id)
        .where(AssetCacheState.is_missing == False)  # noqa: E712
        .correlate(Asset)
        .exists()
    )
    unreferenced_subq = sa.select(Asset.id).where(
        Asset.hash.is_(None), ~active_cache_state_exists
    )
    return [row[0] for row in session.execute(unreferenced_subq).all()]


def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
    """Delete assets and their AssetInfos by ID.

    Args:
        session: Database session
        asset_ids: List of asset IDs to delete

    Returns:
        Number of assets deleted
    """
    if not asset_ids:
        return 0
    session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id.in_(asset_ids)))
    result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
    return result.rowcount
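
# Illustrative hard-delete pass (sketch): the cleanup_unreferenced_assets
# mentioned in the commit message presumably composes these two queries
# roughly like this -- find unhashed assets with no active cache state,
# then delete them outright:
#
#     ids = get_unreferenced_unhashed_asset_ids(session)
#     deleted = delete_assets_by_ids(session, ids)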


def get_cache_states_for_prefixes(
    session: Session,
    prefixes: list[str],
    *,
    include_missing: bool = False,
) -> list[CacheStateRow]:
    """Get all cache states with paths matching any of the given prefixes.

    Args:
        session: Database session
        prefixes: List of absolute directory prefixes to match
        include_missing: If False (default), exclude cache states marked as missing

    Returns:
        List of cache state rows with joined asset data, ordered by asset_id, state_id
    """
    if not prefixes:
        return []
    conds = []
    for p in prefixes:
        base = os.path.abspath(p)
        if not base.endswith(os.sep):
            base += os.sep
        escaped, esc = escape_sql_like_string(base)
        conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
    query = (
        sa.select(
            AssetCacheState.id,
            AssetCacheState.file_path,
            AssetCacheState.mtime_ns,
            AssetCacheState.needs_verify,
            AssetCacheState.asset_id,
            Asset.hash,
            Asset.size_bytes,
        )
        .join(Asset, Asset.id == AssetCacheState.asset_id)
        .where(sa.or_(*conds))
    )
    if not include_missing:
        query = query.where(AssetCacheState.is_missing == False)  # noqa: E712
    rows = session.execute(
        query.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
    ).all()
    return [
        CacheStateRow(
            state_id=row[0],
            file_path=row[1],
            mtime_ns=row[2],
            needs_verify=row[3],
            asset_id=row[4],
            asset_hash=row[5],
            size_bytes=int(row[6] or 0),
        )
        for row in rows
    ]
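
# Illustrative usage (sketch): a verifier could load every state under a
# base directory, including soft-deleted rows, and pick out what needs
# re-hashing ("/abs/models" is a placeholder path):
#
#     rows = get_cache_states_for_prefixes(
#         session, ["/abs/models"], include_missing=True
#     )
#     stale_ids = [r.state_id for r in rows if r.needs_verify]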


def bulk_set_needs_verify(session: Session, state_ids: list[int], value: bool) -> int:
    """Set needs_verify flag for multiple cache states.

    Returns:
        Number of rows updated
    """
    if not state_ids:
        return 0
    result = session.execute(
        sa.update(AssetCacheState)
        .where(AssetCacheState.id.in_(state_ids))
        .values(needs_verify=value)
    )
    return result.rowcount


def delete_cache_states_by_ids(session: Session, state_ids: list[int]) -> int:
    """Delete cache states by their IDs.

    Returns:
        Number of rows deleted
    """
    if not state_ids:
        return 0
    result = session.execute(
        sa.delete(AssetCacheState).where(AssetCacheState.id.in_(state_ids))
    )
    return result.rowcount


def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
    """Delete a seed asset (hash is None) and its AssetInfos.

    Returns:
        True if asset was deleted, False if not found
    """
    session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == asset_id))
    asset = session.get(Asset, asset_id)
    if asset:
        session.delete(asset)
        return True
    return False


def bulk_insert_cache_states_ignore_conflicts(
    session: Session,
    rows: list[dict],
) -> None:
    """Bulk insert cache state rows with ON CONFLICT DO NOTHING on file_path.

    Each dict should have: asset_id, file_path, mtime_ns
    The is_missing field is automatically set to False for new inserts.
    """
    if not rows:
        return
    enriched_rows = [{**row, "is_missing": False} for row in rows]
    ins = sqlite.insert(AssetCacheState).on_conflict_do_nothing(
        index_elements=[AssetCacheState.file_path]
    )
    for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(4)):
        session.execute(ins, chunk)


def get_cache_states_by_paths_and_asset_ids(
    session: Session,
    path_to_asset: dict[str, str],
) -> set[str]:
    """Query cache states to find paths where our asset_id won the insert.

    Args:
        path_to_asset: Mapping of file_path -> asset_id we tried to insert

    Returns:
        Set of file_paths where our asset_id is present
    """
    if not path_to_asset:
        return set()
    paths = list(path_to_asset.keys())
    winners: set[str] = set()
    for chunk in iter_chunks(paths, MAX_BIND_PARAMS):
        # NOTE: path and asset_id are matched independently, so a row whose
        # asset_id belongs to a different path in the same chunk also
        # matches; this is a coarse filter, not an exact pair lookup.
        result = session.execute(
            select(AssetCacheState.file_path).where(
                AssetCacheState.file_path.in_(chunk),
                AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]),
            )
        )
        winners.update(result.scalars().all())
    return winners
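
# Illustrative two-step bulk registration (sketch): insert candidate rows,
# then ask which paths our asset_ids actually won -- conflicting paths keep
# their existing asset_id because of ON CONFLICT DO NOTHING. batch is an
# assumed list of (asset_id, file_path, mtime_ns) tuples:
#
#     rows = [
#         {"asset_id": a, "file_path": p, "mtime_ns": m} for (a, p, m) in batch
#     ]
#     bulk_insert_cache_states_ignore_conflicts(session, rows)
#     won = get_cache_states_by_paths_and_asset_ids(
#         session, {r["file_path"]: r["asset_id"] for r in rows}
#     )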