mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-22 16:04:12 +00:00
add AssetsResolver support
This commit is contained in:
@@ -12,7 +12,7 @@ from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
|
||||
from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
|
||||
from .timeutil import utcnow
|
||||
from .._assets_helpers import normalize_tags
|
||||
|
||||
@@ -38,30 +38,24 @@ async def check_fs_asset_exists_quick(
|
||||
mtime_ns: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path,
|
||||
Returns 'True' if there is already AssetCacheState record that matches this absolute path,
|
||||
AND (if provided) mtime_ns matches stored locator-state,
|
||||
AND (if provided) size_bytes matches verified size when known.
|
||||
"""
|
||||
locator = os.path.abspath(file_path)
|
||||
|
||||
stmt = select(sa.literal(True)).select_from(Asset)
|
||||
stmt = select(sa.literal(True)).select_from(AssetCacheState).join(
|
||||
Asset, Asset.hash == AssetCacheState.asset_hash
|
||||
).where(AssetCacheState.file_path == locator).limit(1)
|
||||
|
||||
conditions = [
|
||||
Asset.storage_backend == "fs",
|
||||
Asset.storage_locator == locator,
|
||||
]
|
||||
|
||||
# If size_bytes provided require equality when the asset has a verified (non-zero) size.
|
||||
# If verified size is 0 (unknown), we don't force equality.
|
||||
if size_bytes is not None:
|
||||
conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||
|
||||
# If mtime_ns provided require the locator-state to exist and match.
|
||||
conds = []
|
||||
if mtime_ns is not None:
|
||||
stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
|
||||
conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
|
||||
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
|
||||
if size_bytes is not None:
|
||||
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||
|
||||
stmt = stmt.where(*conditions).limit(1)
|
||||
if conds:
|
||||
stmt = stmt.where(*conds)
|
||||
|
||||
row = (await session.execute(stmt)).first()
|
||||
return row is not None
|
||||
@@ -85,11 +79,11 @@ async def ingest_fs_asset(
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Creates or updates Asset record for a local (fs) asset.
|
||||
Upsert Asset identity row + cache state pointing at local file.
|
||||
|
||||
Always:
|
||||
- Insert Asset if missing; else update size_bytes (and updated_at) if different.
|
||||
- Insert AssetLocatorState if missing; else update mtime_ns if different.
|
||||
- Insert AssetCacheState if missing; else update mtime_ns if different.
|
||||
|
||||
Optionally (when info_name is provided):
|
||||
- Create an AssetInfo (no refcount changes).
|
||||
@@ -126,8 +120,6 @@ async def ingest_fs_asset(
|
||||
size_bytes=int(size_bytes),
|
||||
mime_type=mime_type,
|
||||
refcount=0,
|
||||
storage_backend="fs",
|
||||
storage_locator=locator,
|
||||
created_at=datetime_now,
|
||||
updated_at=datetime_now,
|
||||
)
|
||||
@@ -145,21 +137,19 @@ async def ingest_fs_asset(
|
||||
if mime_type and existing.mime_type != mime_type:
|
||||
existing.mime_type = mime_type
|
||||
changed = True
|
||||
if existing.storage_locator != locator:
|
||||
existing.storage_locator = locator
|
||||
changed = True
|
||||
if changed:
|
||||
existing.updated_at = datetime_now
|
||||
out["asset_updated"] = True
|
||||
else:
|
||||
logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash)
|
||||
|
||||
# ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
|
||||
# ---- Step 2: INSERT/UPDATE AssetCacheState (mtime_ns, file_path) ----
|
||||
with contextlib.suppress(IntegrityError):
|
||||
async with session.begin_nested():
|
||||
session.add(
|
||||
AssetLocatorState(
|
||||
AssetCacheState(
|
||||
asset_hash=asset_hash,
|
||||
file_path=locator,
|
||||
mtime_ns=int(mtime_ns),
|
||||
)
|
||||
)
|
||||
@@ -167,11 +157,17 @@ async def ingest_fs_asset(
|
||||
out["state_created"] = True
|
||||
|
||||
if not out["state_created"]:
|
||||
state = await session.get(AssetLocatorState, asset_hash)
|
||||
state = await session.get(AssetCacheState, asset_hash)
|
||||
if state is not None:
|
||||
desired_mtime = int(mtime_ns)
|
||||
if state.mtime_ns != desired_mtime:
|
||||
state.mtime_ns = desired_mtime
|
||||
changed = False
|
||||
if state.file_path != locator:
|
||||
state.file_path = locator
|
||||
changed = True
|
||||
if state.mtime_ns != int(mtime_ns):
|
||||
state.mtime_ns = int(mtime_ns)
|
||||
changed = True
|
||||
if changed:
|
||||
await session.flush()
|
||||
out["state_updated"] = True
|
||||
else:
|
||||
logging.error("Locator state missing for %s after conflict; skipping update.", asset_hash)
|
||||
@@ -278,11 +274,10 @@ async def touch_asset_infos_by_fs_path(
|
||||
stmt = sa.update(AssetInfo).where(
|
||||
sa.exists(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(Asset)
|
||||
.select_from(AssetCacheState)
|
||||
.where(
|
||||
Asset.hash == AssetInfo.asset_hash,
|
||||
Asset.storage_backend == "fs",
|
||||
Asset.storage_locator == locator,
|
||||
AssetCacheState.asset_hash == AssetInfo.asset_hash,
|
||||
AssetCacheState.file_path == locator,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -337,13 +332,6 @@ async def list_asset_infos_page(
|
||||
|
||||
We purposely collect tags in a separate (single) query to avoid row explosion.
|
||||
"""
|
||||
# Clamp
|
||||
if limit <= 0:
|
||||
limit = 1
|
||||
if limit > 100:
|
||||
limit = 100
|
||||
if offset < 0:
|
||||
offset = 0
|
||||
|
||||
# Build base query
|
||||
base = (
|
||||
@@ -419,6 +407,66 @@ async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: in
|
||||
return pair[0], pair[1]
|
||||
|
||||
|
||||
async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]:
|
||||
return await session.get(AssetCacheState, asset_hash)
|
||||
|
||||
|
||||
async def list_asset_locations(
|
||||
session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None
|
||||
) -> list[AssetLocation]:
|
||||
stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash)
|
||||
if provider:
|
||||
stmt = stmt.where(AssetLocation.provider == provider)
|
||||
return (await session.execute(stmt)).scalars().all()
|
||||
|
||||
|
||||
async def upsert_asset_location(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
provider: str,
|
||||
locator: str,
|
||||
expected_size_bytes: Optional[int] = None,
|
||||
etag: Optional[str] = None,
|
||||
last_modified: Optional[str] = None,
|
||||
) -> AssetLocation:
|
||||
loc = (
|
||||
await session.execute(
|
||||
select(AssetLocation).where(
|
||||
AssetLocation.asset_hash == asset_hash,
|
||||
AssetLocation.provider == provider,
|
||||
AssetLocation.locator == locator,
|
||||
).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if loc:
|
||||
changed = False
|
||||
if expected_size_bytes is not None and loc.expected_size_bytes != expected_size_bytes:
|
||||
loc.expected_size_bytes = expected_size_bytes
|
||||
changed = True
|
||||
if etag is not None and loc.etag != etag:
|
||||
loc.etag = etag
|
||||
changed = True
|
||||
if last_modified is not None and loc.last_modified != last_modified:
|
||||
loc.last_modified = last_modified
|
||||
changed = True
|
||||
if changed:
|
||||
await session.flush()
|
||||
return loc
|
||||
|
||||
loc = AssetLocation(
|
||||
asset_hash=asset_hash,
|
||||
provider=provider,
|
||||
locator=locator,
|
||||
expected_size_bytes=expected_size_bytes,
|
||||
etag=etag,
|
||||
last_modified=last_modified,
|
||||
)
|
||||
session.add(loc)
|
||||
await session.flush()
|
||||
return loc
|
||||
|
||||
|
||||
async def create_asset_info_for_existing_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
@@ -925,7 +973,8 @@ def _project_kv(key: str, value: Any) -> list[dict]:
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
# store numeric; SQLAlchemy will coerce to Numeric
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": value})
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
@@ -943,7 +992,8 @@ def _project_kv(key: str, value: Any) -> list[dict]:
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
rows.append({"key": key, "ordinal": i, "val_num": x})
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user