mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-11 16:20:00 +00:00
add support for assets duplicates
This commit is contained in:
@@ -4,7 +4,7 @@ import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, Sequence, Optional, Iterable
|
||||
from typing import Any, Sequence, Optional, Iterable, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -82,14 +82,14 @@ async def ingest_fs_asset(
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Upsert Asset identity row + cache state pointing at local file.
|
||||
Upsert Asset identity row + cache state(s) pointing at local file.
|
||||
|
||||
Always:
|
||||
- Insert Asset if missing;
|
||||
- Insert AssetCacheState if missing; else update mtime_ns if different.
|
||||
- Insert AssetCacheState if missing; else update mtime_ns and asset_hash if different.
|
||||
|
||||
Optionally (when info_name is provided):
|
||||
- Create an AssetInfo.
|
||||
- Create or update an AssetInfo on (asset_hash, owner_id, name).
|
||||
- Link provided tags to that AssetInfo.
|
||||
* If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table.
|
||||
* If False (default), create unknown tags.
|
||||
@@ -157,11 +157,16 @@ async def ingest_fs_asset(
|
||||
out["state_created"] = True
|
||||
|
||||
if not out["state_created"]:
|
||||
state = await session.get(AssetCacheState, asset_hash)
|
||||
# most likely a unique(file_path) conflict; update that row
|
||||
state = (
|
||||
await session.execute(
|
||||
select(AssetCacheState).where(AssetCacheState.file_path == locator).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if state is not None:
|
||||
changed = False
|
||||
if state.file_path != locator:
|
||||
state.file_path = locator
|
||||
if state.asset_hash != asset_hash:
|
||||
state.asset_hash = asset_hash
|
||||
changed = True
|
||||
if state.mtime_ns != int(mtime_ns):
|
||||
state.mtime_ns = int(mtime_ns)
|
||||
@@ -260,7 +265,15 @@ async def ingest_fs_asset(
|
||||
# )
|
||||
# start of adding metadata["filename"]
|
||||
if out["asset_info_id"] is not None:
|
||||
computed_filename = compute_model_relative_filename(abs_path)
|
||||
primary_path = (
|
||||
await session.execute(
|
||||
select(AssetCacheState.file_path)
|
||||
.where(AssetCacheState.asset_hash == asset_hash)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
computed_filename = compute_model_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
# Start from current metadata on this AssetInfo, if any
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
@@ -366,7 +379,6 @@ async def list_asset_infos_page(
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
# Sort
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
@@ -381,7 +393,6 @@ async def list_asset_infos_page(
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
# Total count (same filters, no ordering/limit/offset)
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(AssetInfo)
|
||||
@@ -395,10 +406,9 @@ async def list_asset_infos_page(
|
||||
|
||||
total = int((await session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
# Fetch rows
|
||||
infos = (await session.execute(base)).scalars().unique().all()
|
||||
|
||||
# Collect tags in bulk (single query)
|
||||
# Collect tags in bulk
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
@@ -470,12 +480,33 @@ async def fetch_asset_info_asset_and_tags(
|
||||
|
||||
|
||||
async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]:
|
||||
return await session.get(AssetCacheState, asset_hash)
|
||||
"""Return the oldest cache row for this asset."""
|
||||
return (
|
||||
await session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_hash == asset_hash)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
|
||||
|
||||
async def list_cache_states_by_asset_hash(
|
||||
session: AsyncSession, *, asset_hash: str
|
||||
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
|
||||
"""Return all cache rows for this asset ordered by oldest first."""
|
||||
return (
|
||||
await session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_hash == asset_hash)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
async def list_asset_locations(
|
||||
session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None
|
||||
) -> list[AssetLocation] | Sequence[AssetLocation]:
|
||||
) -> Union[list[AssetLocation], Sequence[AssetLocation]]:
|
||||
stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash)
|
||||
if provider:
|
||||
stmt = stmt.where(AssetLocation.provider == provider)
|
||||
@@ -815,7 +846,6 @@ async def list_tags_with_usage(
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
# Ordering
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else: # default "count_desc"
|
||||
@@ -990,6 +1020,7 @@ def _apply_tag_filters(
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def _apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: Optional[dict],
|
||||
@@ -1050,7 +1081,7 @@ def _apply_metadata_filter(
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
# ANY-of (exists for any element)
|
||||
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
@@ -1079,12 +1110,10 @@ def _project_kv(key: str, value: Any) -> list[dict]:
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
# None
|
||||
if value is None:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": None})
|
||||
return rows
|
||||
|
||||
# Scalars
|
||||
if _is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
@@ -1099,9 +1128,7 @@ def _project_kv(key: str, value: Any) -> list[dict]:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
# Lists
|
||||
if isinstance(value, list):
|
||||
# list of scalars?
|
||||
if all(_is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
|
||||
Reference in New Issue
Block a user