Add support for duplicate assets

This commit is contained in:
bigcat88
2025-09-06 19:22:51 +03:00
parent 789a62ce35
commit 2d9be462d3
6 changed files with 116 additions and 62 deletions

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Optional
import uuid
@@ -66,9 +68,8 @@ class Asset(Base):
viewonly=True,
)
cache_state: Mapped["AssetCacheState | None"] = relationship(
cache_states: Mapped[list["AssetCacheState"]] = relationship(
back_populates="asset",
uselist=False,
cascade="all, delete-orphan",
passive_deletes=True,
)
@@ -93,24 +94,25 @@ class Asset(Base):
class AssetCacheState(Base):
    """ORM row recording one local file path known for an asset.

    Several rows may reference the same ``asset_hash`` (the same content
    present at more than one path), so rows are identified by a surrogate
    autoincrement ``id`` instead of by the hash. A given ``file_path`` may
    appear in at most one row (``uq_asset_cache_state_file_path``).
    """

    __tablename__ = "asset_cache_state"

    # Surrogate primary key; asset_hash is a plain (non-unique) foreign key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Rows are removed by the database when the referenced asset is deleted.
    asset_hash: Mapped[str] = mapped_column(
        String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False
    )
    file_path: Mapped[str] = mapped_column(Text, nullable=False)
    # NULL or non-negative; enforced by the ck_acs_mtime_nonneg check constraint.
    mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)

    # Many-to-one side of Asset.cache_states.
    asset: Mapped["Asset"] = relationship(back_populates="cache_states")

    __table_args__ = (
        Index("ix_asset_cache_state_file_path", "file_path"),
        Index("ix_asset_cache_state_asset_hash", "asset_hash"),
        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
        UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        """Return this row as a plain dict.

        Delegates to the module-level ``to_dict`` helper; inside a method body
        the bare name resolves to the module global, not to this method.
        """
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
        """Return a short debug representation with a truncated hash."""
        # asset_hash can still be None on an instance that has not been
        # populated yet; slicing None would make repr() itself raise.
        short_hash = (self.asset_hash or "")[:12]
        return f"<AssetCacheState id={self.id} hash={short_hash} path={self.file_path!r}>"
class AssetLocation(Base):

View File

@@ -4,7 +4,7 @@ import logging
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from typing import Any, Sequence, Optional, Iterable
from typing import Any, Sequence, Optional, Iterable, Union
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession
@@ -82,14 +82,14 @@ async def ingest_fs_asset(
require_existing_tags: bool = False,
) -> dict:
"""
Upsert Asset identity row + cache state pointing at local file.
Upsert Asset identity row + cache state(s) pointing at local file.
Always:
- Insert Asset if missing;
- Insert AssetCacheState if missing; else update mtime_ns if different.
- Insert AssetCacheState if missing; else update mtime_ns and asset_hash if different.
Optionally (when info_name is provided):
- Create an AssetInfo.
- Create or update an AssetInfo on (asset_hash, owner_id, name).
- Link provided tags to that AssetInfo.
* If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table.
* If False (default), create unknown tags.
@@ -157,11 +157,16 @@ async def ingest_fs_asset(
out["state_created"] = True
if not out["state_created"]:
state = await session.get(AssetCacheState, asset_hash)
# most likely a unique(file_path) conflict; update that row
state = (
await session.execute(
select(AssetCacheState).where(AssetCacheState.file_path == locator).limit(1)
)
).scalars().first()
if state is not None:
changed = False
if state.file_path != locator:
state.file_path = locator
if state.asset_hash != asset_hash:
state.asset_hash = asset_hash
changed = True
if state.mtime_ns != int(mtime_ns):
state.mtime_ns = int(mtime_ns)
@@ -260,7 +265,15 @@ async def ingest_fs_asset(
# )
# start of adding metadata["filename"]
if out["asset_info_id"] is not None:
computed_filename = compute_model_relative_filename(abs_path)
primary_path = (
await session.execute(
select(AssetCacheState.file_path)
.where(AssetCacheState.asset_hash == asset_hash)
.order_by(AssetCacheState.id.asc())
.limit(1)
)
).scalars().first()
computed_filename = compute_model_relative_filename(primary_path) if primary_path else None
# Start from current metadata on this AssetInfo, if any
current_meta = existing_info.user_metadata or {}
@@ -366,7 +379,6 @@ async def list_asset_infos_page(
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
# Sort
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
@@ -381,7 +393,6 @@ async def list_asset_infos_page(
base = base.order_by(sort_exp).limit(limit).offset(offset)
# Total count (same filters, no ordering/limit/offset)
count_stmt = (
select(func.count())
.select_from(AssetInfo)
@@ -395,10 +406,9 @@ async def list_asset_infos_page(
total = int((await session.execute(count_stmt)).scalar_one() or 0)
# Fetch rows
infos = (await session.execute(base)).scalars().unique().all()
# Collect tags in bulk (single query)
# Collect tags in bulk
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
@@ -470,12 +480,33 @@ async def fetch_asset_info_asset_and_tags(
async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]:
    """Return the oldest cache row for this asset, or None if there is none.

    "Oldest" is the row with the smallest ``id``. The row is found by
    filtering on ``asset_hash``: ``AssetCacheState`` is keyed by an integer
    ``id``, so a primary-key ``session.get(AssetCacheState, asset_hash)``
    lookup cannot be used to find rows by hash.
    """
    stmt = (
        select(AssetCacheState)
        .where(AssetCacheState.asset_hash == asset_hash)
        .order_by(AssetCacheState.id.asc())
        .limit(1)
    )
    return (await session.execute(stmt)).scalars().first()
async def list_cache_states_by_asset_hash(
    session: AsyncSession, *, asset_hash: str
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
    """Fetch every cache row recorded for ``asset_hash``, lowest ``id`` first."""
    query = select(AssetCacheState).where(AssetCacheState.asset_hash == asset_hash)
    query = query.order_by(AssetCacheState.id.asc())
    result = await session.execute(query)
    return result.scalars().all()
async def list_asset_locations(
session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None
) -> list[AssetLocation] | Sequence[AssetLocation]:
) -> Union[list[AssetLocation], Sequence[AssetLocation]]:
stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash)
if provider:
stmt = stmt.where(AssetLocation.provider == provider)
@@ -815,7 +846,6 @@ async def list_tags_with_usage(
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
# Ordering
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else: # default "count_desc"
@@ -990,6 +1020,7 @@ def _apply_tag_filters(
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: Optional[dict],
@@ -1050,7 +1081,7 @@ def _apply_metadata_filter(
for k, v in metadata_filter.items():
if isinstance(v, list):
# ANY-of (exists for any element)
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
@@ -1079,12 +1110,10 @@ def _project_kv(key: str, value: Any) -> list[dict]:
"""
rows: list[dict] = []
# None
if value is None:
rows.append({"key": key, "ordinal": 0, "val_json": None})
return rows
# Scalars
if _is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
@@ -1099,9 +1128,7 @@ def _project_kv(key: str, value: Any) -> list[dict]:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
# Lists
if isinstance(value, list):
# list of scalars?
if all(_is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None: