mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-02 03:39:57 +00:00
add support for assets duplicates
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
@@ -66,9 +68,8 @@ class Asset(Base):
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
cache_state: Mapped["AssetCacheState | None"] = relationship(
|
||||
cache_states: Mapped[list["AssetCacheState"]] = relationship(
|
||||
back_populates="asset",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
@@ -93,24 +94,25 @@ class Asset(Base):
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
|
||||
asset_hash: Mapped[str] = mapped_column(
|
||||
String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
|
||||
asset: Mapped["Asset"] = relationship(back_populates="cache_state", uselist=False)
|
||||
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_hash", "asset_hash"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState hash={self.asset_hash[:12]} path={self.file_path!r}>"
|
||||
return f"<AssetCacheState id={self.id} hash={self.asset_hash[:12]} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetLocation(Base):
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, Sequence, Optional, Iterable
|
||||
from typing import Any, Sequence, Optional, Iterable, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -82,14 +82,14 @@ async def ingest_fs_asset(
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Upsert Asset identity row + cache state pointing at local file.
|
||||
Upsert Asset identity row + cache state(s) pointing at local file.
|
||||
|
||||
Always:
|
||||
- Insert Asset if missing;
|
||||
- Insert AssetCacheState if missing; else update mtime_ns if different.
|
||||
- Insert AssetCacheState if missing; else update mtime_ns and asset_hash if different.
|
||||
|
||||
Optionally (when info_name is provided):
|
||||
- Create an AssetInfo.
|
||||
- Create or update an AssetInfo on (asset_hash, owner_id, name).
|
||||
- Link provided tags to that AssetInfo.
|
||||
* If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table.
|
||||
* If False (default), create unknown tags.
|
||||
@@ -157,11 +157,16 @@ async def ingest_fs_asset(
|
||||
out["state_created"] = True
|
||||
|
||||
if not out["state_created"]:
|
||||
state = await session.get(AssetCacheState, asset_hash)
|
||||
# most likely a unique(file_path) conflict; update that row
|
||||
state = (
|
||||
await session.execute(
|
||||
select(AssetCacheState).where(AssetCacheState.file_path == locator).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if state is not None:
|
||||
changed = False
|
||||
if state.file_path != locator:
|
||||
state.file_path = locator
|
||||
if state.asset_hash != asset_hash:
|
||||
state.asset_hash = asset_hash
|
||||
changed = True
|
||||
if state.mtime_ns != int(mtime_ns):
|
||||
state.mtime_ns = int(mtime_ns)
|
||||
@@ -260,7 +265,15 @@ async def ingest_fs_asset(
|
||||
# )
|
||||
# start of adding metadata["filename"]
|
||||
if out["asset_info_id"] is not None:
|
||||
computed_filename = compute_model_relative_filename(abs_path)
|
||||
primary_path = (
|
||||
await session.execute(
|
||||
select(AssetCacheState.file_path)
|
||||
.where(AssetCacheState.asset_hash == asset_hash)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
computed_filename = compute_model_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
# Start from current metadata on this AssetInfo, if any
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
@@ -366,7 +379,6 @@ async def list_asset_infos_page(
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
# Sort
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
@@ -381,7 +393,6 @@ async def list_asset_infos_page(
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
# Total count (same filters, no ordering/limit/offset)
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(AssetInfo)
|
||||
@@ -395,10 +406,9 @@ async def list_asset_infos_page(
|
||||
|
||||
total = int((await session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
# Fetch rows
|
||||
infos = (await session.execute(base)).scalars().unique().all()
|
||||
|
||||
# Collect tags in bulk (single query)
|
||||
# Collect tags in bulk
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
@@ -470,12 +480,33 @@ async def fetch_asset_info_asset_and_tags(
|
||||
|
||||
|
||||
async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]:
|
||||
return await session.get(AssetCacheState, asset_hash)
|
||||
"""Return the oldest cache row for this asset."""
|
||||
return (
|
||||
await session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_hash == asset_hash)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
|
||||
|
||||
async def list_cache_states_by_asset_hash(
|
||||
session: AsyncSession, *, asset_hash: str
|
||||
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
|
||||
"""Return all cache rows for this asset ordered by oldest first."""
|
||||
return (
|
||||
await session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_hash == asset_hash)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
async def list_asset_locations(
|
||||
session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None
|
||||
) -> list[AssetLocation] | Sequence[AssetLocation]:
|
||||
) -> Union[list[AssetLocation], Sequence[AssetLocation]]:
|
||||
stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash)
|
||||
if provider:
|
||||
stmt = stmt.where(AssetLocation.provider == provider)
|
||||
@@ -815,7 +846,6 @@ async def list_tags_with_usage(
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
# Ordering
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else: # default "count_desc"
|
||||
@@ -990,6 +1020,7 @@ def _apply_tag_filters(
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def _apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: Optional[dict],
|
||||
@@ -1050,7 +1081,7 @@ def _apply_metadata_filter(
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
# ANY-of (exists for any element)
|
||||
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
@@ -1079,12 +1110,10 @@ def _project_kv(key: str, value: Any) -> list[dict]:
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
# None
|
||||
if value is None:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": None})
|
||||
return rows
|
||||
|
||||
# Scalars
|
||||
if _is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
@@ -1099,9 +1128,7 @@ def _project_kv(key: str, value: Any) -> list[dict]:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
# Lists
|
||||
if isinstance(value, list):
|
||||
# list of scalars?
|
||||
if all(_is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
|
||||
Reference in New Issue
Block a user