This change solves the basename collision bug by using UNIQUE(file_path) on the
unified asset_references table.

Key changes:

Database:
- Migration 0005 merges asset_cache_states and asset_infos into asset_references
- AssetReference now contains the cache state fields (file_path, mtime_ns,
  needs_verify, is_missing, enrichment_level) plus the info fields (name,
  owner_id, preview_id, etc.)
- AssetReferenceMeta replaces AssetInfoMeta
- AssetReferenceTag replaces AssetInfoTag
- A UNIQUE constraint on file_path prevents duplicate entries for the same file

Code:
- New unified query module: asset_reference.py (replaces asset_info.py,
  cache_state.py)
- Updated scanner, seeder, and services to use AssetReference
- Updated API routes to use reference_id instead of asset_info_id

Tests:
- All 175 unit tests updated and passing
- Integration tests require a server environment (not run here)

Amp-Thread-ID: https://ampcode.com/threads/T-019c4fe8-9dcb-75ce-bea8-ea786343a581

Co-authored-by: Amp <amp@ampcode.com>
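As a sketch of the fix (illustrative only; the actual column lives in
app.assets.database.models and is created by migration 0005), filesystem
references are now keyed on the full path rather than the basename:

    # Hypothetical model excerpt, not the real definition:
    file_path = sa.Column(sa.String, nullable=True, unique=True)

Two files sharing a basename in different directories get distinct rows, while
inserting the same path twice raises a conflict that the query layer below
absorbs with ON CONFLICT DO NOTHING or a savepoint.
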
"""Query functions for the unified AssetReference table.
|
|
|
|
This module replaces the separate asset_info.py and cache_state.py query modules,
|
|
providing a unified interface for the merged asset_references table.
|
|
"""
|
|
|
|
import os
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from decimal import Decimal
|
|
from typing import NamedTuple, Sequence
|
|
|
|
import sqlalchemy as sa
|
|
from sqlalchemy import delete, exists, select
|
|
from sqlalchemy.dialects import sqlite
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import Session, noload
|
|
|
|
from app.assets.database.models import (
|
|
Asset,
|
|
AssetReference,
|
|
AssetReferenceMeta,
|
|
AssetReferenceTag,
|
|
Tag,
|
|
)
|
|
from app.assets.database.queries.common import (
|
|
MAX_BIND_PARAMS,
|
|
build_visible_owner_clause,
|
|
calculate_rows_per_statement,
|
|
iter_chunks,
|
|
)
|
|
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
|
|
|
|
|
# =============================================================================
|
|
# Metadata conversion helpers (from former asset_info.py)
|
|
# =============================================================================
|
|
|
|
|
|
def _check_is_scalar(v):
|
|
if v is None:
|
|
return True
|
|
if isinstance(v, bool):
|
|
return True
|
|
if isinstance(v, (int, float, Decimal, str)):
|
|
return True
|
|
return False
|
|
|
|
|
|
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
|
"""Convert a scalar value to a typed projection row."""
|
|
if value is None:
|
|
return {
|
|
"key": key,
|
|
"ordinal": ordinal,
|
|
"val_str": None,
|
|
"val_num": None,
|
|
"val_bool": None,
|
|
"val_json": None,
|
|
}
|
|
if isinstance(value, bool):
|
|
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
|
|
if isinstance(value, (int, float, Decimal)):
|
|
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
|
return {"key": key, "ordinal": ordinal, "val_num": num}
|
|
if isinstance(value, str):
|
|
return {"key": key, "ordinal": ordinal, "val_str": value}
|
|
return {"key": key, "ordinal": ordinal, "val_json": value}
|
|
|
|
|
|
def convert_metadata_to_rows(key: str, value) -> list[dict]:
    """Turn a metadata key/value into typed projection rows."""
    if value is None:
        return [_scalar_to_row(key, 0, None)]

    if _check_is_scalar(value):
        return [_scalar_to_row(key, 0, value)]

    if isinstance(value, list):
        if all(_check_is_scalar(x) for x in value):
            return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
        return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]

    return [{"key": key, "ordinal": 0, "val_json": value}]


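# A quick sanity sketch of the projection (keys and values here are made up
# for illustration):
#
#     convert_metadata_to_rows("steps", [20, 30])
#     # -> [{"key": "steps", "ordinal": 0, "val_num": Decimal("20")},
#     #     {"key": "steps", "ordinal": 1, "val_num": Decimal("30")}]
#     convert_metadata_to_rows("trigger", "zavy")
#     # -> [{"key": "trigger", "ordinal": 0, "val_str": "zavy"}]
#     convert_metadata_to_rows("config", {"cfg": 7})
#     # -> [{"key": "config", "ordinal": 0, "val_json": {"cfg": 7}}]

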
# =============================================================================
# Filter helpers
# =============================================================================


def _apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
    """include_tags: every tag must be present; exclude_tags: none may be present."""
    include_tags = normalize_tags(include_tags)
    exclude_tags = normalize_tags(exclude_tags)

    if include_tags:
        for tag_name in include_tags:
            stmt = stmt.where(
                exists().where(
                    (AssetReferenceTag.asset_reference_id == AssetReference.id)
                    & (AssetReferenceTag.tag_name == tag_name)
                )
            )

    if exclude_tags:
        stmt = stmt.where(
            ~exists().where(
                (AssetReferenceTag.asset_reference_id == AssetReference.id)
                & (AssetReferenceTag.tag_name.in_(exclude_tags))
            )
        )
    return stmt


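# Usage sketch (hypothetical tag names; stmt is any Select over AssetReference):
#
#     stmt = _apply_tag_filters(
#         select(AssetReference),
#         include_tags=["lora", "sdxl"],   # reference must carry BOTH tags
#         exclude_tags=["broken"],         # ... and NONE of these
#     )

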
def _apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
) -> sa.sql.Select:
    """Apply filters using the asset_reference_meta projection table."""
    if not metadata_filter:
        return stmt

    def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
        return sa.exists().where(
            AssetReferenceMeta.asset_reference_id == AssetReference.id,
            AssetReferenceMeta.key == key,
            *preds,
        )

    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
        if value is None:
            no_row_for_key = sa.not_(
                sa.exists().where(
                    AssetReferenceMeta.asset_reference_id == AssetReference.id,
                    AssetReferenceMeta.key == key,
                )
            )
            null_row = _exists_for_pred(
                key,
                AssetReferenceMeta.val_json.is_(None),
                AssetReferenceMeta.val_str.is_(None),
                AssetReferenceMeta.val_num.is_(None),
                AssetReferenceMeta.val_bool.is_(None),
            )
            return sa.or_(no_row_for_key, null_row)

        if isinstance(value, bool):
            return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
        if isinstance(value, (int, float, Decimal)):
            # Include Decimal to match the write path (_scalar_to_row), which
            # stores Decimal values in val_num, not val_json.
            num = value if isinstance(value, Decimal) else Decimal(str(value))
            return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
        if isinstance(value, str):
            return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
        return _exists_for_pred(key, AssetReferenceMeta.val_json == value)

    for k, v in metadata_filter.items():
        if isinstance(v, list):
            ors = [_exists_clause_for_value(k, elem) for elem in v]
            if ors:
                stmt = stmt.where(sa.or_(*ors))
        else:
            stmt = stmt.where(_exists_clause_for_value(k, v))
    return stmt


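# Filter semantics sketch (hypothetical keys/values): dict entries are ANDed,
# list elements within one key are ORed, and None matches "no value stored".
#
#     # base_model == "sdxl" AND steps in (20, 30):
#     stmt = _apply_metadata_filter(
#         select(AssetReference),
#         {"base_model": "sdxl", "steps": [20, 30]},
#     )
#     # {"key": None} matches references with no row for "key" OR an all-NULL row.

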
# =============================================================================
# Basic CRUD operations
# =============================================================================


def get_reference_by_id(
    session: Session,
    reference_id: str,
) -> AssetReference | None:
    return session.get(AssetReference, reference_id)


def get_reference_by_file_path(
    session: Session,
    file_path: str,
) -> AssetReference | None:
    """Get a reference by its file path."""
    return (
        session.execute(
            select(AssetReference).where(AssetReference.file_path == file_path).limit(1)
        )
        .scalars()
        .first()
    )


def reference_exists_for_asset_id(
    session: Session,
    asset_id: str,
) -> bool:
    q = (
        select(sa.literal(True))
        .select_from(AssetReference)
        .where(AssetReference.asset_id == asset_id)
        .limit(1)
    )
    return session.execute(q).first() is not None


def insert_reference(
    session: Session,
    asset_id: str,
    name: str,
    owner_id: str = "",
    file_path: str | None = None,
    mtime_ns: int | None = None,
    preview_id: str | None = None,
) -> AssetReference | None:
    """Insert a new AssetReference. Returns None if a unique constraint is violated."""
    now = get_utc_now()
    try:
        with session.begin_nested():
            ref = AssetReference(
                asset_id=asset_id,
                name=name,
                owner_id=owner_id,
                file_path=file_path,
                mtime_ns=mtime_ns,
                preview_id=preview_id,
                created_at=now,
                updated_at=now,
                last_access_time=now,
            )
            session.add(ref)
            session.flush()
            return ref
    except IntegrityError:
        return None


def get_or_create_reference(
    session: Session,
    asset_id: str,
    name: str,
    owner_id: str = "",
    file_path: str | None = None,
    mtime_ns: int | None = None,
    preview_id: str | None = None,
) -> tuple[AssetReference, bool]:
    """Get an existing AssetReference or create a new one.

    For filesystem references (file_path is set), uniqueness is by file_path.
    For API references (file_path is None), we look for a matching
    asset_id + owner_id + name.

    Returns (reference, created).
    """
    ref = insert_reference(
        session,
        asset_id=asset_id,
        name=name,
        owner_id=owner_id,
        file_path=file_path,
        mtime_ns=mtime_ns,
        preview_id=preview_id,
    )
    if ref:
        return ref, True

    # Insert conflicted: find the existing row, by file_path for filesystem
    # references, otherwise by (asset_id, owner_id, name).
    if file_path:
        existing = get_reference_by_file_path(session, file_path)
    else:
        existing = (
            session.execute(
                select(AssetReference)
                .where(
                    AssetReference.asset_id == asset_id,
                    AssetReference.name == name,
                    AssetReference.owner_id == owner_id,
                    AssetReference.file_path.is_(None),
                )
                .limit(1)
            )
            .unique()
            .scalar_one_or_none()
        )
    if not existing:
        raise RuntimeError("Failed to find AssetReference after insert conflict.")
    return existing, False


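# Usage sketch (IDs and paths are made up), e.g. from the scanner when a file
# is discovered; a second call with the same path returns the existing row:
#
#     ref, created = get_or_create_reference(
#         session, asset_id="a1", name="model.safetensors",
#         file_path="/models/checkpoints/model.safetensors", mtime_ns=123,
#     )
#     # created is True on first sight, False once UNIQUE(file_path) conflicts.

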
def update_reference_timestamps(
    session: Session,
    reference: AssetReference,
    preview_id: str | None = None,
) -> None:
    """If preview_id is given and differs, update it and bump updated_at."""
    now = get_utc_now()
    if preview_id and reference.preview_id != preview_id:
        reference.preview_id = preview_id
        reference.updated_at = now


# =============================================================================
# Listing and pagination
# =============================================================================


def list_references_page(
    session: Session,
    owner_id: str = "",
    limit: int = 100,
    offset: int = 0,
    name_contains: str | None = None,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    metadata_filter: dict | None = None,
    sort: str | None = None,
    order: str | None = None,
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
    """List references with pagination, filtering, and sorting.

    Returns (references, tag_map, total_count).
    """
    base = (
        select(AssetReference)
        .join(Asset, Asset.id == AssetReference.asset_id)
        .where(build_visible_owner_clause(owner_id))
        .options(noload(AssetReference.tags))
    )

    if name_contains:
        escaped, esc = escape_sql_like_string(name_contains)
        base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))

    base = _apply_tag_filters(base, include_tags, exclude_tags)
    base = _apply_metadata_filter(base, metadata_filter)

    sort = (sort or "created_at").lower()
    order = (order or "desc").lower()
    sort_map = {
        "name": AssetReference.name,
        "created_at": AssetReference.created_at,
        "updated_at": AssetReference.updated_at,
        "last_access_time": AssetReference.last_access_time,
        "size": Asset.size_bytes,
    }
    sort_col = sort_map.get(sort, AssetReference.created_at)
    sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()

    base = base.order_by(sort_exp).limit(limit).offset(offset)

    count_stmt = (
        select(sa.func.count())
        .select_from(AssetReference)
        .join(Asset, Asset.id == AssetReference.asset_id)
        .where(build_visible_owner_clause(owner_id))
    )
    if name_contains:
        escaped, esc = escape_sql_like_string(name_contains)
        count_stmt = count_stmt.where(
            AssetReference.name.ilike(f"%{escaped}%", escape=esc)
        )
    count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
    count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)

    total = int(session.execute(count_stmt).scalar_one() or 0)
    refs = session.execute(base).unique().scalars().all()

    id_list: list[str] = [r.id for r in refs]
    tag_map: dict[str, list[str]] = defaultdict(list)
    if id_list:
        rows = session.execute(
            select(AssetReferenceTag.asset_reference_id, Tag.name)
            .join(Tag, Tag.name == AssetReferenceTag.tag_name)
            .where(AssetReferenceTag.asset_reference_id.in_(id_list))
            .order_by(AssetReferenceTag.added_at)
        )
        for ref_id, tag_name in rows.all():
            tag_map[ref_id].append(tag_name)

    return list(refs), tag_map, total


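# Usage sketch (hypothetical filter values), e.g. backing a paginated API route:
#
#     refs, tag_map, total = list_references_page(
#         session,
#         owner_id="user-1",
#         limit=50,
#         include_tags=["checkpoint"],
#         sort="size",
#         order="desc",
#     )
#     page = [(r.name, tag_map.get(r.id, [])) for r in refs]

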
def fetch_reference_asset_and_tags(
    session: Session,
    reference_id: str,
    owner_id: str = "",
) -> tuple[AssetReference, Asset, list[str]] | None:
    stmt = (
        select(AssetReference, Asset, Tag.name)
        .join(Asset, Asset.id == AssetReference.asset_id)
        .join(
            AssetReferenceTag,
            AssetReferenceTag.asset_reference_id == AssetReference.id,
            isouter=True,
        )
        .join(Tag, Tag.name == AssetReferenceTag.tag_name, isouter=True)
        .where(
            AssetReference.id == reference_id,
            build_visible_owner_clause(owner_id),
        )
        .options(noload(AssetReference.tags))
        .order_by(Tag.name.asc())
    )

    rows = session.execute(stmt).all()
    if not rows:
        return None

    first_ref, first_asset, _ = rows[0]
    tags: list[str] = []
    seen: set[str] = set()
    for _ref, _asset, tag_name in rows:
        if tag_name and tag_name not in seen:
            seen.add(tag_name)
            tags.append(tag_name)
    return first_ref, first_asset, tags


def fetch_reference_and_asset(
    session: Session,
    reference_id: str,
    owner_id: str = "",
) -> tuple[AssetReference, Asset] | None:
    stmt = (
        select(AssetReference, Asset)
        .join(Asset, Asset.id == AssetReference.asset_id)
        .where(
            AssetReference.id == reference_id,
            build_visible_owner_clause(owner_id),
        )
        .limit(1)
        .options(noload(AssetReference.tags))
    )
    pair = session.execute(stmt).first()
    if not pair:
        return None
    return pair[0], pair[1]


# =============================================================================
# Timestamp updates
# =============================================================================


def update_reference_access_time(
    session: Session,
    reference_id: str,
    ts: datetime | None = None,
    only_if_newer: bool = True,
) -> None:
    ts = ts or get_utc_now()
    stmt = sa.update(AssetReference).where(AssetReference.id == reference_id)
    if only_if_newer:
        stmt = stmt.where(
            sa.or_(
                AssetReference.last_access_time.is_(None),
                AssetReference.last_access_time < ts,
            )
        )
    session.execute(stmt.values(last_access_time=ts))


def update_reference_name(
    session: Session,
    reference_id: str,
    name: str,
) -> None:
    """Update the name of an AssetReference."""
    now = get_utc_now()
    session.execute(
        sa.update(AssetReference)
        .where(AssetReference.id == reference_id)
        .values(name=name, updated_at=now)
    )


def update_reference_updated_at(
    session: Session,
    reference_id: str,
    ts: datetime | None = None,
) -> None:
    """Update the updated_at timestamp of an AssetReference."""
    ts = ts or get_utc_now()
    session.execute(
        sa.update(AssetReference)
        .where(AssetReference.id == reference_id)
        .values(updated_at=ts)
    )


# =============================================================================
# Metadata operations
# =============================================================================


def set_reference_metadata(
    session: Session,
    reference_id: str,
    user_metadata: dict | None = None,
) -> None:
    ref = session.get(AssetReference, reference_id)
    if not ref:
        raise ValueError(f"AssetReference {reference_id} not found")

    ref.user_metadata = user_metadata or {}
    ref.updated_at = get_utc_now()
    session.flush()

    session.execute(
        delete(AssetReferenceMeta).where(
            AssetReferenceMeta.asset_reference_id == reference_id
        )
    )
    session.flush()

    if not user_metadata:
        return

    rows: list[AssetReferenceMeta] = []
    for k, v in user_metadata.items():
        for r in convert_metadata_to_rows(k, v):
            rows.append(
                AssetReferenceMeta(
                    asset_reference_id=reference_id,
                    key=r["key"],
                    ordinal=int(r["ordinal"]),
                    val_str=r.get("val_str"),
                    val_num=r.get("val_num"),
                    val_bool=r.get("val_bool"),
                    val_json=r.get("val_json"),
                )
            )
    if rows:
        session.add_all(rows)
        session.flush()


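# Behavior sketch (hypothetical metadata): the call replaces the whole
# projection for the reference, so filters in _apply_metadata_filter see the
# new values.
#
#     set_reference_metadata(session, ref.id, {"base_model": "sdxl", "steps": 30})
#     # asset_reference_meta now holds one val_str row ("base_model") and one
#     # val_num row ("steps"); any previous rows for ref.id are gone.

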
# =============================================================================
# Delete and preview operations
# =============================================================================


def delete_reference_by_id(
    session: Session,
    reference_id: str,
    owner_id: str,
) -> bool:
    stmt = sa.delete(AssetReference).where(
        AssetReference.id == reference_id,
        build_visible_owner_clause(owner_id),
    )
    return int(session.execute(stmt).rowcount or 0) > 0


def set_reference_preview(
    session: Session,
    reference_id: str,
    preview_asset_id: str | None = None,
) -> None:
    """Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
    ref = session.get(AssetReference, reference_id)
    if not ref:
        raise ValueError(f"AssetReference {reference_id} not found")

    if preview_asset_id is None:
        ref.preview_id = None
    else:
        if not session.get(Asset, preview_asset_id):
            raise ValueError(f"Preview Asset {preview_asset_id} not found")
        ref.preview_id = preview_asset_id

    ref.updated_at = get_utc_now()
    session.flush()


# =============================================================================
# Cache state operations (from former cache_state.py)
# =============================================================================


class CacheStateRow(NamedTuple):
    """Row from reference query with cache state data."""

    reference_id: str
    file_path: str
    mtime_ns: int | None
    needs_verify: bool
    asset_id: str
    asset_hash: str | None
    size_bytes: int


def list_references_by_asset_id(
    session: Session,
    asset_id: str,
) -> Sequence[AssetReference]:
    return (
        session.execute(
            select(AssetReference)
            .where(AssetReference.asset_id == asset_id)
            .order_by(AssetReference.id.asc())
        )
        .scalars()
        .all()
    )


def upsert_reference(
    session: Session,
    asset_id: str,
    file_path: str,
    name: str,
    mtime_ns: int,
    owner_id: str = "",
) -> tuple[bool, bool]:
    """Upsert a reference by file_path. Returns (created, updated).

    Also restores references that were previously marked as missing.
    """
    now = get_utc_now()
    vals = {
        "asset_id": asset_id,
        "file_path": file_path,
        "name": name,
        "owner_id": owner_id,
        "mtime_ns": int(mtime_ns),
        "is_missing": False,
        "created_at": now,
        "updated_at": now,
        "last_access_time": now,
    }
    ins = (
        sqlite.insert(AssetReference)
        .values(**vals)
        .on_conflict_do_nothing(index_elements=[AssetReference.file_path])
    )
    res = session.execute(ins)
    created = int(res.rowcount or 0) > 0

    if created:
        return True, False

    upd = (
        sa.update(AssetReference)
        .where(AssetReference.file_path == file_path)
        .where(
            sa.or_(
                AssetReference.asset_id != asset_id,
                AssetReference.mtime_ns.is_(None),
                AssetReference.mtime_ns != int(mtime_ns),
                AssetReference.is_missing == True,  # noqa: E712
            )
        )
        .values(
            asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False, updated_at=now
        )
    )
    res2 = session.execute(upd)
    updated = int(res2.rowcount or 0) > 0
    return False, updated


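# Upsert semantics sketch (path and IDs are made up): with UNIQUE(file_path)
# the insert either wins (created=True) or conflicts, in which case the row is
# refreshed only when something actually changed (asset_id, mtime_ns, or a
# stale is_missing flag).
#
#     created, updated = upsert_reference(
#         session, asset_id="a1",
#         file_path="/models/loras/x.safetensors",
#         name="x.safetensors", mtime_ns=1_700_000_000_000_000_000,
#     )
#     # (True, False) on first scan; (False, False) on an unchanged rescan.

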
def mark_references_missing_outside_prefixes(
    session: Session,
    valid_prefixes: list[str],
) -> int:
    """Mark references as missing when file_path doesn't match any valid prefix.

    Returns number of references marked as missing.
    """
    if not valid_prefixes:
        return 0

    def make_prefix_condition(prefix: str):
        base = prefix if prefix.endswith(os.sep) else prefix + os.sep
        escaped, esc = escape_sql_like_string(base)
        return AssetReference.file_path.like(escaped + "%", escape=esc)

    matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes])
    result = session.execute(
        sa.update(AssetReference)
        .where(AssetReference.file_path.isnot(None))
        .where(~matches_valid_prefix)
        .where(AssetReference.is_missing == False)  # noqa: E712
        .values(is_missing=True)
    )
    return result.rowcount


def restore_references_by_paths(session: Session, file_paths: list[str]) -> int:
    """Restore references that were previously marked as missing.

    Returns number of references restored.
    """
    if not file_paths:
        return 0

    result = session.execute(
        sa.update(AssetReference)
        .where(AssetReference.file_path.in_(file_paths))
        .where(AssetReference.is_missing == True)  # noqa: E712
        .values(is_missing=False)
    )
    return result.rowcount


def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
    """Get IDs of unhashed assets (hash=None) with no active references.

    An asset is considered unreferenced if it has no references,
    or all its references are marked as missing.

    Returns list of asset IDs that are unreferenced.
    """
    active_ref_exists = (
        sa.select(sa.literal(1))
        .where(AssetReference.asset_id == Asset.id)
        .where(AssetReference.is_missing == False)  # noqa: E712
        .correlate(Asset)
        .exists()
    )
    unreferenced_subq = sa.select(Asset.id).where(
        Asset.hash.is_(None), ~active_ref_exists
    )
    return [row[0] for row in session.execute(unreferenced_subq).all()]


def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
    """Delete assets and their references by ID.

    Returns number of assets deleted.
    """
    if not asset_ids:
        return 0
    session.execute(
        sa.delete(AssetReference).where(AssetReference.asset_id.in_(asset_ids))
    )
    result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
    return result.rowcount


def get_references_for_prefixes(
    session: Session,
    prefixes: list[str],
    *,
    include_missing: bool = False,
) -> list[CacheStateRow]:
    """Get all references with file paths matching any of the given prefixes.

    Args:
        session: Database session
        prefixes: List of absolute directory prefixes to match
        include_missing: If False (default), exclude references marked as missing

    Returns:
        List of cache state rows with joined asset data
    """
    if not prefixes:
        return []

    conds = []
    for p in prefixes:
        base = os.path.abspath(p)
        if not base.endswith(os.sep):
            base += os.sep
        escaped, esc = escape_sql_like_string(base)
        conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))

    query = (
        sa.select(
            AssetReference.id,
            AssetReference.file_path,
            AssetReference.mtime_ns,
            AssetReference.needs_verify,
            AssetReference.asset_id,
            Asset.hash,
            Asset.size_bytes,
        )
        .join(Asset, Asset.id == AssetReference.asset_id)
        .where(AssetReference.file_path.isnot(None))
        .where(sa.or_(*conds))
    )

    if not include_missing:
        query = query.where(AssetReference.is_missing == False)  # noqa: E712

    rows = session.execute(
        query.order_by(AssetReference.asset_id.asc(), AssetReference.id.asc())
    ).all()

    return [
        CacheStateRow(
            reference_id=row[0],
            file_path=row[1],
            mtime_ns=row[2],
            needs_verify=row[3],
            asset_id=row[4],
            asset_hash=row[5],
            size_bytes=int(row[6] or 0),
        )
        for row in rows
    ]


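# Usage sketch (hypothetical directory): a rescan can diff DB state against
# disk using these rows.
#
#     for row in get_references_for_prefixes(session, ["/models/checkpoints"]):
#         st = os.stat(row.file_path)
#         if st.st_mtime_ns != row.mtime_ns:
#             bulk_update_needs_verify(session, [row.reference_id], True)

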
def bulk_update_needs_verify(
    session: Session, reference_ids: list[str], value: bool
) -> int:
    """Set needs_verify flag for multiple references.

    Returns: Number of rows updated
    """
    if not reference_ids:
        return 0
    result = session.execute(
        sa.update(AssetReference)
        .where(AssetReference.id.in_(reference_ids))
        .values(needs_verify=value)
    )
    return result.rowcount


def bulk_update_is_missing(
    session: Session, reference_ids: list[str], value: bool
) -> int:
    """Set is_missing flag for multiple references.

    Returns: Number of rows updated
    """
    if not reference_ids:
        return 0
    result = session.execute(
        sa.update(AssetReference)
        .where(AssetReference.id.in_(reference_ids))
        .values(is_missing=value)
    )
    return result.rowcount


def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
    """Delete references by their IDs.

    Returns: Number of rows deleted
    """
    if not reference_ids:
        return 0
    result = session.execute(
        sa.delete(AssetReference).where(AssetReference.id.in_(reference_ids))
    )
    return result.rowcount


def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
    """Delete an asset and its references.

    Intended for orphaned seed assets (hash is None); the hash is not
    re-checked here, so callers are expected to pass IDs from
    get_unreferenced_unhashed_asset_ids or an equivalent query.

    Returns: True if asset was deleted, False if not found
    """
    session.execute(
        sa.delete(AssetReference).where(AssetReference.asset_id == asset_id)
    )
    asset = session.get(Asset, asset_id)
    if asset:
        session.delete(asset)
        return True
    return False


# =============================================================================
# Enrichment operations
# =============================================================================


class UnenrichedReferenceRow(NamedTuple):
    """Row for references needing enrichment."""

    reference_id: str
    asset_id: str
    file_path: str
    enrichment_level: int


def get_unenriched_references(
    session: Session,
    prefixes: list[str],
    max_level: int = 0,
    limit: int = 1000,
) -> list[UnenrichedReferenceRow]:
    """Get references that need enrichment (enrichment_level <= max_level).

    Args:
        session: Database session
        prefixes: List of absolute directory prefixes to scan
        max_level: Maximum enrichment level to include (0=stubs, 1=metadata done)
        limit: Maximum number of rows to return

    Returns:
        List of unenriched reference rows with file paths
    """
    if not prefixes:
        return []

    conds = []
    for p in prefixes:
        base = os.path.abspath(p)
        if not base.endswith(os.sep):
            base += os.sep
        escaped, esc = escape_sql_like_string(base)
        conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))

    query = (
        sa.select(
            AssetReference.id,
            AssetReference.asset_id,
            AssetReference.file_path,
            AssetReference.enrichment_level,
        )
        .where(AssetReference.file_path.isnot(None))
        .where(sa.or_(*conds))
        .where(AssetReference.is_missing == False)  # noqa: E712
        .where(AssetReference.enrichment_level <= max_level)
        .order_by(AssetReference.id.asc())
        .limit(limit)
    )

    rows = session.execute(query).all()
    return [
        UnenrichedReferenceRow(
            reference_id=row[0],
            asset_id=row[1],
            file_path=row[2],
            enrichment_level=row[3],
        )
        for row in rows
    ]


def update_enrichment_level(
    session: Session,
    reference_id: str,
    level: int,
) -> None:
    """Update the enrichment level for a reference."""
    session.execute(
        sa.update(AssetReference)
        .where(AssetReference.id == reference_id)
        .values(enrichment_level=level)
    )


def bulk_update_enrichment_level(
    session: Session,
    reference_ids: list[str],
    level: int,
) -> int:
    """Update enrichment level for multiple references.

    Returns: Number of rows updated
    """
    if not reference_ids:
        return 0
    result = session.execute(
        sa.update(AssetReference)
        .where(AssetReference.id.in_(reference_ids))
        .values(enrichment_level=level)
    )
    return result.rowcount


# =============================================================================
# Bulk operations
# =============================================================================


def bulk_insert_references_ignore_conflicts(
    session: Session,
    rows: list[dict],
) -> None:
    """Bulk insert reference rows with ON CONFLICT DO NOTHING on file_path.

    Each dict should have: id, asset_id, file_path, name, owner_id, mtime_ns, etc.
    The is_missing field is automatically set to False for new inserts.
    """
    if not rows:
        return
    enriched_rows = [{**row, "is_missing": False} for row in rows]
    ins = sqlite.insert(AssetReference).on_conflict_do_nothing(
        index_elements=[AssetReference.file_path]
    )
    # 14 appears to be the per-row bind-parameter count, used to size each
    # INSERT so it stays under SQLite's bind-parameter limit.
    for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(14)):
        session.execute(ins, chunk)


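# Row shape sketch (values illustrative; assume now = get_utc_now()):
#
#     bulk_insert_references_ignore_conflicts(session, [{
#         "id": "ref-1", "asset_id": "a1",
#         "file_path": "/models/vae/x.pt", "name": "x.pt", "owner_id": "",
#         "mtime_ns": 0, "created_at": now, "updated_at": now,
#         "last_access_time": now,
#     }])
#     # Rows whose file_path already exists are skipped silently; use
#     # get_references_by_paths_and_asset_ids afterwards to see which inserts won.

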
def get_references_by_paths_and_asset_ids(
    session: Session,
    path_to_asset: dict[str, str],
) -> set[str]:
    """Query references to find paths where our asset_id won the insert.

    Args:
        path_to_asset: Mapping of file_path -> asset_id we tried to insert

    Returns:
        Set of file_paths whose stored asset_id matches the one we tried to insert
    """
    if not path_to_asset:
        return set()

    paths = list(path_to_asset.keys())
    winners: set[str] = set()

    for chunk in iter_chunks(paths, MAX_BIND_PARAMS):
        result = session.execute(
            select(AssetReference.file_path, AssetReference.asset_id).where(
                AssetReference.file_path.in_(chunk)
            )
        )
        # Compare per path: checking asset_id against the chunk's whole
        # asset-id list would report a path as a winner when its row holds an
        # asset_id we inserted for a *different* path.
        winners.update(
            fp for fp, aid in result.all() if path_to_asset.get(fp) == aid
        )

    return winners


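# Winner-detection sketch (IDs are made up): after a bulk insert race, only
# paths whose stored asset_id is ours count as wins.
#
#     attempted = {"/models/a.pt": "a1", "/models/b.pt": "b1"}
#     won = get_references_by_paths_and_asset_ids(session, attempted)
#     lost = set(attempted) - won  # rows another writer created first

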
def get_reference_ids_by_ids(
    session: Session,
    reference_ids: list[str],
) -> set[str]:
    """Query to find which reference IDs exist in the database."""
    if not reference_ids:
        return set()

    found: set[str] = set()
    for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS):
        result = session.execute(
            select(AssetReference.id).where(AssetReference.id.in_(chunk))
        )
        found.update(result.scalars().all())
    return found