import os
from typing import NamedTuple, Sequence

import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session

from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries.common import (
    MAX_BIND_PARAMS,
    calculate_rows_per_statement,
    iter_chunks,
)
from app.assets.helpers import escape_sql_like_string


class CacheStateRow(NamedTuple):
    """Row from a cache state query with joined asset data."""

    state_id: int
    file_path: str
    mtime_ns: int | None
    needs_verify: bool
    asset_id: str
    asset_hash: str | None
    size_bytes: int


def list_cache_states_by_asset_id(
    session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
    """Return all cache states for an asset, ordered by state id."""
    return (
        session.execute(
            select(AssetCacheState)
            .where(AssetCacheState.asset_id == asset_id)
            .order_by(AssetCacheState.id.asc())
        )
        .scalars()
        .all()
    )


def upsert_cache_state(
    session: Session,
    asset_id: str,
    file_path: str,
    mtime_ns: int,
) -> tuple[bool, bool]:
    """Upsert a cache state by file_path. Returns (created, updated).

    Also restores cache states that were previously marked as missing.
    """
    vals = {
        "asset_id": asset_id,
        "file_path": file_path,
        "mtime_ns": int(mtime_ns),
        "is_missing": False,
    }
    # Try to insert first; file_path is unique, so an existing row is a no-op.
    ins = (
        sqlite.insert(AssetCacheState)
        .values(**vals)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )
    res = session.execute(ins)
    created = int(res.rowcount or 0) > 0

    if created:
        return True, False

    # The row already existed: update it only if something actually changed
    # (different asset, stale or NULL mtime, or previously marked missing),
    # so rowcount reflects a real modification.
    upd = (
        sa.update(AssetCacheState)
        .where(AssetCacheState.file_path == file_path)
        .where(
            sa.or_(
                AssetCacheState.asset_id != asset_id,
                AssetCacheState.mtime_ns.is_(None),
                AssetCacheState.mtime_ns != int(mtime_ns),
                AssetCacheState.is_missing == True,  # noqa: E712
            )
        )
        .values(asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False)
    )
    res2 = session.execute(upd)
    updated = int(res2.rowcount or 0) > 0
    return False, updated


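# Illustrative sketch (not part of this module's API): how a directory scanner
# might call upsert_cache_state. The path and asset id below are made up, and
# `session` is assumed to come from the caller's unit of work.
#
#     path = "/models/checkpoints/model.safetensors"  # hypothetical path
#     created, updated = upsert_cache_state(
#         session,
#         asset_id="some-asset-id",  # hypothetical id
#         file_path=path,
#         mtime_ns=os.stat(path).st_mtime_ns,
#     )
#     # created=True  -> a new row was inserted for this path
#     # updated=True  -> an existing row was re-pointed, refreshed, or restored
#     # both False    -> the row already matched; nothing changed

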
def mark_cache_states_missing_outside_prefixes(
    session: Session, valid_prefixes: list[str]
) -> int:
    """Mark cache states as missing when file_path doesn't match any valid prefix.

    This is a non-destructive soft-delete that preserves user metadata.
    Cache states can be restored if the file reappears in a future scan.

    Args:
        session: Database session
        valid_prefixes: List of absolute directory prefixes that are valid

    Returns:
        Number of cache states marked as missing
    """
    if not valid_prefixes:
        return 0

    def make_prefix_condition(prefix: str):
        base = prefix if prefix.endswith(os.sep) else prefix + os.sep
        escaped, esc = escape_sql_like_string(base)
        return AssetCacheState.file_path.like(escaped + "%", escape=esc)

    matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes])
    result = session.execute(
        sa.update(AssetCacheState)
        .where(~matches_valid_prefix)
        .where(AssetCacheState.is_missing == False)  # noqa: E712
        .values(is_missing=True)
    )
    return result.rowcount


def restore_cache_states_by_paths(session: Session, file_paths: list[str]) -> int:
    """Restore cache states that were previously marked as missing.

    Called when a file path is re-scanned and found to exist.

    Args:
        session: Database session
        file_paths: List of file paths that exist and should be restored

    Returns:
        Number of cache states restored
    """
    if not file_paths:
        return 0

    result = session.execute(
        sa.update(AssetCacheState)
        .where(AssetCacheState.file_path.in_(file_paths))
        .where(AssetCacheState.is_missing == True)  # noqa: E712
        .values(is_missing=False)
    )
    return result.rowcount


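# Illustrative sketch (not part of this module's API): a rescan cycle might
# pair the two soft-delete helpers above. The roots and paths are made up;
# `session` is assumed to come from the caller.
#
#     roots = ["/models/checkpoints", "/models/loras"]  # hypothetical roots
#     hidden = mark_cache_states_missing_outside_prefixes(session, roots)
#     found = ["/models/checkpoints/model.safetensors"]  # paths seen on disk
#     restored = restore_cache_states_by_paths(session, found)
#     # Rows marked missing stay in the table, so metadata survives until the
#     # file either reappears (restored) or is garbage-collected elsewhere.

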
def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
    """Get IDs of unhashed assets (hash=None) with no active cache states.

    An asset is considered unreferenced if it has no cache states,
    or all its cache states are marked as missing.

    Returns:
        List of asset IDs that are unreferenced
    """
    # Correlated EXISTS: true when the asset still has at least one
    # cache state that is not marked missing.
    active_cache_state_exists = (
        sa.select(sa.literal(1))
        .where(AssetCacheState.asset_id == Asset.id)
        .where(AssetCacheState.is_missing == False)  # noqa: E712
        .correlate(Asset)
        .exists()
    )
    unreferenced_subq = sa.select(Asset.id).where(
        Asset.hash.is_(None), ~active_cache_state_exists
    )
    return [row[0] for row in session.execute(unreferenced_subq).all()]


def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
    """Delete assets and their AssetInfos by ID.

    Args:
        session: Database session
        asset_ids: List of asset IDs to delete

    Returns:
        Number of assets deleted
    """
    if not asset_ids:
        return 0
    # Delete dependent AssetInfo rows first, then the assets themselves.
    session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id.in_(asset_ids)))
    result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
    return result.rowcount


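# Illustrative sketch (not part of this module's API): garbage-collecting seed
# assets after a scan by combining the two helpers above. `session` is assumed
# to come from the caller.
#
#     orphaned = get_unreferenced_unhashed_asset_ids(session)
#     if orphaned:
#         deleted = delete_assets_by_ids(session, orphaned)
#     # Note: delete_assets_by_ids does not chunk its IN clause, so a very
#     # large id list may need to be split by the caller.

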
def get_cache_states_for_prefixes(
    session: Session,
    prefixes: list[str],
    *,
    include_missing: bool = False,
) -> list[CacheStateRow]:
    """Get all cache states with paths matching any of the given prefixes.

    Args:
        session: Database session
        prefixes: List of absolute directory prefixes to match
        include_missing: If False (default), exclude cache states marked as missing

    Returns:
        List of cache state rows with joined asset data, ordered by asset_id, state_id
    """
    if not prefixes:
        return []

    conds = []
    for p in prefixes:
        base = os.path.abspath(p)
        if not base.endswith(os.sep):
            base += os.sep
        escaped, esc = escape_sql_like_string(base)
        conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))

    query = (
        sa.select(
            AssetCacheState.id,
            AssetCacheState.file_path,
            AssetCacheState.mtime_ns,
            AssetCacheState.needs_verify,
            AssetCacheState.asset_id,
            Asset.hash,
            Asset.size_bytes,
        )
        .join(Asset, Asset.id == AssetCacheState.asset_id)
        .where(sa.or_(*conds))
    )

    if not include_missing:
        query = query.where(AssetCacheState.is_missing == False)  # noqa: E712

    rows = session.execute(
        query.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
    ).all()

    return [
        CacheStateRow(
            state_id=row[0],
            file_path=row[1],
            mtime_ns=row[2],
            needs_verify=row[3],
            asset_id=row[4],
            asset_hash=row[5],
            size_bytes=int(row[6] or 0),
        )
        for row in rows
    ]


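# Illustrative sketch (not part of this module's API): flagging scanned rows
# for re-verification with bulk_update_needs_verify (defined below) when the
# on-disk mtime no longer matches. The prefix is made up; `session` is assumed.
#
#     rows = get_cache_states_for_prefixes(session, ["/models"])
#     stale = [
#         r.state_id
#         for r in rows
#         if r.mtime_ns != os.stat(r.file_path).st_mtime_ns
#     ]
#     flagged = bulk_update_needs_verify(session, stale, True)

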
def bulk_update_needs_verify(
    session: Session, state_ids: list[int], value: bool
) -> int:
    """Set needs_verify flag for multiple cache states.

    Returns: Number of rows updated
    """
    if not state_ids:
        return 0
    result = session.execute(
        sa.update(AssetCacheState)
        .where(AssetCacheState.id.in_(state_ids))
        .values(needs_verify=value)
    )
    return result.rowcount


def delete_cache_states_by_ids(session: Session, state_ids: list[int]) -> int:
    """Delete cache states by their IDs.

    Returns: Number of rows deleted
    """
    if not state_ids:
        return 0
    result = session.execute(
        sa.delete(AssetCacheState).where(AssetCacheState.id.in_(state_ids))
    )
    return result.rowcount


def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
    """Delete a seed asset (hash is None) and its AssetInfos.

    Returns: True if asset was deleted, False if not found
    """
    # Remove dependent AssetInfo rows first; this is a no-op if none exist.
    session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == asset_id))
    asset = session.get(Asset, asset_id)
    if asset:
        session.delete(asset)
        return True
    return False


def bulk_insert_cache_states_ignore_conflicts(
    session: Session,
    rows: list[dict],
) -> None:
    """Bulk insert cache state rows with ON CONFLICT DO NOTHING on file_path.

    Each dict should have: asset_id, file_path, mtime_ns
    The is_missing field is automatically set to False for new inserts.
    """
    if not rows:
        return
    enriched_rows = [{**row, "is_missing": False} for row in rows]
    ins = sqlite.insert(AssetCacheState).on_conflict_do_nothing(
        index_elements=[AssetCacheState.file_path]
    )
    # Four bind parameters per row (asset_id, file_path, mtime_ns, is_missing),
    # so chunk the executemany to stay under the bind-parameter limit.
    for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(4)):
        session.execute(ins, chunk)


def get_cache_states_by_paths_and_asset_ids(
    session: Session,
    path_to_asset: dict[str, str],
) -> set[str]:
    """Query cache states to find paths where our asset_id won the insert.

    Args:
        path_to_asset: Mapping of file_path -> asset_id we tried to insert

    Returns:
        Set of file_paths where our asset_id is present
    """
    if not path_to_asset:
        return set()

    paths = list(path_to_asset.keys())
    winners: set[str] = set()

    for chunk in iter_chunks(paths, MAX_BIND_PARAMS):
        result = session.execute(
            select(AssetCacheState.file_path, AssetCacheState.asset_id).where(
                AssetCacheState.file_path.in_(chunk)
            )
        )
        # Compare (file_path, asset_id) pairs exactly; matching the two
        # columns with independent IN clauses would accept a row whose
        # asset_id belongs to a different path in the same chunk.
        winners.update(
            file_path
            for file_path, found_asset_id in result
            if path_to_asset.get(file_path) == found_asset_id
        )

    return winners


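# Illustrative sketch (not part of this module's API): the two-step pattern
# for racing inserts, pairing bulk_insert_cache_states_ignore_conflicts with
# get_cache_states_by_paths_and_asset_ids. Values are made up; `session` is
# assumed to come from the caller.
#
#     pending = {
#         "/models/a.safetensors": "asset-a",  # hypothetical pairs
#         "/models/b.safetensors": "asset-b",
#     }
#     bulk_insert_cache_states_ignore_conflicts(
#         session,
#         [
#             {"asset_id": aid, "file_path": fp, "mtime_ns": 0}
#             for fp, aid in pending.items()
#         ],
#     )
#     won = get_cache_states_by_paths_and_asset_ids(session, pending)
#     lost = set(pending) - won  # paths already claimed by another asset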