refactor(assets): merge AssetInfo and AssetCacheState into AssetReference

This change solves the basename collision bug by using UNIQUE(file_path) on the
unified asset_references table. Key changes:

Database:
- Migration 0005 merges asset_cache_states and asset_infos into asset_references
- AssetReference now contains: cache state fields (file_path, mtime_ns, needs_verify,
  is_missing, enrichment_level) plus info fields (name, owner_id, preview_id, etc.)
- AssetReferenceMeta replaces AssetInfoMeta
- AssetReferenceTag replaces AssetInfoTag
- UNIQUE constraint on file_path prevents duplicate entries for same file
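  A minimal sketch of why the unique index fixes the collision (schema shape assumed here; migration 0005 is the authority): a second reference to an already-registered path now fails at insert time, while identical basenames under different directories remain legal.

```python
# Hedged sketch: UNIQUE(file_path) semantics on the unified table.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE asset_references (id TEXT PRIMARY KEY, asset_id TEXT, file_path TEXT)"
)
conn.execute(
    "CREATE UNIQUE INDEX uq_asset_references_file_path ON asset_references (file_path)"
)

conn.execute("INSERT INTO asset_references VALUES ('r1', 'a1', '/models/loras/foo.safetensors')")
try:
    # Duplicate full path: rejected, so one file can never own two rows.
    conn.execute("INSERT INTO asset_references VALUES ('r2', 'a2', '/models/loras/foo.safetensors')")
except sqlite3.IntegrityError:
    print("duplicate file_path rejected")

# Same basename in another directory is fine; uniqueness is on the full path.
conn.execute("INSERT INTO asset_references VALUES ('r3', 'a3', '/models/checkpoints/foo.safetensors')")
# API-created references use file_path NULL; SQLite permits any number of NULLs here.
conn.execute("INSERT INTO asset_references VALUES ('r4', 'a4', NULL)")
```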

Code:
- New unified query module: asset_reference.py (replaces asset_info.py, cache_state.py)
- Updated scanner, seeder, and services to use AssetReference
- Updated API routes to use reference_id instead of asset_info_id

Tests:
- All 175 unit tests updated and passing
- Integration tests require a server environment (not run here)

Amp-Thread-ID: https://ampcode.com/threads/T-019c4fe8-9dcb-75ce-bea8-ea786343a581
Co-authored-by: Amp <amp@ampcode.com>
Luke Mino-Altherr committed 2026-02-11 20:03:10 -08:00
parent 5ff6bf7a83
commit 8ff4d38ad1
36 changed files with 3191 additions and 2327 deletions


@@ -43,10 +43,10 @@ UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA
def get_query_dict(request: web.Request) -> dict[str, Any]:
"""
Gets a dictionary of query parameters from the request.
"""Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
request.query is a MultiMapping[str], needs to be converted to a dict
to be validated by Pydantic.
"""
query_dict = {
key: request.query.getall(key)
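The hunk cuts off mid-body; a hedged reconstruction of the full helper, assuming single-valued parameters are collapsed to scalars so plain Pydantic fields validate. Only the docstring and the dict head appear in the diff.

```python
from typing import Any
from aiohttp import web

def get_query_dict(request: web.Request) -> dict[str, Any]:
    """Gets a dictionary of query parameters from the request.

    request.query is a MultiMapping[str], needs to be converted to a dict
    to be validated by Pydantic.
    """
    query_dict: dict[str, Any] = {
        key: request.query.getall(key)
        for key in request.query.keys()
    }
    # Assumed tail: collapse one-element lists so scalar fields validate.
    return {k: v[0] if len(v) == 1 else v for k, v in query_dict.items()}
```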
@@ -58,7 +58,8 @@ def get_query_dict(request: web.Request) -> dict[str, Any]:
# Note to any custom node developers reading this code:
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
# The assets system is not yet fully implemented,
# do not rely on the code in /app/assets remaining the same.
def register_assets_system(
@@ -80,6 +81,7 @@ def _build_error_response(
def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response:
import json
errors = json.loads(ve.json())
return _build_error_response(400, code, "Validation failed.", {"errors": errors})
@@ -142,15 +144,15 @@ async def list_assets_route(request: web.Request) -> web.Response:
summaries = [
schemas_out.AssetSummary(
id=item.info.id,
name=item.info.name,
id=item.ref.id,
name=item.ref.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes) if item.asset else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.info.created_at,
updated_at=item.info.updated_at,
last_access_time=item.info.last_access_time,
created_at=item.ref.created_at,
updated_at=item.ref.updated_at,
last_access_time=item.ref.last_access_time,
)
for item in result.items
]
@@ -168,40 +170,40 @@ async def get_asset_route(request: web.Request) -> web.Response:
"""
GET request to get an asset's info as JSON.
"""
asset_info_id = str(uuid.UUID(request.match_info["id"]))
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
result = get_asset_detail(
asset_info_id=asset_info_id,
reference_id=reference_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if not result:
return _build_error_response(
404,
"ASSET_NOT_FOUND",
f"AssetInfo {asset_info_id} not found",
{"id": asset_info_id},
f"AssetReference {reference_id} not found",
{"id": reference_id},
)
payload = schemas_out.AssetDetail(
id=result.info.id,
name=result.info.name,
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
)
except ValueError as e:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
)
except Exception:
logging.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
"get_asset failed for reference_id=%s, owner_id=%s",
reference_id,
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
@@ -216,7 +218,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
try:
result = resolve_asset_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
reference_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
abs_path = result.abs_path
@@ -232,16 +234,14 @@ async def download_asset_content(request: web.Request) -> web.Response:
)
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f"{disposition}; filename=\"{quoted}\"; filename*=UTF-8''{urllib.parse.quote(quoted)}"
encoded = urllib.parse.quote(quoted)
cd = f"{disposition}; filename=\"{quoted}\"; filename*=UTF-8''{encoded}"
file_size = os.path.getsize(abs_path)
size_mb = file_size / (1024 * 1024)
logging.info(
"download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
abs_path,
file_size,
file_size / (1024 * 1024),
content_type,
filename,
"download_asset_content: path=%s, size=%d bytes (%.2f MB), type=%s, name=%s",
abs_path, file_size, size_mb, content_type, filename,
)
async def stream_file_chunks():
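The refactored header construction above is worth seeing in isolation; a self-contained sketch of the same RFC 6266/RFC 5987 dance (the helper name is invented for illustration):

```python
import urllib.parse

def build_content_disposition(disposition: str, filename: str) -> str:
    # Strip CR/LF to block header injection, swap double quotes so the quoted
    # form stays well-formed, then add an RFC 5987 filename* for non-ASCII names.
    quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
    encoded = urllib.parse.quote(quoted)
    return f"{disposition}; filename=\"{quoted}\"; filename*=UTF-8''{encoded}"

print(build_content_disposition("attachment", "café.safetensors"))
# attachment; filename="café.safetensors"; filename*=UTF-8''caf%C3%A9.safetensors
```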
@@ -288,16 +288,16 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
)
payload_out = schemas_out.AssetCreated(
id=result.info.id,
name=result.info.name,
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new,
)
return web.json_response(payload_out.model_dump(mode="json"), status=201)
@@ -340,7 +340,7 @@ async def upload_asset(request: web.Request) -> web.Response:
)
try:
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
# Fast path: hash exists, create AssetReference without writing anything
if spec.hash and parsed.provided_hash_exists is True:
result = create_from_hash(
hash_str=spec.hash,
@@ -391,16 +391,16 @@ async def upload_asset(request: web.Request) -> web.Response:
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
payload = schemas_out.AssetCreated(
id=result.info.id,
name=result.info.name,
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new,
)
status = 201 if result.created_new else 200
@@ -409,7 +409,7 @@ async def upload_asset(request: web.Request) -> web.Response:
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset_route(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
@@ -421,27 +421,27 @@ async def update_asset_route(request: web.Request) -> web.Response:
try:
result = update_asset_metadata(
asset_info_id=asset_info_id,
reference_id=reference_id,
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
payload = schemas_out.AssetUpdated(
id=result.info.id,
name=result.info.name,
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
updated_at=result.info.updated_at,
user_metadata=result.ref.user_metadata or {},
updated_at=result.ref.updated_at,
)
except (ValueError, PermissionError) as ve:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id}
)
except Exception:
logging.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
"update_asset failed for reference_id=%s, owner_id=%s",
reference_id,
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
@@ -450,7 +450,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset_route(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
reference_id = str(uuid.UUID(request.match_info["id"]))
delete_content_param = request.query.get("delete_content")
delete_content = (
True
@@ -460,21 +460,21 @@ async def delete_asset_route(request: web.Request) -> web.Response:
try:
deleted = delete_asset_reference(
asset_info_id=asset_info_id,
reference_id=reference_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
)
except Exception:
logging.exception(
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
"delete_asset_reference failed for reference_id=%s, owner_id=%s",
reference_id,
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _build_error_response(
404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found."
404, "ASSET_NOT_FOUND", f"AssetReference {reference_id} not found."
)
return web.Response(status=204)
@@ -490,8 +490,12 @@ async def get_tags(request: web.Request) -> web.Response:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as e:
import json
return _build_error_response(
400, "INVALID_QUERY", "Invalid query parameters", {"errors": json.loads(e.json())}
400,
"INVALID_QUERY",
"Invalid query parameters",
{"errors": json.loads(e.json())},
)
rows, total = list_tags(
@@ -515,7 +519,7 @@ async def get_tags(request: web.Request) -> web.Response:
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
json_payload = await request.json()
data = schemas_in.TagsAdd.model_validate(json_payload)
@@ -533,7 +537,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
try:
result = apply_tags(
asset_info_id=asset_info_id,
reference_id=reference_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
@@ -545,12 +549,12 @@ async def add_asset_tags(request: web.Request) -> web.Response:
)
except (ValueError, PermissionError) as ve:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id}
)
except Exception:
logging.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
"add_tags_to_asset failed for reference_id=%s, owner_id=%s",
reference_id,
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
@@ -560,7 +564,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
json_payload = await request.json()
data = schemas_in.TagsRemove.model_validate(json_payload)
@@ -578,7 +582,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
try:
result = remove_tags(
asset_info_id=asset_info_id,
reference_id=reference_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
@@ -589,12 +593,12 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
)
except ValueError as ve:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id}
)
except Exception:
logging.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
"remove_tags_from_asset failed for reference_id=%s, owner_id=%s",
reference_id,
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
@@ -683,11 +687,11 @@ async def cancel_seed(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/prune")
async def mark_missing_assets(request: web.Request) -> web.Response:
"""Mark assets as missing when their cache states point to files outside all known root prefixes.
"""Mark assets as missing when outside all known root prefixes.
This is a non-destructive soft-delete operation. Assets and their metadata
are preserved, but cache states are flagged as missing. They can be restored
if the file reappears in a future scan.
This is a non-destructive soft-delete operation. Assets and metadata
are preserved, but references are flagged as missing. They can be
restored if the file reappears in a future scan.
Returns:
200 OK with count of marked assets
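A hedged client-side sketch of that contract: only POST /api/assets/prune and the 200-with-count behavior come from the route above; the response key and port are assumptions.

```python
import asyncio
import aiohttp

async def prune_missing(base_url: str) -> int:
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/api/assets/prune") as resp:
            resp.raise_for_status()
            body = await resp.json()
    return int(body.get("marked", 0))  # response key assumed

if __name__ == "__main__":
    print(asyncio.run(prune_missing("http://127.0.0.1:8188")))  # port assumed
```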


@@ -13,7 +13,7 @@ from pydantic import (
class UploadError(Exception):
"""Error during upload parsing with HTTP status and code (used in HTTP layer only)."""
"""Error during upload parsing with HTTP status and code."""
def __init__(self, status: int, code: str, message: str):
super().__init__(message)
@@ -216,14 +216,14 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
if root == 'models', second must be a valid category
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
Files are stored using the content hash as filename stem.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
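To make the tag-ordering rule concrete, a hedged re-declaration of the spec; field types and defaults are assumptions, only the field names and model_config appear above.

```python
from pydantic import BaseModel, ConfigDict

class UploadAssetSpec(BaseModel):
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
    tags: list[str]                # ordered: root first, then category for 'models'
    name: str
    user_metadata: dict | None = None
    hash: str | None = None        # canonical 'blake3:<hex>' when the client has it

spec = UploadAssetSpec(
    tags=["models", "loras"],      # 'models' root requires a valid category second
    name="style-lora.safetensors",
    hash="blake3:" + "ab" * 32,    # enables the fast-path if the hash already exists
)
```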


@@ -95,7 +95,7 @@ async def parse_multipart_upload(
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
# Hash exists - drain file but don't write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
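The loop continues past the hunk boundary; a hedged completion showing the whole drain, which keeps the multipart stream consumable without touching disk:

```python
async def drain_field(field) -> int:
    """Read an aiohttp BodyPartReader to EOF, discarding data; returns bytes skipped."""
    skipped = 0
    while True:
        chunk = await field.read_chunk(8 * 1024 * 1024)  # matches the diff's chunk size
        if not chunk:  # empty bytes means the part is exhausted
            break
        skipped += len(chunk)
    return skipped
```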


@@ -16,7 +16,6 @@ from sqlalchemy import (
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
@@ -37,29 +36,23 @@ class Asset(Base):
DateTime(timezone=False), nullable=False, default=get_utc_now
)
infos: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
references: Mapped[list[AssetReference]] = relationship(
"AssetReference",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id),
foreign_keys=lambda: [AssetReference.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
preview_of: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
preview_of: Mapped[list[AssetReference]] = relationship(
"AssetReference",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
foreign_keys=lambda: [AssetReference.preview_id],
viewonly=True,
)
cache_states: Mapped[list[AssetCacheState]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
Index("ix_assets_mime_type", "mime_type"),
@@ -73,54 +66,33 @@ class Asset(Base):
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
class AssetReference(Base):
"""Unified model combining file cache state and user-facing metadata.
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
Each row represents either:
- A filesystem reference (file_path is set) with cache state
- An API-created reference (file_path is NULL) without cache state
"""
__tablename__ = "asset_references"
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
# Cache state fields (from former AssetCacheState)
file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
asset: Mapped[Asset] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
Index("ix_asset_cache_state_is_missing", "is_missing"),
Index("ix_asset_cache_state_enrichment_level", "enrichment_level"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"
),
CheckConstraint(
"enrichment_level >= 0 AND enrichment_level <= 2",
name="ck_acs_enrichment_level_range",
),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
# Info fields (from former AssetInfo)
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False
)
preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="SET NULL")
)
@@ -139,7 +111,7 @@ class AssetInfo(Base):
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
back_populates="references",
foreign_keys=[asset_id],
lazy="selectin",
)
@@ -149,37 +121,44 @@ class AssetInfo(Base):
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
back_populates="asset_info",
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
back_populates="asset_reference",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="asset_info",
tag_links: Mapped[list[AssetReferenceTag]] = relationship(
back_populates="asset_reference",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
overlaps="tags,asset_references",
)
tags: Mapped[list[Tag]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
secondary="asset_reference_tags",
back_populates="asset_references",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
overlaps="tag_links,asset_reference_links,asset_references,tag",
)
__table_args__ = (
UniqueConstraint(
"asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"
Index("uq_asset_references_file_path", "file_path", unique=True),
Index("ix_asset_references_asset_id", "asset_id"),
Index("ix_asset_references_owner_id", "owner_id"),
Index("ix_asset_references_name", "name"),
Index("ix_asset_references_is_missing", "is_missing"),
Index("ix_asset_references_enrichment_level", "enrichment_level"),
Index("ix_asset_references_created_at", "created_at"),
Index("ix_asset_references_last_access_time", "last_access_time"),
Index("ix_asset_references_owner_name", "owner_id", "name"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
),
CheckConstraint(
"enrichment_level >= 0 AND enrichment_level <= 2",
name="ck_ar_enrichment_level_range",
),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
@@ -188,14 +167,17 @@ class AssetInfo(Base):
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
path_part = f" path={self.file_path!r}" if self.file_path else ""
return f"<AssetReference id={self.id} name={self.name!r}{path_part}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
class AssetReferenceMeta(Base):
__tablename__ = "asset_reference_meta"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
asset_reference_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("asset_references.id", ondelete="CASCADE"),
primary_key=True,
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
@@ -205,21 +187,25 @@ class AssetInfoMeta(Base):
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
asset_reference: Mapped[AssetReference] = relationship(
back_populates="metadata_entries"
)
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
Index("ix_asset_reference_meta_key", "key"),
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
class AssetReferenceTag(Base):
__tablename__ = "asset_reference_tags"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
asset_reference_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("asset_references.id", ondelete="CASCADE"),
primary_key=True,
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
@@ -229,12 +215,12 @@ class AssetInfoTag(Base):
DateTime(timezone=False), nullable=False, default=get_utc_now
)
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_reference_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
Index("ix_asset_reference_tags_tag_name", "tag_name"),
Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"),
)
@@ -244,15 +230,15 @@ class Tag(Base):
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
overlaps="asset_references,tags",
)
asset_infos: Mapped[list[AssetInfo]] = relationship(
secondary="asset_info_tags",
asset_references: Mapped[list[AssetReference]] = relationship(
secondary="asset_reference_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
overlaps="asset_reference_links,tag_links,tags,asset_reference",
)
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
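The AssetReference docstring's two row shapes translate directly into queries; a small sketch, assuming the models import path shown in this diff:

```python
import sqlalchemy as sa
from app.assets.database.models import AssetReference

# Filesystem references: file_path set, cache-state fields meaningful.
fs_refs = sa.select(AssetReference).where(AssetReference.file_path.is_not(None))

# API-created references: file_path NULL. SQLite's unique index skips NULLs,
# so uq_asset_references_file_path never blocks these rows.
api_refs = sa.select(AssetReference).where(AssetReference.file_path.is_(None))
```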


@@ -3,59 +3,60 @@ from app.assets.database.queries.asset import (
bulk_insert_assets,
get_asset_by_hash,
get_existing_asset_ids,
reassign_asset_references,
update_asset_hash_and_mime,
upsert_asset,
)
from app.assets.database.queries.asset_info import (
asset_info_exists_for_asset_id,
bulk_insert_asset_infos_ignore_conflicts,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_info_by_id,
get_asset_info_ids_by_ids,
get_or_create_asset_info,
insert_asset_info,
list_asset_infos_page,
set_asset_info_metadata,
set_asset_info_preview,
update_asset_info_access_time,
update_asset_info_name,
update_asset_info_timestamps,
update_asset_info_updated_at,
)
from app.assets.database.queries.cache_state import (
from app.assets.database.queries.asset_reference import (
CacheStateRow,
UnenrichedAssetRow,
bulk_insert_cache_states_ignore_conflicts,
UnenrichedReferenceRow,
bulk_insert_references_ignore_conflicts,
bulk_update_enrichment_level,
bulk_update_is_missing,
bulk_update_needs_verify,
convert_metadata_to_rows,
delete_assets_by_ids,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
get_cache_states_by_paths_and_asset_ids,
get_cache_states_for_prefixes,
get_unenriched_cache_states,
delete_reference_by_id,
delete_references_by_ids,
fetch_reference_and_asset,
fetch_reference_asset_and_tags,
get_or_create_reference,
get_reference_by_file_path,
get_reference_by_id,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_references_for_prefixes,
get_unenriched_references,
get_unreferenced_unhashed_asset_ids,
list_cache_states_by_asset_id,
mark_cache_states_missing_outside_prefixes,
restore_cache_states_by_paths,
insert_reference,
list_references_by_asset_id,
list_references_page,
mark_references_missing_outside_prefixes,
reference_exists_for_asset_id,
restore_references_by_paths,
set_reference_metadata,
set_reference_preview,
update_enrichment_level,
upsert_cache_state,
update_reference_access_time,
update_reference_name,
update_reference_timestamps,
update_reference_updated_at,
upsert_reference,
)
from app.assets.database.queries.tags import (
AddTagsDict,
RemoveTagsDict,
SetTagsDict,
add_missing_tag_for_asset_id,
add_tags_to_asset_info,
add_tags_to_reference,
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_asset_tags,
get_reference_tags,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_asset_info,
set_asset_info_tags,
remove_tags_from_reference,
set_reference_tags,
)
__all__ = [
@@ -63,51 +64,54 @@ __all__ = [
"CacheStateRow",
"RemoveTagsDict",
"SetTagsDict",
"UnenrichedAssetRow",
"UnenrichedReferenceRow",
"add_missing_tag_for_asset_id",
"add_tags_to_asset_info",
"add_tags_to_reference",
"asset_exists_by_hash",
"asset_info_exists_for_asset_id",
"bulk_insert_asset_infos_ignore_conflicts",
"bulk_insert_assets",
"bulk_insert_cache_states_ignore_conflicts",
"bulk_insert_references_ignore_conflicts",
"bulk_insert_tags_and_meta",
"bulk_update_enrichment_level",
"bulk_update_is_missing",
"bulk_update_needs_verify",
"delete_asset_info_by_id",
"convert_metadata_to_rows",
"delete_assets_by_ids",
"delete_cache_states_by_ids",
"delete_orphaned_seed_asset",
"delete_reference_by_id",
"delete_references_by_ids",
"ensure_tags_exist",
"fetch_asset_info_and_asset",
"fetch_asset_info_asset_and_tags",
"fetch_reference_and_asset",
"fetch_reference_asset_and_tags",
"get_asset_by_hash",
"get_existing_asset_ids",
"get_asset_info_by_id",
"get_asset_info_ids_by_ids",
"get_asset_tags",
"get_cache_states_by_paths_and_asset_ids",
"get_cache_states_for_prefixes",
"get_or_create_asset_info",
"get_unenriched_cache_states",
"get_or_create_reference",
"get_reference_by_file_path",
"get_reference_by_id",
"get_reference_ids_by_ids",
"get_reference_tags",
"get_references_by_paths_and_asset_ids",
"get_references_for_prefixes",
"get_unenriched_references",
"get_unreferenced_unhashed_asset_ids",
"insert_asset_info",
"list_asset_infos_page",
"list_cache_states_by_asset_id",
"insert_reference",
"list_references_by_asset_id",
"list_references_page",
"list_tags_with_usage",
"mark_cache_states_missing_outside_prefixes",
"mark_references_missing_outside_prefixes",
"reassign_asset_references",
"reference_exists_for_asset_id",
"remove_missing_tag_for_asset_id",
"remove_tags_from_asset_info",
"restore_cache_states_by_paths",
"set_asset_info_metadata",
"set_asset_info_preview",
"set_asset_info_tags",
"update_asset_info_access_time",
"update_asset_info_name",
"update_asset_info_timestamps",
"update_asset_info_updated_at",
"remove_tags_from_reference",
"restore_references_by_paths",
"set_reference_metadata",
"set_reference_preview",
"set_reference_tags",
"update_asset_hash_and_mime",
"update_enrichment_level",
"update_reference_access_time",
"update_reference_name",
"update_reference_timestamps",
"update_reference_updated_at",
"upsert_asset",
"upsert_cache_state",
"upsert_reference",
]
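For callers, the rename is mostly mechanical; a hedged before/after for one call site, with exact signatures inferred from the export names above:

```python
from app.assets.database import queries

def load_reference_with_tags(session, reference_id: str):
    # Previously: get_asset_info_by_id(...) plus get_asset_tags(asset_info_id=...).
    ref = queries.get_reference_by_id(session, reference_id)  # signature assumed
    tags = queries.get_reference_tags(session, reference_id=reference_id)
    return ref, tags
```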


@@ -82,7 +82,7 @@ def bulk_insert_assets(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert Asset rows. Each dict should have: id, hash, size_bytes, mime_type, created_at."""
"""Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash."""
if not rows:
return
ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash])
@@ -101,3 +101,39 @@ def get_existing_asset_ids(
select(Asset.id).where(Asset.id.in_(asset_ids))
).fetchall()
return {row[0] for row in rows}
def update_asset_hash_and_mime(
session: Session,
asset_id: str,
asset_hash: str | None = None,
mime_type: str | None = None,
) -> bool:
"""Update asset hash and/or mime_type. Returns True if asset was found."""
asset = session.get(Asset, asset_id)
if not asset:
return False
if asset_hash is not None:
asset.hash = asset_hash
if mime_type is not None:
asset.mime_type = mime_type
return True
def reassign_asset_references(
session: Session,
from_asset_id: str,
to_asset_id: str,
reference_id: str,
) -> None:
"""Reassign a reference from one asset to another.
Used when merging a stub asset into an existing asset with the same hash.
"""
from app.assets.database.models import AssetReference
ref = session.get(AssetReference, reference_id)
if ref:
ref.asset_id = to_asset_id
session.flush()
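A hedged sketch of the merge flow reassign_asset_references exists for, pairing it with delete_orphaned_seed_asset from the same package; the orchestration itself is an assumption.

```python
from app.assets.database import queries

def merge_stub_into_canonical(session, stub_asset_id: str,
                              canonical_asset_id: str, reference_id: str) -> None:
    # Point the reference at the canonical asset, then drop the orphaned stub.
    queries.reassign_asset_references(
        session,
        from_asset_id=stub_asset_id,
        to_asset_id=canonical_asset_id,
        reference_id=reference_id,
    )
    queries.delete_orphaned_seed_asset(session, asset_id=stub_asset_id)
```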


@@ -1,527 +0,0 @@
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from typing import Sequence
import sqlalchemy as sa
from sqlalchemy import delete, exists, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import (
Asset,
AssetInfo,
AssetInfoMeta,
AssetInfoTag,
Tag,
)
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
build_visible_owner_clause,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
def _check_is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
"""Convert a scalar value to a typed projection row."""
if value is None:
return {
"key": key,
"ordinal": ordinal,
"val_str": None,
"val_num": None,
"val_bool": None,
"val_json": None,
}
if isinstance(value, bool):
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return {"key": key, "ordinal": ordinal, "val_num": num}
if isinstance(value, str):
return {"key": key, "ordinal": ordinal, "val_str": value}
return {"key": key, "ordinal": ordinal, "val_json": value}
def convert_metadata_to_rows(key: str, value) -> list[dict]:
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
if value is None:
return [_scalar_to_row(key, 0, None)]
if _check_is_scalar(value):
return [_scalar_to_row(key, 0, value)]
if isinstance(value, list):
if all(_check_is_scalar(x) for x in value):
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
return [{"key": key, "ordinal": 0, "val_json": value}]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_info_exists_for_asset_id(
session: Session,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_info_by_id(
session: Session,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def insert_asset_info(
session: Session,
asset_id: str,
owner_id: str,
name: str,
preview_id: str | None = None,
) -> AssetInfo | None:
"""Insert a new AssetInfo. Returns None if unique constraint violated."""
now = get_utc_now()
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset_id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
except IntegrityError:
return None
def get_or_create_asset_info(
session: Session,
asset_id: str,
owner_id: str,
name: str,
preview_id: str | None = None,
) -> tuple[AssetInfo, bool]:
"""Get existing or create new AssetInfo. Returns (info, created)."""
info = insert_asset_info(
session,
asset_id=asset_id,
owner_id=owner_id,
name=name,
preview_id=preview_id,
)
if info:
return info, True
existing = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset_id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
.unique()
.scalar_one_or_none()
)
if not existing:
raise RuntimeError("Failed to find AssetInfo after insert conflict.")
return existing, False
def update_asset_info_timestamps(
session: Session,
asset_info: AssetInfo,
preview_id: str | None = None,
) -> None:
"""Update timestamps and optionally preview_id on existing AssetInfo."""
now = get_utc_now()
if preview_id and asset_info.preview_id != preview_id:
asset_info.preview_id = preview_id
asset_info.updated_at = now
if asset_info.last_access_time < now:
asset_info.last_access_time = now
session.flush()
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(build_visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(build_visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def update_asset_info_access_time(
session: Session,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or get_utc_now()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts
)
)
session.execute(stmt.values(last_access_time=ts))
def update_asset_info_name(
session: Session,
asset_info_id: str,
name: str,
) -> None:
"""Update the name of an AssetInfo."""
now = get_utc_now()
session.execute(
sa.update(AssetInfo)
.where(AssetInfo.id == asset_info_id)
.values(name=name, updated_at=now)
)
def update_asset_info_updated_at(
session: Session,
asset_info_id: str,
ts: datetime | None = None,
) -> None:
"""Update the updated_at timestamp of an AssetInfo."""
ts = ts or get_utc_now()
session.execute(
sa.update(AssetInfo).where(AssetInfo.id == asset_info_id).values(updated_at=ts)
)
def set_asset_info_metadata(
session: Session,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = get_utc_now()
session.flush()
session.execute(
delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)
)
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def delete_asset_info_by_id(
session: Session,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def set_asset_info_preview(
session: Session,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = get_utc_now()
session.flush()
def bulk_insert_asset_infos_ignore_conflicts(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert AssetInfo rows with ON CONFLICT DO NOTHING.
Each dict should have: id, owner_id, name, asset_id, preview_id,
user_metadata, created_at, updated_at, last_access_time
"""
if not rows:
return
ins = sqlite.insert(AssetInfo).on_conflict_do_nothing(
index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]
)
for chunk in iter_chunks(rows, calculate_rows_per_statement(9)):
session.execute(ins, chunk)
def get_asset_info_ids_by_ids(
session: Session,
info_ids: list[str],
) -> set[str]:
"""Query to find which AssetInfo IDs exist in the database."""
if not info_ids:
return set()
found: set[str] = set()
for chunk in iter_chunks(info_ids, MAX_BIND_PARAMS):
result = session.execute(select(AssetInfo.id).where(AssetInfo.id.in_(chunk)))
found.update(result.scalars().all())
return found
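One piece of this deleted module survives the move into asset_reference.py per the package exports above: the typed metadata projection. Its behavior, taken straight from the code shown here:

```python
from app.assets.database.queries import convert_metadata_to_rows

print(convert_metadata_to_rows("trigger_words", ["ohwx", "style"]))
# [{'key': 'trigger_words', 'ordinal': 0, 'val_str': 'ohwx'},
#  {'key': 'trigger_words', 'ordinal': 1, 'val_str': 'style'}]
print(convert_metadata_to_rows("steps", 30))
# [{'key': 'steps', 'ordinal': 0, 'val_num': Decimal('30')}]
```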

File diff suppressed because it is too large.


@@ -1,451 +0,0 @@
import os
from typing import NamedTuple, Sequence
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string
class CacheStateRow(NamedTuple):
"""Row from cache state query with joined asset data."""
state_id: int
file_path: str
mtime_ns: int | None
needs_verify: bool
asset_id: str
asset_hash: str | None
size_bytes: int
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
(
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
)
.scalars()
.all()
)
def upsert_cache_state(
session: Session,
asset_id: str,
file_path: str,
mtime_ns: int,
) -> tuple[bool, bool]:
"""Upsert a cache state by file_path. Returns (created, updated).
Also restores cache states that were previously marked as missing.
"""
vals = {
"asset_id": asset_id,
"file_path": file_path,
"mtime_ns": int(mtime_ns),
"is_missing": False,
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
if created:
return True, False
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == file_path)
.where(
sa.or_(
AssetCacheState.asset_id != asset_id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
AssetCacheState.is_missing == True, # noqa: E712
)
)
.values(asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False)
)
res2 = session.execute(upd)
updated = int(res2.rowcount or 0) > 0
return False, updated
def mark_cache_states_missing_outside_prefixes(
session: Session, valid_prefixes: list[str]
) -> int:
"""Mark cache states as missing when file_path doesn't match any valid prefix.
This is a non-destructive soft-delete that preserves user metadata.
Cache states can be restored if the file reappears in a future scan.
Args:
session: Database session
valid_prefixes: List of absolute directory prefixes that are valid
Returns:
Number of cache states marked as missing
"""
if not valid_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_sql_like_string(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes])
result = session.execute(
sa.update(AssetCacheState)
.where(~matches_valid_prefix)
.where(AssetCacheState.is_missing == False) # noqa: E712
.values(is_missing=True)
)
return result.rowcount
def restore_cache_states_by_paths(session: Session, file_paths: list[str]) -> int:
"""Restore cache states that were previously marked as missing.
Called when a file path is re-scanned and found to exist.
Args:
session: Database session
file_paths: List of file paths that exist and should be restored
Returns:
Number of cache states restored
"""
if not file_paths:
return 0
result = session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.file_path.in_(file_paths))
.where(AssetCacheState.is_missing == True) # noqa: E712
.values(is_missing=False)
)
return result.rowcount
def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
"""Get IDs of unhashed assets (hash=None) with no active cache states.
An asset is considered unreferenced if it has no cache states,
or all its cache states are marked as missing.
Returns:
List of asset IDs that are unreferenced
"""
active_cache_state_exists = (
sa.select(sa.literal(1))
.where(AssetCacheState.asset_id == Asset.id)
.where(AssetCacheState.is_missing == False) # noqa: E712
.correlate(Asset)
.exists()
)
unreferenced_subq = sa.select(Asset.id).where(
Asset.hash.is_(None), ~active_cache_state_exists
)
return [row[0] for row in session.execute(unreferenced_subq).all()]
def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
"""Delete assets and their AssetInfos by ID.
Args:
session: Database session
asset_ids: List of asset IDs to delete
Returns:
Number of assets deleted
"""
if not asset_ids:
return 0
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id.in_(asset_ids)))
result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
return result.rowcount
def get_cache_states_for_prefixes(
session: Session,
prefixes: list[str],
*,
include_missing: bool = False,
) -> list[CacheStateRow]:
"""Get all cache states with paths matching any of the given prefixes.
Args:
session: Database session
prefixes: List of absolute directory prefixes to match
include_missing: If False (default), exclude cache states marked as missing
Returns:
List of cache state rows with joined asset data, ordered by asset_id, state_id
"""
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
query = (
sa.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
)
if not include_missing:
query = query.where(AssetCacheState.is_missing == False) # noqa: E712
rows = session.execute(
query.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
).all()
return [
CacheStateRow(
state_id=row[0],
file_path=row[1],
mtime_ns=row[2],
needs_verify=row[3],
asset_id=row[4],
asset_hash=row[5],
size_bytes=int(row[6] or 0),
)
for row in rows
]
def bulk_update_needs_verify(session: Session, state_ids: list[int], value: bool) -> int:
"""Set needs_verify flag for multiple cache states.
Returns: Number of rows updated
"""
if not state_ids:
return 0
result = session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(state_ids))
.values(needs_verify=value)
)
return result.rowcount
def bulk_update_is_missing(session: Session, state_ids: list[int], value: bool) -> int:
"""Set is_missing flag for multiple cache states.
Returns: Number of rows updated
"""
if not state_ids:
return 0
result = session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(state_ids))
.values(is_missing=value)
)
return result.rowcount
def delete_cache_states_by_ids(session: Session, state_ids: list[int]) -> int:
"""Delete cache states by their IDs.
Returns: Number of rows deleted
"""
if not state_ids:
return 0
result = session.execute(
sa.delete(AssetCacheState).where(AssetCacheState.id.in_(state_ids))
)
return result.rowcount
def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
"""Delete a seed asset (hash is None) and its AssetInfos.
Returns: True if asset was deleted, False if not found
"""
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == asset_id))
asset = session.get(Asset, asset_id)
if asset:
session.delete(asset)
return True
return False
class UnenrichedAssetRow(NamedTuple):
"""Row for assets needing enrichment."""
cache_state_id: int
asset_id: str
asset_info_id: str
file_path: str
enrichment_level: int
def get_unenriched_cache_states(
session: Session,
prefixes: list[str],
max_level: int = 0,
limit: int = 1000,
) -> list[UnenrichedAssetRow]:
"""Get cache states that need enrichment (enrichment_level <= max_level).
Args:
session: Database session
prefixes: List of absolute directory prefixes to scan
max_level: Maximum enrichment level to include (0=stubs, 1=metadata done)
limit: Maximum number of rows to return
Returns:
List of unenriched asset rows with file paths
"""
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
query = (
sa.select(
AssetCacheState.id,
AssetCacheState.asset_id,
AssetInfo.id,
AssetCacheState.file_path,
AssetCacheState.enrichment_level,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.join(AssetInfo, AssetInfo.asset_id == Asset.id)
.where(sa.or_(*conds))
.where(AssetCacheState.is_missing == False) # noqa: E712
.where(AssetCacheState.enrichment_level <= max_level)
.order_by(AssetCacheState.id.asc())
.limit(limit)
)
rows = session.execute(query).all()
return [
UnenrichedAssetRow(
cache_state_id=row[0],
asset_id=row[1],
asset_info_id=row[2],
file_path=row[3],
enrichment_level=row[4],
)
for row in rows
]
def update_enrichment_level(
session: Session,
cache_state_id: int,
level: int,
) -> None:
"""Update the enrichment level for a cache state."""
session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id == cache_state_id)
.values(enrichment_level=level)
)
def bulk_update_enrichment_level(
session: Session,
cache_state_ids: list[int],
level: int,
) -> int:
"""Update enrichment level for multiple cache states.
Returns: Number of rows updated
"""
if not cache_state_ids:
return 0
result = session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(cache_state_ids))
.values(enrichment_level=level)
)
return result.rowcount
def bulk_insert_cache_states_ignore_conflicts(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert cache state rows with ON CONFLICT DO NOTHING on file_path.
Each dict should have: asset_id, file_path, mtime_ns
The is_missing field is automatically set to False for new inserts.
"""
if not rows:
return
enriched_rows = [{**row, "is_missing": False} for row in rows]
ins = sqlite.insert(AssetCacheState).on_conflict_do_nothing(
index_elements=[AssetCacheState.file_path]
)
for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(4)):
session.execute(ins, chunk)
def get_cache_states_by_paths_and_asset_ids(
session: Session,
path_to_asset: dict[str, str],
) -> set[str]:
"""Query cache states to find paths where our asset_id won the insert.
Args:
path_to_asset: Mapping of file_path -> asset_id we tried to insert
Returns:
Set of file_paths where our asset_id is present
"""
if not path_to_asset:
return set()
paths = list(path_to_asset.keys())
winners: set[str] = set()
for chunk in iter_chunks(paths, MAX_BIND_PARAMS):
result = session.execute(
select(AssetCacheState.file_path).where(
AssetCacheState.file_path.in_(chunk),
AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]),
)
)
winners.update(result.scalars().all())
return winners
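The deleted upsert_cache_state returned a (created, updated) pair with restore-on-resurrection semantics; a hedged sketch assuming upsert_reference keeps that contract (its exact signature is not shown in this diff):

```python
from app.assets.database.queries import upsert_reference

def record_scan_hit(session, asset_id: str, file_path: str, mtime_ns: int) -> None:
    created, updated = upsert_reference(  # signature assumed from the old function
        session, asset_id=asset_id, file_path=file_path, mtime_ns=mtime_ns
    )
    if created:
        pass  # first sighting of this path: brand-new reference row
    elif updated:
        pass  # asset/mtime changed, or the file reappeared (is_missing cleared)
```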


@@ -4,7 +4,7 @@ from typing import Iterable
import sqlalchemy as sa
from app.assets.database.models import AssetInfo
from app.assets.database.models import AssetReference
MAX_BIND_PARAMS = 800
@@ -30,8 +30,11 @@ def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
"""Build owner visibility predicate for reads.
Owner-less rows are visible to everyone.
"""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
return AssetReference.owner_id == ""
return AssetReference.owner_id.in_(["", owner_id])
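The visibility rule is easiest to read applied to a query; a short sketch using the clause exactly as updated above:

```python
import sqlalchemy as sa
from app.assets.database.models import AssetReference
from app.assets.database.queries.common import build_visible_owner_clause

# Anonymous callers match only owner-less rows; "user-42" also sees their own.
stmt = sa.select(AssetReference).where(build_visible_owner_clause("user-42"))
```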


@@ -6,7 +6,12 @@ from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.database.models import (
AssetReference,
AssetReferenceMeta,
AssetReferenceTag,
Tag,
)
from app.assets.database.queries.common import (
build_visible_owner_clause,
iter_row_chunks,
@@ -47,22 +52,22 @@ def ensure_tags_exist(
session.execute(ins)
def get_asset_tags(session: Session, asset_info_id: str) -> list[str]:
def get_reference_tags(session: Session, reference_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
select(AssetReferenceTag.tag_name).where(
AssetReferenceTag.asset_reference_id == reference_id
)
)
).all()
]
def set_asset_info_tags(
def set_reference_tags(
session: Session,
asset_info_id: str,
reference_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsDict:
@@ -72,8 +77,8 @@ def set_asset_info_tags(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
select(AssetReferenceTag.tag_name).where(
AssetReferenceTag.asset_reference_id == reference_id
)
)
).all()
@@ -86,8 +91,8 @@ def set_asset_info_tags(
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
@@ -99,9 +104,9 @@ def set_asset_info_tags(
if to_remove:
session.execute(
delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id == reference_id,
AssetReferenceTag.tag_name.in_(to_remove),
)
)
session.flush()
@@ -109,22 +114,22 @@ def set_asset_info_tags(
return {"added": to_add, "removed": to_remove, "total": desired}
def add_tags_to_asset_info(
def add_tags_to_reference(
session: Session,
asset_info_id: str,
reference_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: AssetInfo | None = None,
reference_row: AssetReference | None = None,
) -> AddTagsDict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if not reference_row:
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
total = get_reference_tags(session, reference_id=reference_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
@@ -134,8 +139,8 @@ def add_tags_to_asset_info(
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
sa.select(AssetReferenceTag.tag_name).where(
AssetReferenceTag.asset_reference_id == reference_id
)
)
).all()
@@ -149,8 +154,8 @@ def add_tags_to_asset_info(
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
@@ -162,7 +167,7 @@ def add_tags_to_asset_info(
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
after = set(get_reference_tags(session, reference_id=reference_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
@@ -170,26 +175,26 @@ def add_tags_to_asset_info(
}
def remove_tags_from_asset_info(
def remove_tags_from_reference(
session: Session,
asset_info_id: str,
reference_id: str,
tags: Sequence[str],
) -> RemoveTagsDict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
total = get_reference_tags(session, reference_id=reference_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
sa.select(AssetReferenceTag.tag_name).where(
AssetReferenceTag.asset_reference_id == reference_id
)
)
).all()
@@ -200,14 +205,14 @@ def remove_tags_from_asset_info(
if to_remove:
session.execute(
delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id == reference_id,
AssetReferenceTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
total = get_reference_tags(session, reference_id=reference_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
@@ -218,29 +223,32 @@ def add_missing_tag_for_asset_id(
) -> None:
select_rows = (
sa.select(
AssetInfo.id.label("asset_info_id"),
AssetReference.id.label("asset_reference_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(get_utc_now()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(AssetReference.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == "missing")
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == "missing")
)
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
sqlite.insert(AssetReferenceTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
["asset_reference_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
index_elements=[
AssetReferenceTag.asset_reference_id,
AssetReferenceTag.tag_name,
]
)
)
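# Roughly the SQL the statement above compiles to (a sketch; table names
# follow the new models, and SQLAlchemy's actual output will differ in
# parameter names and quoting):
#
#   INSERT INTO asset_reference_tags
#       (asset_reference_id, tag_name, origin, added_at)
#   SELECT ar.id, 'missing', :origin, :added_at
#   FROM asset_references AS ar
#   WHERE ar.asset_id = :asset_id
#     AND NOT EXISTS (
#         SELECT 1 FROM asset_reference_tags t
#         WHERE t.asset_reference_id = ar.id AND t.tag_name = 'missing')
#   ON CONFLICT (asset_reference_id, tag_name) DO NOTHING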
@@ -250,11 +258,11 @@ def remove_missing_tag_for_asset_id(
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(
sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)
sa.delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id.in_(
sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id)
),
AssetInfoTag.tag_name == "missing",
AssetReferenceTag.tag_name == "missing",
)
)
@@ -270,13 +278,13 @@ def list_tags_with_usage(
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
AssetReferenceTag.tag_name.label("tag_name"),
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.select_from(AssetReferenceTag)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.group_by(AssetReferenceTag.tag_name)
.subquery()
)
@@ -308,7 +316,9 @@ def list_tags_with_usage(
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
Tag.name.in_(
select(AssetReferenceTag.tag_name).group_by(AssetReferenceTag.tag_name)
)
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
@@ -323,26 +333,31 @@ def bulk_insert_tags_and_meta(
tag_rows: list[dict],
meta_rows: list[dict],
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
"""Batch insert into asset_reference_tags and asset_reference_meta.
Uses ON CONFLICT DO NOTHING.
Args:
session: Database session
tag_rows: List of dicts with keys: asset_info_id, tag_name, origin, added_at
meta_rows: List of dicts with keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at
meta_rows: Dicts with: asset_reference_id, key, ordinal, val_*
"""
if tag_rows:
ins_tags = sqlite.insert(AssetInfoTag).on_conflict_do_nothing(
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing(
index_elements=[
AssetReferenceTag.asset_reference_id,
AssetReferenceTag.tag_name,
]
)
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
session.execute(ins_tags, chunk)
if meta_rows:
ins_meta = sqlite.insert(AssetInfoMeta).on_conflict_do_nothing(
ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing(
index_elements=[
AssetInfoMeta.asset_info_id,
AssetInfoMeta.key,
AssetInfoMeta.ordinal,
AssetReferenceMeta.asset_reference_id,
AssetReferenceMeta.key,
AssetReferenceMeta.ordinal,
]
)
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):

View File

@@ -31,8 +31,9 @@ ALLOWED_ROOTS: tuple[Literal["models", "input", "output"], ...] = (
def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""Escapes %, _ and the escape char in a LIKE prefix.
Returns (escaped_prefix, escape_char).
"""
s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards

View File

@@ -10,13 +10,16 @@ from app.assets.database.queries import (
bulk_update_enrichment_level,
bulk_update_is_missing,
bulk_update_needs_verify,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
delete_references_by_ids,
ensure_tags_exist,
get_cache_states_for_prefixes,
get_unenriched_cache_states,
get_asset_by_hash,
get_references_for_prefixes,
get_unenriched_references,
reassign_asset_references,
remove_missing_tag_for_asset_id,
set_asset_info_metadata,
set_reference_metadata,
update_asset_hash_and_mime,
)
from app.assets.services.bulk_ingest import (
SeedAssetSpec,
@@ -38,8 +41,8 @@ from app.assets.services.path_utils import (
from app.database.db import create_session, dependencies_available
class _StateInfo(TypedDict):
sid: int
class _RefInfo(TypedDict):
ref_id: str
fp: str
exists: bool
fast_ok: bool
@@ -49,7 +52,7 @@ class _StateInfo(TypedDict):
class _AssetAccumulator(TypedDict):
hash: str | None
size_db: int
states: list[_StateInfo]
refs: list[_RefInfo]
RootType = Literal["models", "input", "output"]
@@ -97,17 +100,17 @@ def collect_models_files() -> list[str]:
return out
def sync_cache_states_with_filesystem(
def sync_references_with_filesystem(
session,
root: RootType,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Reconcile cache states with filesystem for a root.
"""Reconcile asset references with filesystem for a root.
- Toggle needs_verify per state using fast mtime/size check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Toggle needs_verify per reference using fast mtime/size check
- For hashed assets with at least one fast-ok ref: delete stale missing refs
- For seed assets with all refs missing: delete Asset and its references
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
@@ -124,7 +127,7 @@ def sync_cache_states_with_filesystem(
if not prefixes:
return set() if collect_existing_paths else None
rows = get_cache_states_for_prefixes(
rows = get_references_for_prefixes(
session, prefixes, include_missing=update_missing_tags
)
@@ -132,7 +135,7 @@ def sync_cache_states_with_filesystem(
for row in rows:
acc = by_asset.get(row.asset_id)
if acc is None:
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "states": []}
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
by_asset[row.asset_id] = acc
fast_ok = False
@@ -152,9 +155,9 @@ def sync_cache_states_with_filesystem(
exists = False
logging.debug("OSError checking %s: %s", row.file_path, e)
acc["states"].append(
acc["refs"].append(
{
"sid": row.state_id,
"ref_id": row.reference_id,
"fp": row.file_path,
"exists": exists,
"fast_ok": fast_ok,
@@ -162,61 +165,63 @@ def sync_cache_states_with_filesystem(
}
)
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
to_mark_missing: list[int] = []
to_clear_missing: list[int] = []
to_set_verify: list[str] = []
to_clear_verify: list[str] = []
stale_ref_ids: list[str] = []
to_mark_missing: list[str] = []
to_clear_missing: list[str] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
refs = acc["refs"]
any_fast_ok = any(r["fast_ok"] for r in refs)
all_missing = all(not r["exists"] for r in refs)
for s in states:
if not s["exists"]:
to_mark_missing.append(s["sid"])
for r in refs:
if not r["exists"]:
to_mark_missing.append(r["ref_id"])
continue
if s["fast_ok"]:
to_clear_missing.append(s["sid"])
if s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if r["fast_ok"]:
to_clear_missing.append(r["ref_id"])
if r["needs_verify"]:
to_clear_verify.append(r["ref_id"])
if not r["fast_ok"] and not r["needs_verify"]:
to_set_verify.append(r["ref_id"])
if a_hash is None:
if states and all_missing:
if refs and all_missing:
delete_orphaned_seed_asset(session, aid)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["fp"]))
continue
if any_fast_ok:
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
for r in refs:
if not r["exists"]:
stale_ref_ids.append(r["ref_id"])
if update_missing_tags:
try:
remove_missing_tag_for_asset_id(session, asset_id=aid)
except Exception as e:
logging.warning("Failed to remove missing tag for asset %s: %s", aid, e)
logging.warning(
"Failed to remove missing tag for asset %s: %s", aid, e
)
elif update_missing_tags:
try:
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
except Exception as e:
logging.warning("Failed to add missing tag for asset %s: %s", aid, e)
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["fp"]))
delete_cache_states_by_ids(session, stale_state_ids)
stale_set = set(stale_state_ids)
to_mark_missing = [sid for sid in to_mark_missing if sid not in stale_set]
delete_references_by_ids(session, stale_ref_ids)
stale_set = set(stale_ref_ids)
to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set]
bulk_update_is_missing(session, to_mark_missing, value=True)
bulk_update_is_missing(session, to_clear_missing, value=False)
bulk_update_needs_verify(session, to_set_verify, value=True)
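# Minimal sketch of the fast check feeding exists/fast_ok above (assumed
# helper; the real loop reads mtime_ns and size from the joined DB rows):
import os

def fast_check(file_path: str, mtime_ns_db: int, size_db: int) -> tuple[bool, bool]:
    """Return (exists, fast_ok) without hashing the file."""
    try:
        st = os.stat(file_path)
    except OSError:
        return False, False
    return True, (st.st_mtime_ns == mtime_ns_db and st.st_size == size_db)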
@@ -226,13 +231,13 @@ def sync_cache_states_with_filesystem(
def sync_root_safely(root: RootType) -> set[str]:
"""Sync a single root's cache states with the filesystem.
"""Sync a single root's references with the filesystem.
Returns survivors (existing paths) or empty set on failure.
"""
try:
with create_session() as sess:
survivors = sync_cache_states_with_filesystem(
survivors = sync_references_with_filesystem(
sess,
root,
collect_existing_paths=True,
@@ -246,7 +251,7 @@ def sync_root_safely(root: RootType) -> set[str]:
def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
"""Mark cache states as missing when outside the given prefixes.
"""Mark references as missing when outside the given prefixes.
This is a non-destructive soft-delete. Returns count marked or 0 on failure.
"""
@@ -283,8 +288,8 @@ def build_asset_specs(
Args:
paths: List of file paths to process
existing_paths: Set of paths that already exist in the database
enable_metadata_extraction: If True, extract tier 1 & 2 metadata from files
compute_hashes: If True, compute blake3 hashes for each file (slow for large files)
enable_metadata_extraction: If True, extract tier 1 & 2 metadata
compute_hashes: If True, compute blake3 hashes (slow for large files)
"""
specs: list[SeedAssetSpec] = []
tag_pool: set[str] = set()
@@ -398,7 +403,7 @@ def build_stub_specs(
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created infos."""
"""Insert asset specs into database, returning count of created refs."""
if not specs:
return 0
with create_session() as sess:
@@ -406,7 +411,7 @@ def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
sess.commit()
return result.inserted_infos
return result.inserted_refs
def seed_assets(
@@ -419,10 +424,10 @@ def seed_assets(
Args:
roots: Tuple of root types to scan (models, input, output)
enable_logging: If True, log progress and completion messages
compute_hashes: If True, compute blake3 hashes for each file (slow for large files)
compute_hashes: If True, compute blake3 hashes (slow for large files)
Note: This function does not mark missing assets. Call mark_missing_outside_prefixes_safely
separately if cleanup is needed.
Note: This function does not mark missing assets.
Call mark_missing_outside_prefixes_safely separately if cleanup is needed.
"""
if not dependencies_available():
if enable_logging:
@@ -443,7 +448,8 @@ def seed_assets(
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
"Assets scan(roots=%s) completed in %.3fs "
"(created=%d, skipped_existing=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
@@ -471,7 +477,7 @@ def get_unenriched_assets_for_roots(
limit: Maximum number of rows to return
Returns:
List of UnenrichedAssetRow
List of UnenrichedReferenceRow
"""
prefixes: list[str] = []
for root in roots:
@@ -481,13 +487,15 @@ def get_unenriched_assets_for_roots(
return []
with create_session() as sess:
return get_unenriched_cache_states(sess, prefixes, max_level=max_level, limit=limit)
return get_unenriched_references(
sess, prefixes, max_level=max_level, limit=limit
)
def enrich_asset(
file_path: str,
cache_state_id: int,
asset_info_id: str,
reference_id: str,
asset_id: str,
extract_metadata: bool = True,
compute_hash: bool = False,
) -> int:
@@ -495,8 +503,8 @@ def enrich_asset(
Args:
file_path: Absolute path to the file
cache_state_id: ID of the cache state to update
asset_info_id: ID of the asset info to update
reference_id: ID of the reference to update
asset_id: ID of the asset to update (for mime_type and hash)
extract_metadata: If True, extract safetensors header and mime type
compute_hash: If True, compute blake3 hash
@@ -511,30 +519,46 @@ def enrich_asset(
return new_level
rel_fname = compute_relative_filename(file_path)
mime_type: str | None = None
if extract_metadata:
metadata = extract_file_metadata(
file_path,
stat_result=stat_p,
enable_safetensors=True,
relative_filename=rel_fname,
)
if metadata:
mime_type = metadata.content_type
new_level = ENRICHMENT_METADATA
full_hash: str | None = None
if compute_hash:
try:
digest = compute_blake3_hash(file_path)
full_hash = f"blake3:{digest}"
new_level = ENRICHMENT_HASHED
except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e)
with create_session() as sess:
if extract_metadata:
metadata = extract_file_metadata(
file_path,
stat_result=stat_p,
enable_safetensors=True,
relative_filename=rel_fname,
)
if metadata:
user_metadata = metadata.to_user_metadata()
set_asset_info_metadata(sess, asset_info_id, user_metadata)
new_level = ENRICHMENT_METADATA
if extract_metadata and metadata:
user_metadata = metadata.to_user_metadata()
set_reference_metadata(sess, reference_id, user_metadata)
if compute_hash:
try:
digest = compute_blake3_hash(file_path)
# TODO: Update asset.hash field
# For now just mark the enrichment level
new_level = ENRICHMENT_HASHED
except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e)
if full_hash:
existing = get_asset_by_hash(sess, full_hash)
if existing and existing.id != asset_id:
reassign_asset_references(sess, asset_id, existing.id, reference_id)
delete_orphaned_seed_asset(sess, asset_id)
if mime_type:
update_asset_hash_and_mime(sess, existing.id, mime_type=mime_type)
else:
update_asset_hash_and_mime(sess, asset_id, full_hash, mime_type)
elif mime_type:
update_asset_hash_and_mime(sess, asset_id, mime_type=mime_type)
bulk_update_enrichment_level(sess, [cache_state_id], new_level)
bulk_update_enrichment_level(sess, [reference_id], new_level)
sess.commit()
return new_level
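# The enrichment ladder used above (the constants exist in this codebase;
# the numeric values here are assumed for illustration):
ENRICHMENT_STUB = 0      # fast scan created the row, nothing extracted yet
ENRICHMENT_METADATA = 1  # safetensors header / mime type extracted
ENRICHMENT_HASHED = 2    # blake3 hash computed, asset deduplicated by hash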
@@ -548,7 +572,7 @@ def enrich_assets_batch(
"""Enrich a batch of assets.
Args:
rows: List of UnenrichedAssetRow from get_unenriched_assets_for_roots
rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots
extract_metadata: If True, extract metadata for each asset
compute_hash: If True, compute hash for each asset
@@ -562,8 +586,8 @@ def enrich_assets_batch(
try:
new_level = enrich_asset(
file_path=row.file_path,
cache_state_id=row.cache_state_id,
asset_info_id=row.asset_info_id,
reference_id=row.reference_id,
asset_id=row.asset_id,
extract_metadata=extract_metadata,
compute_hash=compute_hash,
)

View File

@@ -128,7 +128,7 @@ class AssetSeeder:
phase: Scan phase to run (FAST, ENRICH, or FULL for both)
progress_callback: Optional callback called with progress updates
prune_first: If True, prune orphaned assets before scanning
compute_hashes: If True, compute blake3 hashes for each file (slow for large files)
compute_hashes: If True, compute blake3 hashes (slow)
Returns:
True if scan was started, False if already running
@@ -136,7 +136,7 @@ class AssetSeeder:
if self._disabled:
logging.debug("Asset seeder is disabled, skipping start")
return False
logging.info("Asset seeder start requested (roots=%s, phase=%s)", roots, phase.value)
logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value)
with self._lock:
if self._state != State.IDLE:
logging.info("Asset seeder already running, skipping start")
@@ -295,12 +295,15 @@ class AssetSeeder:
if not self.wait(timeout=timeout):
return False
cb = progress_callback if progress_callback is not None else prev_callback
return self.start(
roots=roots if roots is not None else prev_roots,
phase=phase if phase is not None else prev_phase,
progress_callback=progress_callback if progress_callback is not None else prev_callback,
progress_callback=cb,
prune_first=prune_first if prune_first is not None else prev_prune,
compute_hashes=compute_hashes if compute_hashes is not None else prev_hashes,
compute_hashes=(
compute_hashes if compute_hashes is not None else prev_hashes
),
)
def wait(self, timeout: float | None = None) -> bool:
@@ -497,7 +500,7 @@ class AssetSeeder:
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d cache states as missing before scan", marked)
logging.info("Marked %d refs as missing before scan", marked)
if self._check_pause_and_cancel():
logging.info("Asset scan cancelled after pruning phase")
@@ -508,7 +511,8 @@ class AssetSeeder:
# Phase 1: Fast scan (stub records)
if phase in (ScanPhase.FAST, ScanPhase.FULL):
total_created, skipped_existing, total_paths = self._run_fast_phase(roots)
created, skipped, paths = self._run_fast_phase(roots)
total_created, skipped_existing, total_paths = created, skipped, paths
if self._check_pause_and_cancel():
cancelled = True
@@ -542,12 +546,8 @@ class AssetSeeder:
elapsed = time.perf_counter() - t_start
logging.info(
"Asset scan(roots=%s, phase=%s) completed in %.3fs (created=%d, enriched=%d, skipped=%d)",
roots,
phase.value,
elapsed,
total_created,
total_enriched,
"Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d",
roots, phase.value, elapsed, total_created, total_enriched,
skipped_existing,
)
@@ -668,7 +668,10 @@ class AssetSeeder:
progress_interval = 1.0
# Get the target enrichment level based on compute_hashes
target_max_level = ENRICHMENT_STUB if not self._compute_hashes else ENRICHMENT_METADATA
if not self._compute_hashes:
target_max_level = ENRICHMENT_STUB
else:
target_max_level = ENRICHMENT_METADATA
self._emit_event(
"assets.seed.started",

View File

@@ -30,11 +30,11 @@ from app.assets.services.schemas import (
AddTagsResult,
AssetData,
AssetDetailResult,
AssetInfoData,
AssetSummaryData,
DownloadResolutionResult,
IngestResult,
ListAssetsResult,
ReferenceData,
RegisterAssetResult,
RemoveTagsResult,
SetTagsResult,
@@ -52,8 +52,8 @@ __all__ = [
"AddTagsResult",
"AssetData",
"AssetDetailResult",
"AssetInfoData",
"AssetSummaryData",
"ReferenceData",
"BulkInsertResult",
"DependencyMissingError",
"DownloadResolutionResult",

View File

@@ -7,23 +7,23 @@ from typing import Sequence
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
reference_exists_for_asset_id,
delete_reference_by_id,
fetch_reference_and_asset,
fetch_reference_asset_and_tags,
get_asset_by_hash as queries_get_asset_by_hash,
get_asset_info_by_id,
list_asset_infos_page,
list_cache_states_by_asset_id,
set_asset_info_metadata,
set_asset_info_preview,
set_asset_info_tags,
update_asset_info_access_time,
update_asset_info_name,
update_asset_info_updated_at,
get_reference_by_id,
list_references_page,
list_references_by_asset_id,
set_reference_metadata,
set_reference_preview,
set_reference_tags,
update_reference_access_time,
update_reference_name,
update_reference_updated_at,
)
from app.assets.helpers import select_best_live_path
from app.assets.services.path_utils import compute_filename_for_asset
from app.assets.services.path_utils import compute_filename_for_reference
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
@@ -32,34 +32,34 @@ from app.assets.services.schemas import (
ListAssetsResult,
UserMetadata,
extract_asset_data,
extract_info_data,
extract_reference_data,
)
from app.database.db import create_session
def get_asset_detail(
asset_info_id: str,
reference_id: str,
owner_id: str = "",
) -> AssetDetailResult | None:
with create_session() as session:
result = fetch_asset_info_asset_and_tags(
result = fetch_reference_asset_and_tags(
session,
asset_info_id=asset_info_id,
reference_id=reference_id,
owner_id=owner_id,
)
if not result:
return None
info, asset, tags = result
ref, asset, tags = result
return AssetDetailResult(
info=extract_info_data(info),
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tags,
)
def update_asset_metadata(
asset_info_id: str,
reference_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: UserMetadata = None,
@@ -67,58 +67,58 @@ def update_asset_metadata(
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
info = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info.owner_id and info.owner_id != owner_id:
ref = get_reference_by_id(session, reference_id=reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
if ref.owner_id and ref.owner_id != owner_id:
raise PermissionError("not owner")
touched = False
if name is not None and name != info.name:
update_asset_info_name(session, asset_info_id=asset_info_id, name=name)
if name is not None and name != ref.name:
update_reference_name(session, reference_id=reference_id, name=name)
touched = True
computed_filename = compute_filename_for_asset(session, info.asset_id)
computed_filename = compute_filename_for_reference(session, ref)
new_meta: dict | None = None
if user_metadata is not None:
new_meta = dict(user_metadata)
elif computed_filename:
current_meta = info.user_metadata or {}
current_meta = ref.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
if new_meta is not None:
if computed_filename:
new_meta["filename"] = computed_filename
set_asset_info_metadata(
session, asset_info_id=asset_info_id, user_metadata=new_meta
set_reference_metadata(
session, reference_id=reference_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
set_reference_tags(
session,
asset_info_id=asset_info_id,
reference_id=reference_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
update_asset_info_updated_at(session, asset_info_id=asset_info_id)
update_reference_updated_at(session, reference_id=reference_id)
result = fetch_asset_info_asset_and_tags(
result = fetch_reference_asset_and_tags(
session,
asset_info_id=asset_info_id,
reference_id=reference_id,
owner_id=owner_id,
)
if not result:
raise RuntimeError("State changed during update")
info, asset, tag_list = result
ref, asset, tag_list = result
detail = AssetDetailResult(
info=extract_info_data(info),
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_list,
)
@@ -128,16 +128,17 @@ def update_asset_metadata(
def delete_asset_reference(
asset_info_id: str,
reference_id: str,
owner_id: str,
delete_content_if_orphan: bool = True,
) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
ref_row = get_reference_by_id(session, reference_id=reference_id)
asset_id = ref_row.asset_id if ref_row else None
file_path = ref_row.file_path if ref_row else None
deleted = delete_asset_info_by_id(
session, asset_info_id=asset_info_id, owner_id=owner_id
deleted = delete_reference_by_id(
session, reference_id=reference_id, owner_id=owner_id
)
if not deleted:
session.commit()
@@ -147,16 +148,19 @@ def delete_asset_reference(
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
still_exists = reference_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
# Orphaned asset - delete it and its files
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
refs = list_references_by_asset_id(session, asset_id=asset_id)
file_paths = [
s.file_path for s in (states or []) if getattr(s, "file_path", None)
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
]
# Also include the just-deleted file path
if file_path:
file_paths.append(file_path)
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
@@ -174,32 +178,32 @@ def delete_asset_reference(
def set_asset_preview(
asset_info_id: str,
reference_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
ref_row = get_reference_by_id(session, reference_id=reference_id)
if not ref_row:
raise ValueError(f"AssetReference {reference_id} not found")
if ref_row.owner_id and ref_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
set_reference_preview(
session,
asset_info_id=asset_info_id,
reference_id=reference_id,
preview_asset_id=preview_asset_id,
)
result = fetch_asset_info_asset_and_tags(
session, asset_info_id=asset_info_id, owner_id=owner_id
result = fetch_reference_asset_and_tags(
session, reference_id=reference_id, owner_id=owner_id
)
if not result:
raise RuntimeError("State changed during preview update")
info, asset, tags = result
ref, asset, tags = result
detail = AssetDetailResult(
info=extract_info_data(info),
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tags,
)
@@ -231,7 +235,7 @@ def list_assets_page(
order: str = "desc",
) -> ListAssetsResult:
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
refs, tag_map, total = list_references_page(
session,
owner_id=owner_id,
include_tags=include_tags,
@@ -245,12 +249,12 @@ def list_assets_page(
)
items: list[AssetSummaryData] = []
for info in infos:
for ref in refs:
items.append(
AssetSummaryData(
info=extract_info_data(info),
asset=extract_asset_data(info.asset),
tags=tag_map.get(info.id, []),
ref=extract_reference_data(ref),
asset=extract_asset_data(ref.asset),
tags=tag_map.get(ref.id, []),
)
)
@@ -258,33 +262,40 @@ def list_assets_page(
def resolve_asset_for_download(
asset_info_id: str,
reference_id: str,
owner_id: str = "",
) -> DownloadResolutionResult:
with create_session() as session:
pair = fetch_asset_info_and_asset(
session, asset_info_id=asset_info_id, owner_id=owner_id
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
raise ValueError(f"AssetReference {reference_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = select_best_live_path(states)
if not abs_path:
raise FileNotFoundError(
f"No live path for AssetInfo {asset_info_id} (asset id={asset.id}, name={info.name})"
)
ref, asset = pair
update_asset_info_access_time(session, asset_info_id=asset_info_id)
# For references with file_path, use that directly
if ref.file_path and os.path.isfile(ref.file_path):
abs_path = ref.file_path
else:
# For API-created refs without file_path, find a path from other refs
refs = list_references_by_asset_id(session, asset_id=asset.id)
abs_path = select_best_live_path(refs)
if not abs_path:
raise FileNotFoundError(
f"No live path for AssetReference {reference_id} "
f"(asset id={asset.id}, name={ref.name})"
)
update_reference_access_time(session, reference_id=reference_id)
session.commit()
ctype = (
asset.mime_type
or mimetypes.guess_type(info.name or abs_path)[0]
or mimetypes.guess_type(ref.name or abs_path)[0]
or "application/octet-stream"
)
download_name = info.name or os.path.basename(abs_path)
download_name = ref.name or os.path.basename(abs_path)
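# Content-type fallback sketch: stored mime -> guess from name/path ->
# octet-stream. mimetypes.guess_type("x.safetensors") yields (None, None)
# for unregistered extensions, so the final default matters for model files.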
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import logging
import os
import uuid
from dataclasses import dataclass
@@ -10,17 +9,16 @@ from typing import TYPE_CHECKING, Any, TypedDict
from sqlalchemy.orm import Session
from app.assets.database.queries import (
bulk_insert_asset_infos_ignore_conflicts,
bulk_insert_assets,
bulk_insert_cache_states_ignore_conflicts,
bulk_insert_references_ignore_conflicts,
bulk_insert_tags_and_meta,
delete_assets_by_ids,
get_asset_info_ids_by_ids,
get_cache_states_by_paths_and_asset_ids,
get_existing_asset_ids,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_unreferenced_unhashed_asset_ids,
mark_cache_states_missing_outside_prefixes,
restore_cache_states_by_paths,
mark_references_missing_outside_prefixes,
restore_references_by_paths,
)
from app.assets.helpers import get_utc_now
@@ -52,21 +50,15 @@ class AssetRow(TypedDict):
created_at: datetime
class CacheStateRow(TypedDict):
"""Row data for inserting a CacheState."""
class ReferenceRow(TypedDict):
"""Row data for inserting an AssetReference."""
id: str
asset_id: str
file_path: str
mtime_ns: int
class AssetInfoRow(TypedDict):
"""Row data for inserting an AssetInfo."""
id: str
owner_id: str
name: str
asset_id: str
preview_id: str | None
user_metadata: dict[str, Any] | None
created_at: datetime
@@ -74,27 +66,10 @@ class AssetInfoRow(TypedDict):
last_access_time: datetime
class AssetInfoRowInternal(TypedDict):
"""Internal row data for AssetInfo with extra tracking fields."""
id: str
owner_id: str
name: str
asset_id: str
preview_id: str | None
user_metadata: dict[str, Any] | None
created_at: datetime
updated_at: datetime
last_access_time: datetime
_tags: list[str]
_filename: str
_extracted_metadata: ExtractedMetadata | None
class TagRow(TypedDict):
"""Row data for inserting a Tag."""
asset_info_id: str
asset_reference_id: str
tag_name: str
origin: str
added_at: datetime
@@ -103,7 +78,7 @@ class TagRow(TypedDict):
class MetadataRow(TypedDict):
"""Row data for inserting asset metadata."""
asset_info_id: str
asset_reference_id: str
key: str
ordinal: int
val_str: str | None
@@ -116,9 +91,9 @@ class MetadataRow(TypedDict):
class BulkInsertResult:
"""Result of bulk asset insertion."""
inserted_infos: int
won_states: int
lost_states: int
inserted_refs: int
won_paths: int
lost_paths: int
def batch_insert_seed_assets(
@@ -138,29 +113,28 @@ def batch_insert_seed_assets(
This function orchestrates:
1. Insert seed Assets (hash=NULL)
2. Claim cache states with ON CONFLICT DO NOTHING
2. Claim references with ON CONFLICT DO NOTHING on file_path
3. Query to find winners (paths where our asset_id was inserted)
4. Delete Assets for losers (path already claimed by another asset)
5. Insert AssetInfo for winners
6. Insert tags and metadata for successfully inserted AssetInfos
5. Insert tags and metadata for successfully inserted references
Returns:
BulkInsertResult with inserted_infos, won_states, lost_states
BulkInsertResult with inserted_refs, won_paths, lost_paths
"""
if not specs:
return BulkInsertResult(inserted_infos=0, won_states=0, lost_states=0)
return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0)
current_time = get_utc_now()
asset_rows: list[AssetRow] = []
cache_state_rows: list[CacheStateRow] = []
reference_rows: list[ReferenceRow] = []
path_to_asset_id: dict[str, str] = {}
asset_id_to_info: dict[str, AssetInfoRowInternal] = {}
asset_id_to_ref_data: dict[str, dict] = {}
absolute_path_list: list[str] = []
for spec in specs:
absolute_path = os.path.abspath(spec["abs_path"])
asset_id = str(uuid.uuid4())
asset_info_id = str(uuid.uuid4())
reference_id = str(uuid.uuid4())
absolute_path_list.append(absolute_path)
path_to_asset_id[absolute_path] = asset_id
@@ -174,13 +148,7 @@ def batch_insert_seed_assets(
"created_at": current_time,
}
)
cache_state_rows.append(
{
"asset_id": asset_id,
"file_path": absolute_path,
"mtime_ns": spec["mtime_ns"],
}
)
# Build user_metadata from extracted metadata or fallback to filename
extracted_metadata = spec.get("metadata")
if extracted_metadata:
@@ -190,35 +158,43 @@ def batch_insert_seed_assets(
else:
user_metadata = None
asset_id_to_info[asset_id] = {
"id": asset_info_id,
"owner_id": owner_id,
"name": spec["info_name"],
"asset_id": asset_id,
"preview_id": None,
"user_metadata": user_metadata,
"created_at": current_time,
"updated_at": current_time,
"last_access_time": current_time,
"_tags": spec["tags"],
"_filename": spec["fname"],
"_extracted_metadata": extracted_metadata,
reference_rows.append(
{
"id": reference_id,
"asset_id": asset_id,
"file_path": absolute_path,
"mtime_ns": spec["mtime_ns"],
"owner_id": owner_id,
"name": spec["info_name"],
"preview_id": None,
"user_metadata": user_metadata,
"created_at": current_time,
"updated_at": current_time,
"last_access_time": current_time,
}
)
asset_id_to_ref_data[asset_id] = {
"reference_id": reference_id,
"tags": spec["tags"],
"filename": spec["fname"],
"extracted_metadata": extracted_metadata,
}
bulk_insert_assets(session, asset_rows)
# Filter cache states to only those whose assets were actually inserted
# Filter reference rows to only those whose assets were actually inserted
# (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING)
inserted_asset_ids = get_existing_asset_ids(
session, [r["asset_id"] for r in cache_state_rows]
session, [r["asset_id"] for r in reference_rows]
)
cache_state_rows = [
r for r in cache_state_rows if r["asset_id"] in inserted_asset_ids
reference_rows = [
r for r in reference_rows if r["asset_id"] in inserted_asset_ids
]
bulk_insert_cache_states_ignore_conflicts(session, cache_state_rows)
restore_cache_states_by_paths(session, absolute_path_list)
winning_paths = get_cache_states_by_paths_and_asset_ids(session, path_to_asset_id)
bulk_insert_references_ignore_conflicts(session, reference_rows)
restore_references_by_paths(session, absolute_path_list)
winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id)
all_paths_set = set(absolute_path_list)
losing_paths = all_paths_set - winning_paths
@@ -229,44 +205,34 @@ def batch_insert_seed_assets(
if not winning_paths:
return BulkInsertResult(
inserted_infos=0,
won_states=0,
lost_states=len(losing_paths),
inserted_refs=0,
won_paths=0,
lost_paths=len(losing_paths),
)
winner_info_rows = [
asset_id_to_info[path_to_asset_id[path]] for path in winning_paths
# Get reference IDs for winners
winning_ref_ids = [
asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"]
for path in winning_paths
]
database_info_rows: list[AssetInfoRow] = [
{
"id": info_row["id"],
"owner_id": info_row["owner_id"],
"name": info_row["name"],
"asset_id": info_row["asset_id"],
"preview_id": info_row["preview_id"],
"user_metadata": info_row["user_metadata"],
"created_at": info_row["created_at"],
"updated_at": info_row["updated_at"],
"last_access_time": info_row["last_access_time"],
}
for info_row in winner_info_rows
]
bulk_insert_asset_infos_ignore_conflicts(session, database_info_rows)
all_info_ids = [info_row["id"] for info_row in winner_info_rows]
inserted_info_ids = get_asset_info_ids_by_ids(session, all_info_ids)
inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids)
tag_rows: list[TagRow] = []
metadata_rows: list[MetadataRow] = []
if inserted_info_ids:
for info_row in winner_info_rows:
info_id = info_row["id"]
if info_id not in inserted_info_ids:
if inserted_ref_ids:
for path in winning_paths:
asset_id = path_to_asset_id[path]
ref_data = asset_id_to_ref_data[asset_id]
ref_id = ref_data["reference_id"]
if ref_id not in inserted_ref_ids:
continue
for tag in info_row["_tags"]:
for tag in ref_data["tags"]:
tag_rows.append(
{
"asset_info_id": info_id,
"asset_reference_id": ref_id,
"tag_name": tag,
"origin": "automatic",
"added_at": current_time,
@@ -274,17 +240,17 @@ def batch_insert_seed_assets(
)
# Use extracted metadata for meta rows if available
extracted_metadata = info_row.get("_extracted_metadata")
extracted_metadata = ref_data.get("extracted_metadata")
if extracted_metadata:
metadata_rows.extend(extracted_metadata.to_meta_rows(info_id))
elif info_row["_filename"]:
metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id))
elif ref_data["filename"]:
# Fallback: just store filename
metadata_rows.append(
{
"asset_info_id": info_id,
"asset_reference_id": ref_id,
"key": "filename",
"ordinal": 0,
"val_str": info_row["_filename"],
"val_str": ref_data["filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
@@ -294,40 +260,36 @@ def batch_insert_seed_assets(
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
return BulkInsertResult(
inserted_infos=len(inserted_info_ids),
won_states=len(winning_paths),
lost_states=len(losing_paths),
inserted_refs=len(inserted_ref_ids),
won_paths=len(winning_paths),
lost_paths=len(losing_paths),
)
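# Assumed invariants of the result above (a sanity sketch, not asserted by
# this commit; duplicate paths within one batch would break the first one):
def check_bulk_result(total_specs: int, result: BulkInsertResult) -> None:
    assert result.won_paths + result.lost_paths == total_specs
    assert result.inserted_refs <= result.won_paths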
def mark_assets_missing_outside_prefixes(
session: Session, valid_prefixes: list[str]
) -> int:
"""Mark cache states as missing when outside valid prefixes.
"""Mark references as missing when outside valid prefixes.
This is a non-destructive operation that soft-deletes cache states
This is a non-destructive operation that soft-deletes references
by setting is_missing=True. User metadata is preserved and assets
can be restored if the file reappears in a future scan.
Note: This does NOT delete unreferenced unhashed assets. Those are
preserved so user metadata remains intact even when base directories change.
Args:
session: Database session
valid_prefixes: List of absolute directory prefixes that are valid
Returns:
Number of cache states marked as missing
Number of references marked as missing
"""
return mark_cache_states_missing_outside_prefixes(session, valid_prefixes)
return mark_references_missing_outside_prefixes(session, valid_prefixes)
def cleanup_unreferenced_assets(session: Session) -> int:
"""Hard-delete unhashed assets with no active cache states.
"""Hard-delete unhashed assets with no active references.
This is a destructive operation intended for explicit cleanup.
Only deletes assets where hash=None and all cache states are missing.
Only deletes assets where hash=None and all references are missing.
Returns:
Number of assets deleted

View File

@@ -8,24 +8,23 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.models import Asset, AssetInfo, Tag
from app.assets.database.models import Asset, AssetReference, Tag
from app.assets.database.queries import (
add_tags_to_asset_info,
fetch_asset_info_and_asset,
add_tags_to_reference,
fetch_reference_and_asset,
get_asset_by_hash,
get_asset_tags,
get_or_create_asset_info,
get_reference_tags,
get_or_create_reference,
remove_missing_tag_for_asset_id,
set_asset_info_metadata,
set_asset_info_tags,
update_asset_info_timestamps,
set_reference_metadata,
set_reference_tags,
upsert_asset,
upsert_cache_state,
upsert_reference,
)
from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_filename_for_asset,
compute_filename_for_reference,
resolve_destination_from_tags,
validate_path_within_base,
)
@@ -35,7 +34,7 @@ from app.assets.services.schemas import (
UploadResult,
UserMetadata,
extract_asset_data,
extract_info_data,
extract_reference_data,
)
from app.database.db import create_session
@@ -58,9 +57,9 @@ def _ingest_file_from_path(
asset_created = False
asset_updated = False
state_created = False
state_updated = False
asset_info_id: str | None = None
ref_created = False
ref_updated = False
reference_id: str | None = None
with create_session() as session:
if preview_id:
@@ -74,49 +73,42 @@ def _ingest_file_from_path(
mime_type=mime_type,
)
state_created, state_updated = upsert_cache_state(
ref_created, ref_updated = upsert_reference(
session,
asset_id=asset.id,
file_path=locator,
name=info_name or os.path.basename(locator),
mtime_ns=mtime_ns,
owner_id=owner_id,
)
if info_name:
info, info_created = get_or_create_asset_info(
session,
asset_id=asset.id,
owner_id=owner_id,
name=info_name,
preview_id=preview_id,
)
if info_created:
asset_info_id = info.id
else:
update_asset_info_timestamps(
session, asset_info=info, preview_id=preview_id
)
asset_info_id = info.id
# Get the reference we just created/updated
from app.assets.database.queries import get_reference_by_file_path
ref = get_reference_by_file_path(session, locator)
if ref:
reference_id = ref.id
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
norm = normalize_tags(list(tags))
if norm and asset_info_id:
if norm and reference_id:
if require_existing_tags:
_validate_tags_exist(session, norm)
add_tags_to_asset_info(
add_tags_to_reference(
session,
asset_info_id=asset_info_id,
reference_id=reference_id,
tags=norm,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
if asset_info_id:
_update_metadata_with_filename(
session,
asset_info_id=asset_info_id,
asset_id=asset.id,
info=info,
user_metadata=user_metadata,
)
if ref:
    _update_metadata_with_filename(
        session,
        reference_id=reference_id,
        ref=ref,
        user_metadata=user_metadata,
    )
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
@@ -128,9 +120,9 @@ def _ingest_file_from_path(
return IngestResult(
asset_created=asset_created,
asset_updated=asset_updated,
state_created=state_created,
state_updated=state_updated,
asset_info_id=asset_info_id,
ref_created=ref_created,
ref_updated=ref_updated,
reference_id=reference_id,
)
@@ -147,18 +139,17 @@ def _register_existing_asset(
if not asset:
raise ValueError(f"No asset with hash {asset_hash}")
info, info_created = get_or_create_asset_info(
ref, ref_created = get_or_create_reference(
session,
asset_id=asset.id,
owner_id=owner_id,
name=name,
preview_id=None,
)
if not info_created:
tag_names = get_asset_tags(session, asset_info_id=info.id)
if not ref_created:
tag_names = get_reference_tags(session, reference_id=ref.id)
result = RegisterAssetResult(
info=extract_info_data(info),
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created=False,
@@ -167,29 +158,29 @@ def _register_existing_asset(
return result
new_meta = dict(user_metadata or {})
computed_filename = compute_filename_for_asset(session, asset.id)
computed_filename = compute_filename_for_reference(session, ref)
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
set_asset_info_metadata(
set_reference_metadata(
session,
asset_info_id=info.id,
reference_id=ref.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
set_reference_tags(
session,
asset_info_id=info.id,
reference_id=ref.id,
tags=tags,
origin=tag_origin,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.refresh(info)
tag_names = get_reference_tags(session, reference_id=ref.id)
session.refresh(ref)
result = RegisterAssetResult(
info=extract_info_data(info),
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created=True,
@@ -211,14 +202,13 @@ def _validate_tags_exist(session: Session, tags: list[str]) -> None:
def _update_metadata_with_filename(
session: Session,
asset_info_id: str,
asset_id: str,
info: AssetInfo,
reference_id: str,
ref: AssetReference,
user_metadata: UserMetadata,
) -> None:
computed_filename = compute_filename_for_asset(session, asset_id)
computed_filename = compute_filename_for_reference(session, ref)
current_meta = info.user_metadata or {}
current_meta = ref.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata:
for k, v in user_metadata.items():
@@ -227,9 +217,9 @@ def _update_metadata_with_filename(
new_meta["filename"] = computed_filename
if new_meta != current_meta:
set_asset_info_metadata(
set_reference_metadata(
session,
asset_info_id=asset_info_id,
reference_id=reference_id,
user_metadata=new_meta,
)
@@ -287,7 +277,7 @@ def upload_from_temp_path(
owner_id=owner_id,
)
return UploadResult(
info=result.info,
ref=result.ref,
asset=result.asset,
tags=result.tags,
created_new=False,
@@ -334,21 +324,21 @@ def upload_from_temp_path(
tag_origin="manual",
require_existing_tags=False,
)
info_id = ingest_result.asset_info_id
if not info_id:
raise RuntimeError("failed to create asset metadata")
reference_id = ingest_result.reference_id
if not reference_id:
raise RuntimeError("failed to create asset reference")
with create_session() as session:
pair = fetch_asset_info_and_asset(
session, asset_info_id=info_id, owner_id=owner_id
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
ref, asset = pair
tag_names = get_reference_tags(session, reference_id=ref.id)
return UploadResult(
info=extract_info_data(info),
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
@@ -381,7 +371,7 @@ def create_from_hash(
)
return UploadResult(
info=result.info,
ref=result.ref,
asset=result.asset,
tags=result.tags,
created_new=False,

View File

@@ -52,6 +52,7 @@ class ExtractedMetadata:
# Tier 1: Filesystem (always available)
filename: str = ""
file_path: str = "" # Full absolute path to the file
content_length: int = 0
content_type: str | None = None
format: str = "" # file extension without dot
@@ -76,12 +77,14 @@ class ExtractedMetadata:
resolve_url: str | None = None
def to_user_metadata(self) -> dict[str, Any]:
"""Convert to user_metadata dict for AssetInfo.user_metadata JSON field."""
"""Convert to user_metadata dict for AssetReference.user_metadata JSON field."""
data: dict[str, Any] = {
"filename": self.filename,
"content_length": self.content_length,
"format": self.format,
}
if self.file_path:
data["file_path"] = self.file_path
if self.content_type:
data["content_type"] = self.content_type
@@ -119,14 +122,14 @@ class ExtractedMetadata:
return data
def to_meta_rows(self, asset_info_id: str) -> list[dict]:
"""Convert to asset_info_meta rows for typed/indexed querying."""
def to_meta_rows(self, reference_id: str) -> list[dict]:
"""Convert to asset_reference_meta rows for typed/indexed querying."""
rows: list[dict] = []
def add_str(key: str, val: str | None, ordinal: int = 0) -> None:
if val:
rows.append({
"asset_info_id": asset_info_id,
"asset_reference_id": reference_id,
"key": key,
"ordinal": ordinal,
"val_str": val[:2048] if len(val) > 2048 else val,
@@ -138,7 +141,7 @@ class ExtractedMetadata:
def add_num(key: str, val: int | float | None) -> None:
if val is not None:
rows.append({
"asset_info_id": asset_info_id,
"asset_reference_id": reference_id,
"key": key,
"ordinal": 0,
"val_str": None,
@@ -150,7 +153,7 @@ class ExtractedMetadata:
def add_bool(key: str, val: bool | None) -> None:
if val is not None:
rows.append({
"asset_info_id": asset_info_id,
"asset_reference_id": reference_id,
"key": key,
"ordinal": 0,
"val_str": None,
@@ -168,7 +171,8 @@ class ExtractedMetadata:
# Tier 2
add_str("base_model", self.base_model)
add_str("air", self.air)
add_bool("has_preview_images", self.has_preview_images if self.has_preview_images else None)
has_previews = self.has_preview_images if self.has_preview_images else None
add_bool("has_preview_images", has_previews)
# trained_words as multiple rows with ordinals
if self.trained_words:
@@ -191,7 +195,9 @@ class ExtractedMetadata:
return rows
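# Hypothetical usage of to_meta_rows (mirroring how extract_file_metadata
# builds instances below; the reference id is made up):
meta = ExtractedMetadata()
meta.base_model = "sdxl"
rows = meta.to_meta_rows("ref-0001")
# -> includes {"asset_reference_id": "ref-0001", "key": "base_model",
#              "ordinal": 0, "val_str": "sdxl", "val_num": None,
#              "val_bool": None, "val_json": None}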
def _read_safetensors_header(path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE) -> dict[str, Any] | None:
def _read_safetensors_header(
path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE
) -> dict[str, Any] | None:
"""Read only the JSON header from a safetensors file.
This is very fast - reads 8 bytes for header length, then the JSON header.
@@ -220,7 +226,9 @@ def _read_safetensors_header(path: str, max_size: int = MAX_SAFETENSORS_HEADER_S
return None
def _extract_safetensors_metadata(header: dict[str, Any], meta: ExtractedMetadata) -> None:
def _extract_safetensors_metadata(
header: dict[str, Any], meta: ExtractedMetadata
) -> None:
"""Extract metadata from safetensors header __metadata__ section.
Modifies meta in-place.
@@ -230,7 +238,11 @@ def _extract_safetensors_metadata(header: dict[str, Any], meta: ExtractedMetadat
return
# Common model metadata
meta.base_model = st_meta.get("ss_base_model_version") or st_meta.get("modelspec.base_model") or st_meta.get("base_model")
meta.base_model = (
st_meta.get("ss_base_model_version")
or st_meta.get("modelspec.base_model")
or st_meta.get("base_model")
)
# Trained words / trigger words
trained_words = st_meta.get("ss_tag_frequency")
@@ -304,8 +316,8 @@ def extract_file_metadata(
meta = ExtractedMetadata()
# Tier 1: Filesystem metadata
# Use relative_filename if provided (for backward compatibility with existing behavior)
meta.filename = relative_filename if relative_filename else os.path.basename(abs_path)
meta.filename = relative_filename or os.path.basename(abs_path)
meta.file_path = abs_path
_, ext = os.path.splitext(abs_path)
meta.format = ext.lstrip(".").lower() if ext else ""
@@ -333,6 +345,6 @@ def extract_file_metadata(
try:
_extract_safetensors_metadata(header, meta)
except Exception as e:
logging.debug("Failed to extract safetensors metadata from %s: %s", abs_path, e)
logging.debug("Safetensors meta extract failed %s: %s", abs_path, e)
return meta

View File

@@ -7,18 +7,15 @@ from app.assets.helpers import normalize_tags
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
"""Build list of (folder_name, base_paths[]) for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
Includes a category if any of its base paths lies under models_dir.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = (
values[0],
values[1],
) # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
# Unpack carefully to handle nodepacks that modify folder_paths
paths, _exts = values[0], values[1]
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
@@ -70,7 +67,6 @@ def compute_relative_filename(file_path: str) -> str | None:
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_asset_category_and_relative_path(file_path)
@@ -92,18 +88,18 @@ def compute_relative_filename(file_path: str) -> str | None:
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
"""Determine which root category a file path belongs to.
Categories:
- 'input': under folder_paths.get_input_directory()
- 'output': under folder_paths.get_output_directory()
- 'models': under any base path from get_comfy_models_folders()
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
ValueError: path does not belong to any known root.
"""
fp_abs = os.path.abspath(file_path)
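Expected behavior per the docstring, as a doctest-style sketch (directories are illustrative):

# assuming input dir /srv/comfy/input and models dir /srv/comfy/models
get_asset_category_and_relative_path("/srv/comfy/input/img.png")
# -> ("input", "img.png")
get_asset_category_and_relative_path("/srv/comfy/models/vae/test/ae.safetensors")
# -> ("models", "vae/test/ae.safetensors")
get_asset_category_and_relative_path("/tmp/unrelated.bin")
# -> raises ValueError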
@@ -149,32 +145,35 @@ def get_asset_category_and_relative_path(
)
def compute_filename_for_reference(session, ref) -> str | None:
"""Compute the relative filename for an asset reference.
Uses the file_path from the reference if available.
"""
if ref.file_path:
return compute_relative_filename(ref.file_path)
return None
def compute_filename_for_asset(session, asset_id: str) -> str | None:
"""Compute the relative filename for an asset from its best live cache state path."""
from app.assets.database.queries import list_cache_states_by_asset_id
"""Compute the relative filename for an asset from its best live reference path."""
from app.assets.database.queries import list_references_by_asset_id
from app.assets.helpers import select_best_live_path
primary_path = select_best_live_path(
list_cache_states_by_asset_id(session, asset_id=asset_id)
list_references_by_asset_id(session, asset_id=asset_id)
)
return compute_relative_filename(primary_path) if primary_path else None
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
"""Return (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_asset_category_and_relative_path`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
- name: base filename with extension
- tags: [root_category] + parent folder names in order
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
ValueError: path does not belong to any known root.
"""
root_category, some_path = get_asset_category_and_relative_path(file_path)
p = Path(some_path)
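The body after `p = Path(some_path)` is cut off by the hunk; under the semantics above it plausibly reduces to the following sketch (the real code also imports normalize_tags, which presumably post-processes the tag list):

from pathlib import Path

def name_and_tags(root_category: str, some_path: str) -> tuple[str, list[str]]:
    p = Path(some_path)
    name = p.name                          # base filename with extension
    tags = [root_category, *p.parts[:-1]]  # root category, then parent folders in order
    return name, tags

# name_and_tags("models", "vae/test_tag/ae.safetensors")
# -> ("ae.safetensors", ["models", "vae", "test_tag"])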

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Any, NamedTuple
from app.assets.database.models import Asset, AssetInfo
from app.assets.database.models import Asset, AssetReference
UserMetadata = dict[str, Any] | None
@@ -15,9 +15,12 @@ class AssetData:
@dataclass(frozen=True)
class AssetInfoData:
class ReferenceData:
"""Data transfer object for AssetReference."""
id: str
name: str
file_path: str | None
user_metadata: UserMetadata
preview_id: str | None
created_at: datetime
@@ -27,14 +30,14 @@ class AssetInfoData:
@dataclass(frozen=True)
class AssetDetailResult:
info: AssetInfoData
ref: ReferenceData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class RegisterAssetResult:
info: AssetInfoData
ref: ReferenceData
asset: AssetData
tags: list[str]
created: bool
@@ -44,9 +47,9 @@ class RegisterAssetResult:
class IngestResult:
asset_created: bool
asset_updated: bool
state_created: bool
state_updated: bool
asset_info_id: str | None
ref_created: bool
ref_updated: bool
reference_id: str | None
@dataclass(frozen=True)
@@ -78,7 +81,7 @@ class TagUsage(NamedTuple):
@dataclass(frozen=True)
class AssetSummaryData:
info: AssetInfoData
ref: ReferenceData
asset: AssetData | None
tags: list[str]
@@ -98,21 +101,22 @@ class DownloadResolutionResult:
@dataclass(frozen=True)
class UploadResult:
info: AssetInfoData
ref: ReferenceData
asset: AssetData
tags: list[str]
created_new: bool
def extract_info_data(info: AssetInfo) -> AssetInfoData:
return AssetInfoData(
id=info.id,
name=info.name,
user_metadata=info.user_metadata,
preview_id=info.preview_id,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
def extract_reference_data(ref: AssetReference) -> ReferenceData:
return ReferenceData(
id=ref.id,
name=ref.name,
file_path=ref.file_path,
user_metadata=ref.user_metadata,
preview_id=ref.preview_id,
created_at=ref.created_at,
updated_at=ref.updated_at,
last_access_time=ref.last_access_time,
)

View File

@@ -1,33 +1,33 @@
from app.assets.database.queries import (
add_tags_to_asset_info,
get_asset_info_by_id,
add_tags_to_reference,
get_reference_by_id,
list_tags_with_usage,
remove_tags_from_asset_info,
remove_tags_from_reference,
)
from app.assets.services.schemas import AddTagsResult, RemoveTagsResult, TagUsage
from app.database.db import create_session
def apply_tags(
asset_info_id: str,
reference_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> AddTagsResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
ref_row = get_reference_by_id(session, reference_id=reference_id)
if not ref_row:
raise ValueError(f"AssetReference {reference_id} not found")
if ref_row.owner_id and ref_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
data = add_tags_to_reference(
session,
asset_info_id=asset_info_id,
reference_id=reference_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
reference_row=ref_row,
)
session.commit()
@@ -39,20 +39,20 @@ def apply_tags(
def remove_tags(
asset_info_id: str,
reference_id: str,
tags: list[str],
owner_id: str = "",
) -> RemoveTagsResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
ref_row = get_reference_by_id(session, reference_id=reference_id)
if not ref_row:
raise ValueError(f"AssetReference {reference_id} not found")
if ref_row.owner_id and ref_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
data = remove_tags_from_reference(
session,
asset_info_id=asset_info_id,
reference_id=reference_id,
tags=tags,
)
session.commit()
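A hypothetical call site for these two services, illustrating the ownership check (IDs are made up; the import path is assumed):

from app.assets.services.tags import apply_tags, remove_tags  # assumed module path

REF_ID = "11111111-1111-1111-1111-111111111111"  # hypothetical reference UUID

# raises ValueError for an unknown reference, PermissionError when
# owner_id does not match the stored owner
added = apply_tags(reference_id=REF_ID, tags=["models", "vae"], owner_id="user-1")
removed = remove_tags(reference_id=REF_ID, tags=["vae"], owner_id="user-1")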