diff --git a/alembic_db/versions/0002_merge_to_asset_references.py b/alembic_db/versions/0002_merge_to_asset_references.py new file mode 100644 index 000000000..1ac1b980c --- /dev/null +++ b/alembic_db/versions/0002_merge_to_asset_references.py @@ -0,0 +1,267 @@ +""" +Merge AssetInfo and AssetCacheState into unified asset_references table. + +This migration drops old tables and creates the new unified schema. +All existing data is discarded. + +Revision ID: 0002_merge_to_asset_references +Revises: 0001_assets +Create Date: 2025-02-11 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0002_merge_to_asset_references" +down_revision = "0001_assets" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Drop old tables (order matters due to FK constraints) + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") + + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + # Truncate assets table (cascades handled by dropping dependent tables first) + op.execute("DELETE FROM assets") + + # Create asset_references table + op.create_table( + "asset_references", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column( + "asset_id", + sa.String(length=36), + sa.ForeignKey("assets.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("file_path", sa.Text(), nullable=True), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column( + "needs_verify", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column( + "is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false") + ), + sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column( + "preview_id", + sa.String(length=36), + sa.ForeignKey("assets.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True), + sa.CheckConstraint( + "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" + ), + sa.CheckConstraint( + "enrichment_level >= 0 AND enrichment_level <= 2", + name="ck_ar_enrichment_level_range", + ), + ) + op.create_index( + "uq_asset_references_file_path", "asset_references", ["file_path"], unique=True + ) + op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"]) + op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"]) + op.create_index("ix_asset_references_name", "asset_references", ["name"]) + op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"]) + op.create_index( + "ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"] + ) + op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"]) + op.create_index( + "ix_asset_references_last_access_time", "asset_references", ["last_access_time"] + ) + op.create_index( + "ix_asset_references_owner_name", "asset_references", ["owner_id", "name"] + ) + op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"]) + + # Create asset_reference_tags table + op.create_table( + "asset_reference_tags", + sa.Column( + "asset_reference_id", + sa.String(length=36), + sa.ForeignKey("asset_references.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "tag_name", + sa.String(length=512), + sa.ForeignKey("tags.name", ondelete="RESTRICT"), + nullable=False, + ), + sa.Column( + "origin", sa.String(length=32), nullable=False, server_default="manual" + ), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint( + "asset_reference_id", "tag_name", name="pk_asset_reference_tags" + ), + ) + op.create_index( + "ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"] + ) + op.create_index( + "ix_asset_reference_tags_asset_reference_id", + "asset_reference_tags", + ["asset_reference_id"], + ) + + # Create asset_reference_meta table + op.create_table( + "asset_reference_meta", + sa.Column( + "asset_reference_id", + sa.String(length=36), + sa.ForeignKey("asset_references.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint( + "asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta" + ), + ) + op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"]) + op.create_index( + "ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"] + ) + op.create_index( + "ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"] + ) + op.create_index( + "ix_asset_reference_meta_key_val_bool", + "asset_reference_meta", + ["key", "val_bool"], + ) + + +def downgrade() -> None: + """Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema. + + NOTE: Data is not recoverable. The upgrade discards all rows from the old + tables and truncates assets. After downgrade the old schema will be empty. + A filesystem rescan will repopulate data once the older code is running. + """ + # Drop new tables (order matters due to FK constraints) + op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta") + op.drop_table("asset_reference_meta") + + op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags") + op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags") + op.drop_table("asset_reference_tags") + + op.drop_index("ix_asset_references_deleted_at", table_name="asset_references") + op.drop_index("ix_asset_references_owner_name", table_name="asset_references") + op.drop_index("ix_asset_references_last_access_time", table_name="asset_references") + op.drop_index("ix_asset_references_created_at", table_name="asset_references") + op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references") + op.drop_index("ix_asset_references_is_missing", table_name="asset_references") + op.drop_index("ix_asset_references_name", table_name="asset_references") + op.drop_index("ix_asset_references_owner_id", table_name="asset_references") + op.drop_index("ix_asset_references_asset_id", table_name="asset_references") + op.drop_index("uq_asset_references_file_path", table_name="asset_references") + op.drop_table("asset_references") + + # Truncate assets (upgrade deleted all rows; downgrade starts fresh too) + op.execute("DELETE FROM assets") + + # Recreate old tables from 0001_assets schema + op.create_table( + "assets_info", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) + + op.create_table( + "asset_cache_state", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("file_path", sa.Text(), nullable=False), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) + + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 7676e50b4..40dee9f46 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -1,56 +1,144 @@ +import asyncio +import functools +import json import logging -import uuid -import urllib.parse import os -import contextlib -from aiohttp import web +import urllib.parse +import uuid +from typing import Any +from aiohttp import web from pydantic import ValidationError -import app.assets.manager as manager -from app import user_manager -from app.assets.api import schemas_in -from app.assets.helpers import get_query_dict -from app.assets.scanner import seed_assets - import folder_paths +from app import user_manager +from app.assets.api import schemas_in, schemas_out +from app.assets.api.schemas_in import ( + AssetValidationError, + UploadError, +) +from app.assets.helpers import validate_blake3_hash +from app.assets.api.upload import ( + delete_temp_file_if_exists, + parse_multipart_upload, +) +from app.assets.seeder import ScanInProgressError, asset_seeder +from app.assets.services import ( + DependencyMissingError, + HashMismatchError, + apply_tags, + asset_exists, + create_from_hash, + delete_asset_reference, + get_asset_detail, + list_assets_page, + list_tags, + remove_tags, + resolve_asset_for_download, + update_asset_metadata, + upload_from_temp_path, +) ROUTES = web.RouteTableDef() USER_MANAGER: user_manager.UserManager | None = None +_ASSETS_ENABLED = False + + +def _require_assets_feature_enabled(handler): + @functools.wraps(handler) + async def wrapper(request: web.Request) -> web.Response: + if not _ASSETS_ENABLED: + return _build_error_response( + 503, + "SERVICE_DISABLED", + "Assets system is disabled. Start the server with --enable-assets to use this feature.", + ) + return await handler(request) + + return wrapper + # UUID regex (canonical hyphenated form, case-insensitive) UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" -# Note to any custom node developers reading this code: -# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same. -def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: - global USER_MANAGER - USER_MANAGER = user_manager_instance +def get_query_dict(request: web.Request) -> dict[str, Any]: + """Gets a dictionary of query parameters from the request. + + request.query is a MultiMapping[str], needs to be converted to a dict + to be validated by Pydantic. + """ + query_dict = { + key: request.query.getall(key) + if len(request.query.getall(key)) > 1 + else request.query.get(key) + for key in request.query.keys() + } + return query_dict + + +# Note to any custom node developers reading this code: +# The assets system is not yet fully implemented, +# do not rely on the code in /app/assets remaining the same. + + +def register_assets_routes( + app: web.Application, + user_manager_instance: user_manager.UserManager | None = None, +) -> None: + global USER_MANAGER, _ASSETS_ENABLED + if user_manager_instance is not None: + USER_MANAGER = user_manager_instance + _ASSETS_ENABLED = True app.add_routes(ROUTES) -def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: - return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) + +def disable_assets_routes() -> None: + """Disable asset routes at runtime (e.g. after DB init failure).""" + global _ASSETS_ENABLED + _ASSETS_ENABLED = False -def _validation_error_response(code: str, ve: ValidationError) -> web.Response: - return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) +def _build_error_response( + status: int, code: str, message: str, details: dict | None = None +) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) + + +def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response: + errors = json.loads(ve.json()) + return _build_error_response(400, code, "Validation failed.", {"errors": errors}) + + +def _validate_sort_field(requested: str | None) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" @ROUTES.head("/api/assets/hash/{hash}") +@_require_assets_feature_enabled async def head_asset_by_hash(request: web.Request) -> web.Response: hash_str = request.match_info.get("hash", "").strip().lower() - if not hash_str or ":" not in hash_str: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - algo, digest = hash_str.split(":", 1) - if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - exists = manager.asset_exists(asset_hash=hash_str) + try: + hash_str = validate_blake3_hash(hash_str) + except ValueError: + return _build_error_response( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) + exists = asset_exists(hash_str) return web.Response(status=200 if exists else 404) @ROUTES.get("/api/assets") -async def list_assets(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def list_assets_route(request: web.Request) -> web.Response: """ GET request to list assets. """ @@ -58,78 +146,140 @@ async def list_assets(request: web.Request) -> web.Response: try: q = schemas_in.ListAssetsQuery.model_validate(query_dict) except ValidationError as ve: - return _validation_error_response("INVALID_QUERY", ve) + return _build_validation_error_response("INVALID_QUERY", ve) - payload = manager.list_assets( + sort = _validate_sort_field(q.sort) + order_candidate = (q.order or "desc").lower() + order = order_candidate if order_candidate in {"asc", "desc"} else "desc" + + result = list_assets_page( + owner_id=USER_MANAGER.get_request_user_id(request), include_tags=q.include_tags, exclude_tags=q.exclude_tags, name_contains=q.name_contains, metadata_filter=q.metadata_filter, limit=q.limit, offset=q.offset, - sort=q.sort, - order=q.order, - owner_id=USER_MANAGER.get_request_user_id(request), + sort=sort, + order=order, + ) + + summaries = [ + schemas_out.AssetSummary( + id=item.ref.id, + name=item.ref.name, + asset_hash=item.asset.hash if item.asset else None, + size=int(item.asset.size_bytes) if item.asset else None, + mime_type=item.asset.mime_type if item.asset else None, + tags=item.tags, + created_at=item.ref.created_at, + updated_at=item.ref.updated_at, + last_access_time=item.ref.last_access_time, + ) + for item in result.items + ] + + payload = schemas_out.AssetsList( + assets=summaries, + total=result.total, + has_more=(q.offset + len(summaries)) < result.total, ) return web.json_response(payload.model_dump(mode="json", exclude_none=True)) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") -async def get_asset(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def get_asset_route(request: web.Request) -> web.Response: """ GET request to get an asset's info as JSON. """ - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - result = manager.get_asset( - asset_info_id=asset_info_id, + result = get_asset_detail( + reference_id=reference_id, owner_id=USER_MANAGER.get_request_user_id(request), ) + if not result: + return _build_error_response( + 404, + "ASSET_NOT_FOUND", + f"AssetReference {reference_id} not found", + {"id": reference_id}, + ) + + payload = schemas_out.AssetDetail( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + size=int(result.asset.size_bytes) if result.asset else None, + mime_type=result.asset.mime_type if result.asset else None, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + ) except ValueError as e: - return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} + ) except Exception: logging.exception( - "get_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "get_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") +@_require_assets_feature_enabled async def download_asset_content(request: web.Request) -> web.Response: - # question: do we need disposition? could we just stick with one of these? disposition = request.query.get("disposition", "attachment").lower().strip() if disposition not in {"inline", "attachment"}: disposition = "attachment" try: - abs_path, content_type, filename = manager.resolve_asset_content_for_download( - asset_info_id=str(uuid.UUID(request.match_info["id"])), + result = resolve_asset_for_download( + reference_id=str(uuid.UUID(request.match_info["id"])), owner_id=USER_MANAGER.get_request_user_id(request), ) + abs_path = result.abs_path + content_type = result.content_type + filename = result.download_name except ValueError as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve)) + return _build_error_response(404, "ASSET_NOT_FOUND", str(ve)) except NotImplementedError as nie: - return _error_response(501, "BACKEND_UNSUPPORTED", str(nie)) + return _build_error_response(501, "BACKEND_UNSUPPORTED", str(nie)) except FileNotFoundError: - return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") + return _build_error_response( + 404, "FILE_NOT_FOUND", "Underlying file not found on disk." + ) - quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") - cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' + _DANGEROUS_MIME_TYPES = { + "text/html", "text/html-sandboxed", "application/xhtml+xml", + "text/javascript", "text/css", + } + if content_type in _DANGEROUS_MIME_TYPES: + content_type = "application/octet-stream" + + safe_name = (filename or "").replace("\r", "").replace("\n", "") + encoded = urllib.parse.quote(safe_name) + cd = f"{disposition}; filename*=UTF-8''{encoded}" file_size = os.path.getsize(abs_path) + size_mb = file_size / (1024 * 1024) logging.info( - "download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s", + "download_asset_content: path=%s, size=%d bytes (%.2f MB), type=%s, name=%s", abs_path, file_size, - file_size / (1024 * 1024), + size_mb, content_type, filename, ) - async def file_sender(): + async def stream_file_chunks(): chunk_size = 64 * 1024 with open(abs_path, "rb") as f: while True: @@ -139,26 +289,30 @@ async def download_asset_content(request: web.Request) -> web.Response: yield chunk return web.Response( - body=file_sender(), + body=stream_file_chunks(), content_type=content_type, headers={ "Content-Disposition": cd, "Content-Length": str(file_size), + "X-Content-Type-Options": "nosniff", }, ) @ROUTES.post("/api/assets/from-hash") -async def create_asset_from_hash(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def create_asset_from_hash_route(request: web.Request) -> web.Response: try: payload = await request.json() body = schemas_in.CreateFromHashBody.model_validate(payload) except ValidationError as ve: - return _validation_error_response("INVALID_BODY", ve) + return _build_validation_error_response("INVALID_BODY", ve) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) - result = manager.create_asset_from_hash( + result = create_from_hash( hash_str=body.hash, name=body.name, tags=body.tags, @@ -166,246 +320,209 @@ async def create_asset_from_hash(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), ) if result is None: - return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") - return web.json_response(result.model_dump(mode="json"), status=201) + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" + ) + + payload_out = schemas_out.AssetCreated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash, + size=int(result.asset.size_bytes), + mime_type=result.asset.mime_type, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + created_new=result.created_new, + ) + return web.json_response(payload_out.model_dump(mode="json"), status=201) @ROUTES.post("/api/assets") +@_require_assets_feature_enabled async def upload_asset(request: web.Request) -> web.Response: """Multipart/form-data endpoint for Asset uploads.""" - if not (request.content_type or "").lower().startswith("multipart/"): - return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") - - reader = await request.multipart() - - file_present = False - file_client_name: str | None = None - tags_raw: list[str] = [] - provided_name: str | None = None - user_metadata_raw: str | None = None - provided_hash: str | None = None - provided_hash_exists: bool | None = None - - file_written = 0 - tmp_path: str | None = None - while True: - field = await reader.next() - if field is None: - break - - fname = getattr(field, "name", "") or "" - - if fname == "hash": - try: - s = ((await field.text()) or "").strip().lower() - except Exception: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - - if s: - if ":" not in s: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - provided_hash = f"{algo}:{digest}" - try: - provided_hash_exists = manager.asset_exists(asset_hash=provided_hash) - except Exception: - provided_hash_exists = None # do not fail the whole request here - - elif fname == "file": - file_present = True - file_client_name = (field.filename or "").strip() - - if provided_hash and provided_hash_exists is True: - # If client supplied a hash that we know exists, drain but do not write to disk - try: - while True: - chunk = await field.read_chunk(8 * 1024 * 1024) - if not chunk: - break - file_written += len(chunk) - except Exception: - return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.") - continue # Do not create temp file; we will create AssetInfo from the existing content - - # Otherwise, store to temp for hashing/ingest - uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") - unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) - os.makedirs(unique_dir, exist_ok=True) - tmp_path = os.path.join(unique_dir, ".upload.part") - - try: - with open(tmp_path, "wb") as f: - while True: - chunk = await field.read_chunk(8 * 1024 * 1024) - if not chunk: - break - f.write(chunk) - file_written += len(chunk) - except Exception: - try: - if os.path.exists(tmp_path or ""): - os.remove(tmp_path) - finally: - return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") - elif fname == "tags": - tags_raw.append((await field.text()) or "") - elif fname == "name": - provided_name = (await field.text()) or None - elif fname == "user_metadata": - user_metadata_raw = (await field.text()) or None - - # If client did not send file, and we are not doing a from-hash fast path -> error - if not file_present and not (provided_hash and provided_hash_exists): - return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.") - - if file_present and file_written == 0 and not (provided_hash and provided_hash_exists): - # Empty upload is only acceptable if we are fast-pathing from existing hash - try: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - finally: - return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.") - try: - spec = schemas_in.UploadAssetSpec.model_validate({ - "tags": tags_raw, - "name": provided_name, - "user_metadata": user_metadata_raw, - "hash": provided_hash, - }) - except ValidationError as ve: - try: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - finally: - return _validation_error_response("INVALID_BODY", ve) - - # Validate models category against configured folders (consistent with previous behavior) - if spec.tags and spec.tags[0] == "models": - if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - return _error_response( - 400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'" - ) + parsed = await parse_multipart_upload(request, check_hash_exists=asset_exists) + except UploadError as e: + return _build_error_response(e.status, e.code, e.message) owner_id = USER_MANAGER.get_request_user_id(request) - # Fast path: if a valid provided hash exists, create AssetInfo without writing anything - if spec.hash and provided_hash_exists is True: - try: - result = manager.create_asset_from_hash( + try: + spec = schemas_in.UploadAssetSpec.model_validate( + { + "tags": parsed.tags_raw, + "name": parsed.provided_name, + "user_metadata": parsed.user_metadata_raw, + "hash": parsed.provided_hash, + } + ) + except ValidationError as ve: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response( + 400, "INVALID_BODY", f"Validation failed: {ve.json()}" + ) + + if spec.tags and spec.tags[0] == "models": + if ( + len(spec.tags) < 2 + or spec.tags[1] not in folder_paths.folder_names_and_paths + ): + delete_temp_file_if_exists(parsed.tmp_path) + category = spec.tags[1] if len(spec.tags) >= 2 else "" + return _build_error_response( + 400, "INVALID_BODY", f"unknown models category '{category}'" + ) + + try: + # Fast path: hash exists, create AssetReference without writing anything + if spec.hash and parsed.provided_hash_exists is True: + result = create_from_hash( hash_str=spec.hash, name=spec.name or (spec.hash.split(":", 1)[1]), tags=spec.tags, user_metadata=spec.user_metadata or {}, owner_id=owner_id, ) - except Exception: - logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id) - return _error_response(500, "INTERNAL", "Unexpected server error.") + if result is None: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist" + ) + delete_temp_file_if_exists(parsed.tmp_path) + else: + # Otherwise, we must have a temp file path to ingest + if not parsed.tmp_path or not os.path.exists(parsed.tmp_path): + return _build_error_response( + 400, + "MISSING_INPUT", + "Provided hash not found and no file uploaded.", + ) - if result is None: - return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist") - - # Drain temp if we accidentally saved (e.g., hash field came after file) - if tmp_path and os.path.exists(tmp_path): - with contextlib.suppress(Exception): - os.remove(tmp_path) - - status = 200 if (not result.created_new) else 201 - return web.json_response(result.model_dump(mode="json"), status=status) - - # Otherwise, we must have a temp file path to ingest - if not tmp_path or not os.path.exists(tmp_path): - # The only case we reach here without a temp file is: client sent a hash that does not exist and no file - return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") - - try: - created = manager.upload_asset_from_temp_path( - spec, - temp_path=tmp_path, - client_filename=file_client_name, - owner_id=owner_id, - expected_asset_hash=spec.hash, - ) - status = 201 if created.created_new else 200 - return web.json_response(created.model_dump(mode="json"), status=status) - except ValueError as e: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - msg = str(e) - if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": - return _error_response( - 400, - "HASH_MISMATCH", - "Uploaded file hash does not match provided hash.", + result = upload_from_temp_path( + temp_path=parsed.tmp_path, + name=spec.name, + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + client_filename=parsed.file_client_name, + owner_id=owner_id, + expected_hash=spec.hash, ) - return _error_response(400, "BAD_REQUEST", "Invalid inputs.") + except AssetValidationError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, e.code, str(e)) + except ValueError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, "BAD_REQUEST", str(e)) + except HashMismatchError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, "HASH_MISMATCH", str(e)) + except DependencyMissingError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(503, "DEPENDENCY_MISSING", e.message) except Exception: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id) - return _error_response(500, "INTERNAL", "Unexpected server error.") + delete_temp_file_if_exists(parsed.tmp_path) + logging.exception("upload_asset failed for owner_id=%s", owner_id) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + + payload = schemas_out.AssetCreated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash, + size=int(result.asset.size_bytes), + mime_type=result.asset.mime_type, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + created_new=result.created_new, + ) + status = 201 if result.created_new else 200 + return web.json_response(payload.model_dump(mode="json"), status=status) @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") -async def update_asset(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) +@_require_assets_feature_enabled +async def update_asset_route(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) try: body = schemas_in.UpdateAssetBody.model_validate(await request.json()) except ValidationError as ve: - return _validation_error_response("INVALID_BODY", ve) + return _build_validation_error_response("INVALID_BODY", ve) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.update_asset( - asset_info_id=asset_info_id, + result = update_asset_metadata( + reference_id=reference_id, name=body.name, user_metadata=body.user_metadata, owner_id=USER_MANAGER.get_request_user_id(request), ) - except (ValueError, PermissionError) as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + payload = schemas_out.AssetUpdated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + updated_at=result.ref.updated_at, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "update_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "update_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") -async def delete_asset(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) - delete_content = request.query.get("delete_content") - delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"} +@_require_assets_feature_enabled +async def delete_asset_route(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) + delete_content_param = request.query.get("delete_content") + delete_content = ( + False + if delete_content_param is None + else delete_content_param.lower() not in {"0", "false", "no"} + ) try: - deleted = manager.delete_asset_reference( - asset_info_id=asset_info_id, + deleted = delete_asset_reference( + reference_id=reference_id, owner_id=USER_MANAGER.get_request_user_id(request), delete_content_if_orphan=delete_content, ) except Exception: logging.exception( - "delete_asset_reference failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "delete_asset_reference failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") if not deleted: - return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.") + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"AssetReference {reference_id} not found." + ) return web.Response(status=204) @ROUTES.get("/api/tags") +@_require_assets_feature_enabled async def get_tags(request: web.Request) -> web.Response: """ GET request to list all tags based on query parameters. @@ -415,12 +532,14 @@ async def get_tags(request: web.Request) -> web.Response: try: query = schemas_in.TagsListQuery.model_validate(query_map) except ValidationError as e: - return web.json_response( - {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}}, - status=400, + return _build_error_response( + 400, + "INVALID_QUERY", + "Invalid query parameters", + {"errors": json.loads(e.json())}, ) - result = manager.list_tags( + rows, total = list_tags( prefix=query.prefix, limit=query.limit, offset=query.offset, @@ -428,87 +547,212 @@ async def get_tags(request: web.Request) -> web.Response: include_zero=query.include_zero, owner_id=USER_MANAGER.get_request_user_id(request), ) - return web.json_response(result.model_dump(mode="json")) + + tags = [ + schemas_out.TagUsage(name=name, count=count, type=tag_type) + for (name, tag_type, count) in rows + ] + payload = schemas_out.TagsList( + tags=tags, total=total, has_more=(query.offset + len(tags)) < total + ) + return web.json_response(payload.model_dump(mode="json")) @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled async def add_asset_tags(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - payload = await request.json() - data = schemas_in.TagsAdd.model_validate(payload) + json_payload = await request.json() + data = schemas_in.TagsAdd.model_validate(json_payload) except ValidationError as ve: - return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags add.", + {"errors": ve.errors()}, + ) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.add_tags_to_asset( - asset_info_id=asset_info_id, + result = apply_tags( + reference_id=reference_id, tags=data.tags, origin="manual", owner_id=USER_MANAGER.get_request_user_id(request), ) - except (ValueError, PermissionError) as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + payload = schemas_out.TagsAdd( + added=result.added, + already_present=result.already_present, + total_tags=result.total_tags, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "add_tags_to_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled async def delete_asset_tags(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - payload = await request.json() - data = schemas_in.TagsRemove.model_validate(payload) + json_payload = await request.json() + data = schemas_in.TagsRemove.model_validate(json_payload) except ValidationError as ve: - return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags remove.", + {"errors": ve.errors()}, + ) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.remove_tags_from_asset( - asset_info_id=asset_info_id, + result = remove_tags( + reference_id=reference_id, tags=data.tags, owner_id=USER_MANAGER.get_request_user_id(request), ) + payload = schemas_out.TagsRemove( + removed=result.removed, + not_present=result.not_present, + total_tags=result.total_tags, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) except ValueError as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "remove_tags_from_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.post("/api/assets/seed") -async def seed_assets_endpoint(request: web.Request) -> web.Response: - """Trigger asset seeding for specified roots (models, input, output).""" +@_require_assets_feature_enabled +async def seed_assets(request: web.Request) -> web.Response: + """Trigger asset seeding for specified roots (models, input, output). + + Query params: + wait: If "true", block until scan completes (synchronous behavior for tests) + + Returns: + 202 Accepted if scan started + 409 Conflict if scan already running + 200 OK with final stats if wait=true + """ try: payload = await request.json() roots = payload.get("roots", ["models", "input", "output"]) except Exception: roots = ["models", "input", "output"] - valid_roots = [r for r in roots if r in ("models", "input", "output")] + valid_roots = tuple(r for r in roots if r in ("models", "input", "output")) if not valid_roots: - return _error_response(400, "INVALID_BODY", "No valid roots specified") + return _build_error_response(400, "INVALID_BODY", "No valid roots specified") + wait_param = request.query.get("wait", "").lower() + should_wait = wait_param in ("true", "1", "yes") + + started = asset_seeder.start(roots=valid_roots) + if not started: + return web.json_response({"status": "already_running"}, status=409) + + if should_wait: + await asyncio.to_thread(asset_seeder.wait) + status = asset_seeder.get_status() + return web.json_response( + { + "status": "completed", + "progress": { + "scanned": status.progress.scanned if status.progress else 0, + "total": status.progress.total if status.progress else 0, + "created": status.progress.created if status.progress else 0, + "skipped": status.progress.skipped if status.progress else 0, + }, + "errors": status.errors, + }, + status=200, + ) + + return web.json_response({"status": "started"}, status=202) + + +@ROUTES.get("/api/assets/seed/status") +@_require_assets_feature_enabled +async def get_seed_status(request: web.Request) -> web.Response: + """Get current scan status and progress.""" + status = asset_seeder.get_status() + return web.json_response( + { + "state": status.state.value, + "progress": { + "scanned": status.progress.scanned, + "total": status.progress.total, + "created": status.progress.created, + "skipped": status.progress.skipped, + } + if status.progress + else None, + "errors": status.errors, + }, + status=200, + ) + + +@ROUTES.post("/api/assets/seed/cancel") +@_require_assets_feature_enabled +async def cancel_seed(request: web.Request) -> web.Response: + """Request cancellation of in-progress scan.""" + cancelled = asset_seeder.cancel() + if cancelled: + return web.json_response({"status": "cancelling"}, status=200) + return web.json_response({"status": "idle"}, status=200) + + +@ROUTES.post("/api/assets/prune") +@_require_assets_feature_enabled +async def mark_missing_assets(request: web.Request) -> web.Response: + """Mark assets as missing when outside all known root prefixes. + + This is a non-destructive soft-delete operation. Assets and metadata + are preserved, but references are flagged as missing. They can be + restored if the file reappears in a future scan. + + Returns: + 200 OK with count of marked assets + 409 Conflict if a scan is currently running + """ try: - seed_assets(tuple(valid_roots)) - except Exception: - logging.exception("seed_assets failed for roots=%s", valid_roots) - return _error_response(500, "INTERNAL", "Seed operation failed") - - return web.json_response({"seeded": valid_roots}, status=200) + marked = asset_seeder.mark_missing_outside_prefixes() + except ScanInProgressError: + return web.json_response( + {"status": "scan_running", "marked": 0}, + status=409, + ) + return web.json_response({"status": "completed", "marked": marked}, status=200) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 6707ffb0c..d255c938e 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -1,6 +1,8 @@ import json +from dataclasses import dataclass from typing import Any, Literal +from app.assets.helpers import validate_blake3_hash from pydantic import ( BaseModel, ConfigDict, @@ -10,6 +12,41 @@ from pydantic import ( model_validator, ) + +class UploadError(Exception): + """Error during upload parsing with HTTP status and code.""" + + def __init__(self, status: int, code: str, message: str): + super().__init__(message) + self.status = status + self.code = code + self.message = message + + +class AssetValidationError(Exception): + """Validation error in asset processing (invalid tags, metadata, etc.).""" + + def __init__(self, code: str, message: str): + super().__init__(message) + self.code = code + self.message = message + + +@dataclass +class ParsedUpload: + """Result of parsing a multipart upload request.""" + + file_present: bool + file_written: int + file_client_name: str | None + tmp_path: str | None + tags_raw: list[str] + provided_name: str | None + user_metadata_raw: str | None + provided_hash: str | None + provided_hash_exists: bool | None + + class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) @@ -21,7 +58,9 @@ class ListAssetsQuery(BaseModel): limit: conint(ge=1, le=500) = 20 offset: conint(ge=0) = 0 - sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = ( + "created_at" + ) order: Literal["asc", "desc"] = "desc" @field_validator("include_tags", "exclude_tags", mode="before") @@ -61,7 +100,7 @@ class UpdateAssetBody(BaseModel): user_metadata: dict[str, Any] | None = None @model_validator(mode="after") - def _at_least_one(self): + def _validate_at_least_one_field(self): if self.name is None and self.user_metadata is None: raise ValueError("Provide at least one of: name, user_metadata.") return self @@ -78,19 +117,11 @@ class CreateFromHashBody(BaseModel): @field_validator("hash") @classmethod def _require_blake3(cls, v): - s = (v or "").strip().lower() - if ":" not in s: - raise ValueError("hash must be 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3": - raise ValueError("only canonical 'blake3:' is accepted here") - if not digest or any(c for c in digest if c not in "0123456789abcdef"): - raise ValueError("hash digest must be lowercase hex") - return s + return validate_blake3_hash(v or "") @field_validator("tags", mode="before") @classmethod - def _tags_norm(cls, v): + def _normalize_tags_field(cls, v): if v is None: return [] if isinstance(v, list): @@ -154,15 +185,16 @@ class TagsRemove(TagsAdd): class UploadAssetSpec(BaseModel): """Upload Asset operation. + - tags: ordered; first is root ('models'|'input'|'output'); - if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths + if root == 'models', second must be a valid category - name: display name - user_metadata: arbitrary JSON object (optional) - - hash: optional canonical 'blake3:' provided by the client for validation / fast-path + - hash: optional canonical 'blake3:' for validation / fast-path - Files created via this endpoint are stored on disk using the **content hash** as the filename stem - and the original extension is preserved when available. + Files are stored using the content hash as filename stem. """ + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) tags: list[str] = Field(..., min_length=1) @@ -175,17 +207,10 @@ class UploadAssetSpec(BaseModel): def _parse_hash(cls, v): if v is None: return None - s = str(v).strip().lower() + s = str(v).strip() if not s: return None - if ":" not in s: - raise ValueError("hash must be 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3": - raise ValueError("only canonical 'blake3:' is accepted here") - if not digest or any(c for c in digest if c not in "0123456789abcdef"): - raise ValueError("hash digest must be lowercase hex") - return f"{algo}:{digest}" + return validate_blake3_hash(s) @field_validator("tags", mode="before") @classmethod @@ -260,5 +285,7 @@ class UploadAssetSpec(BaseModel): raise ValueError("first tag must be one of: models, input, output") if root == "models": if len(self.tags) < 2: - raise ValueError("models uploads require a category tag as the second tag") + raise ValueError( + "models uploads require a category tag as the second tag" + ) return self diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index b6fb3da0c..f36447856 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -19,7 +19,7 @@ class AssetSummary(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("created_at", "updated_at", "last_access_time") - def _ser_dt(self, v: datetime | None, _info): + def _serialize_datetime(self, v: datetime | None, _info): return v.isoformat() if v else None @@ -40,7 +40,7 @@ class AssetUpdated(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("updated_at") - def _ser_updated(self, v: datetime | None, _info): + def _serialize_updated_at(self, v: datetime | None, _info): return v.isoformat() if v else None @@ -59,7 +59,7 @@ class AssetDetail(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("created_at", "last_access_time") - def _ser_dt(self, v: datetime | None, _info): + def _serialize_datetime(self, v: datetime | None, _info): return v.isoformat() if v else None diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py new file mode 100644 index 000000000..721c12f4d --- /dev/null +++ b/app/assets/api/upload.py @@ -0,0 +1,171 @@ +import logging +import os +import uuid +from typing import Callable + +from aiohttp import web + +import folder_paths +from app.assets.api.schemas_in import ParsedUpload, UploadError +from app.assets.helpers import validate_blake3_hash + + +def normalize_and_validate_hash(s: str) -> str: + """Validate and normalize a hash string. + + Returns canonical 'blake3:' or raises UploadError. + """ + try: + return validate_blake3_hash(s) + except ValueError: + raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:'") + + +async def parse_multipart_upload( + request: web.Request, + check_hash_exists: Callable[[str], bool], +) -> ParsedUpload: + """ + Parse a multipart/form-data upload request. + + Args: + request: The aiohttp request + check_hash_exists: Callable(hash_str) -> bool to check if a hash exists + + Returns: + ParsedUpload with parsed fields and temp file path + + Raises: + UploadError: On validation or I/O errors + """ + if not (request.content_type or "").lower().startswith("multipart/"): + raise UploadError( + 415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads." + ) + + reader = await request.multipart() + + file_present = False + file_client_name: str | None = None + tags_raw: list[str] = [] + provided_name: str | None = None + user_metadata_raw: str | None = None + provided_hash: str | None = None + provided_hash_exists: bool | None = None + + file_written = 0 + tmp_path: str | None = None + + while True: + field = await reader.next() + if field is None: + break + + fname = getattr(field, "name", "") or "" + + if fname == "hash": + try: + s = ((await field.text()) or "").strip().lower() + except Exception: + raise UploadError( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) + + if s: + provided_hash = normalize_and_validate_hash(s) + try: + provided_hash_exists = check_hash_exists(provided_hash) + except Exception as e: + logging.exception( + "check_hash_exists failed for hash=%s: %s", provided_hash, e + ) + raise UploadError( + 500, + "HASH_CHECK_FAILED", + "Backend error while checking asset hash.", + ) + + elif fname == "file": + file_present = True + file_client_name = (field.filename or "").strip() + + if provided_hash and provided_hash_exists is True: + # Hash exists - drain file but don't write to disk + try: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + file_written += len(chunk) + except Exception: + raise UploadError( + 500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file." + ) + continue + + uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") + unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) + os.makedirs(unique_dir, exist_ok=True) + tmp_path = os.path.join(unique_dir, ".upload.part") + + try: + with open(tmp_path, "wb") as f: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + file_written += len(chunk) + except Exception: + delete_temp_file_if_exists(tmp_path) + raise UploadError( + 500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file." + ) + + elif fname == "tags": + tags_raw.append((await field.text()) or "") + elif fname == "name": + provided_name = (await field.text()) or None + elif fname == "user_metadata": + user_metadata_raw = (await field.text()) or None + + if not file_present and not (provided_hash and provided_hash_exists): + raise UploadError( + 400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'." + ) + + if ( + file_present + and file_written == 0 + and not (provided_hash and provided_hash_exists) + ): + delete_temp_file_if_exists(tmp_path) + raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.") + + return ParsedUpload( + file_present=file_present, + file_written=file_written, + file_client_name=file_client_name, + tmp_path=tmp_path, + tags_raw=tags_raw, + provided_name=provided_name, + user_metadata_raw=user_metadata_raw, + provided_hash=provided_hash, + provided_hash_exists=provided_hash_exists, + ) + + +def delete_temp_file_if_exists(tmp_path: str | None) -> None: + """Safely remove a temp file and its parent directory if empty.""" + if tmp_path: + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + except OSError as e: + logging.debug("Failed to delete temp file %s: %s", tmp_path, e) + try: + parent = os.path.dirname(tmp_path) + if parent and os.path.isdir(parent): + os.rmdir(parent) # only succeeds if empty + except OSError: + pass diff --git a/app/assets/database/bulk_ops.py b/app/assets/database/bulk_ops.py deleted file mode 100644 index c7b75290a..000000000 --- a/app/assets/database/bulk_ops.py +++ /dev/null @@ -1,204 +0,0 @@ -import os -import uuid -import sqlalchemy -from typing import Iterable -from sqlalchemy.orm import Session -from sqlalchemy.dialects import sqlite - -from app.assets.helpers import utcnow -from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta - -MAX_BIND_PARAMS = 800 - -def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]: - if not rows: - return [] - rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row)) - for i in range(0, len(rows), rows_per_stmt): - yield rows[i:i + rows_per_stmt] - -def _iter_chunks(seq, n: int): - for i in range(0, len(seq), n): - yield seq[i:i + n] - -def _rows_per_stmt(cols: int) -> int: - return max(1, MAX_BIND_PARAMS // max(1, cols)) - - -def seed_from_paths_batch( - session: Session, - *, - specs: list[dict], - owner_id: str = "", -) -> dict: - """Each spec is a dict with keys: - - abs_path: str - - size_bytes: int - - mtime_ns: int - - info_name: str - - tags: list[str] - - fname: Optional[str] - """ - if not specs: - return {"inserted_infos": 0, "won_states": 0, "lost_states": 0} - - now = utcnow() - asset_rows: list[dict] = [] - state_rows: list[dict] = [] - path_to_asset: dict[str, str] = {} - asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row - path_list: list[str] = [] - - for sp in specs: - ap = os.path.abspath(sp["abs_path"]) - aid = str(uuid.uuid4()) - iid = str(uuid.uuid4()) - path_list.append(ap) - path_to_asset[ap] = aid - - asset_rows.append( - { - "id": aid, - "hash": None, - "size_bytes": sp["size_bytes"], - "mime_type": None, - "created_at": now, - } - ) - state_rows.append( - { - "asset_id": aid, - "file_path": ap, - "mtime_ns": sp["mtime_ns"], - } - ) - asset_to_info[aid] = { - "id": iid, - "owner_id": owner_id, - "name": sp["info_name"], - "asset_id": aid, - "preview_id": None, - "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None, - "created_at": now, - "updated_at": now, - "last_access_time": now, - "_tags": sp["tags"], - "_filename": sp["fname"], - } - - # insert all seed Assets (hash=NULL) - ins_asset = sqlite.insert(Asset) - for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)): - session.execute(ins_asset, chunk) - - # try to claim AssetCacheState (file_path) - # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted - ins_state = ( - sqlite.insert(AssetCacheState) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): - session.execute(ins_state, chunk) - - # Query to find which of our paths won (were actually inserted) - winners_by_path: set[str] = set() - for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS): - result = session.execute( - sqlalchemy.select(AssetCacheState.file_path) - .where(AssetCacheState.file_path.in_(chunk)) - .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk])) - ) - winners_by_path.update(result.scalars().all()) - - all_paths_set = set(path_list) - losers_by_path = all_paths_set - winners_by_path - lost_assets = [path_to_asset[p] for p in losers_by_path] - if lost_assets: # losers get their Asset removed - for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS): - session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk))) - - if not winners_by_path: - return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} - - # insert AssetInfo only for winners - # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted - winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] - ins_info = ( - sqlite.insert(AssetInfo) - .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) - ) - for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): - session.execute(ins_info, chunk) - - # Query to find which info rows were actually inserted (by matching our generated IDs) - all_info_ids = [row["id"] for row in winner_info_rows] - inserted_info_ids: set[str] = set() - for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS): - result = session.execute( - sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk)) - ) - inserted_info_ids.update(result.scalars().all()) - - # build and insert tag + meta rows for the AssetInfo - tag_rows: list[dict] = [] - meta_rows: list[dict] = [] - if inserted_info_ids: - for row in winner_info_rows: - iid = row["id"] - if iid not in inserted_info_ids: - continue - for t in row["_tags"]: - tag_rows.append({ - "asset_info_id": iid, - "tag_name": t, - "origin": "automatic", - "added_at": now, - }) - if row["_filename"]: - meta_rows.append( - { - "asset_info_id": iid, - "key": "filename", - "ordinal": 0, - "val_str": row["_filename"], - "val_num": None, - "val_bool": None, - "val_json": None, - } - ) - - bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS) - return { - "inserted_infos": len(inserted_info_ids), - "won_states": len(winners_by_path), - "lost_states": len(losers_by_path), - } - - -def bulk_insert_tags_and_meta( - session: Session, - *, - tag_rows: list[dict], - meta_rows: list[dict], - max_bind_params: int, -) -> None: - """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING. - - tag_rows keys: asset_info_id, tag_name, origin, added_at - - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json - """ - if tag_rows: - ins_links = ( - sqlite.insert(AssetInfoTag) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params): - session.execute(ins_links, chunk) - if meta_rows: - ins_meta = ( - sqlite.insert(AssetInfoMeta) - .on_conflict_do_nothing( - index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] - ) - ) - for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params): - session.execute(ins_meta, chunk) diff --git a/app/assets/database/models.py b/app/assets/database/models.py index 3cd28f68b..03c1c1707 100644 --- a/app/assets/database/models.py +++ b/app/assets/database/models.py @@ -2,8 +2,8 @@ from __future__ import annotations import uuid from datetime import datetime - from typing import Any + from sqlalchemy import ( JSON, BigInteger, @@ -16,102 +16,102 @@ from sqlalchemy import ( Numeric, String, Text, - UniqueConstraint, ) from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship -from app.assets.helpers import utcnow -from app.database.models import to_dict, Base +from app.assets.helpers import get_utc_now +from app.database.models import Base class Asset(Base): __tablename__ = "assets" - id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) hash: Mapped[str | None] = mapped_column(String(256), nullable=True) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) mime_type: Mapped[str | None] = mapped_column(String(255)) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow + DateTime(timezone=False), nullable=False, default=get_utc_now ) - infos: Mapped[list[AssetInfo]] = relationship( - "AssetInfo", + references: Mapped[list[AssetReference]] = relationship( + "AssetReference", back_populates="asset", - primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id), - foreign_keys=lambda: [AssetInfo.asset_id], + primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id), + foreign_keys=lambda: [AssetReference.asset_id], cascade="all,delete-orphan", passive_deletes=True, ) - preview_of: Mapped[list[AssetInfo]] = relationship( - "AssetInfo", + preview_of: Mapped[list[AssetReference]] = relationship( + "AssetReference", back_populates="preview_asset", - primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id), - foreign_keys=lambda: [AssetInfo.preview_id], + primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id), + foreign_keys=lambda: [AssetReference.preview_id], viewonly=True, ) - cache_states: Mapped[list[AssetCacheState]] = relationship( - back_populates="asset", - cascade="all, delete-orphan", - passive_deletes=True, - ) - __table_args__ = ( Index("uq_assets_hash", "hash", unique=True), Index("ix_assets_mime_type", "mime_type"), CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - return to_dict(self, include_none=include_none) - def __repr__(self) -> str: return f"" -class AssetCacheState(Base): - __tablename__ = "asset_cache_state" +class AssetReference(Base): + """Unified model combining file cache state and user-facing metadata. - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) - file_path: Mapped[str] = mapped_column(Text, nullable=False) - mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) - needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + Each row represents either: + - A filesystem reference (file_path is set) with cache state + - An API-created reference (file_path is NULL) without cache state + """ - asset: Mapped[Asset] = relationship(back_populates="cache_states") + __tablename__ = "asset_references" - __table_args__ = ( - Index("ix_asset_cache_state_file_path", "file_path"), - Index("ix_asset_cache_state_asset_id", "asset_id"), - CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), - UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) + asset_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - return to_dict(self, include_none=include_none) + # Cache state fields (from former AssetCacheState) + file_path: Mapped[str | None] = mapped_column(Text, nullable=True) + mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - def __repr__(self) -> str: - return f"" - - -class AssetInfo(Base): - __tablename__ = "assets_info" - - id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + # Info fields (from former AssetInfo) owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) - asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) - preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) - user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) - last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + preview_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="SET NULL") + ) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column( + JSON(none_as_null=True) + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + last_access_time: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + deleted_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=False), nullable=True, default=None + ) asset: Mapped[Asset] = relationship( "Asset", - back_populates="infos", + back_populates="references", foreign_keys=[asset_id], lazy="selectin", ) @@ -121,51 +121,59 @@ class AssetInfo(Base): foreign_keys=[preview_id], ) - metadata_entries: Mapped[list[AssetInfoMeta]] = relationship( - back_populates="asset_info", + metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship( + back_populates="asset_reference", cascade="all,delete-orphan", passive_deletes=True, ) - tag_links: Mapped[list[AssetInfoTag]] = relationship( - back_populates="asset_info", + tag_links: Mapped[list[AssetReferenceTag]] = relationship( + back_populates="asset_reference", cascade="all,delete-orphan", passive_deletes=True, - overlaps="tags,asset_infos", + overlaps="tags,asset_references", ) tags: Mapped[list[Tag]] = relationship( - secondary="asset_info_tags", - back_populates="asset_infos", + secondary="asset_reference_tags", + back_populates="asset_references", lazy="selectin", viewonly=True, - overlaps="tag_links,asset_info_links,asset_infos,tag", + overlaps="tag_links,asset_reference_links,asset_references,tag", ) __table_args__ = ( - UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), - Index("ix_assets_info_owner_name", "owner_id", "name"), - Index("ix_assets_info_owner_id", "owner_id"), - Index("ix_assets_info_asset_id", "asset_id"), - Index("ix_assets_info_name", "name"), - Index("ix_assets_info_created_at", "created_at"), - Index("ix_assets_info_last_access_time", "last_access_time"), + Index("uq_asset_references_file_path", "file_path", unique=True), + Index("ix_asset_references_asset_id", "asset_id"), + Index("ix_asset_references_owner_id", "owner_id"), + Index("ix_asset_references_name", "name"), + Index("ix_asset_references_is_missing", "is_missing"), + Index("ix_asset_references_enrichment_level", "enrichment_level"), + Index("ix_asset_references_created_at", "created_at"), + Index("ix_asset_references_last_access_time", "last_access_time"), + Index("ix_asset_references_deleted_at", "deleted_at"), + Index("ix_asset_references_owner_name", "owner_id", "name"), + CheckConstraint( + "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" + ), + CheckConstraint( + "enrichment_level >= 0 AND enrichment_level <= 2", + name="ck_ar_enrichment_level_range", + ), ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - data = to_dict(self, include_none=include_none) - data["tags"] = [t.name for t in self.tags] - return data - def __repr__(self) -> str: - return f"" + path_part = f" path={self.file_path!r}" if self.file_path else "" + return f"" -class AssetInfoMeta(Base): - __tablename__ = "asset_info_meta" +class AssetReferenceMeta(Base): + __tablename__ = "asset_reference_meta" - asset_info_id: Mapped[str] = mapped_column( - String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + asset_reference_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("asset_references.id", ondelete="CASCADE"), + primary_key=True, ) key: Mapped[str] = mapped_column(String(256), primary_key=True) ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) @@ -175,36 +183,40 @@ class AssetInfoMeta(Base): val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True) val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True) - asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries") + asset_reference: Mapped[AssetReference] = relationship( + back_populates="metadata_entries" + ) __table_args__ = ( - Index("ix_asset_info_meta_key", "key"), - Index("ix_asset_info_meta_key_val_str", "key", "val_str"), - Index("ix_asset_info_meta_key_val_num", "key", "val_num"), - Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), + Index("ix_asset_reference_meta_key", "key"), + Index("ix_asset_reference_meta_key_val_str", "key", "val_str"), + Index("ix_asset_reference_meta_key_val_num", "key", "val_num"), + Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"), ) -class AssetInfoTag(Base): - __tablename__ = "asset_info_tags" +class AssetReferenceTag(Base): + __tablename__ = "asset_reference_tags" - asset_info_id: Mapped[str] = mapped_column( - String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + asset_reference_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("asset_references.id", ondelete="CASCADE"), + primary_key=True, ) tag_name: Mapped[str] = mapped_column( String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True ) origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") added_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow + DateTime(timezone=False), nullable=False, default=get_utc_now ) - asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links") - tag: Mapped[Tag] = relationship(back_populates="asset_info_links") + asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links") + tag: Mapped[Tag] = relationship(back_populates="asset_reference_links") __table_args__ = ( - Index("ix_asset_info_tags_tag_name", "tag_name"), - Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), + Index("ix_asset_reference_tags_tag_name", "tag_name"), + Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"), ) @@ -214,20 +226,18 @@ class Tag(Base): name: Mapped[str] = mapped_column(String(512), primary_key=True) tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") - asset_info_links: Mapped[list[AssetInfoTag]] = relationship( + asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship( back_populates="tag", - overlaps="asset_infos,tags", + overlaps="asset_references,tags", ) - asset_infos: Mapped[list[AssetInfo]] = relationship( - secondary="asset_info_tags", + asset_references: Mapped[list[AssetReference]] = relationship( + secondary="asset_reference_tags", back_populates="tags", viewonly=True, - overlaps="asset_info_links,tag_links,tags,asset_info", + overlaps="asset_reference_links,tag_links,tags,asset_reference", ) - __table_args__ = ( - Index("ix_tags_tag_type", "tag_type"), - ) + __table_args__ = (Index("ix_tags_tag_type", "tag_type"),) def __repr__(self) -> str: return f"" diff --git a/app/assets/database/queries.py b/app/assets/database/queries.py deleted file mode 100644 index d6b33ec7b..000000000 --- a/app/assets/database/queries.py +++ /dev/null @@ -1,976 +0,0 @@ -import os -import logging -import sqlalchemy as sa -from collections import defaultdict -from datetime import datetime -from typing import Iterable, Any -from sqlalchemy import select, delete, exists, func -from sqlalchemy.dialects import sqlite -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session, contains_eager, noload -from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag -from app.assets.helpers import ( - compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow -) -from typing import Sequence - - -def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: - """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" - owner_id = (owner_id or "").strip() - if owner_id == "": - return AssetInfo.owner_id == "" - return AssetInfo.owner_id.in_(["", owner_id]) - - -def pick_best_live_path(states: Sequence[AssetCacheState]) -> str: - """ - Return the best on-disk path among cache states: - 1) Prefer a path that exists with needs_verify == False (already verified). - 2) Otherwise, pick the first path that exists. - 3) Otherwise return empty string. - """ - alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] - if not alive: - return "" - for s in alive: - if not getattr(s, "needs_verify", False): - return s.file_path - return alive[0].file_path - - -def apply_tag_filters( - stmt: sa.sql.Select, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, -) -> sa.sql.Select: - """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = normalize_tags(include_tags) - exclude_tags = normalize_tags(exclude_tags) - - if include_tags: - for tag_name in include_tags: - stmt = stmt.where( - exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name == tag_name) - ) - ) - - if exclude_tags: - stmt = stmt.where( - ~exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name.in_(exclude_tags)) - ) - ) - return stmt - - -def apply_metadata_filter( - stmt: sa.sql.Select, - metadata_filter: dict | None = None, -) -> sa.sql.Select: - """Apply filters using asset_info_meta projection table.""" - if not metadata_filter: - return stmt - - def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - return sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - *preds, - ) - - def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: - if value is None: - no_row_for_key = sa.not_( - sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - ) - ) - null_row = _exists_for_pred( - key, - AssetInfoMeta.val_json.is_(None), - AssetInfoMeta.val_str.is_(None), - AssetInfoMeta.val_num.is_(None), - AssetInfoMeta.val_bool.is_(None), - ) - return sa.or_(no_row_for_key, null_row) - - if isinstance(value, bool): - return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) - if isinstance(value, (int, float)): - from decimal import Decimal - num = value if isinstance(value, Decimal) else Decimal(str(value)) - return _exists_for_pred(key, AssetInfoMeta.val_num == num) - if isinstance(value, str): - return _exists_for_pred(key, AssetInfoMeta.val_str == value) - return _exists_for_pred(key, AssetInfoMeta.val_json == value) - - for k, v in metadata_filter.items(): - if isinstance(v, list): - ors = [_exists_clause_for_value(k, elem) for elem in v] - if ors: - stmt = stmt.where(sa.or_(*ors)) - else: - stmt = stmt.where(_exists_clause_for_value(k, v)) - return stmt - - -def asset_exists_by_hash( - session: Session, - *, - asset_hash: str, -) -> bool: - """ - Check if an asset with a given hash exists in database. - """ - row = ( - session.execute( - select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).first() - return row is not None - - -def asset_info_exists_for_asset_id( - session: Session, - *, - asset_id: str, -) -> bool: - q = ( - select(sa.literal(True)) - .select_from(AssetInfo) - .where(AssetInfo.asset_id == asset_id) - .limit(1) - ) - return (session.execute(q)).first() is not None - - -def get_asset_by_hash( - session: Session, - *, - asset_hash: str, -) -> Asset | None: - return ( - session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - - -def get_asset_info_by_id( - session: Session, - *, - asset_info_id: str, -) -> AssetInfo | None: - return session.get(AssetInfo, asset_info_id) - - -def list_asset_infos_page( - session: Session, - owner_id: str = "", - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, - name_contains: str | None = None, - metadata_filter: dict | None = None, - limit: int = 20, - offset: int = 0, - sort: str = "created_at", - order: str = "desc", -) -> tuple[list[AssetInfo], dict[str, list[str]], int]: - base = ( - select(AssetInfo) - .join(Asset, Asset.id == AssetInfo.asset_id) - .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) - .where(visible_owner_clause(owner_id)) - ) - - if name_contains: - escaped, esc = escape_like_prefix(name_contains) - base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) - - base = apply_tag_filters(base, include_tags, exclude_tags) - base = apply_metadata_filter(base, metadata_filter) - - sort = (sort or "created_at").lower() - order = (order or "desc").lower() - sort_map = { - "name": AssetInfo.name, - "created_at": AssetInfo.created_at, - "updated_at": AssetInfo.updated_at, - "last_access_time": AssetInfo.last_access_time, - "size": Asset.size_bytes, - } - sort_col = sort_map.get(sort, AssetInfo.created_at) - sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() - - base = base.order_by(sort_exp).limit(limit).offset(offset) - - count_stmt = ( - select(sa.func.count()) - .select_from(AssetInfo) - .join(Asset, Asset.id == AssetInfo.asset_id) - .where(visible_owner_clause(owner_id)) - ) - if name_contains: - escaped, esc = escape_like_prefix(name_contains) - count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) - count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = apply_metadata_filter(count_stmt, metadata_filter) - - total = int((session.execute(count_stmt)).scalar_one() or 0) - - infos = (session.execute(base)).unique().scalars().all() - - id_list: list[str] = [i.id for i in infos] - tag_map: dict[str, list[str]] = defaultdict(list) - if id_list: - rows = session.execute( - select(AssetInfoTag.asset_info_id, Tag.name) - .join(Tag, Tag.name == AssetInfoTag.tag_name) - .where(AssetInfoTag.asset_info_id.in_(id_list)) - .order_by(AssetInfoTag.added_at) - ) - for aid, tag_name in rows.all(): - tag_map[aid].append(tag_name) - - return infos, tag_map, total - - -def fetch_asset_info_asset_and_tags( - session: Session, - asset_info_id: str, - owner_id: str = "", -) -> tuple[AssetInfo, Asset, list[str]] | None: - stmt = ( - select(AssetInfo, Asset, Tag.name) - .join(Asset, Asset.id == AssetInfo.asset_id) - .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) - .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .options(noload(AssetInfo.tags)) - .order_by(Tag.name.asc()) - ) - - rows = (session.execute(stmt)).all() - if not rows: - return None - - first_info, first_asset, _ = rows[0] - tags: list[str] = [] - seen: set[str] = set() - for _info, _asset, tag_name in rows: - if tag_name and tag_name not in seen: - seen.add(tag_name) - tags.append(tag_name) - return first_info, first_asset, tags - - -def fetch_asset_info_and_asset( - session: Session, - *, - asset_info_id: str, - owner_id: str = "", -) -> tuple[AssetInfo, Asset] | None: - stmt = ( - select(AssetInfo, Asset) - .join(Asset, Asset.id == AssetInfo.asset_id) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .limit(1) - .options(noload(AssetInfo.tags)) - ) - row = session.execute(stmt) - pair = row.first() - if not pair: - return None - return pair[0], pair[1] - -def list_cache_states_by_asset_id( - session: Session, *, asset_id: str -) -> Sequence[AssetCacheState]: - return ( - session.execute( - select(AssetCacheState) - .where(AssetCacheState.asset_id == asset_id) - .order_by(AssetCacheState.id.asc()) - ) - ).scalars().all() - - -def touch_asset_info_by_id( - session: Session, - *, - asset_info_id: str, - ts: datetime | None = None, - only_if_newer: bool = True, -) -> None: - ts = ts or utcnow() - stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) - if only_if_newer: - stmt = stmt.where( - sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) - ) - session.execute(stmt.values(last_access_time=ts)) - - -def create_asset_info_for_existing_asset( - session: Session, - *, - asset_hash: str, - name: str, - user_metadata: dict | None = None, - tags: Sequence[str] | None = None, - tag_origin: str = "manual", - owner_id: str = "", -) -> AssetInfo: - """Create or return an existing AssetInfo for an Asset identified by asset_hash.""" - now = utcnow() - asset = get_asset_by_hash(session, asset_hash=asset_hash) - if not asset: - raise ValueError(f"Unknown asset hash {asset_hash}") - - info = AssetInfo( - owner_id=owner_id, - name=name, - asset_id=asset.id, - preview_id=None, - created_at=now, - updated_at=now, - last_access_time=now, - ) - try: - with session.begin_nested(): - session.add(info) - session.flush() - except IntegrityError: - existing = ( - session.execute( - select(AssetInfo) - .options(noload(AssetInfo.tags)) - .where( - AssetInfo.asset_id == asset.id, - AssetInfo.name == name, - AssetInfo.owner_id == owner_id, - ) - .limit(1) - ) - ).unique().scalars().first() - if not existing: - raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") - return existing - - # metadata["filename"] hack - new_meta = dict(user_metadata or {}) - computed_filename = None - try: - p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) - if p: - computed_filename = compute_relative_filename(p) - except Exception: - computed_filename = None - if computed_filename: - new_meta["filename"] = computed_filename - if new_meta: - replace_asset_info_metadata_projection( - session, - asset_info_id=info.id, - user_metadata=new_meta, - ) - - if tags is not None: - set_asset_info_tags( - session, - asset_info_id=info.id, - tags=tags, - origin=tag_origin, - ) - return info - - -def set_asset_info_tags( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", -) -> dict: - desired = normalize_tags(tags) - - current = set( - tag_name for (tag_name,) in ( - session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) - ).all() - ) - - to_add = [t for t in desired if t not in current] - to_remove = [t for t in current if t not in desired] - - if to_add: - ensure_tags_exist(session, to_add, tag_type="user") - session.add_all([ - AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) - for t in to_add - ]) - session.flush() - - if to_remove: - session.execute( - delete(AssetInfoTag) - .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) - ) - session.flush() - - return {"added": to_add, "removed": to_remove, "total": desired} - - -def replace_asset_info_metadata_projection( - session: Session, - *, - asset_info_id: str, - user_metadata: dict | None = None, -) -> None: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - info.user_metadata = user_metadata or {} - info.updated_at = utcnow() - session.flush() - - session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) - session.flush() - - if not user_metadata: - return - - rows: list[AssetInfoMeta] = [] - for k, v in user_metadata.items(): - for r in project_kv(k, v): - rows.append( - AssetInfoMeta( - asset_info_id=asset_info_id, - key=r["key"], - ordinal=int(r["ordinal"]), - val_str=r.get("val_str"), - val_num=r.get("val_num"), - val_bool=r.get("val_bool"), - val_json=r.get("val_json"), - ) - ) - if rows: - session.add_all(rows) - session.flush() - - -def ingest_fs_asset( - session: Session, - *, - asset_hash: str, - abs_path: str, - size_bytes: int, - mtime_ns: int, - mime_type: str | None = None, - info_name: str | None = None, - owner_id: str = "", - preview_id: str | None = None, - user_metadata: dict | None = None, - tags: Sequence[str] = (), - tag_origin: str = "manual", - require_existing_tags: bool = False, -) -> dict: - """ - Idempotently upsert: - - Asset by content hash (create if missing) - - AssetCacheState(file_path) pointing to asset_id - - Optionally AssetInfo + tag links and metadata projection - Returns flags and ids. - """ - locator = os.path.abspath(abs_path) - now = utcnow() - - if preview_id: - if not session.get(Asset, preview_id): - preview_id = None - - out: dict[str, Any] = { - "asset_created": False, - "asset_updated": False, - "state_created": False, - "state_updated": False, - "asset_info_id": None, - } - - # 1) Asset by hash - asset = ( - session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - if not asset: - vals = { - "hash": asset_hash, - "size_bytes": int(size_bytes), - "mime_type": mime_type, - "created_at": now, - } - res = session.execute( - sqlite.insert(Asset) - .values(**vals) - .on_conflict_do_nothing(index_elements=[Asset.hash]) - ) - if int(res.rowcount or 0) > 0: - out["asset_created"] = True - asset = ( - session.execute( - select(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).scalars().first() - if not asset: - raise RuntimeError("Asset row not found after upsert.") - else: - changed = False - if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: - asset.size_bytes = int(size_bytes) - changed = True - if mime_type and asset.mime_type != mime_type: - asset.mime_type = mime_type - changed = True - if changed: - out["asset_updated"] = True - - # 2) AssetCacheState upsert by file_path (unique) - vals = { - "asset_id": asset.id, - "file_path": locator, - "mtime_ns": int(mtime_ns), - } - ins = ( - sqlite.insert(AssetCacheState) - .values(**vals) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - - res = session.execute(ins) - if int(res.rowcount or 0) > 0: - out["state_created"] = True - else: - upd = ( - sa.update(AssetCacheState) - .where(AssetCacheState.file_path == locator) - .where( - sa.or_( - AssetCacheState.asset_id != asset.id, - AssetCacheState.mtime_ns.is_(None), - AssetCacheState.mtime_ns != int(mtime_ns), - ) - ) - .values(asset_id=asset.id, mtime_ns=int(mtime_ns)) - ) - res2 = session.execute(upd) - if int(res2.rowcount or 0) > 0: - out["state_updated"] = True - - # 3) Optional AssetInfo + tags + metadata - if info_name: - try: - with session.begin_nested(): - info = AssetInfo( - owner_id=owner_id, - name=info_name, - asset_id=asset.id, - preview_id=preview_id, - created_at=now, - updated_at=now, - last_access_time=now, - ) - session.add(info) - session.flush() - out["asset_info_id"] = info.id - except IntegrityError: - pass - - existing_info = ( - session.execute( - select(AssetInfo) - .where( - AssetInfo.asset_id == asset.id, - AssetInfo.name == info_name, - (AssetInfo.owner_id == owner_id), - ) - .limit(1) - ) - ).unique().scalar_one_or_none() - if not existing_info: - raise RuntimeError("Failed to update or insert AssetInfo.") - - if preview_id and existing_info.preview_id != preview_id: - existing_info.preview_id = preview_id - - existing_info.updated_at = now - if existing_info.last_access_time < now: - existing_info.last_access_time = now - session.flush() - out["asset_info_id"] = existing_info.id - - norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] - if norm and out["asset_info_id"] is not None: - if not require_existing_tags: - ensure_tags_exist(session, norm, tag_type="user") - - existing_tag_names = set( - name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() - ) - missing = [t for t in norm if t not in existing_tag_names] - if missing and require_existing_tags: - raise ValueError(f"Unknown tags: {missing}") - - existing_links = set( - tag_name - for (tag_name,) in ( - session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) - ) - ).all() - ) - to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] - if to_add: - session.add_all( - [ - AssetInfoTag( - asset_info_id=out["asset_info_id"], - tag_name=t, - origin=tag_origin, - added_at=now, - ) - for t in to_add - ] - ) - session.flush() - - # metadata["filename"] hack - if out["asset_info_id"] is not None: - primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) - computed_filename = compute_relative_filename(primary_path) if primary_path else None - - current_meta = existing_info.user_metadata or {} - new_meta = dict(current_meta) - if user_metadata is not None: - for k, v in user_metadata.items(): - new_meta[k] = v - if computed_filename: - new_meta["filename"] = computed_filename - - if new_meta != current_meta: - replace_asset_info_metadata_projection( - session, - asset_info_id=out["asset_info_id"], - user_metadata=new_meta, - ) - - try: - remove_missing_tag_for_asset_id(session, asset_id=asset.id) - except Exception: - logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) - return out - - -def update_asset_info_full( - session: Session, - *, - asset_info_id: str, - name: str | None = None, - tags: Sequence[str] | None = None, - user_metadata: dict | None = None, - tag_origin: str = "manual", - asset_info_row: Any = None, -) -> AssetInfo: - if not asset_info_row: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - else: - info = asset_info_row - - touched = False - if name is not None and name != info.name: - info.name = name - touched = True - - computed_filename = None - try: - p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id)) - if p: - computed_filename = compute_relative_filename(p) - except Exception: - computed_filename = None - - if user_metadata is not None: - new_meta = dict(user_metadata) - if computed_filename: - new_meta["filename"] = computed_filename - replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - else: - if computed_filename: - current_meta = info.user_metadata or {} - if current_meta.get("filename") != computed_filename: - new_meta = dict(current_meta) - new_meta["filename"] = computed_filename - replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - - if tags is not None: - set_asset_info_tags( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=tag_origin, - ) - touched = True - - if touched and user_metadata is None: - info.updated_at = utcnow() - session.flush() - - return info - - -def delete_asset_info_by_id( - session: Session, - *, - asset_info_id: str, - owner_id: str, -) -> bool: - stmt = sa.delete(AssetInfo).where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - return int((session.execute(stmt)).rowcount or 0) > 0 - - -def list_tags_with_usage( - session: Session, - prefix: str | None = None, - limit: int = 100, - offset: int = 0, - include_zero: bool = True, - order: str = "count_desc", - owner_id: str = "", -) -> tuple[list[tuple[str, str, int]], int]: - counts_sq = ( - select( - AssetInfoTag.tag_name.label("tag_name"), - func.count(AssetInfoTag.asset_info_id).label("cnt"), - ) - .select_from(AssetInfoTag) - .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) - .where(visible_owner_clause(owner_id)) - .group_by(AssetInfoTag.tag_name) - .subquery() - ) - - q = ( - select( - Tag.name, - Tag.tag_type, - func.coalesce(counts_sq.c.cnt, 0).label("count"), - ) - .select_from(Tag) - .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) - ) - - if prefix: - escaped, esc = escape_like_prefix(prefix.strip().lower()) - q = q.where(Tag.name.like(escaped + "%", escape=esc)) - - if not include_zero: - q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) - - if order == "name_asc": - q = q.order_by(Tag.name.asc()) - else: - q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) - - total_q = select(func.count()).select_from(Tag) - if prefix: - escaped, esc = escape_like_prefix(prefix.strip().lower()) - total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) - if not include_zero: - total_q = total_q.where( - Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) - ) - - rows = (session.execute(q.limit(limit).offset(offset))).all() - total = (session.execute(total_q)).scalar_one() - - rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] - return rows_norm, int(total or 0) - - -def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: - wanted = normalize_tags(list(names)) - if not wanted: - return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] - ins = ( - sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - session.execute(ins) - - -def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]: - return [ - tag_name for (tag_name,) in ( - session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - ] - - -def add_tags_to_asset_info( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", - create_if_missing: bool = True, - asset_info_row: Any = None, -) -> dict: - if not asset_info_row: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"added": [], "already_present": [], "total_tags": total} - - if create_if_missing: - ensure_tags_exist(session, norm, tag_type="user") - - current = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - want = set(norm) - to_add = sorted(want - current) - - if to_add: - with session.begin_nested() as nested: - try: - session.add_all( - [ - AssetInfoTag( - asset_info_id=asset_info_id, - tag_name=t, - origin=origin, - added_at=utcnow(), - ) - for t in to_add - ] - ) - session.flush() - except IntegrityError: - nested.rollback() - - after = set(get_asset_tags(session, asset_info_id=asset_info_id)) - return { - "added": sorted(((after - current) & want)), - "already_present": sorted(want & current), - "total_tags": sorted(after), - } - - -def remove_tags_from_asset_info( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], -) -> dict: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": [], "not_present": [], "total_tags": total} - - existing = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - to_remove = sorted(set(t for t in norm if t in existing)) - not_present = sorted(set(t for t in norm if t not in existing)) - - if to_remove: - session.execute( - delete(AssetInfoTag) - .where( - AssetInfoTag.asset_info_id == asset_info_id, - AssetInfoTag.tag_name.in_(to_remove), - ) - ) - session.flush() - - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": to_remove, "not_present": not_present, "total_tags": total} - - -def remove_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, -) -> None: - session.execute( - sa.delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), - AssetInfoTag.tag_name == "missing", - ) - ) - - -def set_asset_info_preview( - session: Session, - *, - asset_info_id: str, - preview_asset_id: str | None = None, -) -> None: - """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - if preview_asset_id is None: - info.preview_id = None - else: - # validate preview asset exists - if not session.get(Asset, preview_asset_id): - raise ValueError(f"Preview Asset {preview_asset_id} not found") - info.preview_id = preview_asset_id - - info.updated_at = utcnow() - session.flush() diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py new file mode 100644 index 000000000..7888d0645 --- /dev/null +++ b/app/assets/database/queries/__init__.py @@ -0,0 +1,121 @@ +from app.assets.database.queries.asset import ( + asset_exists_by_hash, + bulk_insert_assets, + get_asset_by_hash, + get_existing_asset_ids, + reassign_asset_references, + update_asset_hash_and_mime, + upsert_asset, +) +from app.assets.database.queries.asset_reference import ( + CacheStateRow, + UnenrichedReferenceRow, + bulk_insert_references_ignore_conflicts, + bulk_update_enrichment_level, + bulk_update_is_missing, + bulk_update_needs_verify, + convert_metadata_to_rows, + delete_assets_by_ids, + delete_orphaned_seed_asset, + delete_reference_by_id, + delete_references_by_ids, + fetch_reference_and_asset, + fetch_reference_asset_and_tags, + get_or_create_reference, + get_reference_by_file_path, + get_reference_by_id, + get_reference_with_owner_check, + get_reference_ids_by_ids, + get_references_by_paths_and_asset_ids, + get_references_for_prefixes, + get_unenriched_references, + get_unreferenced_unhashed_asset_ids, + insert_reference, + list_references_by_asset_id, + list_references_page, + mark_references_missing_outside_prefixes, + reference_exists_for_asset_id, + restore_references_by_paths, + set_reference_metadata, + set_reference_preview, + soft_delete_reference_by_id, + update_reference_access_time, + update_reference_name, + update_reference_timestamps, + update_reference_updated_at, + upsert_reference, +) +from app.assets.database.queries.tags import ( + AddTagsResult, + RemoveTagsResult, + SetTagsResult, + add_missing_tag_for_asset_id, + add_tags_to_reference, + bulk_insert_tags_and_meta, + ensure_tags_exist, + get_reference_tags, + list_tags_with_usage, + remove_missing_tag_for_asset_id, + remove_tags_from_reference, + set_reference_tags, + validate_tags_exist, +) + +__all__ = [ + "AddTagsResult", + "CacheStateRow", + "RemoveTagsResult", + "SetTagsResult", + "UnenrichedReferenceRow", + "add_missing_tag_for_asset_id", + "add_tags_to_reference", + "asset_exists_by_hash", + "bulk_insert_assets", + "bulk_insert_references_ignore_conflicts", + "bulk_insert_tags_and_meta", + "bulk_update_enrichment_level", + "bulk_update_is_missing", + "bulk_update_needs_verify", + "convert_metadata_to_rows", + "delete_assets_by_ids", + "delete_orphaned_seed_asset", + "delete_reference_by_id", + "delete_references_by_ids", + "ensure_tags_exist", + "fetch_reference_and_asset", + "fetch_reference_asset_and_tags", + "get_asset_by_hash", + "get_existing_asset_ids", + "get_or_create_reference", + "get_reference_by_file_path", + "get_reference_by_id", + "get_reference_with_owner_check", + "get_reference_ids_by_ids", + "get_reference_tags", + "get_references_by_paths_and_asset_ids", + "get_references_for_prefixes", + "get_unenriched_references", + "get_unreferenced_unhashed_asset_ids", + "insert_reference", + "list_references_by_asset_id", + "list_references_page", + "list_tags_with_usage", + "mark_references_missing_outside_prefixes", + "reassign_asset_references", + "reference_exists_for_asset_id", + "remove_missing_tag_for_asset_id", + "remove_tags_from_reference", + "restore_references_by_paths", + "set_reference_metadata", + "set_reference_preview", + "soft_delete_reference_by_id", + "set_reference_tags", + "update_asset_hash_and_mime", + "update_reference_access_time", + "update_reference_name", + "update_reference_timestamps", + "update_reference_updated_at", + "upsert_asset", + "upsert_reference", + "validate_tags_exist", +] diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py new file mode 100644 index 000000000..a21f5b68f --- /dev/null +++ b/app/assets/database/queries/asset.py @@ -0,0 +1,140 @@ +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.dialects import sqlite +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks + + +def asset_exists_by_hash( + session: Session, + asset_hash: str, +) -> bool: + """ + Check if an asset with a given hash exists in database. + """ + row = ( + session.execute( + select(sa.literal(True)) + .select_from(Asset) + .where(Asset.hash == asset_hash) + .limit(1) + ) + ).first() + return row is not None + + +def get_asset_by_hash( + session: Session, + asset_hash: str, +) -> Asset | None: + return ( + (session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))) + .scalars() + .first() + ) + + +def upsert_asset( + session: Session, + asset_hash: str, + size_bytes: int, + mime_type: str | None = None, +) -> tuple[Asset, bool, bool]: + """Upsert an Asset by hash. Returns (asset, created, updated).""" + vals = {"hash": asset_hash, "size_bytes": int(size_bytes)} + if mime_type: + vals["mime_type"] = mime_type + + ins = ( + sqlite.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + res = session.execute(ins) + created = int(res.rowcount or 0) > 0 + + asset = ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + .scalars() + .first() + ) + if not asset: + raise RuntimeError("Asset row not found after upsert.") + + updated = False + if not created: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and asset.mime_type != mime_type: + asset.mime_type = mime_type + changed = True + if changed: + updated = True + + return asset, created, updated + + +def bulk_insert_assets( + session: Session, + rows: list[dict], +) -> None: + """Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash.""" + if not rows: + return + ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash]) + for chunk in iter_chunks(rows, calculate_rows_per_statement(5)): + session.execute(ins, chunk) + + +def get_existing_asset_ids( + session: Session, + asset_ids: list[str], +) -> set[str]: + """Return the subset of asset_ids that exist in the database.""" + if not asset_ids: + return set() + found: set[str] = set() + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + rows = session.execute( + select(Asset.id).where(Asset.id.in_(chunk)) + ).fetchall() + found.update(row[0] for row in rows) + return found + + +def update_asset_hash_and_mime( + session: Session, + asset_id: str, + asset_hash: str | None = None, + mime_type: str | None = None, +) -> bool: + """Update asset hash and/or mime_type. Returns True if asset was found.""" + asset = session.get(Asset, asset_id) + if not asset: + return False + if asset_hash is not None: + asset.hash = asset_hash + if mime_type is not None: + asset.mime_type = mime_type + return True + + +def reassign_asset_references( + session: Session, + from_asset_id: str, + to_asset_id: str, + reference_id: str, +) -> None: + """Reassign a reference from one asset to another. + + Used when merging a stub asset into an existing asset with the same hash. + """ + ref = session.get(AssetReference, reference_id) + if ref and ref.asset_id == from_asset_id: + ref.asset_id = to_asset_id + + session.flush() diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py new file mode 100644 index 000000000..6524791cc --- /dev/null +++ b/app/assets/database/queries/asset_reference.py @@ -0,0 +1,1033 @@ +"""Query functions for the unified AssetReference table. + +This module replaces the separate asset_info.py and cache_state.py query modules, +providing a unified interface for the merged asset_references table. +""" + +from collections import defaultdict +from datetime import datetime +from decimal import Decimal +from typing import NamedTuple, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, exists, select +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session, noload + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceMeta, + AssetReferenceTag, + Tag, +) +from app.assets.database.queries.common import ( + MAX_BIND_PARAMS, + build_prefix_like_conditions, + build_visible_owner_clause, + calculate_rows_per_statement, + iter_chunks, +) +from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags + + +def _check_is_scalar(v): + if v is None: + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def _scalar_to_row(key: str, ordinal: int, value) -> dict: + """Convert a scalar value to a typed projection row.""" + if value is None: + return { + "key": key, + "ordinal": ordinal, + "val_str": None, + "val_num": None, + "val_bool": None, + "val_json": None, + } + if isinstance(value, bool): + return {"key": key, "ordinal": ordinal, "val_bool": bool(value)} + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return {"key": key, "ordinal": ordinal, "val_num": num} + if isinstance(value, str): + return {"key": key, "ordinal": ordinal, "val_str": value} + return {"key": key, "ordinal": ordinal, "val_json": value} + + +def convert_metadata_to_rows(key: str, value) -> list[dict]: + """Turn a metadata key/value into typed projection rows.""" + if value is None: + return [_scalar_to_row(key, 0, None)] + + if _check_is_scalar(value): + return [_scalar_to_row(key, 0, value)] + + if isinstance(value, list): + if all(_check_is_scalar(x) for x in value): + return [_scalar_to_row(key, i, x) for i, x in enumerate(value)] + return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)] + + return [{"key": key, "ordinal": 0, "val_json": value}] + + +def _apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def _apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None = None, +) -> sa.sql.Select: + """Apply filters using asset_reference_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + no_row_for_key = sa.not_( + sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + ) + ) + null_row = _exists_for_pred( + key, + AssetReferenceMeta.val_json.is_(None), + AssetReferenceMeta.val_str.is_(None), + AssetReferenceMeta.val_num.is_(None), + AssetReferenceMeta.val_bool.is_(None), + ) + return sa.or_(no_row_for_key, null_row) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetReferenceMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetReferenceMeta.val_str == value) + return _exists_for_pred(key, AssetReferenceMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def get_reference_by_id( + session: Session, + reference_id: str, +) -> AssetReference | None: + return session.get(AssetReference, reference_id) + + +def get_reference_with_owner_check( + session: Session, + reference_id: str, + owner_id: str, +) -> AssetReference: + """Fetch a reference and verify ownership. + + Raises: + ValueError: if reference not found or soft-deleted + PermissionError: if owner_id doesn't match + """ + ref = get_reference_by_id(session, reference_id=reference_id) + if not ref or ref.deleted_at is not None: + raise ValueError(f"AssetReference {reference_id} not found") + if ref.owner_id and ref.owner_id != owner_id: + raise PermissionError("not owner") + return ref + + +def get_reference_by_file_path( + session: Session, + file_path: str, +) -> AssetReference | None: + """Get a reference by its file path.""" + return ( + session.execute( + select(AssetReference).where(AssetReference.file_path == file_path).limit(1) + ) + .scalars() + .first() + ) + + +def reference_exists_for_asset_id( + session: Session, + asset_id: str, +) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetReference) + .where(AssetReference.asset_id == asset_id) + .where(AssetReference.deleted_at.is_(None)) + .limit(1) + ) + return session.execute(q).first() is not None + + +def insert_reference( + session: Session, + asset_id: str, + name: str, + owner_id: str = "", + file_path: str | None = None, + mtime_ns: int | None = None, + preview_id: str | None = None, +) -> AssetReference | None: + """Insert a new AssetReference. Returns None if unique constraint violated.""" + now = get_utc_now() + try: + with session.begin_nested(): + ref = AssetReference( + asset_id=asset_id, + name=name, + owner_id=owner_id, + file_path=file_path, + mtime_ns=mtime_ns, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + except IntegrityError: + return None + + +def get_or_create_reference( + session: Session, + asset_id: str, + name: str, + owner_id: str = "", + file_path: str | None = None, + mtime_ns: int | None = None, + preview_id: str | None = None, +) -> tuple[AssetReference, bool]: + """Get existing or create new AssetReference. + + For filesystem references (file_path is set), uniqueness is by file_path. + For API references (file_path is None), we look for matching + asset_id + owner_id + name. + + Returns (reference, created). + """ + ref = insert_reference( + session, + asset_id=asset_id, + name=name, + owner_id=owner_id, + file_path=file_path, + mtime_ns=mtime_ns, + preview_id=preview_id, + ) + if ref: + return ref, True + + # Find existing - priority to file_path match, then name match + if file_path: + existing = get_reference_by_file_path(session, file_path) + else: + existing = ( + session.execute( + select(AssetReference) + .where( + AssetReference.asset_id == asset_id, + AssetReference.name == name, + AssetReference.owner_id == owner_id, + AssetReference.file_path.is_(None), + ) + .limit(1) + ) + .unique() + .scalar_one_or_none() + ) + if not existing: + raise RuntimeError("Failed to find AssetReference after insert conflict.") + return existing, False + + +def update_reference_timestamps( + session: Session, + reference: AssetReference, + preview_id: str | None = None, +) -> None: + """Update timestamps and optionally preview_id on existing AssetReference.""" + now = get_utc_now() + if preview_id and reference.preview_id != preview_id: + reference.preview_id = preview_id + reference.updated_at = now + + +def list_references_page( + session: Session, + owner_id: str = "", + limit: int = 100, + offset: int = 0, + name_contains: str | None = None, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + metadata_filter: dict | None = None, + sort: str | None = None, + order: str | None = None, +) -> tuple[list[AssetReference], dict[str, list[str]], int]: + """List references with pagination, filtering, and sorting. + + Returns (references, tag_map, total_count). + """ + base = ( + select(AssetReference) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .options(noload(AssetReference.tags)) + ) + + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) + + base = _apply_tag_filters(base, include_tags, exclude_tags) + base = _apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetReference.name, + "created_at": AssetReference.created_at, + "updated_at": AssetReference.updated_at, + "last_access_time": AssetReference.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetReference.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(sa.func.count()) + .select_from(AssetReference) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + ) + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + count_stmt = count_stmt.where( + AssetReference.name.ilike(f"%{escaped}%", escape=esc) + ) + count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) + + total = int(session.execute(count_stmt).scalar_one() or 0) + refs = session.execute(base).unique().scalars().all() + + id_list: list[str] = [r.id for r in refs] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = session.execute( + select(AssetReferenceTag.asset_reference_id, Tag.name) + .join(Tag, Tag.name == AssetReferenceTag.tag_name) + .where(AssetReferenceTag.asset_reference_id.in_(id_list)) + .order_by(AssetReferenceTag.added_at) + ) + for ref_id, tag_name in rows.all(): + tag_map[ref_id].append(tag_name) + + return list(refs), tag_map, total + + +def fetch_reference_asset_and_tags( + session: Session, + reference_id: str, + owner_id: str = "", +) -> tuple[AssetReference, Asset, list[str]] | None: + stmt = ( + select(AssetReference, Asset, Tag.name) + .join(Asset, Asset.id == AssetReference.asset_id) + .join( + AssetReferenceTag, + AssetReferenceTag.asset_reference_id == AssetReference.id, + isouter=True, + ) + .join(Tag, Tag.name == AssetReferenceTag.tag_name, isouter=True) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .options(noload(AssetReference.tags)) + .order_by(Tag.name.asc()) + ) + + rows = session.execute(stmt).all() + if not rows: + return None + + first_ref, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _ref, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_ref, first_asset, tags + + +def fetch_reference_and_asset( + session: Session, + reference_id: str, + owner_id: str = "", +) -> tuple[AssetReference, Asset] | None: + stmt = ( + select(AssetReference, Asset) + .join(Asset, Asset.id == AssetReference.asset_id) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetReference.tags)) + ) + pair = session.execute(stmt).first() + if not pair: + return None + return pair[0], pair[1] + + +def update_reference_access_time( + session: Session, + reference_id: str, + ts: datetime | None = None, + only_if_newer: bool = True, +) -> None: + ts = ts or get_utc_now() + stmt = sa.update(AssetReference).where(AssetReference.id == reference_id) + if only_if_newer: + stmt = stmt.where( + sa.or_( + AssetReference.last_access_time.is_(None), + AssetReference.last_access_time < ts, + ) + ) + session.execute(stmt.values(last_access_time=ts)) + + +def update_reference_name( + session: Session, + reference_id: str, + name: str, +) -> None: + """Update the name of an AssetReference.""" + now = get_utc_now() + session.execute( + sa.update(AssetReference) + .where(AssetReference.id == reference_id) + .values(name=name, updated_at=now) + ) + + +def update_reference_updated_at( + session: Session, + reference_id: str, + ts: datetime | None = None, +) -> None: + """Update the updated_at timestamp of an AssetReference.""" + ts = ts or get_utc_now() + session.execute( + sa.update(AssetReference) + .where(AssetReference.id == reference_id) + .values(updated_at=ts) + ) + + +def set_reference_metadata( + session: Session, + reference_id: str, + user_metadata: dict | None = None, +) -> None: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + ref.user_metadata = user_metadata or {} + ref.updated_at = get_utc_now() + session.flush() + + session.execute( + delete(AssetReferenceMeta).where( + AssetReferenceMeta.asset_reference_id == reference_id + ) + ) + session.flush() + + if not user_metadata: + return + + rows: list[AssetReferenceMeta] = [] + for k, v in user_metadata.items(): + for r in convert_metadata_to_rows(k, v): + rows.append( + AssetReferenceMeta( + asset_reference_id=reference_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + session.flush() + + +def delete_reference_by_id( + session: Session, + reference_id: str, + owner_id: str, +) -> bool: + stmt = sa.delete(AssetReference).where( + AssetReference.id == reference_id, + build_visible_owner_clause(owner_id), + ) + return int(session.execute(stmt).rowcount or 0) > 0 + + +def soft_delete_reference_by_id( + session: Session, + reference_id: str, + owner_id: str, +) -> bool: + """Mark a reference as soft-deleted by setting deleted_at timestamp. + + Returns True if the reference was found and marked deleted. + """ + now = get_utc_now() + stmt = ( + sa.update(AssetReference) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .values(deleted_at=now) + ) + return int(session.execute(stmt).rowcount or 0) > 0 + + +def set_reference_preview( + session: Session, + reference_id: str, + preview_asset_id: str | None = None, +) -> None: + """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + if preview_asset_id is None: + ref.preview_id = None + else: + if not session.get(Asset, preview_asset_id): + raise ValueError(f"Preview Asset {preview_asset_id} not found") + ref.preview_id = preview_asset_id + + ref.updated_at = get_utc_now() + session.flush() + + +class CacheStateRow(NamedTuple): + """Row from reference query with cache state data.""" + + reference_id: str + file_path: str + mtime_ns: int | None + needs_verify: bool + asset_id: str + asset_hash: str | None + size_bytes: int | None + + +def list_references_by_asset_id( + session: Session, + asset_id: str, +) -> Sequence[AssetReference]: + return ( + session.execute( + select(AssetReference) + .where(AssetReference.asset_id == asset_id) + .order_by(AssetReference.id.asc()) + ) + .scalars() + .all() + ) + + +def upsert_reference( + session: Session, + asset_id: str, + file_path: str, + name: str, + mtime_ns: int, + owner_id: str = "", +) -> tuple[bool, bool]: + """Upsert a reference by file_path. Returns (created, updated). + + Also restores references that were previously marked as missing. + """ + now = get_utc_now() + vals = { + "asset_id": asset_id, + "file_path": file_path, + "name": name, + "owner_id": owner_id, + "mtime_ns": int(mtime_ns), + "is_missing": False, + "created_at": now, + "updated_at": now, + "last_access_time": now, + } + ins = ( + sqlite.insert(AssetReference) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetReference.file_path]) + ) + res = session.execute(ins) + created = int(res.rowcount or 0) > 0 + + if created: + return True, False + + upd = ( + sa.update(AssetReference) + .where(AssetReference.file_path == file_path) + .where( + sa.or_( + AssetReference.asset_id != asset_id, + AssetReference.mtime_ns.is_(None), + AssetReference.mtime_ns != int(mtime_ns), + AssetReference.is_missing == True, # noqa: E712 + AssetReference.deleted_at.isnot(None), + ) + ) + .values( + asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False, + deleted_at=None, updated_at=now, + ) + ) + res2 = session.execute(upd) + updated = int(res2.rowcount or 0) > 0 + return False, updated + + +def mark_references_missing_outside_prefixes( + session: Session, + valid_prefixes: list[str], +) -> int: + """Mark references as missing when file_path doesn't match any valid prefix. + + Returns number of references marked as missing. + """ + if not valid_prefixes: + return 0 + + conds = build_prefix_like_conditions(valid_prefixes) + matches_valid_prefix = sa.or_(*conds) + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(~matches_valid_prefix) + .where(AssetReference.is_missing == False) # noqa: E712 + .values(is_missing=True) + ) + return result.rowcount + + +def restore_references_by_paths(session: Session, file_paths: list[str]) -> int: + """Restore references that were previously marked as missing. + + Returns number of references restored. + """ + if not file_paths: + return 0 + + total = 0 + for chunk in iter_chunks(file_paths, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.in_(chunk)) + .where(AssetReference.is_missing == True) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .values(is_missing=False) + ) + total += result.rowcount + return total + + +def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]: + """Get IDs of unhashed assets (hash=None) with no active references. + + An asset is considered unreferenced if it has no references, + or all its references are marked as missing. + + Returns list of asset IDs that are unreferenced. + """ + active_ref_exists = ( + sa.select(sa.literal(1)) + .where(AssetReference.asset_id == Asset.id) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .correlate(Asset) + .exists() + ) + unreferenced_subq = sa.select(Asset.id).where( + Asset.hash.is_(None), ~active_ref_exists + ) + return [row[0] for row in session.execute(unreferenced_subq).all()] + + +def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int: + """Delete assets and their references by ID. + + Returns number of assets deleted. + """ + if not asset_ids: + return 0 + total = 0 + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk)) + ) + result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk))) + total += result.rowcount + return total + + +def get_references_for_prefixes( + session: Session, + prefixes: list[str], + *, + include_missing: bool = False, +) -> list[CacheStateRow]: + """Get all references with file paths matching any of the given prefixes. + + Args: + session: Database session + prefixes: List of absolute directory prefixes to match + include_missing: If False (default), exclude references marked as missing + + Returns: + List of cache state rows with joined asset data + """ + if not prefixes: + return [] + + conds = build_prefix_like_conditions(prefixes) + + query = ( + sa.select( + AssetReference.id, + AssetReference.file_path, + AssetReference.mtime_ns, + AssetReference.needs_verify, + AssetReference.asset_id, + Asset.hash, + Asset.size_bytes, + ) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(sa.or_(*conds)) + ) + + if not include_missing: + query = query.where(AssetReference.is_missing == False) # noqa: E712 + + rows = session.execute( + query.order_by(AssetReference.asset_id.asc(), AssetReference.id.asc()) + ).all() + + return [ + CacheStateRow( + reference_id=row[0], + file_path=row[1], + mtime_ns=row[2], + needs_verify=row[3], + asset_id=row[4], + asset_hash=row[5], + size_bytes=int(row[6]) if row[6] is not None else None, + ) + for row in rows + ] + + +def bulk_update_needs_verify( + session: Session, reference_ids: list[str], value: bool +) -> int: + """Set needs_verify flag for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(chunk)) + .values(needs_verify=value) + ) + total += result.rowcount + return total + + +def bulk_update_is_missing( + session: Session, reference_ids: list[str], value: bool +) -> int: + """Set is_missing flag for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(chunk)) + .values(is_missing=value) + ) + total += result.rowcount + return total + + +def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int: + """Delete references by their IDs. + + Returns: Number of rows deleted + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.delete(AssetReference).where(AssetReference.id.in_(chunk)) + ) + total += result.rowcount + return total + + +def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool: + """Delete a seed asset (hash is None) and its references. + + Returns: True if asset was deleted, False if not found or has a hash + """ + asset = session.get(Asset, asset_id) + if not asset: + return False + if asset.hash is not None: + return False + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id == asset_id) + ) + session.delete(asset) + return True + + +class UnenrichedReferenceRow(NamedTuple): + """Row for references needing enrichment.""" + + reference_id: str + asset_id: str + file_path: str + enrichment_level: int + + +def get_unenriched_references( + session: Session, + prefixes: list[str], + max_level: int = 0, + limit: int = 1000, +) -> list[UnenrichedReferenceRow]: + """Get references that need enrichment (enrichment_level <= max_level). + + Args: + session: Database session + prefixes: List of absolute directory prefixes to scan + max_level: Maximum enrichment level to include (0=stubs, 1=metadata done) + limit: Maximum number of rows to return + + Returns: + List of unenriched reference rows with file paths + """ + if not prefixes: + return [] + + conds = build_prefix_like_conditions(prefixes) + + query = ( + sa.select( + AssetReference.id, + AssetReference.asset_id, + AssetReference.file_path, + AssetReference.enrichment_level, + ) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(sa.or_(*conds)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.enrichment_level <= max_level) + .order_by(AssetReference.id.asc()) + .limit(limit) + ) + + rows = session.execute(query).all() + return [ + UnenrichedReferenceRow( + reference_id=row[0], + asset_id=row[1], + file_path=row[2], + enrichment_level=row[3], + ) + for row in rows + ] + + +def bulk_update_enrichment_level( + session: Session, + reference_ids: list[str], + level: int, +) -> int: + """Update enrichment level for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(reference_ids)) + .values(enrichment_level=level) + ) + return result.rowcount + + +def bulk_insert_references_ignore_conflicts( + session: Session, + rows: list[dict], +) -> None: + """Bulk insert reference rows with ON CONFLICT DO NOTHING on file_path. + + Each dict should have: id, asset_id, file_path, name, owner_id, mtime_ns, etc. + The is_missing field is automatically set to False for new inserts. + """ + if not rows: + return + enriched_rows = [{**row, "is_missing": False} for row in rows] + ins = sqlite.insert(AssetReference).on_conflict_do_nothing( + index_elements=[AssetReference.file_path] + ) + for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(14)): + session.execute(ins, chunk) + + +def get_references_by_paths_and_asset_ids( + session: Session, + path_to_asset: dict[str, str], +) -> set[str]: + """Query references to find paths where our asset_id won the insert. + + Args: + path_to_asset: Mapping of file_path -> asset_id we tried to insert + + Returns: + Set of file_paths where our asset_id is present + """ + if not path_to_asset: + return set() + + pairs = list(path_to_asset.items()) + winners: set[str] = set() + + # Each pair uses 2 bind params, so chunk at MAX_BIND_PARAMS // 2 + for chunk in iter_chunks(pairs, MAX_BIND_PARAMS // 2): + pairwise = sa.tuple_(AssetReference.file_path, AssetReference.asset_id).in_( + chunk + ) + result = session.execute( + select(AssetReference.file_path).where(pairwise) + ) + winners.update(result.scalars().all()) + + return winners + + +def get_reference_ids_by_ids( + session: Session, + reference_ids: list[str], +) -> set[str]: + """Query to find which reference IDs exist in the database.""" + if not reference_ids: + return set() + + found: set[str] = set() + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + select(AssetReference.id).where(AssetReference.id.in_(chunk)) + ) + found.update(result.scalars().all()) + return found diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py new file mode 100644 index 000000000..194c39a1e --- /dev/null +++ b/app/assets/database/queries/common.py @@ -0,0 +1,54 @@ +"""Shared utilities for database query modules.""" + +import os +from typing import Iterable + +import sqlalchemy as sa + +from app.assets.database.models import AssetReference +from app.assets.helpers import escape_sql_like_string + +MAX_BIND_PARAMS = 800 + + +def calculate_rows_per_statement(cols: int) -> int: + """Calculate how many rows can fit in one statement given column count.""" + return max(1, MAX_BIND_PARAMS // max(1, cols)) + + +def iter_chunks(seq, n: int): + """Yield successive n-sized chunks from seq.""" + for i in range(0, len(seq), n): + yield seq[i : i + n] + + +def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]: + """Yield chunks of rows sized to fit within bind param limits.""" + if not rows: + return + yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row)) + + +def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. + + Owner-less rows are visible to everyone. + """ + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetReference.owner_id == "" + return AssetReference.owner_id.in_(["", owner_id]) + + +def build_prefix_like_conditions( + prefixes: list[str], +) -> list[sa.sql.ColumnElement]: + """Build LIKE conditions for matching file paths under directory prefixes.""" + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_sql_like_string(base) + conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) + return conds diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py new file mode 100644 index 000000000..8b25fee67 --- /dev/null +++ b/app/assets/database/queries/tags.py @@ -0,0 +1,356 @@ +from dataclasses import dataclass +from typing import Iterable, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, func, select +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from app.assets.database.models import ( + AssetReference, + AssetReferenceMeta, + AssetReferenceTag, + Tag, +) +from app.assets.database.queries.common import ( + build_visible_owner_clause, + iter_row_chunks, +) +from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags + + +@dataclass(frozen=True) +class AddTagsResult: + added: list[str] + already_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class RemoveTagsResult: + removed: list[str] + not_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class SetTagsResult: + added: list[str] + removed: list[str] + total: list[str] + + +def validate_tags_exist(session: Session, tags: list[str]) -> None: + """Raise ValueError if any of the given tag names do not exist.""" + existing_tag_names = set( + name + for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all() + ) + missing = [t for t in tags if t not in existing_tag_names] + if missing: + raise ValueError(f"Unknown tags: {missing}") + + +def ensure_tags_exist( + session: Session, names: Iterable[str], tag_type: str = "user" +) -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + session.execute(ins) + + +def get_reference_tags(session: Session, reference_id: str) -> list[str]: + return [ + tag_name + for (tag_name,) in ( + session.execute( + select(AssetReferenceTag.tag_name).where( + AssetReferenceTag.asset_reference_id == reference_id + ) + ) + ).all() + ] + + +def set_reference_tags( + session: Session, + reference_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> SetTagsResult: + desired = normalize_tags(tags) + + current = set(get_reference_tags(session, reference_id)) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + ensure_tags_exist(session, to_add, tag_type="user") + session.add_all( + [ + AssetReferenceTag( + asset_reference_id=reference_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) + session.flush() + + if to_remove: + session.execute( + delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id == reference_id, + AssetReferenceTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + return SetTagsResult(added=to_add, removed=to_remove, total=desired) + + +def add_tags_to_reference( + session: Session, + reference_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + reference_row: AssetReference | None = None, +) -> AddTagsResult: + if not reference_row: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_reference_tags(session, reference_id=reference_id) + return AddTagsResult(added=[], already_present=[], total_tags=total) + + if create_if_missing: + ensure_tags_exist(session, norm, tag_type="user") + + current = set(get_reference_tags(session, reference_id)) + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetReferenceTag( + asset_reference_id=reference_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) + session.flush() + except IntegrityError: + nested.rollback() + + after = set(get_reference_tags(session, reference_id=reference_id)) + return AddTagsResult( + added=sorted(((after - current) & want)), + already_present=sorted(want & current), + total_tags=sorted(after), + ) + + +def remove_tags_from_reference( + session: Session, + reference_id: str, + tags: Sequence[str], +) -> RemoveTagsResult: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_reference_tags(session, reference_id=reference_id) + return RemoveTagsResult(removed=[], not_present=[], total_tags=total) + + existing = set(get_reference_tags(session, reference_id)) + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + session.execute( + delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id == reference_id, + AssetReferenceTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + total = get_reference_tags(session, reference_id=reference_id) + return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total) + + +def add_missing_tag_for_asset_id( + session: Session, + asset_id: str, + origin: str = "automatic", +) -> None: + select_rows = ( + sa.select( + AssetReference.id.label("asset_reference_id"), + sa.literal("missing").label("tag_name"), + sa.literal(origin).label("origin"), + sa.literal(get_utc_now()).label("added_at"), + ) + .where(AssetReference.asset_id == asset_id) + .where( + sa.not_( + sa.exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == "missing") + ) + ) + ) + ) + session.execute( + sqlite.insert(AssetReferenceTag) + .from_select( + ["asset_reference_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing( + index_elements=[ + AssetReferenceTag.asset_reference_id, + AssetReferenceTag.tag_name, + ] + ) + ) + + +def remove_missing_tag_for_asset_id( + session: Session, + asset_id: str, +) -> None: + session.execute( + sa.delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id.in_( + sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id) + ), + AssetReferenceTag.tag_name == "missing", + ) + ) + + +def list_tags_with_usage( + session: Session, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetReferenceTag.tag_name.label("tag_name"), + func.count(AssetReferenceTag.asset_reference_id).label("cnt"), + ) + .select_from(AssetReferenceTag) + .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.deleted_at.is_(None)) + .group_by(AssetReferenceTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + escaped, esc = escape_sql_like_string(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + escaped, esc = escape_sql_like_string(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if not include_zero: + visible_tags_sq = ( + select(AssetReferenceTag.tag_name) + .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.deleted_at.is_(None)) + .group_by(AssetReferenceTag.tag_name) + ) + total_q = total_q.where(Tag.name.in_(visible_tags_sq)) + + rows = (session.execute(q.limit(limit).offset(offset))).all() + total = (session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +def bulk_insert_tags_and_meta( + session: Session, + tag_rows: list[dict], + meta_rows: list[dict], +) -> None: + """Batch insert into asset_reference_tags and asset_reference_meta. + + Uses ON CONFLICT DO NOTHING. + + Args: + session: Database session + tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at + meta_rows: Dicts with: asset_reference_id, key, ordinal, val_* + """ + if tag_rows: + ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing( + index_elements=[ + AssetReferenceTag.asset_reference_id, + AssetReferenceTag.tag_name, + ] + ) + for chunk in iter_row_chunks(tag_rows, cols_per_row=4): + session.execute(ins_tags, chunk) + + if meta_rows: + ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing( + index_elements=[ + AssetReferenceMeta.asset_reference_id, + AssetReferenceMeta.key, + AssetReferenceMeta.ordinal, + ] + ) + for chunk in iter_row_chunks(meta_rows, cols_per_row=7): + session.execute(ins_meta, chunk) diff --git a/app/assets/database/tags.py b/app/assets/database/tags.py deleted file mode 100644 index 3ab6497c2..000000000 --- a/app/assets/database/tags.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Iterable - -import sqlalchemy -from sqlalchemy.orm import Session -from sqlalchemy.dialects import sqlite - -from app.assets.helpers import normalize_tags, utcnow -from app.assets.database.models import Tag, AssetInfoTag, AssetInfo - - -def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: - wanted = normalize_tags(list(names)) - if not wanted: - return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] - ins = ( - sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - return session.execute(ins) - -def add_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, - origin: str = "automatic", -) -> None: - select_rows = ( - sqlalchemy.select( - AssetInfo.id.label("asset_info_id"), - sqlalchemy.literal("missing").label("tag_name"), - sqlalchemy.literal(origin).label("origin"), - sqlalchemy.literal(utcnow()).label("added_at"), - ) - .where(AssetInfo.asset_id == asset_id) - .where( - sqlalchemy.not_( - sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) - ) - ) - ) - session.execute( - sqlite.insert(AssetInfoTag) - .from_select( - ["asset_info_id", "tag_name", "origin", "added_at"], - select_rows, - ) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - -def remove_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, -) -> None: - session.execute( - sqlalchemy.delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), - AssetInfoTag.tag_name == "missing", - ) - ) diff --git a/app/assets/hashing.py b/app/assets/hashing.py deleted file mode 100644 index 4b72084b9..000000000 --- a/app/assets/hashing.py +++ /dev/null @@ -1,75 +0,0 @@ -from blake3 import blake3 -from typing import IO -import os -import asyncio - - -DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB - -# NOTE: this allows hashing different representations of a file-like object -def blake3_hash( - fp: str | IO[bytes], - chunk_size: int = DEFAULT_CHUNK, -) -> str: - """ - Returns a BLAKE3 hex digest for ``fp``, which may be: - - a filename (str/bytes) or PathLike - - an open binary file object - If ``fp`` is a file object, it must be opened in **binary** mode and support - ``read``, ``seek``, and ``tell``. The function will seek to the start before - reading and will attempt to restore the original position afterward. - """ - # duck typing to check if input is a file-like object - if hasattr(fp, "read"): - return _hash_file_obj(fp, chunk_size) - - with open(os.fspath(fp), "rb") as f: - return _hash_file_obj(f, chunk_size) - - -async def blake3_hash_async( - fp: str | IO[bytes], - chunk_size: int = DEFAULT_CHUNK, -) -> str: - """Async wrapper for ``blake3_hash_sync``. - Uses a worker thread so the event loop remains responsive. - """ - # If it is a path, open inside the worker thread to keep I/O off the loop. - if hasattr(fp, "read"): - return await asyncio.to_thread(blake3_hash, fp, chunk_size) - - def _worker() -> str: - with open(os.fspath(fp), "rb") as f: - return _hash_file_obj(f, chunk_size) - - return await asyncio.to_thread(_worker) - - -def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str: - """ - Hash an already-open binary file object by streaming in chunks. - - Seeks to the beginning before reading (if supported). - - Restores the original position afterward (if tell/seek are supported). - """ - if chunk_size <= 0: - chunk_size = DEFAULT_CHUNK - - # in case file object is already open and not at the beginning, track so can be restored after hashing - orig_pos = file_obj.tell() - - try: - # seek to the beginning before reading - if orig_pos != 0: - file_obj.seek(0) - - h = blake3() - while True: - chunk = file_obj.read(chunk_size) - if not chunk: - break - h.update(chunk) - return h.hexdigest() - finally: - # restore original position in file object, if needed - if orig_pos != 0: - file_obj.seek(orig_pos) diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 5030b123a..3798f3933 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -1,226 +1,42 @@ -import contextlib import os -from decimal import Decimal -from aiohttp import web from datetime import datetime, timezone -from pathlib import Path -from typing import Literal, Any - -import folder_paths +from typing import Sequence -RootType = Literal["models", "input", "output"] -ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") - -def get_query_dict(request: web.Request) -> dict[str, Any]: +def select_best_live_path(states: Sequence) -> str: """ - Gets a dictionary of query parameters from the request. - - 'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic. + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. """ - query_dict = { - key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key) - for key in request.query.keys() - } - return query_dict + alive = [ + s + for s in states + if getattr(s, "file_path", None) and os.path.isfile(s.file_path) + ] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path -def list_tree(base_dir: str) -> list[str]: - out: list[str] = [] - base_abs = os.path.abspath(base_dir) - if not os.path.isdir(base_abs): - return out - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - for name in filenames: - out.append(os.path.abspath(os.path.join(dirpath, name))) - return out -def prefixes_for_root(root: RootType) -> list[str]: - if root == "models": - bases: list[str] = [] - for _bucket, paths in get_comfy_models_folders(): - bases.extend(paths) - return [os.path.abspath(p) for p in bases] - if root == "input": - return [os.path.abspath(folder_paths.get_input_directory())] - if root == "output": - return [os.path.abspath(folder_paths.get_output_directory())] - return [] +def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]: + """Escapes %, _ and the escape char in a LIKE prefix. -def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]: - """Escapes %, _ and the escape char itself in a LIKE prefix. - Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like(). + Returns (escaped_prefix, escape_char). """ s = s.replace(escape, escape + escape) # escape the escape char first s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards return s, escape -def fast_asset_file_check( - *, - mtime_db: int | None, - size_db: int | None, - stat_result: os.stat_result, -) -> bool: - if mtime_db is None: - return False - actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)) - if int(mtime_db) != int(actual_mtime_ns): - return False - sz = int(size_db or 0) - if sz > 0: - return int(stat_result.st_size) == sz - return True -def utcnow() -> datetime: +def get_utc_now() -> datetime: """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC.""" return datetime.now(timezone.utc).replace(tzinfo=None) -def get_comfy_models_folders() -> list[tuple[str, list[str]]]: - """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. - - We trust `folder_paths.folder_names_and_paths` and include a category if - *any* of its base paths lies under the Comfy `models_dir`. - """ - targets: list[tuple[str, list[str]]] = [] - models_root = os.path.abspath(folder_paths.models_dir) - for name, values in folder_paths.folder_names_and_paths.items(): - paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI - if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): - targets.append((name, paths)) - return targets - -def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: - """Validates and maps tags -> (base_dir, subdirs_for_fs)""" - root = tags[0] - if root == "models": - if len(tags) < 2: - raise ValueError("at least two tags required for model asset") - try: - bases = folder_paths.folder_names_and_paths[tags[1]][0] - except KeyError: - raise ValueError(f"unknown model category '{tags[1]}'") - if not bases: - raise ValueError(f"no base path configured for category '{tags[1]}'") - base_dir = os.path.abspath(bases[0]) - raw_subdirs = tags[2:] - else: - base_dir = os.path.abspath( - folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory() - ) - raw_subdirs = tags[1:] - for i in raw_subdirs: - if i in (".", ".."): - raise ValueError("invalid path component in tags") - - return base_dir, raw_subdirs if raw_subdirs else [] - -def ensure_within_base(candidate: str, base: str) -> None: - cand_abs = os.path.abspath(candidate) - base_abs = os.path.abspath(base) - try: - if os.path.commonpath([cand_abs, base_abs]) != base_abs: - raise ValueError("destination escapes base directory") - except Exception: - raise ValueError("invalid destination path") - -def compute_relative_filename(file_path: str) -> str | None: - """ - Return the model's path relative to the last well-known folder (the model category), - using forward slashes, eg: - /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" - /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" - - For non-model paths, returns None. - NOTE: this is a temporary helper, used only for initializing metadata["filename"] field. - """ - try: - root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path) - except ValueError: - return None - - p = Path(rel_path) - parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] - if not parts: - return None - - if root_category == "models": - # parts[0] is the category ("checkpoints", "vae", etc) – drop it - inside = parts[1:] if len(parts) > 1 else [parts[0]] - return "/".join(inside) - return "/".join(parts) # input/output: keep all parts - -def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: - """Given an absolute or relative file path, determine which root category the path belongs to: - - 'input' if the file resides under `folder_paths.get_input_directory()` - - 'output' if the file resides under `folder_paths.get_output_directory()` - - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()` - - Returns: - (root_category, relative_path_inside_that_root) - For 'models', the relative path is prefixed with the category name: - e.g. ('models', 'vae/test/sub/ae.safetensors') - - Raises: - ValueError: if the path does not belong to input, output, or configured model bases. - """ - fp_abs = os.path.abspath(file_path) - - def _is_within(child: str, parent: str) -> bool: - try: - return os.path.commonpath([child, parent]) == parent - except Exception: - return False - - def _rel(child: str, parent: str) -> str: - return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep) - - # 1) input - input_base = os.path.abspath(folder_paths.get_input_directory()) - if _is_within(fp_abs, input_base): - return "input", _rel(fp_abs, input_base) - - # 2) output - output_base = os.path.abspath(folder_paths.get_output_directory()) - if _is_within(fp_abs, output_base): - return "output", _rel(fp_abs, output_base) - - # 3) models (check deepest matching base to avoid ambiguity) - best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) - for bucket, bases in get_comfy_models_folders(): - for b in bases: - base_abs = os.path.abspath(b) - if not _is_within(fp_abs, base_abs): - continue - cand = (len(base_abs), bucket, _rel(fp_abs, base_abs)) - if best is None or cand[0] > best[0]: - best = cand - - if best is not None: - _, bucket, rel_inside = best - combined = os.path.join(bucket, rel_inside) - return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) - - raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}") - -def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: - """Return a tuple (name, tags) derived from a filesystem path. - - Semantics: - - Root category is determined by `get_relative_to_root_category_path_of_asset`. - - The returned `name` is the base filename with extension from the relative path. - - The returned `tags` are: - [root_category] + parent folders of the relative path (in order) - For 'models', this means: - file '/.../ModelsDir/vae/test_tag/ae.safetensors' - -> root_category='models', some_path='vae/test_tag/ae.safetensors' - -> name='ae.safetensors', tags=['models', 'vae', 'test_tag'] - - Raises: - ValueError: if the path does not belong to input, output, or configured model bases. - """ - root_category, some_path = get_relative_to_root_category_path_of_asset(file_path) - p = Path(some_path) - parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] - return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) def normalize_tags(tags: list[str] | None) -> list[str]: """ @@ -228,85 +44,22 @@ def normalize_tags(tags: list[str] | None) -> list[str]: - Stripping whitespace and converting to lowercase. - Removing duplicates. """ - return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())) -def collect_models_files() -> list[str]: - out: list[str] = [] - for folder_name, bases in get_comfy_models_folders(): - rel_files = folder_paths.get_filename_list(folder_name) or [] - for rel_path in rel_files: - abs_path = folder_paths.get_full_path(folder_name, rel_path) - if not abs_path: - continue - abs_path = os.path.abspath(abs_path) - allowed = False - for b in bases: - base_abs = os.path.abspath(b) - with contextlib.suppress(Exception): - if os.path.commonpath([abs_path, base_abs]) == base_abs: - allowed = True - break - if allowed: - out.append(abs_path) - return out -def is_scalar(v): - if v is None: - return True - if isinstance(v, bool): - return True - if isinstance(v, (int, float, Decimal, str)): - return True - return False +def validate_blake3_hash(s: str) -> str: + """Validate and normalize a blake3 hash string. -def project_kv(key: str, value): + Returns canonical 'blake3:' or raises ValueError. """ - Turn a metadata key/value into typed projection rows. - Returns list[dict] with keys: - key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) - """ - rows: list[dict] = [] - - def _null_row(ordinal: int) -> dict: - return { - "key": key, "ordinal": ordinal, - "val_str": None, "val_num": None, "val_bool": None, "val_json": None - } - - if value is None: - rows.append(_null_row(0)) - return rows - - if is_scalar(value): - if isinstance(value, bool): - rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) - elif isinstance(value, (int, float, Decimal)): - num = value if isinstance(value, Decimal) else Decimal(str(value)) - rows.append({"key": key, "ordinal": 0, "val_num": num}) - elif isinstance(value, str): - rows.append({"key": key, "ordinal": 0, "val_str": value}) - else: - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows - - if isinstance(value, list): - if all(is_scalar(x) for x in value): - for i, x in enumerate(value): - if x is None: - rows.append(_null_row(i)) - elif isinstance(x, bool): - rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) - elif isinstance(x, (int, float, Decimal)): - num = x if isinstance(x, Decimal) else Decimal(str(x)) - rows.append({"key": key, "ordinal": i, "val_num": num}) - elif isinstance(x, str): - rows.append({"key": key, "ordinal": i, "val_str": x}) - else: - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - for i, x in enumerate(value): - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows + s = s.strip().lower() + if not s or ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if ( + algo != "blake3" + or len(digest) != 64 + or any(c for c in digest if c not in "0123456789abcdef") + ): + raise ValueError("hash must be 'blake3:'") + return f"{algo}:{digest}" diff --git a/app/assets/manager.py b/app/assets/manager.py deleted file mode 100644 index a68c8c8ae..000000000 --- a/app/assets/manager.py +++ /dev/null @@ -1,516 +0,0 @@ -import os -import mimetypes -import contextlib -from typing import Sequence - -from app.database.db import create_session -from app.assets.api import schemas_out, schemas_in -from app.assets.database.queries import ( - asset_exists_by_hash, - asset_info_exists_for_asset_id, - get_asset_by_hash, - get_asset_info_by_id, - fetch_asset_info_asset_and_tags, - fetch_asset_info_and_asset, - create_asset_info_for_existing_asset, - touch_asset_info_by_id, - update_asset_info_full, - delete_asset_info_by_id, - list_cache_states_by_asset_id, - list_asset_infos_page, - list_tags_with_usage, - get_asset_tags, - add_tags_to_asset_info, - remove_tags_from_asset_info, - pick_best_live_path, - ingest_fs_asset, - set_asset_info_preview, -) -from app.assets.helpers import resolve_destination_from_tags, ensure_within_base -from app.assets.database.models import Asset - - -def _safe_sort_field(requested: str | None) -> str: - if not requested: - return "created_at" - v = requested.lower() - if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: - return v - return "created_at" - - -def _get_size_mtime_ns(path: str) -> tuple[int, int]: - st = os.stat(path, follow_symlinks=True) - return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) - - -def _safe_filename(name: str | None, fallback: str) -> str: - n = os.path.basename((name or "").strip() or fallback) - if n: - return n - return fallback - - -def asset_exists(*, asset_hash: str) -> bool: - """ - Check if an asset with a given hash exists in database. - """ - with create_session() as session: - return asset_exists_by_hash(session, asset_hash=asset_hash) - - -def list_assets( - *, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, - name_contains: str | None = None, - metadata_filter: dict | None = None, - limit: int = 20, - offset: int = 0, - sort: str = "created_at", - order: str = "desc", - owner_id: str = "", -) -> schemas_out.AssetsList: - sort = _safe_sort_field(sort) - order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() - - with create_session() as session: - infos, tag_map, total = list_asset_infos_page( - session, - owner_id=owner_id, - include_tags=include_tags, - exclude_tags=exclude_tags, - name_contains=name_contains, - metadata_filter=metadata_filter, - limit=limit, - offset=offset, - sort=sort, - order=order, - ) - - summaries: list[schemas_out.AssetSummary] = [] - for info in infos: - asset = info.asset - tags = tag_map.get(info.id, []) - summaries.append( - schemas_out.AssetSummary( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset else None, - mime_type=asset.mime_type if asset else None, - tags=tags, - created_at=info.created_at, - updated_at=info.updated_at, - last_access_time=info.last_access_time, - ) - ) - - return schemas_out.AssetsList( - assets=summaries, - total=total, - has_more=(offset + len(summaries)) < total, - ) - - -def get_asset( - *, - asset_info_id: str, - owner_id: str = "", -) -> schemas_out.AssetDetail: - with create_session() as session: - res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not res: - raise ValueError(f"AssetInfo {asset_info_id} not found") - info, asset, tag_names = res - preview_id = info.preview_id - - return schemas_out.AssetDetail( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, - mime_type=asset.mime_type if asset else None, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - ) - - -def resolve_asset_content_for_download( - *, - asset_info_id: str, - owner_id: str = "", -) -> tuple[str, str, str]: - with create_session() as session: - pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not pair: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - info, asset = pair - states = list_cache_states_by_asset_id(session, asset_id=asset.id) - abs_path = pick_best_live_path(states) - if not abs_path: - raise FileNotFoundError - - touch_asset_info_by_id(session, asset_info_id=asset_info_id) - session.commit() - - ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" - download_name = info.name or os.path.basename(abs_path) - return abs_path, ctype, download_name - - -def upload_asset_from_temp_path( - spec: schemas_in.UploadAssetSpec, - *, - temp_path: str, - client_filename: str | None = None, - owner_id: str = "", - expected_asset_hash: str | None = None, -) -> schemas_out.AssetCreated: - """ - Create new asset or update existing asset from a temporary file path. - """ - try: - # NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment - import app.assets.hashing as hashing - digest = hashing.blake3_hash(temp_path) - except Exception as e: - raise RuntimeError(f"failed to hash uploaded file: {e}") - asset_hash = "blake3:" + digest - - if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): - raise ValueError("HASH_MISMATCH") - - with create_session() as session: - existing = get_asset_by_hash(session, asset_hash=asset_hash) - if existing is not None: - with contextlib.suppress(Exception): - if temp_path and os.path.exists(temp_path): - os.remove(temp_path) - - display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) - info = create_asset_info_for_existing_asset( - session, - asset_hash=asset_hash, - name=display_name, - user_metadata=spec.user_metadata or {}, - tags=spec.tags or [], - tag_origin="manual", - owner_id=owner_id, - ) - tag_names = get_asset_tags(session, asset_info_id=info.id) - session.commit() - - return schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=existing.hash, - size=int(existing.size_bytes) if existing.size_bytes is not None else None, - mime_type=existing.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=False, - ) - - base_dir, subdirs = resolve_destination_from_tags(spec.tags) - dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir - os.makedirs(dest_dir, exist_ok=True) - - src_for_ext = (client_filename or spec.name or "").strip() - _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" - ext = _ext if 0 < len(_ext) <= 16 else "" - hashed_basename = f"{digest}{ext}" - dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) - ensure_within_base(dest_abs, base_dir) - - content_type = ( - mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] - or mimetypes.guess_type(hashed_basename, strict=False)[0] - or "application/octet-stream" - ) - - try: - os.replace(temp_path, dest_abs) - except Exception as e: - raise RuntimeError(f"failed to move uploaded file into place: {e}") - - try: - size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) - except OSError as e: - raise RuntimeError(f"failed to stat destination file: {e}") - - with create_session() as session: - result = ingest_fs_asset( - session, - asset_hash=asset_hash, - abs_path=dest_abs, - size_bytes=size_bytes, - mtime_ns=mtime_ns, - mime_type=content_type, - info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest), - owner_id=owner_id, - preview_id=None, - user_metadata=spec.user_metadata or {}, - tags=spec.tags, - tag_origin="manual", - require_existing_tags=False, - ) - info_id = result["asset_info_id"] - if not info_id: - raise RuntimeError("failed to create asset metadata") - - pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id) - if not pair: - raise RuntimeError("inconsistent DB state after ingest") - info, asset = pair - tag_names = get_asset_tags(session, asset_info_id=info.id) - created_result = schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=asset.hash, - size=int(asset.size_bytes), - mime_type=asset.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=result["asset_created"], - ) - session.commit() - - return created_result - - -def update_asset( - *, - asset_info_id: str, - name: str | None = None, - tags: list[str] | None = None, - user_metadata: dict | None = None, - owner_id: str = "", -) -> schemas_out.AssetUpdated: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - info = update_asset_info_full( - session, - asset_info_id=asset_info_id, - name=name, - tags=tags, - user_metadata=user_metadata, - tag_origin="manual", - asset_info_row=info_row, - ) - - tag_names = get_asset_tags(session, asset_info_id=asset_info_id) - result = schemas_out.AssetUpdated( - id=info.id, - name=info.name, - asset_hash=info.asset.hash if info.asset else None, - tags=tag_names, - user_metadata=info.user_metadata or {}, - updated_at=info.updated_at, - ) - session.commit() - - return result - - -def set_asset_preview( - *, - asset_info_id: str, - preview_asset_id: str | None = None, - owner_id: str = "", -) -> schemas_out.AssetDetail: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - set_asset_info_preview( - session, - asset_info_id=asset_info_id, - preview_asset_id=preview_asset_id, - ) - - res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not res: - raise RuntimeError("State changed during preview update") - info, asset, tags = res - result = schemas_out.AssetDetail( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, - mime_type=asset.mime_type if asset else None, - tags=tags, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - ) - session.commit() - - return result - - -def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - asset_id = info_row.asset_id if info_row else None - deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not deleted: - session.commit() - return False - - if not delete_content_if_orphan or not asset_id: - session.commit() - return True - - still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id) - if still_exists: - session.commit() - return True - - states = list_cache_states_by_asset_id(session, asset_id=asset_id) - file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] - - asset_row = session.get(Asset, asset_id) - if asset_row is not None: - session.delete(asset_row) - - session.commit() - for p in file_paths: - with contextlib.suppress(Exception): - if p and os.path.isfile(p): - os.remove(p) - return True - - -def create_asset_from_hash( - *, - hash_str: str, - name: str, - tags: list[str] | None = None, - user_metadata: dict | None = None, - owner_id: str = "", -) -> schemas_out.AssetCreated | None: - canonical = hash_str.strip().lower() - with create_session() as session: - asset = get_asset_by_hash(session, asset_hash=canonical) - if not asset: - return None - - info = create_asset_info_for_existing_asset( - session, - asset_hash=canonical, - name=_safe_filename(name, fallback=canonical.split(":", 1)[1]), - user_metadata=user_metadata or {}, - tags=tags or [], - tag_origin="manual", - owner_id=owner_id, - ) - tag_names = get_asset_tags(session, asset_info_id=info.id) - result = schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=asset.hash, - size=int(asset.size_bytes), - mime_type=asset.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=False, - ) - session.commit() - - return result - - -def add_tags_to_asset( - *, - asset_info_id: str, - tags: list[str], - origin: str = "manual", - owner_id: str = "", -) -> schemas_out.TagsAdd: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - data = add_tags_to_asset_info( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=origin, - create_if_missing=True, - asset_info_row=info_row, - ) - session.commit() - return schemas_out.TagsAdd(**data) - - -def remove_tags_from_asset( - *, - asset_info_id: str, - tags: list[str], - owner_id: str = "", -) -> schemas_out.TagsRemove: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - data = remove_tags_from_asset_info( - session, - asset_info_id=asset_info_id, - tags=tags, - ) - session.commit() - return schemas_out.TagsRemove(**data) - - -def list_tags( - prefix: str | None = None, - limit: int = 100, - offset: int = 0, - order: str = "count_desc", - include_zero: bool = True, - owner_id: str = "", -) -> schemas_out.TagsList: - limit = max(1, min(1000, limit)) - offset = max(0, offset) - - with create_session() as session: - rows, total = list_tags_with_usage( - session, - prefix=prefix, - limit=limit, - offset=offset, - include_zero=include_zero, - order=order, - owner_id=owner_id, - ) - - tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] - return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 0172a5c2f..e27ea5123 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -1,263 +1,567 @@ -import contextlib -import time import logging import os -import sqlalchemy +from pathlib import Path +from typing import Callable, Literal, TypedDict import folder_paths -from app.database.db import create_session, dependencies_available -from app.assets.helpers import ( - collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path, - list_tree,prefixes_for_root, escape_like_prefix, - RootType +from app.assets.database.queries import ( + add_missing_tag_for_asset_id, + bulk_update_enrichment_level, + bulk_update_is_missing, + bulk_update_needs_verify, + delete_orphaned_seed_asset, + delete_references_by_ids, + ensure_tags_exist, + get_asset_by_hash, + get_references_for_prefixes, + get_unenriched_references, + mark_references_missing_outside_prefixes, + reassign_asset_references, + remove_missing_tag_for_asset_id, + set_reference_metadata, + update_asset_hash_and_mime, ) -from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id -from app.assets.database.bulk_ops import seed_from_paths_batch -from app.assets.database.models import Asset, AssetCacheState, AssetInfo +from app.assets.services.bulk_ingest import ( + SeedAssetSpec, + batch_insert_seed_assets, +) +from app.assets.services.file_utils import ( + get_mtime_ns, + is_visible, + list_files_recursively, + verify_file_unchanged, +) +from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash +from app.assets.services.metadata_extract import extract_file_metadata +from app.assets.services.path_utils import ( + compute_relative_filename, + get_comfy_models_folders, + get_name_and_tags_from_asset_path, +) +from app.database.db import create_session -def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None: - """ - Scan the given roots and seed the assets into the database. - """ - if not dependencies_available(): - if enable_logging: - logging.warning("Database dependencies not available, skipping assets scan") - return - t_start = time.perf_counter() - created = 0 - skipped_existing = 0 - orphans_pruned = 0 - paths: list[str] = [] - try: - existing_paths: set[str] = set() - for r in roots: - try: - survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True) - if survivors: - existing_paths.update(survivors) - except Exception as e: - logging.exception("fast DB scan failed for %s: %s", r, e) +class _RefInfo(TypedDict): + ref_id: str + file_path: str + exists: bool + stat_unchanged: bool + needs_verify: bool - try: - orphans_pruned = _prune_orphaned_assets(roots) - except Exception as e: - logging.exception("orphan pruning failed: %s", e) - if "models" in roots: - paths.extend(collect_models_files()) - if "input" in roots: - paths.extend(list_tree(folder_paths.get_input_directory())) - if "output" in roots: - paths.extend(list_tree(folder_paths.get_output_directory())) +class _AssetAccumulator(TypedDict): + hash: str | None + size_db: int + refs: list[_RefInfo] - specs: list[dict] = [] - tag_pool: set[str] = set() - for p in paths: - abs_p = os.path.abspath(p) - if abs_p in existing_paths: - skipped_existing += 1 + +RootType = Literal["models", "input", "output"] + + +def get_prefixes_for_root(root: RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + + +def get_all_known_prefixes() -> list[str]: + """Get all known asset prefixes across all root types.""" + all_roots: tuple[RootType, ...] = ("models", "input", "output") + return [p for root in all_roots for p in get_prefixes_for_root(root)] + + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + if not all(is_visible(part) for part in Path(rel_path).parts): continue - try: - stat_p = os.stat(abs_p, follow_symlinks=False) - except OSError: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: continue - # skip empty files - if not stat_p.st_size: - continue - name, tags = get_name_and_tags_from_asset_path(abs_p) - specs.append( - { - "abs_path": abs_p, - "size_bytes": stat_p.st_size, - "mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)), - "info_name": name, - "tags": tags, - "fname": compute_relative_filename(abs_p), - } - ) - for t in tags: - tag_pool.add(t) - # if no file specs, nothing to do - if not specs: - return - with create_session() as sess: - if tag_pool: - ensure_tags_exist(sess, tag_pool, tag_type="user") - - result = seed_from_paths_batch(sess, specs=specs, owner_id="") - created += result["inserted_infos"] - sess.commit() - finally: - if enable_logging: - logging.info( - "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)", - roots, - time.perf_counter() - t_start, - created, - skipped_existing, - orphans_pruned, - len(paths), - ) + abs_path = os.path.abspath(abs_path) + allowed = False + abs_p = Path(abs_path) + for b in bases: + if abs_p.is_relative_to(os.path.abspath(b)): + allowed = True + break + if allowed: + out.append(abs_path) + return out -def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int: - """Prune cache states outside configured prefixes, then delete orphaned seed assets.""" - all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)] - if not all_prefixes: - return 0 - - def make_prefix_condition(prefix: str): - base = prefix if prefix.endswith(os.sep) else prefix + os.sep - escaped, esc = escape_like_prefix(base) - return AssetCacheState.file_path.like(escaped + "%", escape=esc) - - matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes]) - - orphan_subq = ( - sqlalchemy.select(Asset.id) - .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id) - .where(Asset.hash.is_(None), AssetCacheState.id.is_(None)) - ).scalar_subquery() - - with create_session() as sess: - sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix)) - sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq))) - result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq))) - sess.commit() - return result.rowcount - - -def _fast_db_consistency_pass( +def sync_references_with_filesystem( + session, root: RootType, - *, collect_existing_paths: bool = False, update_missing_tags: bool = False, ) -> set[str] | None: - """Fast DB+FS pass for a root: - - Toggle needs_verify per state using fast check - - For hashed assets with at least one fast-ok state in this root: delete stale missing states - - For seed assets with all states missing: delete Asset and its AssetInfos - - Optionally add/remove 'missing' tags based on fast-ok in this root - - Optionally return surviving absolute paths + """Reconcile asset references with filesystem for a root. + + - Toggle needs_verify per reference using mtime/size stat check + - For hashed assets with at least one stat-unchanged ref: delete stale missing refs + - For seed assets with all refs missing: delete Asset and its references + - Optionally add/remove 'missing' tags based on stat check in this root + - Optionally return surviving absolute paths + + Args: + session: Database session + root: Root type to scan + collect_existing_paths: If True, return set of surviving file paths + update_missing_tags: If True, update 'missing' tags based on file status + + Returns: + Set of surviving absolute paths if collect_existing_paths=True, else None """ - prefixes = prefixes_for_root(root) + prefixes = get_prefixes_for_root(root) if not prefixes: return set() if collect_existing_paths else None - conds = [] - for p in prefixes: - base = os.path.abspath(p) - if not base.endswith(os.sep): - base += os.sep - escaped, esc = escape_like_prefix(base) - conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + rows = get_references_for_prefixes( + session, prefixes, include_missing=update_missing_tags + ) + + by_asset: dict[str, _AssetAccumulator] = {} + for row in rows: + acc = by_asset.get(row.asset_id) + if acc is None: + acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []} + by_asset[row.asset_id] = acc + + stat_unchanged = False + try: + exists = True + stat_unchanged = verify_file_unchanged( + mtime_db=row.mtime_ns, + size_db=acc["size_db"], + stat_result=os.stat(row.file_path, follow_symlinks=True), + ) + except FileNotFoundError: + exists = False + except PermissionError: + exists = True + logging.debug("Permission denied accessing %s", row.file_path) + except OSError as e: + exists = False + logging.debug("OSError checking %s: %s", row.file_path, e) + + acc["refs"].append( + { + "ref_id": row.reference_id, + "file_path": row.file_path, + "exists": exists, + "stat_unchanged": stat_unchanged, + "needs_verify": row.needs_verify, + } + ) + + to_set_verify: list[str] = [] + to_clear_verify: list[str] = [] + stale_ref_ids: list[str] = [] + to_mark_missing: list[str] = [] + to_clear_missing: list[str] = [] + survivors: set[str] = set() + + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + refs = acc["refs"] + any_unchanged = any(r["stat_unchanged"] for r in refs) + all_missing = all(not r["exists"] for r in refs) + + for r in refs: + if not r["exists"]: + to_mark_missing.append(r["ref_id"]) + continue + if r["stat_unchanged"]: + to_clear_missing.append(r["ref_id"]) + if r["needs_verify"]: + to_clear_verify.append(r["ref_id"]) + if not r["stat_unchanged"] and not r["needs_verify"]: + to_set_verify.append(r["ref_id"]) + + if a_hash is None: + if refs and all_missing: + delete_orphaned_seed_asset(session, aid) + else: + for r in refs: + if r["exists"]: + survivors.add(os.path.abspath(r["file_path"])) + continue + + if any_unchanged: + for r in refs: + if not r["exists"]: + stale_ref_ids.append(r["ref_id"]) + if update_missing_tags: + try: + remove_missing_tag_for_asset_id(session, asset_id=aid) + except Exception as e: + logging.warning( + "Failed to remove missing tag for asset %s: %s", aid, e + ) + elif update_missing_tags: + try: + add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic") + except Exception as e: + logging.warning("Failed to add missing tag for asset %s: %s", aid, e) + + for r in refs: + if r["exists"]: + survivors.add(os.path.abspath(r["file_path"])) + + delete_references_by_ids(session, stale_ref_ids) + stale_set = set(stale_ref_ids) + to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set] + bulk_update_is_missing(session, to_mark_missing, value=True) + bulk_update_is_missing(session, to_clear_missing, value=False) + bulk_update_needs_verify(session, to_set_verify, value=True) + bulk_update_needs_verify(session, to_clear_verify, value=False) + + return survivors if collect_existing_paths else None + + +def sync_root_safely(root: RootType) -> set[str]: + """Sync a single root's references with the filesystem. + + Returns survivors (existing paths) or empty set on failure. + """ + try: + with create_session() as sess: + survivors = sync_references_with_filesystem( + sess, + root, + collect_existing_paths=True, + update_missing_tags=True, + ) + sess.commit() + return survivors or set() + except Exception as e: + logging.exception("fast DB scan failed for %s: %s", root, e) + return set() + + +def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int: + """Mark references as missing when outside the given prefixes. + + This is a non-destructive soft-delete. Returns count marked or 0 on failure. + """ + try: + with create_session() as sess: + count = mark_references_missing_outside_prefixes(sess, prefixes) + sess.commit() + return count + except Exception as e: + logging.exception("marking missing assets failed: %s", e) + return 0 + + +def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: + """Collect all file paths for the given roots.""" + paths: list[str] = [] + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_files_recursively(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_files_recursively(folder_paths.get_output_directory())) + return paths + + +def build_asset_specs( + paths: list[str], + existing_paths: set[str], + enable_metadata_extraction: bool = True, + compute_hashes: bool = False, +) -> tuple[list[SeedAssetSpec], set[str], int]: + """Build asset specs from paths, returning (specs, tag_pool, skipped_count). + + Args: + paths: List of file paths to process + existing_paths: Set of paths that already exist in the database + enable_metadata_extraction: If True, extract tier 1 & 2 metadata + compute_hashes: If True, compute blake3 hashes (slow for large files) + """ + specs: list[SeedAssetSpec] = [] + tag_pool: set[str] = set() + skipped = 0 + + for p in paths: + abs_p = os.path.abspath(p) + if abs_p in existing_paths: + skipped += 1 + continue + try: + stat_p = os.stat(abs_p, follow_symlinks=True) + except OSError: + continue + if not stat_p.st_size: + continue + name, tags = get_name_and_tags_from_asset_path(abs_p) + rel_fname = compute_relative_filename(abs_p) + + # Extract metadata (tier 1: filesystem, tier 2: safetensors header) + metadata = None + if enable_metadata_extraction: + metadata = extract_file_metadata( + abs_p, + stat_result=stat_p, + relative_filename=rel_fname, + ) + + # Compute hash if requested + asset_hash: str | None = None + if compute_hashes: + try: + digest, _ = compute_blake3_hash(abs_p) + asset_hash = "blake3:" + digest + except Exception as e: + logging.warning("Failed to hash %s: %s", abs_p, e) + + mime_type = metadata.content_type if metadata else None + specs.append( + { + "abs_path": abs_p, + "size_bytes": stat_p.st_size, + "mtime_ns": get_mtime_ns(stat_p), + "info_name": name, + "tags": tags, + "fname": rel_fname, + "metadata": metadata, + "hash": asset_hash, + "mime_type": mime_type, + } + ) + tag_pool.update(tags) + + return specs, tag_pool, skipped + + + +def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: + """Insert asset specs into database, returning count of created refs.""" + if not specs: + return 0 + with create_session() as sess: + if tag_pool: + ensure_tags_exist(sess, tag_pool, tag_type="user") + result = batch_insert_seed_assets(sess, specs=specs, owner_id="") + sess.commit() + return result.inserted_refs + + +# Enrichment level constants +ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only +ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type) +ENRICHMENT_HASHED = 2 # Hash computed (blake3) + + +def get_unenriched_assets_for_roots( + roots: tuple[RootType, ...], + max_level: int = ENRICHMENT_STUB, + limit: int = 1000, +) -> list: + """Get assets that need enrichment for the given roots. + + Args: + roots: Tuple of root types to scan + max_level: Maximum enrichment level to include + limit: Maximum number of rows to return + + Returns: + List of UnenrichedReferenceRow + """ + prefixes: list[str] = [] + for root in roots: + prefixes.extend(get_prefixes_for_root(root)) + + if not prefixes: + return [] with create_session() as sess: - rows = ( - sess.execute( - sqlalchemy.select( - AssetCacheState.id, - AssetCacheState.file_path, - AssetCacheState.mtime_ns, - AssetCacheState.needs_verify, - AssetCacheState.asset_id, - Asset.hash, - Asset.size_bytes, - ) - .join(Asset, Asset.id == AssetCacheState.asset_id) - .where(sqlalchemy.or_(*conds)) - .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + return get_unenriched_references( + sess, prefixes, max_level=max_level, limit=limit + ) + + +def enrich_asset( + session, + file_path: str, + reference_id: str, + asset_id: str, + extract_metadata: bool = True, + compute_hash: bool = False, + interrupt_check: Callable[[], bool] | None = None, + hash_checkpoints: dict[str, HashCheckpoint] | None = None, +) -> int: + """Enrich a single asset with metadata and/or hash. + + Args: + session: Database session (caller manages lifecycle) + file_path: Absolute path to the file + reference_id: ID of the reference to update + asset_id: ID of the asset to update (for mime_type and hash) + extract_metadata: If True, extract safetensors header and mime type + compute_hash: If True, compute blake3 hash + interrupt_check: Optional non-blocking callable that returns True if + the operation should be interrupted (e.g. paused or cancelled) + hash_checkpoints: Optional dict for saving/restoring hash progress + across interruptions, keyed by file path + + Returns: + New enrichment level achieved + """ + new_level = ENRICHMENT_STUB + + try: + stat_p = os.stat(file_path, follow_symlinks=True) + except OSError: + return new_level + + rel_fname = compute_relative_filename(file_path) + mime_type: str | None = None + metadata = None + + if extract_metadata: + metadata = extract_file_metadata( + file_path, + stat_result=stat_p, + relative_filename=rel_fname, + ) + if metadata: + mime_type = metadata.content_type + new_level = ENRICHMENT_METADATA + + full_hash: str | None = None + if compute_hash: + try: + mtime_before = get_mtime_ns(stat_p) + size_before = stat_p.st_size + + # Restore checkpoint if available and file unchanged + checkpoint = None + if hash_checkpoints is not None: + checkpoint = hash_checkpoints.get(file_path) + if checkpoint is not None: + cur_stat = os.stat(file_path, follow_symlinks=True) + if (checkpoint.mtime_ns != get_mtime_ns(cur_stat) + or checkpoint.file_size != cur_stat.st_size): + checkpoint = None + hash_checkpoints.pop(file_path, None) + else: + mtime_before = get_mtime_ns(cur_stat) + + digest, new_checkpoint = compute_blake3_hash( + file_path, + interrupt_check=interrupt_check, + checkpoint=checkpoint, ) - ).all() - by_asset: dict[str, dict] = {} - for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: - acc = by_asset.get(aid) - if acc is None: - acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} - by_asset[aid] = acc + if digest is None: + # Interrupted — save checkpoint for later resumption + if hash_checkpoints is not None and new_checkpoint is not None: + new_checkpoint.mtime_ns = mtime_before + new_checkpoint.file_size = size_before + hash_checkpoints[file_path] = new_checkpoint + return new_level + + # Completed — clear any saved checkpoint + if hash_checkpoints is not None: + hash_checkpoints.pop(file_path, None) + + stat_after = os.stat(file_path, follow_symlinks=True) + mtime_after = get_mtime_ns(stat_after) + if mtime_before != mtime_after: + logging.warning("File modified during hashing, discarding hash: %s", file_path) + else: + full_hash = f"blake3:{digest}" + metadata_ok = not extract_metadata or metadata is not None + if metadata_ok: + new_level = ENRICHMENT_HASHED + except Exception as e: + logging.warning("Failed to hash %s: %s", file_path, e) + + if extract_metadata and metadata: + user_metadata = metadata.to_user_metadata() + set_reference_metadata(session, reference_id, user_metadata) + + if full_hash: + existing = get_asset_by_hash(session, full_hash) + if existing and existing.id != asset_id: + reassign_asset_references(session, asset_id, existing.id, reference_id) + delete_orphaned_seed_asset(session, asset_id) + if mime_type: + update_asset_hash_and_mime(session, existing.id, mime_type=mime_type) + else: + update_asset_hash_and_mime(session, asset_id, full_hash, mime_type) + elif mime_type: + update_asset_hash_and_mime(session, asset_id, mime_type=mime_type) + + bulk_update_enrichment_level(session, [reference_id], new_level) + session.commit() + + return new_level + + +def enrich_assets_batch( + rows: list, + extract_metadata: bool = True, + compute_hash: bool = False, + interrupt_check: Callable[[], bool] | None = None, + hash_checkpoints: dict[str, HashCheckpoint] | None = None, +) -> tuple[int, list[str]]: + """Enrich a batch of assets. + + Uses a single DB session for the entire batch, committing after each + individual asset to avoid long-held transactions while eliminating + per-asset session creation overhead. + + Args: + rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots + extract_metadata: If True, extract metadata for each asset + compute_hash: If True, compute hash for each asset + interrupt_check: Optional non-blocking callable that returns True if + the operation should be interrupted (e.g. paused or cancelled) + hash_checkpoints: Optional dict for saving/restoring hash progress + across interruptions, keyed by file path + + Returns: + Tuple of (enriched_count, failed_reference_ids) + """ + enriched = 0 + failed_ids: list[str] = [] + + with create_session() as sess: + for row in rows: + if interrupt_check is not None and interrupt_check(): + break - fast_ok = False try: - exists = True - fast_ok = fast_asset_file_check( - mtime_db=mtime_db, - size_db=acc["size_db"], - stat_result=os.stat(fp, follow_symlinks=True), + new_level = enrich_asset( + sess, + file_path=row.file_path, + reference_id=row.reference_id, + asset_id=row.asset_id, + extract_metadata=extract_metadata, + compute_hash=compute_hash, + interrupt_check=interrupt_check, + hash_checkpoints=hash_checkpoints, ) - except FileNotFoundError: - exists = False - except OSError: - exists = False - - acc["states"].append({ - "sid": sid, - "fp": fp, - "exists": exists, - "fast_ok": fast_ok, - "needs_verify": bool(needs_verify), - }) - - to_set_verify: list[int] = [] - to_clear_verify: list[int] = [] - stale_state_ids: list[int] = [] - survivors: set[str] = set() - - for aid, acc in by_asset.items(): - a_hash = acc["hash"] - states = acc["states"] - any_fast_ok = any(s["fast_ok"] for s in states) - all_missing = all(not s["exists"] for s in states) - - for s in states: - if not s["exists"]: - continue - if s["fast_ok"] and s["needs_verify"]: - to_clear_verify.append(s["sid"]) - if not s["fast_ok"] and not s["needs_verify"]: - to_set_verify.append(s["sid"]) - - if a_hash is None: - if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists - sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid)) - asset = sess.get(Asset, aid) - if asset: - sess.delete(asset) + if new_level > row.enrichment_level: + enriched += 1 else: - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["fp"])) - continue + failed_ids.append(row.reference_id) + except Exception as e: + logging.warning("Failed to enrich %s: %s", row.file_path, e) + sess.rollback() + failed_ids.append(row.reference_id) - if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records - for s in states: - if not s["exists"]: - stale_state_ids.append(s["sid"]) - if update_missing_tags: - with contextlib.suppress(Exception): - remove_missing_tag_for_asset_id(sess, asset_id=aid) - elif update_missing_tags: - with contextlib.suppress(Exception): - add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") - - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["fp"])) - - if stale_state_ids: - sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) - if to_set_verify: - sess.execute( - sqlalchemy.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_set_verify)) - .values(needs_verify=True) - ) - if to_clear_verify: - sess.execute( - sqlalchemy.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_clear_verify)) - .values(needs_verify=False) - ) - sess.commit() - return survivors if collect_existing_paths else None + return enriched, failed_ids diff --git a/app/assets/seeder.py b/app/assets/seeder.py new file mode 100644 index 000000000..029448464 --- /dev/null +++ b/app/assets/seeder.py @@ -0,0 +1,794 @@ +"""Background asset seeder with thread management and cancellation support.""" + +import logging +import os +import threading +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Callable + +from app.assets.scanner import ( + ENRICHMENT_METADATA, + ENRICHMENT_STUB, + RootType, + build_asset_specs, + collect_paths_for_roots, + enrich_assets_batch, + get_all_known_prefixes, + get_prefixes_for_root, + get_unenriched_assets_for_roots, + insert_asset_specs, + mark_missing_outside_prefixes_safely, + sync_root_safely, +) +from app.database.db import dependencies_available + + +class ScanInProgressError(Exception): + """Raised when an operation cannot proceed because a scan is running.""" + + +class State(Enum): + """Seeder state machine states.""" + + IDLE = "IDLE" + RUNNING = "RUNNING" + PAUSED = "PAUSED" + CANCELLING = "CANCELLING" + + +class ScanPhase(Enum): + """Scan phase options.""" + + FAST = "fast" # Phase 1: filesystem only (stubs) + ENRICH = "enrich" # Phase 2: metadata + hash + FULL = "full" # Both phases sequentially + + +@dataclass +class Progress: + """Progress information for a scan operation.""" + + scanned: int = 0 + total: int = 0 + created: int = 0 + skipped: int = 0 + + +@dataclass +class ScanStatus: + """Current status of the asset seeder.""" + + state: State + progress: Progress | None + errors: list[str] = field(default_factory=list) + + +ProgressCallback = Callable[[Progress], None] + + +class _AssetSeeder: + """Background asset scanning manager. + + Spawns ephemeral daemon threads for scanning. + Each scan creates a new thread that exits when complete. + Use the module-level ``asset_seeder`` instance. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = State.IDLE + self._progress: Progress | None = None + self._last_progress: Progress | None = None + self._errors: list[str] = [] + self._thread: threading.Thread | None = None + self._cancel_event = threading.Event() + self._run_gate = threading.Event() + self._run_gate.set() # Start unpaused (set = running, clear = paused) + self._roots: tuple[RootType, ...] = () + self._phase: ScanPhase = ScanPhase.FULL + self._compute_hashes: bool = False + self._prune_first: bool = False + self._progress_callback: ProgressCallback | None = None + self._disabled: bool = False + + def disable(self) -> None: + """Disable the asset seeder, preventing any scans from starting.""" + self._disabled = True + logging.info("Asset seeder disabled") + + def is_disabled(self) -> bool: + """Check if the asset seeder is disabled.""" + return self._disabled + + def start( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + phase: ScanPhase = ScanPhase.FULL, + progress_callback: ProgressCallback | None = None, + prune_first: bool = False, + compute_hashes: bool = False, + ) -> bool: + """Start a background scan for the given roots. + + Args: + roots: Tuple of root types to scan (models, input, output) + phase: Scan phase to run (FAST, ENRICH, or FULL for both) + progress_callback: Optional callback called with progress updates + prune_first: If True, prune orphaned assets before scanning + compute_hashes: If True, compute blake3 hashes (slow) + + Returns: + True if scan was started, False if already running + """ + if self._disabled: + logging.debug("Asset seeder is disabled, skipping start") + return False + logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value) + with self._lock: + if self._state != State.IDLE: + logging.info("Asset seeder already running, skipping start") + return False + self._state = State.RUNNING + self._progress = Progress() + self._errors = [] + self._roots = roots + self._phase = phase + self._prune_first = prune_first + self._compute_hashes = compute_hashes + self._progress_callback = progress_callback + self._cancel_event.clear() + self._run_gate.set() # Ensure unpaused when starting + self._thread = threading.Thread( + target=self._run_scan, + name="_AssetSeeder", + daemon=True, + ) + self._thread.start() + return True + + def start_fast( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + progress_callback: ProgressCallback | None = None, + prune_first: bool = False, + ) -> bool: + """Start a fast scan (phase 1 only) - creates stub records. + + Args: + roots: Tuple of root types to scan + progress_callback: Optional callback for progress updates + prune_first: If True, prune orphaned assets before scanning + + Returns: + True if scan was started, False if already running + """ + return self.start( + roots=roots, + phase=ScanPhase.FAST, + progress_callback=progress_callback, + prune_first=prune_first, + compute_hashes=False, + ) + + def start_enrich( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + progress_callback: ProgressCallback | None = None, + compute_hashes: bool = False, + ) -> bool: + """Start an enrichment scan (phase 2 only) - extracts metadata and hashes. + + Args: + roots: Tuple of root types to scan + progress_callback: Optional callback for progress updates + compute_hashes: If True, compute blake3 hashes + + Returns: + True if scan was started, False if already running + """ + return self.start( + roots=roots, + phase=ScanPhase.ENRICH, + progress_callback=progress_callback, + prune_first=False, + compute_hashes=compute_hashes, + ) + + def cancel(self) -> bool: + """Request cancellation of the current scan. + + Returns: + True if cancellation was requested, False if not running or paused + """ + with self._lock: + if self._state not in (State.RUNNING, State.PAUSED): + return False + logging.info("Asset seeder cancelling (was %s)", self._state.value) + self._state = State.CANCELLING + self._cancel_event.set() + self._run_gate.set() # Unblock if paused so thread can exit + return True + + def stop(self) -> bool: + """Stop the current scan (alias for cancel). + + Returns: + True if stop was requested, False if not running + """ + return self.cancel() + + def pause(self) -> bool: + """Pause the current scan. + + The scan will complete its current batch before pausing. + + Returns: + True if pause was requested, False if not running + """ + with self._lock: + if self._state != State.RUNNING: + return False + logging.info("Asset seeder pausing") + self._state = State.PAUSED + self._run_gate.clear() + return True + + def resume(self) -> bool: + """Resume a paused scan. + + This is a noop if the scan is not in the PAUSED state + + Returns: + True if resumed, False if not paused + """ + with self._lock: + if self._state != State.PAUSED: + return False + logging.info("Asset seeder resuming") + self._state = State.RUNNING + self._run_gate.set() + self._emit_event("assets.seed.resumed", {}) + return True + + def restart( + self, + roots: tuple[RootType, ...] | None = None, + phase: ScanPhase | None = None, + progress_callback: ProgressCallback | None = None, + prune_first: bool | None = None, + compute_hashes: bool | None = None, + timeout: float = 5.0, + ) -> bool: + """Cancel any running scan and start a new one. + + Args: + roots: Roots to scan (defaults to previous roots) + phase: Scan phase (defaults to previous phase) + progress_callback: Progress callback (defaults to previous) + prune_first: Prune before scan (defaults to previous) + compute_hashes: Compute hashes (defaults to previous) + timeout: Max seconds to wait for current scan to stop + + Returns: + True if new scan was started, False if failed to stop previous + """ + logging.info("Asset seeder restart requested") + with self._lock: + prev_roots = self._roots + prev_phase = self._phase + prev_callback = self._progress_callback + prev_prune = self._prune_first + prev_hashes = self._compute_hashes + + self.cancel() + if not self.wait(timeout=timeout): + return False + + cb = progress_callback if progress_callback is not None else prev_callback + return self.start( + roots=roots if roots is not None else prev_roots, + phase=phase if phase is not None else prev_phase, + progress_callback=cb, + prune_first=prune_first if prune_first is not None else prev_prune, + compute_hashes=( + compute_hashes if compute_hashes is not None else prev_hashes + ), + ) + + def wait(self, timeout: float | None = None) -> bool: + """Wait for the current scan to complete. + + Args: + timeout: Maximum seconds to wait, or None for no timeout + + Returns: + True if scan completed, False if timeout expired or no scan running + """ + with self._lock: + thread = self._thread + if thread is None: + return True + thread.join(timeout=timeout) + return not thread.is_alive() + + def get_status(self) -> ScanStatus: + """Get the current status and progress of the seeder.""" + with self._lock: + src = self._progress or self._last_progress + return ScanStatus( + state=self._state, + progress=Progress( + scanned=src.scanned, + total=src.total, + created=src.created, + skipped=src.skipped, + ) + if src + else None, + errors=list(self._errors), + ) + + def shutdown(self, timeout: float = 5.0) -> None: + """Gracefully shutdown: cancel any running scan and wait for thread. + + Args: + timeout: Maximum seconds to wait for thread to exit + """ + self.cancel() + self.wait(timeout=timeout) + with self._lock: + self._thread = None + + def mark_missing_outside_prefixes(self) -> int: + """Mark references as missing when outside all known root prefixes. + + This is a non-destructive soft-delete operation. Assets and their + metadata are preserved, but references are flagged as missing. + They can be restored if the file reappears in a future scan. + + This operation is decoupled from scanning to prevent partial scans + from accidentally marking assets belonging to other roots. + + Should be called explicitly when cleanup is desired, typically after + a full scan of all roots or during maintenance. + + Returns: + Number of references marked as missing + + Raises: + ScanInProgressError: If a scan is currently running + """ + with self._lock: + if self._state != State.IDLE: + raise ScanInProgressError( + "Cannot mark missing assets while scan is running" + ) + self._state = State.RUNNING + + try: + if not dependencies_available(): + logging.warning( + "Database dependencies not available, skipping mark missing" + ) + return 0 + + all_prefixes = get_all_known_prefixes() + marked = mark_missing_outside_prefixes_safely(all_prefixes) + if marked > 0: + logging.info("Marked %d references as missing", marked) + return marked + finally: + with self._lock: + self._last_progress = self._progress + self._state = State.IDLE + self._progress = None + + def _is_cancelled(self) -> bool: + """Check if cancellation has been requested.""" + return self._cancel_event.is_set() + + def _is_paused_or_cancelled(self) -> bool: + """Non-blocking check: True if paused or cancelled. + + Use as interrupt_check for I/O-bound work (e.g. hashing) so that + file handles are released immediately on pause rather than held + open while blocked. The caller is responsible for blocking on + _check_pause_and_cancel() afterward. + """ + return not self._run_gate.is_set() or self._cancel_event.is_set() + + def _check_pause_and_cancel(self) -> bool: + """Block while paused, then check if cancelled. + + Call this at checkpoint locations in scan loops. It will: + 1. Block indefinitely while paused (until resume or cancel) + 2. Return True if cancelled, False to continue + + Returns: + True if scan should stop, False to continue + """ + if not self._run_gate.is_set(): + self._emit_event("assets.seed.paused", {}) + self._run_gate.wait() # Blocks if paused + return self._is_cancelled() + + def _emit_event(self, event_type: str, data: dict) -> None: + """Emit a WebSocket event if server is available.""" + try: + from server import PromptServer + + if hasattr(PromptServer, "instance") and PromptServer.instance: + PromptServer.instance.send_sync(event_type, data) + except Exception: + pass + + def _update_progress( + self, + scanned: int | None = None, + total: int | None = None, + created: int | None = None, + skipped: int | None = None, + ) -> None: + """Update progress counters (thread-safe).""" + callback: ProgressCallback | None = None + progress: Progress | None = None + + with self._lock: + if self._progress is None: + return + if scanned is not None: + self._progress.scanned = scanned + if total is not None: + self._progress.total = total + if created is not None: + self._progress.created = created + if skipped is not None: + self._progress.skipped = skipped + if self._progress_callback: + callback = self._progress_callback + progress = Progress( + scanned=self._progress.scanned, + total=self._progress.total, + created=self._progress.created, + skipped=self._progress.skipped, + ) + + if callback and progress: + try: + callback(progress) + except Exception: + pass + + _MAX_ERRORS = 200 + + def _add_error(self, message: str) -> None: + """Add an error message (thread-safe), capped at _MAX_ERRORS.""" + with self._lock: + if len(self._errors) < self._MAX_ERRORS: + self._errors.append(message) + + def _log_scan_config(self, roots: tuple[RootType, ...]) -> None: + """Log the directories that will be scanned.""" + import folder_paths + + for root in roots: + if root == "models": + logging.info( + "Asset scan [models] directory: %s", + os.path.abspath(folder_paths.models_dir), + ) + else: + prefixes = get_prefixes_for_root(root) + if prefixes: + logging.info("Asset scan [%s] directories: %s", root, prefixes) + + def _run_scan(self) -> None: + """Main scan loop running in background thread.""" + t_start = time.perf_counter() + roots = self._roots + phase = self._phase + cancelled = False + total_created = 0 + total_enriched = 0 + skipped_existing = 0 + total_paths = 0 + + try: + if not dependencies_available(): + self._add_error("Database dependencies not available") + self._emit_event( + "assets.seed.error", + {"message": "Database dependencies not available"}, + ) + return + + if self._prune_first: + all_prefixes = get_all_known_prefixes() + marked = mark_missing_outside_prefixes_safely(all_prefixes) + if marked > 0: + logging.info("Marked %d refs as missing before scan", marked) + + if self._check_pause_and_cancel(): + logging.info("Asset scan cancelled after pruning phase") + cancelled = True + return + + self._log_scan_config(roots) + + # Phase 1: Fast scan (stub records) + if phase in (ScanPhase.FAST, ScanPhase.FULL): + created, skipped, paths = self._run_fast_phase(roots) + total_created, skipped_existing, total_paths = created, skipped, paths + + if self._check_pause_and_cancel(): + cancelled = True + return + + self._emit_event( + "assets.seed.fast_complete", + { + "roots": list(roots), + "created": total_created, + "skipped": skipped_existing, + "total": total_paths, + }, + ) + + # Phase 2: Enrichment scan (metadata + hashes) + if phase in (ScanPhase.ENRICH, ScanPhase.FULL): + if self._check_pause_and_cancel(): + cancelled = True + return + + enrich_cancelled, total_enriched = self._run_enrich_phase(roots) + + if enrich_cancelled: + cancelled = True + return + + self._emit_event( + "assets.seed.enrich_complete", + { + "roots": list(roots), + "enriched": total_enriched, + }, + ) + + elapsed = time.perf_counter() - t_start + logging.info( + "Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d", + roots, + phase.value, + elapsed, + total_created, + total_enriched, + skipped_existing, + ) + + self._emit_event( + "assets.seed.completed", + { + "phase": phase.value, + "total": total_paths, + "created": total_created, + "enriched": total_enriched, + "skipped": skipped_existing, + "elapsed": round(elapsed, 3), + }, + ) + + except Exception as e: + self._add_error(f"Scan failed: {e}") + logging.exception("Asset scan failed") + self._emit_event("assets.seed.error", {"message": str(e)}) + finally: + if cancelled: + self._emit_event( + "assets.seed.cancelled", + { + "scanned": self._progress.scanned if self._progress else 0, + "total": total_paths, + "created": total_created, + }, + ) + with self._lock: + self._last_progress = self._progress + self._state = State.IDLE + self._progress = None + + def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]: + """Run phase 1: fast scan to create stub records. + + Returns: + Tuple of (total_created, skipped_existing, total_paths) + """ + t_fast_start = time.perf_counter() + total_created = 0 + skipped_existing = 0 + + existing_paths: set[str] = set() + t_sync = time.perf_counter() + for r in roots: + if self._check_pause_and_cancel(): + return total_created, skipped_existing, 0 + existing_paths.update(sync_root_safely(r)) + logging.debug( + "Fast scan: sync_root phase took %.3fs (%d existing paths)", + time.perf_counter() - t_sync, + len(existing_paths), + ) + + if self._check_pause_and_cancel(): + return total_created, skipped_existing, 0 + + t_collect = time.perf_counter() + paths = collect_paths_for_roots(roots) + logging.debug( + "Fast scan: collect_paths took %.3fs (%d paths found)", + time.perf_counter() - t_collect, + len(paths), + ) + total_paths = len(paths) + self._update_progress(total=total_paths) + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "total": total_paths, "phase": "fast"}, + ) + + # Use stub specs (no metadata extraction, no hashing) + t_specs = time.perf_counter() + specs, tag_pool, skipped_existing = build_asset_specs( + paths, + existing_paths, + enable_metadata_extraction=False, + compute_hashes=False, + ) + logging.debug( + "Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)", + time.perf_counter() - t_specs, + len(specs), + skipped_existing, + ) + self._update_progress(skipped=skipped_existing) + + if self._check_pause_and_cancel(): + return total_created, skipped_existing, total_paths + + batch_size = 500 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + for i in range(0, len(specs), batch_size): + if self._check_pause_and_cancel(): + logging.info( + "Fast scan cancelled after %d/%d files (created=%d)", + i, + len(specs), + total_created, + ) + return total_created, skipped_existing, total_paths + + batch = specs[i : i + batch_size] + batch_tags = {t for spec in batch for t in spec["tags"]} + try: + created = insert_asset_specs(batch, batch_tags) + total_created += created + except Exception as e: + self._add_error(f"Batch insert failed at offset {i}: {e}") + logging.exception("Batch insert failed at offset %d", i) + + scanned = i + len(batch) + now = time.perf_counter() + self._update_progress(scanned=scanned, created=total_created) + + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + { + "phase": "fast", + "scanned": scanned, + "total": len(specs), + "created": total_created, + }, + ) + last_progress_time = now + + self._update_progress(scanned=len(specs), created=total_created) + logging.info( + "Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)", + time.perf_counter() - t_fast_start, + total_created, + skipped_existing, + total_paths, + ) + return total_created, skipped_existing, total_paths + + def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]: + """Run phase 2: enrich existing records with metadata and hashes. + + Returns: + Tuple of (cancelled, total_enriched) + """ + total_enriched = 0 + batch_size = 100 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + # Get the target enrichment level based on compute_hashes + if not self._compute_hashes: + target_max_level = ENRICHMENT_STUB + else: + target_max_level = ENRICHMENT_METADATA + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "phase": "enrich"}, + ) + + skip_ids: set[str] = set() + consecutive_empty = 0 + max_consecutive_empty = 3 + + # Hash checkpoints survive across batches so interrupted hashes + # can be resumed without re-reading the entire file. + hash_checkpoints: dict[str, object] = {} + + while True: + if self._check_pause_and_cancel(): + logging.info("Enrich scan cancelled after %d assets", total_enriched) + return True, total_enriched + + # Fetch next batch of unenriched assets + unenriched = get_unenriched_assets_for_roots( + roots, + max_level=target_max_level, + limit=batch_size, + ) + + # Filter out previously failed references + if skip_ids: + unenriched = [r for r in unenriched if r.reference_id not in skip_ids] + + if not unenriched: + break + + enriched, failed_ids = enrich_assets_batch( + unenriched, + extract_metadata=True, + compute_hash=self._compute_hashes, + interrupt_check=self._is_paused_or_cancelled, + hash_checkpoints=hash_checkpoints, + ) + total_enriched += enriched + skip_ids.update(failed_ids) + + if enriched == 0: + consecutive_empty += 1 + if consecutive_empty >= max_consecutive_empty: + logging.warning( + "Enrich phase stopping: %d consecutive batches with no progress (%d skipped)", + consecutive_empty, + len(skip_ids), + ) + break + else: + consecutive_empty = 0 + + now = time.perf_counter() + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + { + "phase": "enrich", + "enriched": total_enriched, + }, + ) + last_progress_time = now + + return False, total_enriched + + +asset_seeder = _AssetSeeder() diff --git a/app/assets/services/__init__.py b/app/assets/services/__init__.py new file mode 100644 index 000000000..11fcb4122 --- /dev/null +++ b/app/assets/services/__init__.py @@ -0,0 +1,87 @@ +from app.assets.services.asset_management import ( + asset_exists, + delete_asset_reference, + get_asset_by_hash, + get_asset_detail, + list_assets_page, + resolve_asset_for_download, + set_asset_preview, + update_asset_metadata, +) +from app.assets.services.bulk_ingest import ( + BulkInsertResult, + batch_insert_seed_assets, + cleanup_unreferenced_assets, +) +from app.assets.services.file_utils import ( + get_mtime_ns, + get_size_and_mtime_ns, + list_files_recursively, + verify_file_unchanged, +) +from app.assets.services.ingest import ( + DependencyMissingError, + HashMismatchError, + create_from_hash, + upload_from_temp_path, +) +from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, +) +from app.assets.services.schemas import ( + AssetData, + AssetDetailResult, + AssetSummaryData, + DownloadResolutionResult, + IngestResult, + ListAssetsResult, + ReferenceData, + RegisterAssetResult, + TagUsage, + UploadResult, + UserMetadata, +) +from app.assets.services.tagging import ( + apply_tags, + list_tags, + remove_tags, +) + +__all__ = [ + "AddTagsResult", + "AssetData", + "AssetDetailResult", + "AssetSummaryData", + "ReferenceData", + "BulkInsertResult", + "DependencyMissingError", + "DownloadResolutionResult", + "HashMismatchError", + "IngestResult", + "ListAssetsResult", + "RegisterAssetResult", + "RemoveTagsResult", + "TagUsage", + "UploadResult", + "UserMetadata", + "apply_tags", + "asset_exists", + "batch_insert_seed_assets", + "create_from_hash", + "delete_asset_reference", + "get_asset_by_hash", + "get_asset_detail", + "get_mtime_ns", + "get_size_and_mtime_ns", + "list_assets_page", + "list_files_recursively", + "list_tags", + "cleanup_unreferenced_assets", + "remove_tags", + "resolve_asset_for_download", + "set_asset_preview", + "update_asset_metadata", + "upload_from_temp_path", + "verify_file_unchanged", +] diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py new file mode 100644 index 000000000..3fe7115c8 --- /dev/null +++ b/app/assets/services/asset_management.py @@ -0,0 +1,309 @@ +import contextlib +import mimetypes +import os +from typing import Sequence + + +from app.assets.database.models import Asset +from app.assets.database.queries import ( + asset_exists_by_hash, + reference_exists_for_asset_id, + delete_reference_by_id, + fetch_reference_and_asset, + soft_delete_reference_by_id, + fetch_reference_asset_and_tags, + get_asset_by_hash as queries_get_asset_by_hash, + get_reference_by_id, + get_reference_with_owner_check, + list_references_page, + list_references_by_asset_id, + set_reference_metadata, + set_reference_preview, + set_reference_tags, + update_reference_access_time, + update_reference_name, + update_reference_updated_at, +) +from app.assets.helpers import select_best_live_path +from app.assets.services.path_utils import compute_relative_filename +from app.assets.services.schemas import ( + AssetData, + AssetDetailResult, + AssetSummaryData, + DownloadResolutionResult, + ListAssetsResult, + UserMetadata, + extract_asset_data, + extract_reference_data, +) +from app.database.db import create_session + + +def get_asset_detail( + reference_id: str, + owner_id: str = "", +) -> AssetDetailResult | None: + with create_session() as session: + result = fetch_reference_asset_and_tags( + session, + reference_id=reference_id, + owner_id=owner_id, + ) + if not result: + return None + + ref, asset, tags = result + return AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tags, + ) + + +def update_asset_metadata( + reference_id: str, + name: str | None = None, + tags: Sequence[str] | None = None, + user_metadata: UserMetadata = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> AssetDetailResult: + with create_session() as session: + ref = get_reference_with_owner_check(session, reference_id, owner_id) + + touched = False + if name is not None and name != ref.name: + update_reference_name(session, reference_id=reference_id, name=name) + touched = True + + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + + new_meta: dict | None = None + if user_metadata is not None: + new_meta = dict(user_metadata) + elif computed_filename: + current_meta = ref.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + + if new_meta is not None: + if computed_filename: + new_meta["filename"] = computed_filename + set_reference_metadata( + session, reference_id=reference_id, user_metadata=new_meta + ) + touched = True + + if tags is not None: + set_reference_tags( + session, + reference_id=reference_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if touched and user_metadata is None: + update_reference_updated_at(session, reference_id=reference_id) + + result = fetch_reference_asset_and_tags( + session, + reference_id=reference_id, + owner_id=owner_id, + ) + if not result: + raise RuntimeError("State changed during update") + + ref, asset, tag_list = result + detail = AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_list, + ) + session.commit() + + return detail + + +def delete_asset_reference( + reference_id: str, + owner_id: str, + delete_content_if_orphan: bool = True, +) -> bool: + with create_session() as session: + if not delete_content_if_orphan: + # Soft delete: mark the reference as deleted but keep everything + deleted = soft_delete_reference_by_id( + session, reference_id=reference_id, owner_id=owner_id + ) + session.commit() + return deleted + + ref_row = get_reference_by_id(session, reference_id=reference_id) + asset_id = ref_row.asset_id if ref_row else None + file_path = ref_row.file_path if ref_row else None + + deleted = delete_reference_by_id( + session, reference_id=reference_id, owner_id=owner_id + ) + if not deleted: + session.commit() + return False + + if not asset_id: + session.commit() + return True + + still_exists = reference_exists_for_asset_id(session, asset_id=asset_id) + if still_exists: + session.commit() + return True + + # Orphaned asset - delete it and its files + refs = list_references_by_asset_id(session, asset_id=asset_id) + file_paths = [ + r.file_path for r in (refs or []) if getattr(r, "file_path", None) + ] + # Also include the just-deleted file path + if file_path: + file_paths.append(file_path) + + asset_row = session.get(Asset, asset_id) + if asset_row is not None: + session.delete(asset_row) + + session.commit() + + # Delete files after commit + for p in file_paths: + with contextlib.suppress(Exception): + if p and os.path.isfile(p): + os.remove(p) + + return True + + +def set_asset_preview( + reference_id: str, + preview_asset_id: str | None = None, + owner_id: str = "", +) -> AssetDetailResult: + with create_session() as session: + get_reference_with_owner_check(session, reference_id, owner_id) + + set_reference_preview( + session, + reference_id=reference_id, + preview_asset_id=preview_asset_id, + ) + + result = fetch_reference_asset_and_tags( + session, reference_id=reference_id, owner_id=owner_id + ) + if not result: + raise RuntimeError("State changed during preview update") + + ref, asset, tags = result + detail = AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tags, + ) + session.commit() + + return detail + + +def asset_exists(asset_hash: str) -> bool: + with create_session() as session: + return asset_exists_by_hash(session, asset_hash=asset_hash) + + +def get_asset_by_hash(asset_hash: str) -> AssetData | None: + with create_session() as session: + asset = queries_get_asset_by_hash(session, asset_hash=asset_hash) + return extract_asset_data(asset) + + +def list_assets_page( + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> ListAssetsResult: + with create_session() as session: + refs, tag_map, total = list_references_page( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + + items: list[AssetSummaryData] = [] + for ref in refs: + items.append( + AssetSummaryData( + ref=extract_reference_data(ref), + asset=extract_asset_data(ref.asset), + tags=tag_map.get(ref.id, []), + ) + ) + + return ListAssetsResult(items=items, total=total) + + +def resolve_asset_for_download( + reference_id: str, + owner_id: str = "", +) -> DownloadResolutionResult: + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise ValueError(f"AssetReference {reference_id} not found") + + ref, asset = pair + + # For references with file_path, use that directly + if ref.file_path and os.path.isfile(ref.file_path): + abs_path = ref.file_path + else: + # For API-created refs without file_path, find a path from other refs + refs = list_references_by_asset_id(session, asset_id=asset.id) + abs_path = select_best_live_path(refs) + if not abs_path: + raise FileNotFoundError( + f"No live path for AssetReference {reference_id} " + f"(asset id={asset.id}, name={ref.name})" + ) + + # Capture ORM attributes before commit (commit expires loaded objects) + ref_name = ref.name + asset_mime = asset.mime_type + + update_reference_access_time(session, reference_id=reference_id) + session.commit() + + ctype = ( + asset_mime + or mimetypes.guess_type(ref_name or abs_path)[0] + or "application/octet-stream" + ) + download_name = ref_name or os.path.basename(abs_path) + return DownloadResolutionResult( + abs_path=abs_path, + content_type=ctype, + download_name=download_name, + ) diff --git a/app/assets/services/bulk_ingest.py b/app/assets/services/bulk_ingest.py new file mode 100644 index 000000000..54e72730c --- /dev/null +++ b/app/assets/services/bulk_ingest.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import os +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any, TypedDict + +from sqlalchemy.orm import Session + +from app.assets.database.queries import ( + bulk_insert_assets, + bulk_insert_references_ignore_conflicts, + bulk_insert_tags_and_meta, + delete_assets_by_ids, + get_existing_asset_ids, + get_reference_ids_by_ids, + get_references_by_paths_and_asset_ids, + get_unreferenced_unhashed_asset_ids, + restore_references_by_paths, +) +from app.assets.helpers import get_utc_now + +if TYPE_CHECKING: + from app.assets.services.metadata_extract import ExtractedMetadata + + +class SeedAssetSpec(TypedDict): + """Spec for seeding an asset from filesystem.""" + + abs_path: str + size_bytes: int + mtime_ns: int + info_name: str + tags: list[str] + fname: str + metadata: ExtractedMetadata | None + hash: str | None + mime_type: str | None + + +class AssetRow(TypedDict): + """Row data for inserting an Asset.""" + + id: str + hash: str | None + size_bytes: int + mime_type: str | None + created_at: datetime + + +class ReferenceRow(TypedDict): + """Row data for inserting an AssetReference.""" + + id: str + asset_id: str + file_path: str + mtime_ns: int + owner_id: str + name: str + preview_id: str | None + user_metadata: dict[str, Any] | None + created_at: datetime + updated_at: datetime + last_access_time: datetime + + +class TagRow(TypedDict): + """Row data for inserting a Tag.""" + + asset_reference_id: str + tag_name: str + origin: str + added_at: datetime + + +class MetadataRow(TypedDict): + """Row data for inserting asset metadata.""" + + asset_reference_id: str + key: str + ordinal: int + val_str: str | None + val_num: float | None + val_bool: bool | None + val_json: dict[str, Any] | None + + +@dataclass +class BulkInsertResult: + """Result of bulk asset insertion.""" + + inserted_refs: int + won_paths: int + lost_paths: int + + +def batch_insert_seed_assets( + session: Session, + specs: list[SeedAssetSpec], + owner_id: str = "", +) -> BulkInsertResult: + """Seed assets from filesystem specs in batch. + + Each spec is a dict with keys: + - abs_path: str + - size_bytes: int + - mtime_ns: int + - info_name: str + - tags: list[str] + - fname: Optional[str] + + This function orchestrates: + 1. Insert seed Assets (hash=NULL) + 2. Claim references with ON CONFLICT DO NOTHING on file_path + 3. Query to find winners (paths where our asset_id was inserted) + 4. Delete Assets for losers (path already claimed by another asset) + 5. Insert tags and metadata for successfully inserted references + + Returns: + BulkInsertResult with inserted_refs, won_paths, lost_paths + """ + if not specs: + return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0) + + current_time = get_utc_now() + asset_rows: list[AssetRow] = [] + reference_rows: list[ReferenceRow] = [] + path_to_asset_id: dict[str, str] = {} + asset_id_to_ref_data: dict[str, dict] = {} + absolute_path_list: list[str] = [] + + for spec in specs: + absolute_path = os.path.abspath(spec["abs_path"]) + asset_id = str(uuid.uuid4()) + reference_id = str(uuid.uuid4()) + absolute_path_list.append(absolute_path) + path_to_asset_id[absolute_path] = asset_id + + mime_type = spec.get("mime_type") + asset_rows.append( + { + "id": asset_id, + "hash": spec.get("hash"), + "size_bytes": spec["size_bytes"], + "mime_type": mime_type, + "created_at": current_time, + } + ) + + # Build user_metadata from extracted metadata or fallback to filename + extracted_metadata = spec.get("metadata") + if extracted_metadata: + user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata() + elif spec["fname"]: + user_metadata = {"filename": spec["fname"]} + else: + user_metadata = None + + reference_rows.append( + { + "id": reference_id, + "asset_id": asset_id, + "file_path": absolute_path, + "mtime_ns": spec["mtime_ns"], + "owner_id": owner_id, + "name": spec["info_name"], + "preview_id": None, + "user_metadata": user_metadata, + "created_at": current_time, + "updated_at": current_time, + "last_access_time": current_time, + } + ) + + asset_id_to_ref_data[asset_id] = { + "reference_id": reference_id, + "tags": spec["tags"], + "filename": spec["fname"], + "extracted_metadata": extracted_metadata, + } + + bulk_insert_assets(session, asset_rows) + + # Filter reference rows to only those whose assets were actually inserted + # (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING) + inserted_asset_ids = get_existing_asset_ids( + session, [r["asset_id"] for r in reference_rows] + ) + reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids] + + bulk_insert_references_ignore_conflicts(session, reference_rows) + restore_references_by_paths(session, absolute_path_list) + winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id) + + inserted_paths = { + path + for path in absolute_path_list + if path_to_asset_id[path] in inserted_asset_ids + } + losing_paths = inserted_paths - winning_paths + lost_asset_ids = [path_to_asset_id[path] for path in losing_paths] + + if lost_asset_ids: + delete_assets_by_ids(session, lost_asset_ids) + + if not winning_paths: + return BulkInsertResult( + inserted_refs=0, + won_paths=0, + lost_paths=len(losing_paths), + ) + + # Get reference IDs for winners + winning_ref_ids = [ + asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"] + for path in winning_paths + ] + inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids) + + tag_rows: list[TagRow] = [] + metadata_rows: list[MetadataRow] = [] + + if inserted_ref_ids: + for path in winning_paths: + asset_id = path_to_asset_id[path] + ref_data = asset_id_to_ref_data[asset_id] + ref_id = ref_data["reference_id"] + + if ref_id not in inserted_ref_ids: + continue + + for tag in ref_data["tags"]: + tag_rows.append( + { + "asset_reference_id": ref_id, + "tag_name": tag, + "origin": "automatic", + "added_at": current_time, + } + ) + + # Use extracted metadata for meta rows if available + extracted_metadata = ref_data.get("extracted_metadata") + if extracted_metadata: + metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id)) + elif ref_data["filename"]: + # Fallback: just store filename + metadata_rows.append( + { + "asset_reference_id": ref_id, + "key": "filename", + "ordinal": 0, + "val_str": ref_data["filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) + + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows) + + return BulkInsertResult( + inserted_refs=len(inserted_ref_ids), + won_paths=len(winning_paths), + lost_paths=len(losing_paths), + ) + + +def cleanup_unreferenced_assets(session: Session) -> int: + """Hard-delete unhashed assets with no active references. + + This is a destructive operation intended for explicit cleanup. + Only deletes assets where hash=None and all references are missing. + + Returns: + Number of assets deleted + """ + unreferenced_ids = get_unreferenced_unhashed_asset_ids(session) + return delete_assets_by_ids(session, unreferenced_ids) diff --git a/app/assets/services/file_utils.py b/app/assets/services/file_utils.py new file mode 100644 index 000000000..c47ebe460 --- /dev/null +++ b/app/assets/services/file_utils.py @@ -0,0 +1,70 @@ +import os + + +def get_mtime_ns(stat_result: os.stat_result) -> int: + """Extract mtime in nanoseconds from a stat result.""" + return getattr( + stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000) + ) + + +def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]: + """Get file size in bytes and mtime in nanoseconds.""" + st = os.stat(path, follow_symlinks=follow_symlinks) + return st.st_size, get_mtime_ns(st) + + +def verify_file_unchanged( + mtime_db: int | None, + size_db: int | None, + stat_result: os.stat_result, +) -> bool: + """Check if a file is unchanged based on mtime and size. + + Returns True if the file's mtime and size match the database values. + Returns False if mtime_db is None or values don't match. + + size_db=None means don't check size; 0 is a valid recorded size. + """ + if mtime_db is None: + return False + actual_mtime_ns = get_mtime_ns(stat_result) + if int(mtime_db) != int(actual_mtime_ns): + return False + if size_db is not None: + return int(stat_result.st_size) == int(size_db) + return True + + +def is_visible(name: str) -> bool: + """Return True if a file or directory name is visible (not hidden).""" + return not name.startswith(".") + + +def list_files_recursively(base_dir: str) -> list[str]: + """Recursively list all files in a directory, following symlinks.""" + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + # Track seen real directory identities to prevent circular symlink loops + seen_dirs: set[tuple[int, int]] = set() + for dirpath, subdirs, filenames in os.walk( + base_abs, topdown=True, followlinks=True + ): + try: + st = os.stat(dirpath) + dir_id = (st.st_dev, st.st_ino) + except OSError: + subdirs.clear() + continue + if dir_id in seen_dirs: + subdirs.clear() + continue + seen_dirs.add(dir_id) + subdirs[:] = [d for d in subdirs if is_visible(d)] + for name in filenames: + if not is_visible(name): + continue + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out diff --git a/app/assets/services/hashing.py b/app/assets/services/hashing.py new file mode 100644 index 000000000..41d8b4615 --- /dev/null +++ b/app/assets/services/hashing.py @@ -0,0 +1,99 @@ +import io +import os +from contextlib import contextmanager +from dataclasses import dataclass +from typing import IO, Any, Callable, Iterator +import logging + +try: + from blake3 import blake3 +except ModuleNotFoundError: + logging.warning("WARNING: blake3 package not installed") + +DEFAULT_CHUNK = 8 * 1024 * 1024 + +InterruptCheck = Callable[[], bool] + + +@dataclass +class HashCheckpoint: + """Saved state for resuming an interrupted hash computation.""" + + bytes_processed: int + hasher: Any # blake3 hasher instance + mtime_ns: int = 0 + file_size: int = 0 + + +@contextmanager +def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]: + """Yield (file_object, is_path) with appropriate setup/teardown.""" + if hasattr(fp, "read"): + seekable = getattr(fp, "seekable", lambda: False)() + orig_pos = None + if seekable: + try: + orig_pos = fp.tell() + if orig_pos != 0: + fp.seek(0) + except io.UnsupportedOperation: + orig_pos = None + try: + yield fp, False + finally: + if orig_pos is not None: + fp.seek(orig_pos) + else: + with open(os.fspath(fp), "rb") as f: + yield f, True + + +def compute_blake3_hash( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, + interrupt_check: InterruptCheck | None = None, + checkpoint: HashCheckpoint | None = None, +) -> tuple[str | None, HashCheckpoint | None]: + """Compute BLAKE3 hash of a file, with optional checkpoint support. + + Args: + fp: File path or file-like object + chunk_size: Size of chunks to read at a time + interrupt_check: Optional callable that returns True if the operation + should be interrupted (e.g. paused or cancelled). Must be + non-blocking so file handles are released immediately. Checked + between chunk reads. + checkpoint: Optional checkpoint to resume from (file paths only) + + Returns: + Tuple of (hex_digest, None) on completion, or + (None, checkpoint) on interruption (file paths only), or + (None, None) on interruption of a file object + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + with _open_for_hashing(fp) as (f, is_path): + if checkpoint is not None and is_path: + f.seek(checkpoint.bytes_processed) + h = checkpoint.hasher + bytes_processed = checkpoint.bytes_processed + else: + h = blake3() + bytes_processed = 0 + + while True: + if interrupt_check is not None and interrupt_check(): + if is_path: + return None, HashCheckpoint( + bytes_processed=bytes_processed, + hasher=h, + ) + return None, None + chunk = f.read(chunk_size) + if not chunk: + break + h.update(chunk) + bytes_processed += len(chunk) + + return h.hexdigest(), None diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py new file mode 100644 index 000000000..44d7aef36 --- /dev/null +++ b/app/assets/services/ingest.py @@ -0,0 +1,375 @@ +import contextlib +import logging +import mimetypes +import os +from typing import Any, Sequence + +from sqlalchemy.orm import Session + +import app.assets.services.hashing as hashing +from app.assets.database.queries import ( + add_tags_to_reference, + fetch_reference_and_asset, + get_asset_by_hash, + get_existing_asset_ids, + get_reference_by_file_path, + get_reference_tags, + get_or_create_reference, + remove_missing_tag_for_asset_id, + set_reference_metadata, + set_reference_tags, + upsert_asset, + upsert_reference, + validate_tags_exist, +) +from app.assets.helpers import normalize_tags +from app.assets.services.file_utils import get_size_and_mtime_ns +from app.assets.services.path_utils import ( + compute_relative_filename, + resolve_destination_from_tags, + validate_path_within_base, +) +from app.assets.services.schemas import ( + IngestResult, + RegisterAssetResult, + UploadResult, + UserMetadata, + extract_asset_data, + extract_reference_data, +) +from app.database.db import create_session + + +def _ingest_file_from_path( + abs_path: str, + asset_hash: str, + size_bytes: int, + mtime_ns: int, + mime_type: str | None = None, + info_name: str | None = None, + owner_id: str = "", + preview_id: str | None = None, + user_metadata: UserMetadata = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> IngestResult: + locator = os.path.abspath(abs_path) + user_metadata = user_metadata or {} + + asset_created = False + asset_updated = False + ref_created = False + ref_updated = False + reference_id: str | None = None + + with create_session() as session: + if preview_id: + if preview_id not in get_existing_asset_ids(session, [preview_id]): + preview_id = None + + asset, asset_created, asset_updated = upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=size_bytes, + mime_type=mime_type, + ) + + ref_created, ref_updated = upsert_reference( + session, + asset_id=asset.id, + file_path=locator, + name=info_name or os.path.basename(locator), + mtime_ns=mtime_ns, + owner_id=owner_id, + ) + + # Get the reference we just created/updated + ref = get_reference_by_file_path(session, locator) + if ref: + reference_id = ref.id + + if preview_id and ref.preview_id != preview_id: + ref.preview_id = preview_id + + norm = normalize_tags(list(tags)) + if norm: + if require_existing_tags: + validate_tags_exist(session, norm) + add_tags_to_reference( + session, + reference_id=reference_id, + tags=norm, + origin=tag_origin, + create_if_missing=not require_existing_tags, + ) + + _update_metadata_with_filename( + session, + reference_id=reference_id, + file_path=ref.file_path, + current_metadata=ref.user_metadata, + user_metadata=user_metadata, + ) + + try: + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + + session.commit() + + return IngestResult( + asset_created=asset_created, + asset_updated=asset_updated, + ref_created=ref_created, + ref_updated=ref_updated, + reference_id=reference_id, + ) + + +def _register_existing_asset( + asset_hash: str, + name: str, + user_metadata: UserMetadata = None, + tags: list[str] | None = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> RegisterAssetResult: + user_metadata = user_metadata or {} + + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"No asset with hash {asset_hash}") + + ref, ref_created = get_or_create_reference( + session, + asset_id=asset.id, + owner_id=owner_id, + name=name, + ) + + if not ref_created: + tag_names = get_reference_tags(session, reference_id=ref.id) + result = RegisterAssetResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created=False, + ) + session.commit() + return result + + new_meta = dict(user_metadata) + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta: + set_reference_metadata( + session, + reference_id=ref.id, + user_metadata=new_meta, + ) + + if tags is not None: + set_reference_tags( + session, + reference_id=ref.id, + tags=tags, + origin=tag_origin, + ) + + tag_names = get_reference_tags(session, reference_id=ref.id) + session.refresh(ref) + result = RegisterAssetResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created=True, + ) + session.commit() + + return result + + + +def _update_metadata_with_filename( + session: Session, + reference_id: str, + file_path: str | None, + current_metadata: dict | None, + user_metadata: dict[str, Any], +) -> None: + computed_filename = compute_relative_filename(file_path) if file_path else None + + current_meta = current_metadata or {} + new_meta = dict(current_meta) + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + set_reference_metadata( + session, + reference_id=reference_id, + user_metadata=new_meta, + ) + + +def _sanitize_filename(name: str | None, fallback: str) -> str: + n = os.path.basename((name or "").strip() or fallback) + return n if n else fallback + + +class HashMismatchError(Exception): + pass + + +class DependencyMissingError(Exception): + def __init__(self, message: str): + self.message = message + super().__init__(message) + + +def upload_from_temp_path( + temp_path: str, + name: str | None = None, + tags: list[str] | None = None, + user_metadata: dict | None = None, + client_filename: str | None = None, + owner_id: str = "", + expected_hash: str | None = None, +) -> UploadResult: + try: + digest, _ = hashing.compute_blake3_hash(temp_path) + except ImportError as e: + raise DependencyMissingError(str(e)) + except Exception as e: + raise RuntimeError(f"failed to hash uploaded file: {e}") + asset_hash = "blake3:" + digest + + if expected_hash and asset_hash != expected_hash.strip().lower(): + raise HashMismatchError("Uploaded file hash does not match provided hash.") + + with create_session() as session: + existing = get_asset_by_hash(session, asset_hash=asset_hash) + + if existing is not None: + with contextlib.suppress(Exception): + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + display_name = _sanitize_filename(name or client_filename, fallback=digest) + result = _register_existing_asset( + asset_hash=asset_hash, + name=display_name, + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + return UploadResult( + ref=result.ref, + asset=result.asset, + tags=result.tags, + created_new=False, + ) + + if not tags: + raise ValueError("tags are required for new asset uploads") + base_dir, subdirs = resolve_destination_from_tags(tags) + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + src_for_ext = (client_filename or name or "").strip() + _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" + ext = _ext if 0 < len(_ext) <= 16 else "" + hashed_basename = f"{digest}{ext}" + dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) + validate_path_within_base(dest_abs, base_dir) + + content_type = ( + mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] + or mimetypes.guess_type(hashed_basename, strict=False)[0] + or "application/octet-stream" + ) + + try: + os.replace(temp_path, dest_abs) + except Exception as e: + raise RuntimeError(f"failed to move uploaded file into place: {e}") + + try: + size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs) + except OSError as e: + raise RuntimeError(f"failed to stat destination file: {e}") + + ingest_result = _ingest_file_from_path( + asset_hash=asset_hash, + abs_path=dest_abs, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_sanitize_filename(name or client_filename, fallback=digest), + owner_id=owner_id, + preview_id=None, + user_metadata=user_metadata or {}, + tags=tags, + tag_origin="manual", + require_existing_tags=False, + ) + reference_id = ingest_result.reference_id + if not reference_id: + raise RuntimeError("failed to create asset reference") + + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + ref, asset = pair + tag_names = get_reference_tags(session, reference_id=ref.id) + + return UploadResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created_new=ingest_result.asset_created, + ) + + +def create_from_hash( + hash_str: str, + name: str, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", +) -> UploadResult | None: + canonical = hash_str.strip().lower() + + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=canonical) + if not asset: + return None + + result = _register_existing_asset( + asset_hash=canonical, + name=_sanitize_filename( + name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical + ), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + + return UploadResult( + ref=result.ref, + asset=result.asset, + tags=result.tags, + created_new=False, + ) diff --git a/app/assets/services/metadata_extract.py b/app/assets/services/metadata_extract.py new file mode 100644 index 000000000..a004929bc --- /dev/null +++ b/app/assets/services/metadata_extract.py @@ -0,0 +1,327 @@ +"""Metadata extraction for asset scanning. + +Tier 1: Filesystem metadata (zero parsing) +Tier 2: Safetensors header metadata (fast JSON read only) +""" + +from __future__ import annotations + +import json +import logging +import mimetypes +import os +import struct +from dataclasses import dataclass +from typing import Any + +from utils.mime_types import init_mime_types + +init_mime_types() + +# Supported safetensors extensions +SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"}) + +# Maximum safetensors header size to read (8MB) +MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024 + + +@dataclass +class ExtractedMetadata: + """Metadata extracted from a file during scanning.""" + + # Tier 1: Filesystem (always available) + filename: str = "" + file_path: str = "" # Full absolute path to the file + content_length: int = 0 + content_type: str | None = None + format: str = "" # file extension without dot + + # Tier 2: Safetensors header (if available) + base_model: str | None = None + trained_words: list[str] | None = None + air: str | None = None # CivitAI AIR identifier + has_preview_images: bool = False + + # Source provenance (populated if embedded in safetensors) + source_url: str | None = None + source_arn: str | None = None + repo_url: str | None = None + preview_url: str | None = None + source_hash: str | None = None + + # HuggingFace specific + repo_id: str | None = None + revision: str | None = None + filepath: str | None = None + resolve_url: str | None = None + + def to_user_metadata(self) -> dict[str, Any]: + """Convert to user_metadata dict for AssetReference.user_metadata JSON field.""" + data: dict[str, Any] = { + "filename": self.filename, + "content_length": self.content_length, + "format": self.format, + } + if self.file_path: + data["file_path"] = self.file_path + if self.content_type: + data["content_type"] = self.content_type + + # Tier 2 fields + if self.base_model: + data["base_model"] = self.base_model + if self.trained_words: + data["trained_words"] = self.trained_words + if self.air: + data["air"] = self.air + if self.has_preview_images: + data["has_preview_images"] = True + + # Source provenance + if self.source_url: + data["source_url"] = self.source_url + if self.source_arn: + data["source_arn"] = self.source_arn + if self.repo_url: + data["repo_url"] = self.repo_url + if self.preview_url: + data["preview_url"] = self.preview_url + if self.source_hash: + data["source_hash"] = self.source_hash + + # HuggingFace + if self.repo_id: + data["repo_id"] = self.repo_id + if self.revision: + data["revision"] = self.revision + if self.filepath: + data["filepath"] = self.filepath + if self.resolve_url: + data["resolve_url"] = self.resolve_url + + return data + + def to_meta_rows(self, reference_id: str) -> list[dict]: + """Convert to asset_reference_meta rows for typed/indexed querying.""" + rows: list[dict] = [] + + def add_str(key: str, val: str | None, ordinal: int = 0) -> None: + if val: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": ordinal, + "val_str": val[:2048] if len(val) > 2048 else val, + "val_num": None, + "val_bool": None, + "val_json": None, + }) + + def add_num(key: str, val: int | float | None) -> None: + if val is not None: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": 0, + "val_str": None, + "val_num": val, + "val_bool": None, + "val_json": None, + }) + + def add_bool(key: str, val: bool | None) -> None: + if val is not None: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": 0, + "val_str": None, + "val_num": None, + "val_bool": val, + "val_json": None, + }) + + # Tier 1 + add_str("filename", self.filename) + add_num("content_length", self.content_length) + add_str("content_type", self.content_type) + add_str("format", self.format) + + # Tier 2 + add_str("base_model", self.base_model) + add_str("air", self.air) + has_previews = self.has_preview_images if self.has_preview_images else None + add_bool("has_preview_images", has_previews) + + # trained_words as multiple rows with ordinals + if self.trained_words: + for i, word in enumerate(self.trained_words[:100]): # limit to 100 words + add_str("trained_words", word, ordinal=i) + + # Source provenance + add_str("source_url", self.source_url) + add_str("source_arn", self.source_arn) + add_str("repo_url", self.repo_url) + add_str("preview_url", self.preview_url) + add_str("source_hash", self.source_hash) + + # HuggingFace + add_str("repo_id", self.repo_id) + add_str("revision", self.revision) + add_str("filepath", self.filepath) + add_str("resolve_url", self.resolve_url) + + return rows + + +def _read_safetensors_header( + path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE +) -> dict[str, Any] | None: + """Read only the JSON header from a safetensors file. + + This is very fast - reads 8 bytes for header length, then the JSON header. + No tensor data is loaded. + + Args: + path: Absolute path to safetensors file + max_size: Maximum header size to read (default 8MB) + + Returns: + Parsed header dict or None if failed + """ + try: + with open(path, "rb") as f: + header_bytes = f.read(8) + if len(header_bytes) < 8: + return None + length_of_header = struct.unpack(" max_size: + return None + header_data = f.read(length_of_header) + if len(header_data) < length_of_header: + return None + return json.loads(header_data.decode("utf-8")) + except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error): + return None + + +def _extract_safetensors_metadata( + header: dict[str, Any], meta: ExtractedMetadata +) -> None: + """Extract metadata from safetensors header __metadata__ section. + + Modifies meta in-place. + """ + st_meta = header.get("__metadata__", {}) + if not isinstance(st_meta, dict): + return + + # Common model metadata + meta.base_model = ( + st_meta.get("ss_base_model_version") + or st_meta.get("modelspec.base_model") + or st_meta.get("base_model") + ) + + # Trained words / trigger words + trained_words = st_meta.get("ss_tag_frequency") + if trained_words and isinstance(trained_words, str): + try: + tag_freq = json.loads(trained_words) + # Extract unique tags from all datasets + all_tags: set[str] = set() + for dataset_tags in tag_freq.values(): + if isinstance(dataset_tags, dict): + all_tags.update(dataset_tags.keys()) + if all_tags: + meta.trained_words = sorted(all_tags)[:100] + except json.JSONDecodeError: + pass + + # Direct trained_words field (some formats) + if not meta.trained_words: + tw = st_meta.get("trained_words") + if isinstance(tw, str): + try: + parsed = json.loads(tw) + if isinstance(parsed, list): + meta.trained_words = [str(x) for x in parsed] + else: + meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] + except json.JSONDecodeError: + meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] + elif isinstance(tw, list): + meta.trained_words = [str(x) for x in tw] + + # CivitAI AIR + meta.air = st_meta.get("air") or st_meta.get("modelspec.air") + + # Preview images (ssmd_cover_images) + cover_images = st_meta.get("ssmd_cover_images") + if cover_images: + meta.has_preview_images = True + + # Source provenance fields + meta.source_url = st_meta.get("source_url") + meta.source_arn = st_meta.get("source_arn") + meta.repo_url = st_meta.get("repo_url") + meta.preview_url = st_meta.get("preview_url") + meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash") + + # HuggingFace fields + meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id") + meta.revision = st_meta.get("revision") or st_meta.get("hf_revision") + meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath") + meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url") + + +def extract_file_metadata( + abs_path: str, + stat_result: os.stat_result | None = None, + relative_filename: str | None = None, +) -> ExtractedMetadata: + """Extract metadata from a file using tier 1 and tier 2 methods. + + Tier 1: Filesystem metadata from path and stat + Tier 2: Safetensors header parsing if applicable + + Args: + abs_path: Absolute path to the file + stat_result: Optional pre-fetched stat result (saves a syscall) + relative_filename: Optional relative filename to use instead of basename + (e.g., "flux/123/model.safetensors" for model paths) + + Returns: + ExtractedMetadata with all available fields populated + """ + meta = ExtractedMetadata() + + # Tier 1: Filesystem metadata + meta.filename = relative_filename or os.path.basename(abs_path) + meta.file_path = abs_path + _, ext = os.path.splitext(abs_path) + meta.format = ext.lstrip(".").lower() if ext else "" + + mime_type, _ = mimetypes.guess_type(abs_path) + meta.content_type = mime_type + + # Size from stat + if stat_result is None: + try: + stat_result = os.stat(abs_path, follow_symlinks=True) + except OSError: + pass + + if stat_result: + meta.content_length = stat_result.st_size + + # Tier 2: Safetensors header (if applicable and enabled) + if ext.lower() in SAFETENSORS_EXTENSIONS: + header = _read_safetensors_header(abs_path) + if header: + try: + _extract_safetensors_metadata(header, meta) + except Exception as e: + logging.debug("Safetensors meta extract failed %s: %s", abs_path, e) + + return meta diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py new file mode 100644 index 000000000..f5dd7f7fd --- /dev/null +++ b/app/assets/services/path_utils.py @@ -0,0 +1,167 @@ +import os +from pathlib import Path +from typing import Literal + +import folder_paths +from app.assets.helpers import normalize_tags + + +_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) + + +def get_comfy_models_folders() -> list[tuple[str, list[str]]]: + """Build list of (folder_name, base_paths[]) for all model locations. + + Includes every category registered in folder_names_and_paths, + regardless of whether its paths are under the main models_dir, + but excludes non-model entries like custom_nodes. + """ + targets: list[tuple[str, list[str]]] = [] + for name, values in folder_paths.folder_names_and_paths.items(): + if name in _NON_MODEL_FOLDER_NAMES: + continue + paths, _exts = values[0], values[1] + if paths: + targets.append((name, paths)) + return targets + + +def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: + """Validates and maps tags -> (base_dir, subdirs_for_fs)""" + if not tags: + raise ValueError("tags must not be empty") + root = tags[0].lower() + if root == "models": + if len(tags) < 2: + raise ValueError("at least two tags required for model asset") + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") + if not bases: + raise ValueError(f"no base path configured for category '{tags[1]}'") + base_dir = os.path.abspath(bases[0]) + raw_subdirs = tags[2:] + elif root == "input": + base_dir = os.path.abspath(folder_paths.get_input_directory()) + raw_subdirs = tags[1:] + elif root == "output": + base_dir = os.path.abspath(folder_paths.get_output_directory()) + raw_subdirs = tags[1:] + else: + raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'") + _sep_chars = frozenset(("/", "\\", os.sep)) + for i in raw_subdirs: + if i in (".", "..") or _sep_chars & set(i): + raise ValueError("invalid path component in tags") + + return base_dir, raw_subdirs if raw_subdirs else [] + + +def validate_path_within_base(candidate: str, base: str) -> None: + cand_abs = Path(os.path.abspath(candidate)) + base_abs = Path(os.path.abspath(base)) + if not cand_abs.is_relative_to(base_abs): + raise ValueError("destination escapes base directory") + + +def compute_relative_filename(file_path: str) -> str | None: + """ + Return the model's path relative to the last well-known folder (the model category), + using forward slashes, eg: + /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" + /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + + For non-model paths, returns None. + """ + try: + root_category, rel_path = get_asset_category_and_relative_path(file_path) + except ValueError: + return None + + p = Path(rel_path) + parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + if not parts: + return None + + if root_category == "models": + # parts[0] is the category ("checkpoints", "vae", etc) – drop it + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) + return "/".join(parts) # input/output: keep all parts + + +def get_asset_category_and_relative_path( + file_path: str, +) -> tuple[Literal["input", "output", "models"], str]: + """Determine which root category a file path belongs to. + + Categories: + - 'input': under folder_paths.get_input_directory() + - 'output': under folder_paths.get_output_directory() + - 'models': under any base path from get_comfy_models_folders() + + Returns: + (root_category, relative_path_inside_that_root) + + Raises: + ValueError: path does not belong to any known root. + """ + fp_abs = os.path.abspath(file_path) + + def _check_is_within(child: str, parent: str) -> bool: + return Path(child).is_relative_to(parent) + + def _compute_relative(child: str, parent: str) -> str: + # Normalize relative path, stripping any leading ".." components + # by anchoring to root (os.sep) then computing relpath back from it. + return os.path.relpath( + os.path.join(os.sep, os.path.relpath(child, parent)), os.sep + ) + + # 1) input + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _check_is_within(fp_abs, input_base): + return "input", _compute_relative(fp_abs, input_base) + + # 2) output + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _check_is_within(fp_abs, output_base): + return "output", _compute_relative(fp_abs, output_base) + + # 3) models (check deepest matching base to avoid ambiguity) + best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) + for bucket, bases in get_comfy_models_folders(): + for b in bases: + base_abs = os.path.abspath(b) + if not _check_is_within(fp_abs, base_abs): + continue + cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs)) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, bucket, rel_inside = best + combined = os.path.join(bucket, rel_inside) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + + raise ValueError( + f"Path is not within input, output, or configured model bases: {file_path}" + ) + + +def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: + """Return (name, tags) derived from a filesystem path. + + - name: base filename with extension + - tags: [root_category] + parent folder names in order + + Raises: + ValueError: path does not belong to any known root. + """ + root_category, some_path = get_asset_category_and_relative_path(file_path) + p = Path(some_path) + parent_parts = [ + part for part in p.parent.parts if part not in (".", "..", p.anchor) + ] + return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py new file mode 100644 index 000000000..8b1f1f4dc --- /dev/null +++ b/app/assets/services/schemas.py @@ -0,0 +1,109 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Any, NamedTuple + +from app.assets.database.models import Asset, AssetReference + +UserMetadata = dict[str, Any] | None + + +@dataclass(frozen=True) +class AssetData: + hash: str | None + size_bytes: int | None + mime_type: str | None + + +@dataclass(frozen=True) +class ReferenceData: + """Data transfer object for AssetReference.""" + + id: str + name: str + file_path: str | None + user_metadata: UserMetadata + preview_id: str | None + created_at: datetime + updated_at: datetime + last_access_time: datetime | None + + +@dataclass(frozen=True) +class AssetDetailResult: + ref: ReferenceData + asset: AssetData | None + tags: list[str] + + +@dataclass(frozen=True) +class RegisterAssetResult: + ref: ReferenceData + asset: AssetData + tags: list[str] + created: bool + + +@dataclass(frozen=True) +class IngestResult: + asset_created: bool + asset_updated: bool + ref_created: bool + ref_updated: bool + reference_id: str | None + + +class TagUsage(NamedTuple): + name: str + tag_type: str + count: int + + +@dataclass(frozen=True) +class AssetSummaryData: + ref: ReferenceData + asset: AssetData | None + tags: list[str] + + +@dataclass(frozen=True) +class ListAssetsResult: + items: list[AssetSummaryData] + total: int + + +@dataclass(frozen=True) +class DownloadResolutionResult: + abs_path: str + content_type: str + download_name: str + + +@dataclass(frozen=True) +class UploadResult: + ref: ReferenceData + asset: AssetData + tags: list[str] + created_new: bool + + +def extract_reference_data(ref: AssetReference) -> ReferenceData: + return ReferenceData( + id=ref.id, + name=ref.name, + file_path=ref.file_path, + user_metadata=ref.user_metadata, + preview_id=ref.preview_id, + created_at=ref.created_at, + updated_at=ref.updated_at, + last_access_time=ref.last_access_time, + ) + + +def extract_asset_data(asset: Asset | None) -> AssetData | None: + if asset is None: + return None + return AssetData( + hash=asset.hash, + size_bytes=asset.size_bytes, + mime_type=asset.mime_type, + ) diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py new file mode 100644 index 000000000..28900464d --- /dev/null +++ b/app/assets/services/tagging.py @@ -0,0 +1,75 @@ +from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, + add_tags_to_reference, + get_reference_with_owner_check, + list_tags_with_usage, + remove_tags_from_reference, +) +from app.assets.services.schemas import TagUsage +from app.database.db import create_session + + +def apply_tags( + reference_id: str, + tags: list[str], + origin: str = "manual", + owner_id: str = "", +) -> AddTagsResult: + with create_session() as session: + ref_row = get_reference_with_owner_check(session, reference_id, owner_id) + + result = add_tags_to_reference( + session, + reference_id=reference_id, + tags=tags, + origin=origin, + create_if_missing=True, + reference_row=ref_row, + ) + session.commit() + + return result + + +def remove_tags( + reference_id: str, + tags: list[str], + owner_id: str = "", +) -> RemoveTagsResult: + with create_session() as session: + get_reference_with_owner_check(session, reference_id, owner_id) + + result = remove_tags_from_reference( + session, + reference_id=reference_id, + tags=tags, + ) + session.commit() + + return result + + +def list_tags( + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, + owner_id: str = "", +) -> tuple[list[TagUsage], int]: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + with create_session() as session: + rows, total = list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + owner_id=owner_id, + ) + + return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total diff --git a/app/database/db.py b/app/database/db.py index 1de8b80ed..0aab09a49 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -3,6 +3,7 @@ import os import shutil from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message +from filelock import FileLock, Timeout from comfy.cli_args import args _DB_AVAILABLE = False @@ -14,8 +15,12 @@ try: from alembic.config import Config from alembic.runtime.migration import MigrationContext from alembic.script import ScriptDirectory - from sqlalchemy import create_engine + from sqlalchemy import create_engine, event from sqlalchemy.orm import sessionmaker + from sqlalchemy.pool import StaticPool + + from app.database.models import Base + import app.assets.database.models # noqa: F401 — register models with Base.metadata _DB_AVAILABLE = True except ImportError as e: @@ -65,9 +70,69 @@ def get_db_path(): raise ValueError(f"Unsupported database URL '{url}'.") +_db_lock = None + +def _acquire_file_lock(db_path): + """Acquire an OS-level file lock to prevent multi-process access. + + Uses filelock for cross-platform support (macOS, Linux, Windows). + The OS automatically releases the lock when the process exits, even on crashes. + """ + global _db_lock + lock_path = db_path + ".lock" + _db_lock = FileLock(lock_path) + try: + _db_lock.acquire(timeout=0) + except Timeout: + raise RuntimeError( + f"Could not acquire lock on database '{db_path}'. " + "Another ComfyUI process may already be using it. " + "Use --database-url to specify a separate database file." + ) + + +def _is_memory_db(db_url): + """Check if the database URL refers to an in-memory SQLite database.""" + return db_url in ("sqlite:///:memory:", "sqlite://") + + def init_db(): db_url = args.database_url logging.debug(f"Database URL: {db_url}") + + if _is_memory_db(db_url): + _init_memory_db(db_url) + else: + _init_file_db(db_url) + + +def _init_memory_db(db_url): + """Initialize an in-memory SQLite database using metadata.create_all. + + Alembic migrations don't work with in-memory SQLite because each + connection gets its own separate database — tables created by Alembic's + internal connection are lost immediately. + """ + engine = create_engine( + db_url, + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + Base.metadata.create_all(engine) + + global Session + Session = sessionmaker(bind=engine) + + +def _init_file_db(db_url): + """Initialize a file-backed SQLite database using Alembic migrations.""" db_path = get_db_path() db_exists = os.path.exists(db_path) @@ -75,6 +140,14 @@ def init_db(): # Check if we need to upgrade engine = create_engine(db_url) + + # Enable foreign key enforcement for SQLite + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + conn = engine.connect() context = MigrationContext.configure(conn) @@ -104,6 +177,12 @@ def init_db(): logging.exception("Error upgrading database: ") raise e + # Acquire an OS-level file lock after migrations are complete. + # Alembic uses its own connection, so we must wait until it's done + # before locking — otherwise our own lock blocks the migration. + conn.close() + _acquire_file_lock(db_path) + global Session Session = sessionmaker(bind=engine) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 16998af94..0de7584b0 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -27,6 +27,7 @@ class AudioEncoderModel(): self.model.eval() self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 + comfy.model_management.archive_model_dtypes(self.model) def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 13079c7bc..e9832acaf 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -232,7 +232,7 @@ database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") -parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") +parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 8b3f500d7..e20d498f8 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module): del txt_k, img_k v = torch.cat((txt_v, img_v), dim=2) del txt_v, img_v + + extra_options["img_slice"] = [txt.shape[1], q.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # run actual attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v if "attn1_output_patch" in transformer_patches: - extra_options["img_slice"] = [txt.shape[1], attn.shape[1]] patch = transformer_patches["attn1_output_patch"] for p in patch: attn = p(attn, extra_options) @@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module): del qkv q, k = self.norm(q, k, v) + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 5e764bb46..824daf5e6 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def _apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]: + freqs_cis = freqs_cis[:, :, :x_.shape[2]] x_out = freqs_cis[..., 0] * x_[..., 0] x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index ef4dcf7c5..00f12c031 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -170,7 +170,7 @@ class Flux(nn.Module): if "post_input" in patches: for p in patches["post_input"]: - out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) + out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) img = out["img"] txt = out["txt"] img_ids = out["img_ids"] diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 553fd5b38..08d686b7b 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -2,11 +2,16 @@ from typing import Tuple import torch import torch.nn as nn from comfy.ldm.lightricks.model import ( + ADALN_BASE_PARAMS_COUNT, + ADALN_CROSS_ATTN_PARAMS_COUNT, CrossAttention, FeedForward, AdaLayerNormSingle, PixArtAlphaTextProjection, + NormSingleLinearTextProjection, LTXVModel, + apply_cross_attention_adaln, + compute_prompt_timestep, ) from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module): v_context_dim=None, a_context_dim=None, attn_precision=None, + apply_gated_attention=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=v_dim, @@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=vd_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=ad_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module): heads=v_heads, dim_head=vd_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module): a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations ) - self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype)) self.audio_scale_shift_table = nn.Parameter( - torch.empty(6, a_dim, device=device, dtype=dtype) + torch.empty(num_ada_params, a_dim, device=device, dtype=dtype) ) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype)) + self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype)) + self.scale_shift_table_a2v_ca_audio = nn.Parameter( torch.empty(5, a_dim, device=device, dtype=dtype) ) @@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module): return (*scale_shift_ada_values, *gate_ada_values) + def _apply_text_cross_attention( + self, x, context, attn, scale_shift_table, prompt_scale_shift_table, + timestep, prompt_timestep, attention_mask, transformer_options, + ): + """Apply text cross-attention, with optional ADaLN modulation.""" + if self.cross_attention_adaln: + shift_q, scale_q, gate = self.get_ada_values( + scale_shift_table, x.shape[0], timestep, slice(6, 9) + ) + return apply_cross_attention_adaln( + x, context, attn, shift_q, scale_q, gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask, transformer_options, + ) + return attn( + comfy.ldm.common_dit.rms_norm(x), context=context, + mask=attention_mask, transformer_options=transformer_options, + ) + def forward( self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, + v_prompt_timestep=None, a_prompt_timestep=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module): vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] vx.addcmul_(attn1_out, vgate_msa) del vgate_msa, attn1_out - vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options)) + vx.add_(self._apply_text_cross_attention( + vx, v_context, self.attn2, self.scale_shift_table, + getattr(self, 'prompt_scale_shift_table', None), + v_timestep, v_prompt_timestep, attention_mask, transformer_options,) + ) # audio if run_ax: @@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module): agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] ax.addcmul_(attn1_out, agate_msa) del agate_msa, attn1_out - ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options)) + ax.add_(self._apply_text_cross_attention( + ax, a_context, self.audio_attn2, self.audio_scale_shift_table, + getattr(self, 'audio_prompt_scale_shift_table', None), + a_timestep, a_prompt_timestep, attention_mask, transformer_options,) + ) # video - audio cross attention. if run_a2v or run_v2a: @@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel): use_middle_indices_grid=False, timestep_scale_multiplier=1000.0, av_ca_timestep_scale_multiplier=1.0, + apply_gated_attention=False, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel): self.audio_attention_head_dim = audio_attention_head_dim self.audio_num_attention_heads = audio_num_attention_heads self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.apply_gated_attention = apply_gated_attention # Calculate audio dimensions self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim @@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel): ) # Audio-specific AdaLN + audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.audio_adaln_single = AdaLayerNormSingle( self.audio_inner_dim, + embedding_coefficient=audio_embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations, ) + if self.cross_attention_adaln: + self.audio_prompt_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=2, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_prompt_adaln_single = None + num_scale_shift_values = 4 self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( self.inner_dim, @@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel): ) # Audio caption projection - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.audio_inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.audio_caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_caption_projection = lambda a: a + else: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + connector_split_rope = kwargs.get("rope_type", "split") == "split" + connector_gated_attention = kwargs.get("connector_apply_gated_attention", False) + attention_head_dim = kwargs.get("connector_attention_head_dim", 128) + num_attention_heads = kwargs.get("connector_num_attention_heads", 30) + num_layers = kwargs.get("connector_num_layers", 2) self.audio_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim), + num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads), + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) self.video_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) - def preprocess_text_embeds(self, context): - if context.shape[-1] == self.caption_channels * 2: - return context - out_vid = self.video_embeddings_connector(context)[0] - out_audio = self.audio_embeddings_connector(context)[0] + def preprocess_text_embeds(self, context, unprocessed=False): + # LTXv2 fully processed context has dimension of self.caption_channels * 2 + # LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim + if not unprocessed: + if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2): + return context + if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim: + context_vid = context[:, :, :self.cross_attention_dim] + context_audio = context[:, :, self.cross_attention_dim:] + else: + context_vid = context + context_audio = context + if self.caption_proj_before_connector: + context_vid = self.caption_projection(context_vid) + context_audio = self.audio_caption_projection(context_audio) + out_vid = self.video_embeddings_connector(context_vid)[0] + out_audio = self.audio_embeddings_connector(context_audio)[0] return torch.concat((out_vid, out_audio), dim=-1) def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel): ad_head=self.audio_attention_head_dim, v_context_dim=self.cross_attention_dim, a_context_dim=self.audio_cross_attention_dim, + apply_gated_attention=self.apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel): v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) + v_prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + # Prepare audio timestep a_timestep = kwargs.get("a_timestep") if a_timestep is not None: @@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel): # Cross-attention timesteps - compress these too av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( - a_timestep_flat, + timestep.max().expand_as(a_timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( - timestep_flat, + a_timestep.max().expand_as(timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( - timestep_flat * av_ca_factor, + a_timestep.max().expand_as(timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( - a_timestep_flat * av_ca_factor, + timestep.max().expand_as(a_timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel): # Audio timesteps a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) + + a_prompt_timestep = compute_prompt_timestep( + self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype + ) else: a_timestep = timestep_scaled a_embedded_timestep = kwargs.get("embedded_timestep") cross_av_timestep_ss = [] + a_prompt_timestep = None - return [v_timestep, a_timestep, cross_av_timestep_ss], [ + return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [ v_embedded_timestep, a_embedded_timestep, - ] + ], None def _prepare_context(self, context, batch_size, x, attention_mask=None): vx = x[0] ax = x[1] + video_dim = vx.shape[-1] + audio_dim = ax.shape[-1] + + v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim + a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim + v_context, a_context = torch.split( - context, int(context.shape[-1] / 2), len(context.shape) - 1 + context, [v_context_dim, a_context_dim], len(context.shape) - 1 ) v_context, attention_mask = super()._prepare_context( v_context, batch_size, vx, attention_mask ) - if self.audio_caption_projection is not None: + if self.caption_proj_before_connector is False: a_context = self.audio_caption_projection(a_context) - a_context = a_context.view(batch_size, -1, ax.shape[-1]) + a_context = a_context.view(batch_size, -1, audio_dim) return [v_context, a_context], attention_mask @@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel): av_ca_v2a_gate_noise_timestep, ) = timestep[2] + v_prompt_timestep = timestep[3] + a_prompt_timestep = timestep[4] + """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=args["a_cross_gate_timestep"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), + v_prompt_timestep=args.get("v_prompt_timestep"), + a_prompt_timestep=args.get("a_prompt_timestep"), ) return out @@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel): "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, + "v_prompt_timestep": v_prompt_timestep, + "a_prompt_timestep": a_prompt_timestep, }, {"original_block": block_wrap}, ) @@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + v_prompt_timestep=v_prompt_timestep, + a_prompt_timestep=a_prompt_timestep, ) return [vx, ax] diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index 33adb9671..2811080be 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module): d_head, context_dim=None, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module): heads=n_heads, dim_head=d_head, context_dim=None, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module): positional_embedding_max_pos=[4096], causal_temporal_positioning=False, num_learnable_registers: Optional[int] = 128, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module): num_attention_heads, attention_head_dim, context_dim=cross_attention_dim, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 60d760d29..bfbc08357 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module): return hidden_states +class NormSingleLinearTextProjection(nn.Module): + """Text projection for 20B models - single linear with RMSNorm (no activation).""" + + def __init__( + self, in_features, hidden_size, dtype=None, device=None, operations=None + ): + super().__init__() + if operations is None: + operations = comfy.ops.disable_weight_init + self.in_norm = operations.RMSNorm( + in_features, eps=1e-6, elementwise_affine=False + ) + self.linear_1 = operations.Linear( + in_features, hidden_size, bias=True, dtype=dtype, device=device + ) + self.hidden_size = hidden_size + self.in_features = in_features + + def forward(self, caption): + caption = self.in_norm(caption) + caption = caption * (self.hidden_size / self.in_features) ** 0.5 + return self.linear_1(caption) + + class GELU_approx(nn.Module): def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None): super().__init__() @@ -343,6 +367,7 @@ class CrossAttention(nn.Module): dim_head=64, dropout=0.0, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -362,6 +387,12 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) + # Optional per-head gating + if apply_gated_attention: + self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device) + else: + self.to_gate_logits = None + self.to_out = nn.Sequential( operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -383,16 +414,30 @@ class CrossAttention(nn.Module): out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) + + # Apply per-head gating if enabled + if self.to_gate_logits is not None: + gate_logits = self.to_gate_logits(x) # (B, T, H) + b, t, _ = out.shape + out = out.view(b, t, self.heads, self.dim_head) + gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity + out = out * gates.unsqueeze(-1) + out = out.view(b, t, self.heads * self.dim_head) + return self.to_out(out) +# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate) +ADALN_BASE_PARAMS_COUNT = 6 +ADALN_CROSS_ATTN_PARAMS_COUNT = 9 class BasicTransformerBlock(nn.Module): def __init__( - self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None ): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, @@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module): operations=operations, ) - self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype)) - attn1_input = comfy.ldm.common_dit.rms_norm(x) - attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) - x.addcmul_(attn1_input, gate_msa) - del attn1_input + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2) - x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa + + if self.cross_attention_adaln: + shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2) + x += apply_cross_attention_adaln( + x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca, + self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options, + ) + else: + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) y = comfy.ldm.common_dit.rms_norm(x) y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp) @@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module): return x +def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype): + """Compute a single global prompt timestep for cross-attention ADaLN. + + Uses the max across tokens (matching JAX max_per_segment) and broadcasts + over text tokens. Returns None when *adaln_module* is None. + """ + if adaln_module is None: + return None + ts_input = ( + timestep_scaled.max(dim=1, keepdim=True).values.flatten() + if timestep_scaled.dim() > 1 + else timestep_scaled.flatten() + ) + prompt_ts, _ = adaln_module( + ts_input, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1]) + + +def apply_cross_attention_adaln( + x, context, attn, q_shift, q_scale, q_gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask=None, transformer_options={}, +): + """Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV). + + Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so + that both regular tensors and CompressedTimestep are supported. + """ + batch_size = x.shape[0] + shift_kv, scale_kv = ( + prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1) + ).unbind(dim=2) + attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift + encoder_hidden_states = context * (1 + scale_kv) + shift_kv + return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate + def get_fractional_positions(indices_grid, max_pos): n_pos_dims = indices_grid.shape[1] assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' @@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC): vae_scale_factors: tuple = (8, 32, 32), use_middle_indices_grid=False, timestep_scale_multiplier = 1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, + caption_projection_first_linear=True, dtype=None, device=None, operations=None, @@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC): self.causal_temporal_positioning = causal_temporal_positioning self.operations = operations self.timestep_scale_multiplier = timestep_scale_multiplier + self.caption_proj_before_connector = caption_proj_before_connector + self.cross_attention_adaln = cross_attention_adaln + self.caption_projection_first_linear = caption_projection_first_linear # Common dimensions self.inner_dim = num_attention_heads * attention_head_dim @@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC): self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device ) + embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations ) - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.cross_attention_adaln: + self.prompt_adaln_single = AdaLayerNormSingle( + self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + ) + else: + self.prompt_adaln_single = None + + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.caption_projection = lambda a: a + else: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) @abstractmethod def _init_model_components(self, device, dtype, **kwargs): @@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC): if grid_mask is not None: timestep = timestep[:, grid_mask] - timestep = timestep * self.timestep_scale_multiplier + timestep_scaled = timestep * self.timestep_scale_multiplier timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), + timestep_scaled.flatten(), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC): timestep = timestep.view(batch_size, -1, timestep.shape[-1]) embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) - return timestep, embedded_timestep + prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + + return timestep, embedded_timestep, prompt_timestep def _prepare_context(self, context, batch_size, x, attention_mask=None): """Prepare context for transformer blocks.""" - if self.caption_projection is not None: + if self.caption_proj_before_connector is False: context = self.caption_projection(context) - context = context.view(batch_size, -1, x.shape[-1]) + context = context.view(batch_size, -1, x.shape[-1]) return context, attention_mask def _precompute_freqs_cis( @@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC): merged_args.update(additional_args) # Prepare timestep and context - timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + merged_args["prompt_timestep"] = prompt_timestep context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) # Prepare attention mask and positional embeddings @@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel): causal_temporal_positioning=False, vae_scale_factors=(8, 32, 32), use_middle_indices_grid=False, - timestep_scale_multiplier = 1000.0, + timestep_scale_multiplier=1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel): def _init_model_components(self, device, dtype, **kwargs): """Initialize LTXV-specific components.""" - # No additional components needed for LTXV beyond base class pass def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel): self.num_attention_heads, self.attention_head_dim, context_dim=self.cross_attention_dim, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prompt_timestep = kwargs.get("prompt_timestep", None) for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask")) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel): pe=pe, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + prompt_timestep=prompt_timestep, ) return x diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index 55a074661..fa0a00748 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( CausalityAxis, CausalAudioAutoencoder, ) -from comfy.ldm.lightricks.vocoders.vocoder import Vocoder +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE LATENT_DOWNSAMPLE_FACTOR = 4 @@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module): vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) - self.vocoder = Vocoder(config=component_config.vocoder) + if "bwe" in component_config.vocoder: + self.vocoder = VocoderWithBWE(config=component_config.vocoder) + else: + self.vocoder = Vocoder(config=component_config.vocoder) self.autoencoder.load_state_dict(vae_sd, strict=False) self.vocoder.load_state_dict(vocoder_sd, strict=False) diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py index f12b9bb53..b556b128f 100644 --- a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module): super().__init__() if config is None: - config = self._guess_config() + config = self.get_default_config() - # Extract encoder and decoder configs from the new format model_config = config.get("model", {}).get("params", {}) - variables_config = config.get("variables", {}) - self.sampling_rate = variables_config.get( - "sampling_rate", - model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + self.sampling_rate = model_config.get( + "sampling_rate", config.get("sampling_rate", 16000) ) encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) decoder_config = model_config.get("decoder", encoder_config) # Load mel spectrogram parameters self.mel_bins = encoder_config.get("mel_bins", 64) - self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) - self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) # Store causality configuration at VAE level (not just in encoder internals) - causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value) self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) self.is_causal = self.causality_axis == CausalityAxis.HEIGHT @@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module): self.per_channel_statistics = processor() - def _guess_config(self): - encoder_config = { - # Required parameters - based on ltx-video-av-1679000 model metadata - "ch": 128, - "out_ch": 8, - "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] - "num_res_blocks": 2, - "attn_resolutions": [], # Based on metadata: empty list, no attention - "dropout": 0.0, - "resamp_with_conv": True, - "in_channels": 2, # stereo - "resolution": 256, - "z_channels": 8, + def get_default_config(self): + ddconfig = { "double_z": True, - "attn_type": "vanilla", - "mid_block_add_attention": False, # Based on metadata: false + "mel_bins": 64, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 2, + "out_ch": 2, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + "mid_block_add_attention": False, "norm_type": "pixel", - "causality_axis": "height", # Based on metadata - "mel_bins": 64, # Based on metadata: mel_bins = 64 - } - - decoder_config = { - # Inherits encoder config, can override specific params - **encoder_config, - "out_ch": 2, # Stereo audio output (2 channels) - "give_pre_end": False, - "tanh_out": False, + "causality_axis": "height", } config = { - "_class_name": "CausalAudioAutoencoder", - "sampling_rate": 16000, "model": { "params": { - "encoder": encoder_config, - "decoder": decoder_config, + "ddconfig": ddconfig, + "sampling_rate": 16000, } }, + "preprocessing": { + "stft": { + "filter_length": 1024, + "hop_length": 160, + }, + }, } return config diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index cbfdf412d..5b57dfc5e 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +def in_meta_context(): + return torch.device("meta") == torch.empty(0).device + def mark_conv3d_ended(module): tid = threading.get_ident() for _, m in module.named_modules(): @@ -350,6 +353,10 @@ class Decoder(nn.Module): output_channel = output_channel * block_params.get("multiplier", 2) if block_name == "compress_all": output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_space": + output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_time": + output_channel = output_channel * block_params.get("multiplier", 1) self.conv_in = make_conv_nd( dims, @@ -395,17 +402,21 @@ class Decoder(nn.Module): spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(2, 1, 1), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(1, 2, 2), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": @@ -455,6 +466,15 @@ class Decoder(nn.Module): output_channel * 2, 0, operations=ops, ) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + else: + self.register_buffer( + "last_scale_shift_table", + torch.tensor( + [0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(2, output_channel), + persistent=False, + ) # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: @@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module): self.scale_shift_table = nn.Parameter( torch.randn(4, in_channels) / in_channels**0.5 ) + else: + self.register_buffer( + "scale_shift_table", + torch.tensor( + [0.0, 0.0, 0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(4, in_channels), + persistent=False, + ) self.temporal_cache_state={} @@ -1012,9 +1041,6 @@ class processor(nn.Module): super().__init__() self.register_buffer("std-of-means", torch.empty(128)) self.register_buffer("mean-of-means", torch.empty(128)) - self.register_buffer("mean-of-stds", torch.empty(128)) - self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128)) - self.register_buffer("channel", torch.empty(128)) def un_normalize(self, x): return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x) @@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module): super().__init__() if config is None: - config = self.guess_config(version) + config = self.get_default_config(version) + self.config = config self.timestep_conditioning = config.get("timestep_conditioning", False) + self.decode_noise_scale = config.get("decode_noise_scale", 0.025) + self.decode_timestep = config.get("decode_timestep", 0.05) double_z = config.get("double_z", True) latent_log_var = config.get( "latent_log_var", "per_channel" if double_z else "none" @@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module): latent_log_var=latent_log_var, norm_layer=config.get("norm_layer", "group_norm"), spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + base_channels=config.get("encoder_base_channels", 128), ) self.decoder = Decoder( @@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module): in_channels=config["latent_channels"], out_channels=config.get("out_channels", 3), blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), + base_channels=config.get("decoder_base_channels", 128), patch_size=config.get("patch_size", 1), norm_layer=config.get("norm_layer", "group_norm"), causal=config.get("causal_decoder", False), @@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module): self.per_channel_statistics = processor() - def guess_config(self, version): + def get_default_config(self, version): if version == 0: config = { "_class_name": "CausalVideoAutoencoder", @@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module): means, logvar = torch.chunk(self.encoder(x), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x, timestep=0.05, noise_scale=0.025): + def decode(self, x): if self.timestep_conditioning: #TODO: seed - x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x - return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep) - + x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index b1f15f2c5..2481d8bdd 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -2,7 +2,9 @@ import torch import torch.nn.functional as F import torch.nn as nn import comfy.ops +import comfy.model_management import numpy as np +import math ops = comfy.ops.disable_weight_init @@ -12,6 +14,307 @@ def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) +# --------------------------------------------------------------------------- +# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2 +# Adopted from https://github.com/NVIDIA/BigVGAN +# --------------------------------------------------------------------------- + + +def _sinc(x: torch.Tensor): + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride=1, + padding=True, + padding_mode="replicate", + kernel_size=12, + ): + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + def forward(self, x): + _, C, _ = x.shape + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"): + super().__init__() + self.ratio = ratio + self.stride = ratio + + if window_type == "hann": + # Hann-windowed sinc filter — identical to torchaudio.functional.resample + # with its default parameters (rolloff=0.99, lowpass_filter_width=6). + # Uses replicate boundary padding, matching the reference resampler exactly. + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + t = (torch.arange(self.kernel_size) / ratio - width) * rolloff + t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2 + filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser-windowed sinc filter (BigVGAN default). + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + + self.register_buffer("filter", filter, persistent=persistent) + + def forward(self, x): + _, C, _ = x.shape + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + return self.lowpass(x) + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio=2, + down_ratio=2, + up_kernel_size=12, + down_kernel_size=12, + ): + super().__init__() + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +# --------------------------------------------------------------------------- +# BigVGAN v2 activations (Snake / SnakeBeta) +# --------------------------------------------------------------------------- + + +class Snake(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) + if self.alpha_logscale: + a = torch.exp(a) + return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2) + + +class SnakeBeta(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.beta = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.beta.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) + b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) + if self.alpha_logscale: + a = torch.exp(a) + b = torch.exp(b) + return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2) + + +# --------------------------------------------------------------------------- +# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity) +# --------------------------------------------------------------------------- + + +class AMPBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"): + super().__init__() + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + self.acts1 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs1))] + ) + self.acts2 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs2))] + ) + + def forward(self, x): + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = x + xt + return x + + +# --------------------------------------------------------------------------- +# HiFi-GAN residual blocks +# --------------------------------------------------------------------------- + + class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() @@ -119,6 +422,7 @@ class Vocoder(torch.nn.Module): """ Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1"). """ def __init__(self, config=None): @@ -128,19 +432,39 @@ class Vocoder(torch.nn.Module): config = self.get_default_config() resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) - upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) - upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4]) resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) upsample_initial_channel = config.get("upsample_initial_channel", 1024) stereo = config.get("stereo", True) - resblock = config.get("resblock", "1") + activation = config.get("activation", "snake") + use_bias_at_final = config.get("use_bias_at_final", True) + + # "output_sample_rate" is not present in recent checkpoint configs. + # When absent (None), AudioVAE.output_sample_rate computes it as: + # sample_rate * vocoder.upsample_factor / mel_hop_length + # where upsample_factor = product of all upsample stride lengths, + # and mel_hop_length is loaded from the autoencoder config at + # preprocessing.stft.hop_length (see CausalAudioAutoencoder). self.output_sample_rate = config.get("output_sample_rate") + self.resblock = config.get("resblock", "1") + self.use_tanh_at_final = config.get("use_tanh_at_final", True) + self.apply_final_activation = config.get("apply_final_activation", True) self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) - resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + if self.resblock == "1": + resblock_cls = ResBlock1 + elif self.resblock == "2": + resblock_cls = ResBlock2 + elif self.resblock == "AMP1": + resblock_cls = AMPBlock1 + else: + raise ValueError(f"Unknown resblock type: {self.resblock}") self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): @@ -157,25 +481,40 @@ class Vocoder(torch.nn.Module): self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock_class(ch, k, d)) + for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes): + if self.resblock == "AMP1": + self.resblocks.append(resblock_cls(ch, k, d, activation=activation)) + else: + self.resblocks.append(resblock_cls(ch, k, d)) out_channels = 2 if stereo else 1 - self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + if self.resblock == "AMP1": + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.act_post = Activation1d(act_cls(ch)) + else: + self.act_post = nn.LeakyReLU() + + self.conv_post = ops.Conv1d( + ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final + ) self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + def get_default_config(self): """Generate default configuration for the vocoder.""" config = { "resblock_kernel_sizes": [3, 7, 11], - "upsample_rates": [6, 5, 2, 2, 2], - "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_initial_channel": 1024, "stereo": True, "resblock": "1", + "activation": "snake", + "use_bias_at_final": True, + "use_tanh_at_final": True, } return config @@ -196,8 +535,10 @@ class Vocoder(torch.nn.Module): assert x.shape[1] == 2, "Input must have 2 channels for stereo" x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) x = self.conv_pre(x) + for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) + if self.resblock != "AMP1": + x = F.leaky_relu(x, LRELU_SLOPE) x = self.ups[i](x) xs = None for j in range(self.num_kernels): @@ -206,8 +547,167 @@ class Vocoder(torch.nn.Module): else: xs += self.resblocks[i * self.num_kernels + j](x) x = xs / self.num_kernels - x = F.leaky_relu(x) + + x = self.act_post(x) x = self.conv_post(x) - x = torch.tanh(x) + + if self.apply_final_activation: + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, -1, 1) return x + + +class _STFTFn(nn.Module): + """Implements STFT as a convolution with precomputed DFT × Hann-window bases. + + The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal + Hann window are stored as buffers and loaded from the checkpoint. Using the exact + bfloat16 bases from training ensures the mel values fed to the BWE generator are + bit-identical to what it was trained on. + """ + + def __init__(self, filter_length: int, hop_length: int, win_length: int): + super().__init__() + self.hop_length = hop_length + self.win_length = win_length + n_freqs = filter_length // 2 + 1 + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + + def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Compute magnitude and phase spectrogram from a batch of waveforms. + + Applies causal (left-only) padding of win_length - hop_length samples so that + each output frame depends only on past and present input — no lookahead. + The STFT is computed by convolving the padded signal with forward_basis. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + Computed in float32 for numerical stability, then cast back to + the input dtype. + """ + if y.dim() == 2: + y = y.unsqueeze(1) # (B, 1, T) + left_pad = max(0, self.win_length - self.hop_length) # causal: left-only + y = F.pad(y, (left_pad, 0)) + spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real ** 2 + imag ** 2) + phase = torch.atan2(imag.float(), real.float()).to(real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint. + + Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input + waveform and projecting the linear magnitude spectrum onto the mel filterbank. + + The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint + (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis). + """ + + def __init__( + self, + filter_length: int, + hop_length: int, + win_length: int, + n_mel_channels: int, + sampling_rate: int, + mel_fmin: float, + mel_fmax: float, + ): + super().__init__() + self.stft_fn = _STFTFn(filter_length, hop_length, win_length) + + n_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs)) + + def mel_spectrogram( + self, y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute log-mel spectrogram and auxiliary spectral quantities. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames). + Computed as log(clamp(mel_basis @ magnitude, min=1e-5)). + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames). + """ + magnitude, phase = self.stft_fn(y) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class VocoderWithBWE(torch.nn.Module): + """Vocoder with bandwidth extension (BWE) for higher sample rate output. + + Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples + to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform. + """ + + def __init__(self, config): + super().__init__() + vocoder_config = config["vocoder"] + bwe_config = config["bwe"] + + self.vocoder = Vocoder(config=vocoder_config) + self.bwe_generator = Vocoder( + config={**bwe_config, "apply_final_activation": False} + ) + + self.input_sample_rate = bwe_config["input_sampling_rate"] + self.output_sample_rate = bwe_config["output_sampling_rate"] + self.hop_length = bwe_config["hop_length"] + + self.mel_stft = MelSTFT( + filter_length=bwe_config["n_fft"], + hop_length=bwe_config["hop_length"], + win_length=bwe_config["n_fft"], + n_mel_channels=bwe_config["num_mels"], + sampling_rate=bwe_config["input_sampling_rate"], + mel_fmin=0.0, + mel_fmax=bwe_config["input_sampling_rate"] / 2.0, + ) + self.resampler = UpSample1d( + ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"], + persistent=False, + window_type="hann", + ) + + def _compute_mel(self, audio): + """Compute log-mel spectrogram from waveform using causal STFT bases.""" + B, C, T = audio.shape + flat = audio.reshape(B * C, -1) # (B*C, T) + mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) + return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) + + def forward(self, mel_spec): + x = self.vocoder(mel_spec) + _, _, T_low = x.shape + T_out = T_low * self.output_sample_rate // self.input_sample_rate + + remainder = T_low % self.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + mel = self._compute_mel(x) + residual = self.bwe_generator(mel) + skip = self.resampler(x) + assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" + + return torch.clamp(residual + skip, -1, 1)[..., :T_out] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 10d051325..b193fe5e8 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) if first_op_done == False: model_management.soft_empty_cache(True) if cleared_cache == False: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 805592aa5..fcbaa074f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -258,7 +258,8 @@ def slice_attention(q, k, v): r1[:, :, i:end] = torch.bmm(v, s2) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) model_management.soft_empty_cache(True) steps *= 2 if steps > 128: @@ -314,7 +315,8 @@ def pytorch_attention(q, k, v): try: out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") oom_fallback = True if oom_fallback: diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index fab145f1c..f982afc2b 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking( try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined torch.exp(attn_scores, out=attn_scores) diff --git a/comfy/lora.py b/comfy/lora.py index f36ddb046..63ee85323 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}): for k in sdk: if k.endswith(".weight"): key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names + tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models + if tp > 0 and not k.startswith("clip_"): + key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False diff --git a/comfy/model_base.py b/comfy/model_base.py index 1e01e9edc..d9d5a9293 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1021,7 +1021,7 @@ class LTXAV(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: if hasattr(self.diffusion_model, "preprocess_text_embeds"): - cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference())) + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False)) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 6eace4628..35a6822e3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,4 +1,5 @@ import json +import comfy.memory_management import comfy.supported_models import comfy.supported_models_base import comfy.utils @@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): new[:old_weight.shape[0]] = old_weight old_weight = new + if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled: + old_weight = old_weight.clone() + w = old_weight.narrow(offset[0], offset[1], offset[2]) else: + if comfy.memory_management.aimdo_enabled: + weight = weight.clone() old_weight = weight w = weight w[:] = fun(weight) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0e0e96672..81550c790 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -270,6 +270,18 @@ try: except: OOM_EXCEPTION = Exception +def is_oom(e): + if isinstance(e, OOM_EXCEPTION): + return True + if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2: + discard_cuda_async_error() + return True + return False + +def raise_non_oom(e): + if not is_oom(e): + raise e + XFORMERS_VERSION = "" XFORMERS_ENABLED_VAE = True if args.disable_xformers: @@ -796,6 +808,8 @@ def archive_model_dtypes(model): for name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + for buf_name, buf in module.named_buffers(recurse=False): + setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype) def cleanup_models(): @@ -828,11 +842,14 @@ def unet_offload_device(): return torch.device("cpu") def unet_inital_load_device(parameters, dtype): + cpu_dev = torch.device("cpu") + if comfy.memory_management.aimdo_enabled: + return cpu_dev + torch_dev = get_torch_device() if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: return torch_dev - cpu_dev = torch.device("cpu") if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM: return cpu_dev @@ -840,7 +857,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled: + if mem_dev > mem_cpu and model_size < mem_dev: return torch_dev else: return cpu_dev @@ -934,7 +951,7 @@ def text_encoder_offload_device(): def text_encoder_device(): if args.gpu_only: return get_torch_device() - elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: + elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled: if should_use_fp16(prioritize_performance=False): return get_torch_device() else: @@ -943,6 +960,9 @@ def text_encoder_device(): return torch.device("cpu") def text_encoder_initial_device(load_device, offload_device, model_size=0): + if comfy.memory_management.aimdo_enabled: + return offload_device + if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device @@ -1140,6 +1160,7 @@ def reset_cast_buffers(): LARGEST_CASTED_WEIGHT = (None, 0) for offload_stream in STREAM_CAST_BUFFERS: offload_stream.synchronize() + synchronize() STREAM_CAST_BUFFERS.clear() soft_empty_cache() @@ -1658,12 +1679,16 @@ def lora_compute_dtype(device): return dtype def synchronize(): + if cpu_mode(): + return if is_intel_xpu(): torch.xpu.synchronize() elif torch.cuda.is_available(): torch.cuda.synchronize() def soft_empty_cache(force=False): + if cpu_mode(): + return global cpu_state if cpu_state == CPUState.MPS: torch.mps.empty_cache() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e380e406b..745384271 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -241,6 +241,7 @@ class ModelPatcher: self.patches = {} self.backup = {} + self.backup_buffers = {} self.object_patches = {} self.object_patches_backup = {} self.weight_wrapper_patches = {} @@ -306,10 +307,16 @@ class ModelPatcher: return self.model.lowvram_patch_counter def get_free_memory(self, device): - return comfy.model_management.get_free_memory(device) + #Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In + #the vast majority of setups a little bit of offloading on the giant model more + #than pays for CFG. So return everything both torch and Aimdo could give us + aimdo_mem = 0 + if comfy.memory_management.aimdo_enabled: + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + return comfy.model_management.get_free_memory(device) + aimdo_mem def get_clone_model_override(self): - return self.model, (self.backup, self.object_patches_backup, self.pinned) + return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ @@ -336,7 +343,7 @@ class ModelPatcher: n.force_cast_weights = self.force_cast_weights - n.backup, n.object_patches_backup, n.pinned = model_override[1] + n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1] # attachments n.attachments = {} @@ -698,7 +705,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self, prio_comfy_cast_weights=False, default_device=None): + def _load_list(self, for_dynamic=False, default_device=None): loading = [] for n, m in self.model.named_modules(): default = False @@ -708,8 +715,8 @@ class ModelPatcher: default = True # default random weights in non leaf modules break if default and default_device is not None: - for param in params.values(): - param.data = param.data.to(device=default_device) + for param_name, param in params.items(): + param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None)) if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0): module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem @@ -726,8 +733,13 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () - loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) + # Dynamic: small weights (<64KB) first, then larger weights prioritized by size. + # Non-dynamic: prioritize by module offload cost. + if for_dynamic: + sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem) + else: + sort_criteria = (module_offload_mem,) + loading.append(sort_criteria + (module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -1459,12 +1471,6 @@ class ModelPatcherDynamic(ModelPatcher): vbar = self._vbar_get() return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory - def get_free_memory(self, device): - #NOTE: on high condition / batch counts, estimate should have already vacated - #all non-dynamic models so this is safe even if its not 100% true that this - #would all be avaiable for inference use. - return comfy.model_management.get_total_memory(device) - self.model_size() - #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. def pin_weight_to_device(self, key): @@ -1507,11 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) - loading.sort(reverse=True) + loading = self._load_list(for_dynamic=True, default_device=device_to) + loading.sort() for x in loading: - _, _, _, n, m, params = x + *_, module_mem, n, m, params = x def set_dirty(item, dirty): if dirty or not hasattr(item, "_v_signature"): @@ -1579,11 +1585,22 @@ class ModelPatcherDynamic(ModelPatcher): weight, _, _ = get_key_weight(self.model, key) if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) - comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) - self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() + model_dtype = getattr(m, param + "_comfy_model_dtype", None) + casted_weight = weight.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_param(self.model, key, casted_weight) + self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size() move_weight_functions(m, device_to) + for key, buf in self.model.named_buffers(recurse=True): + if key not in self.backup_buffers: + self.backup_buffers[key] = buf + module, buf_name = comfy.utils.resolve_attr(self.model, key) + model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None) + casted_buf = buf.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_buffer(self.model, key, casted_buf) + self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size() + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") @@ -1607,15 +1624,17 @@ class ModelPatcherDynamic(ModelPatcher): for key in list(self.backup.keys()): bk = self.backup.pop(key) comfy.utils.set_attr_param(self.model, key, bk.weight) + for key in list(self.backup_buffers.keys()): + comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key)) freed += self.model.model_loaded_weight_memory self.model.model_loaded_weight_memory = 0 return freed def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) + loading = self._load_list(for_dynamic=True, default_device=self.offload_device) for x in loading: - _, _, _, _, m, _ = x + *_, m, _ = x ram_to_unload -= comfy.pinned_memory.unpin_memory(m) if ram_to_unload <= 0: return diff --git a/comfy/ops.py b/comfy/ops.py index 6ee6075fb..87b36b5c5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): + + #vbar doesn't support CPU weights, but some custom nodes have weird paths + #that might switch the layer to the CPU and expect it to work. We have to take + #a clone conservatively as we are mmapped and some SFT files are packed misaligned + #If you are a custom node author reading this, please move your layer to the GPU + #or declare your ModelPatcher as CPU in the first place. + if comfy.model_management.is_device_cpu(device): + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = None + if s.bias is not None: + bias = s.bias.to(dtype=bias_dtype, copy=True) + return weight, bias, (None, None, None) + offload_stream = None xfer_dest = None @@ -269,8 +284,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream): return os, weight_a, bias_a = offload_stream device=None - #FIXME: This is not good RTTI - if not isinstance(weight_a, torch.Tensor): + #FIXME: This is really bad RTTI + if weight_a is not None and not isinstance(weight_a, torch.Tensor): comfy_aimdo.model_vbar.vbar_unpin(s._v) device = weight_a if os is None: @@ -660,23 +675,29 @@ class fp8_ops(manual_cast): CUBLAS_IS_AVAILABLE = False try: - from cublas_ops import CublasLinear + from cublas_ops import CublasLinear, cublas_half_matmul CUBLAS_IS_AVAILABLE = True except ImportError: pass if CUBLAS_IS_AVAILABLE: - class cublas_ops(disable_weight_init): - class Linear(CublasLinear, disable_weight_init.Linear): + class cublas_ops(manual_cast): + class Linear(CublasLinear, manual_cast.Linear): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - return super().forward(input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) # ============================================================================== # Mixed Precision Operations diff --git a/comfy/sd.py b/comfy/sd.py index a9ad7c2d2..adcd67767 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -428,7 +428,7 @@ class CLIP: def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): self.cond_stage_model.reset_clip_options() - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"layer": None}) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) @@ -954,7 +954,8 @@ class VAE: if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x+batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. @@ -1029,7 +1030,8 @@ class VAE: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) samples[x:x + batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. @@ -1467,7 +1469,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage elif clip_type == CLIPType.LTXV: - clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e86ea9f4e..5e1273c6e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel): comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is +class DualLinearProjection(torch.nn.Module): + def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None): + super().__init__() + self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device) + self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device) + + def forward(self, x): + source_dim = x.shape[-1] + x = x.movedim(1, -1) + x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2) + + video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim)) + audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim)) + return torch.cat((video, audio), dim=-1) + class LTXAVTEModel(torch.nn.Module): - def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}): super().__init__() self.dtypes = set() self.dtypes.add(dtype) self.compat_mode = False + self.text_projection_type = text_projection_type self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None) self.dtypes.add(dtype_llama) operations = self.gemma3_12b.operations # TODO - self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + if self.text_projection_type == "single_linear": + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + elif self.text_projection_type == "dual_linear": + self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations) + def enable_compat_mode(self): # TODO: remove from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module): out_device = out.device if comfy.model_management.should_use_bf16(self.execution_device): out = out.to(device=self.execution_device, dtype=torch.bfloat16) - out = out.movedim(1, -1).to(self.execution_device) - out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) - out = out.reshape((out.shape[0], out.shape[1], -1)) - out = self.text_embedding_projection(out) - out = out.float() - if self.compat_mode: - out_vid = self.video_embeddings_connector(out)[0] - out_audio = self.audio_embeddings_connector(out)[0] - out = torch.concat((out_vid, out_audio), dim=-1) + if self.text_projection_type == "single_linear": + out = out.movedim(1, -1).to(self.execution_device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) - return out.to(out_device), pooled + if self.compat_mode: + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + extra = {} + else: + extra = {"unprocessed_ltxav_embeds": True} + elif self.text_projection_type == "dual_linear": + out = self.text_embedding_projection(out) + extra = {"unprocessed_ltxav_embeds": True} + + return out.to(device=out_device, dtype=torch.float), pooled, extra def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) @@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module): if "model.layers.47.self_attn.q_norm.weight" in sd: return self.gemma3_12b.load_sd(sd) else: - sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True) + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True) if len(sdo) == 0: sdo = sd @@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module): num_tokens = max(num_tokens, 642) return num_tokens * constant * 1024 * 1024 -def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): +def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"): class LTXAVTEModel_(LTXAVTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: @@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama - super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options) return LTXAVTEModel_ + +def sd_detect(state_dict_list, prefix=""): + for sd in state_dict_list: + if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd: + return {"text_projection_type": "dual_linear"} + if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd: + return {"text_projection_type": "single_linear"} + return {} + + def gemma3_te(dtype_llama=None, llama_quantization_metadata=None): class Gemma3_12BModel_(Gemma3_12BModel): def __init__(self, device="cpu", dtype=None, model_options={}): diff --git a/comfy/utils.py b/comfy/utils.py index 0769cef44..6e1d14419 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): ATTR_UNSET={} -def set_attr(obj, attr, value): +def resolve_attr(obj, attr): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1], ATTR_UNSET) + return obj, attrs[-1] + +def set_attr(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) if value is ATTR_UNSET: - delattr(obj, attrs[-1]) + delattr(obj, name) else: - setattr(obj, attrs[-1], value) + setattr(obj, name, value) return prev def set_attr_param(obj, attr, value): return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) +def set_attr_buffer(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) + persistent = name not in getattr(obj, "_non_persistent_buffers_set", set()) + obj.register_buffer(name, value, persistent=persistent) + return prev + def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index a90a5ca40..9f6918315 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -15,6 +15,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = { "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, "node_replacements": True, + "assets": args.enable_assets, } diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index a3d48c87f..58a37c9e8 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput): codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, ): + """Save the video to a file path or BytesIO buffer.""" if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: @@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput): extra_kwargs = {} if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: extra_kwargs["format"] = format.value + elif isinstance(path, io.BytesIO): + # BytesIO has no file extension, so av.open can't infer the format. + # Default to mp4 since that's the only supported format anyway. + extra_kwargs["format"] = "mp4" with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams if metadata is not None: diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py index 8e3c79ab9..c56c8aecc 100644 --- a/comfy_api_nodes/apis/grok.py +++ b/comfy_api_nodes/apis/grok.py @@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel): aspect_ratio: str = Field(...) n: int = Field(...) seed: int = Field(...) - response_for: str = Field("url") + response_format: str = Field("url") + resolution: str = Field(...) class InputUrlObject(BaseModel): @@ -16,12 +17,13 @@ class InputUrlObject(BaseModel): class ImageEditRequest(BaseModel): model: str = Field(...) - image: InputUrlObject = Field(...) + images: list[InputUrlObject] = Field(...) prompt: str = Field(...) resolution: str = Field(...) n: int = Field(...) seed: int = Field(...) - response_for: str = Field("url") + response_format: str = Field("url") + aspect_ratio: str | None = Field(...) class VideoGenerationRequest(BaseModel): @@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel): revised_prompt: str | None = Field(None) +class UsageObject(BaseModel): + cost_in_usd_ticks: int | None = Field(None) + + class ImageGenerationResponse(BaseModel): data: list[ImageResponseObject] = Field(...) + usage: UsageObject | None = Field(None) class VideoGenerationResponse(BaseModel): @@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel): status: str | None = Field(None) video: VideoResponseObject | None = Field(None) model: str | None = Field(None) + usage: UsageObject | None = Field(None) diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py index e84eba31e..dad9bc2fa 100644 --- a/comfy_api_nodes/apis/hunyuan3d.py +++ b/comfy_api_nodes/apis/hunyuan3d.py @@ -66,13 +66,17 @@ class To3DProTaskQueryRequest(BaseModel): JobId: str = Field(...) -class To3DUVFileInput(BaseModel): +class TaskFile3DInput(BaseModel): Type: str = Field(..., description="File type: GLB, OBJ, or FBX") Url: str = Field(...) class To3DUVTaskRequest(BaseModel): - File: To3DUVFileInput = Field(...) + File: TaskFile3DInput = Field(...) + + +class To3DPartTaskRequest(BaseModel): + File: TaskFile3DInput = Field(...) class TextureEditImageInfo(BaseModel): @@ -80,7 +84,13 @@ class TextureEditImageInfo(BaseModel): class TextureEditTaskRequest(BaseModel): - File3D: To3DUVFileInput = Field(...) + File3D: TaskFile3DInput = Field(...) Image: TextureEditImageInfo | None = Field(None) Prompt: str | None = Field(None) EnablePBR: bool | None = Field(None) + + +class SmartTopologyRequest(BaseModel): + File3D: TaskFile3DInput = Field(...) + PolygonType: str | None = Field(...) + FaceLevel: str | None = Field(...) diff --git a/comfy_api_nodes/apis/kling.py b/comfy_api_nodes/apis/kling.py index a5bd5f1d3..fe0f97cb3 100644 --- a/comfy_api_nodes/apis/kling.py +++ b/comfy_api_nodes/apis/kling.py @@ -148,3 +148,4 @@ class MotionControlRequest(BaseModel): keep_original_sound: str = Field(...) character_orientation: str = Field(...) mode: str = Field(..., description="'pro' or 'std'") + model_name: str = Field(...) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index d83d2fc15..8225ea67e 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -72,18 +72,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( ) -class GeminiModel(str, Enum): - """ - Gemini Model Names allowed by comfy-api - """ - - gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06" - gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" - gemini_2_5_pro = "gemini-2.5-pro" - gemini_2_5_flash = "gemini-2.5-flash" - gemini_3_0_pro = "gemini-3-pro-preview" - - class GeminiImageModel(str, Enum): """ Gemini Image Model Names allowed by comfy-api @@ -237,10 +225,14 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N input_tokens_price = 0.30 output_text_tokens_price = 2.50 output_image_tokens_price = 30.0 - elif response.modelVersion == "gemini-3-pro-preview": + elif response.modelVersion in ("gemini-3-pro-preview", "gemini-3.1-pro-preview"): input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 0.0 + elif response.modelVersion == "gemini-3.1-flash-lite-preview": + input_tokens_price = 0.25 + output_text_tokens_price = 1.50 + output_image_tokens_price = 0.0 elif response.modelVersion == "gemini-3-pro-image-preview": input_tokens_price = 2 output_text_tokens_price = 12.0 @@ -292,8 +284,16 @@ class GeminiNode(IO.ComfyNode): ), IO.Combo.Input( "model", - options=GeminiModel, - default=GeminiModel.gemini_2_5_pro, + options=[ + "gemini-2.5-pro-preview-05-06", + "gemini-2.5-flash-preview-04-17", + "gemini-2.5-pro", + "gemini-2.5-flash", + "gemini-3-pro-preview", + "gemini-3-1-pro", + "gemini-3-1-flash-lite", + ], + default="gemini-3-1-pro", tooltip="The Gemini model to use for generating responses.", ), IO.Int.Input( @@ -363,11 +363,16 @@ class GeminiNode(IO.ComfyNode): "usd": [0.00125, 0.01], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } - : $contains($m, "gemini-3-pro-preview") ? { + : ($contains($m, "gemini-3-pro-preview") or $contains($m, "gemini-3-1-pro")) ? { "type": "list_usd", "usd": [0.002, 0.012], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } + : $contains($m, "gemini-3-1-flash-lite") ? { + "type": "list_usd", + "usd": [0.00025, 0.0015], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } : {"type":"text", "text":"Token-based"} ) """, @@ -436,12 +441,14 @@ class GeminiNode(IO.ComfyNode): files: list[GeminiPart] | None = None, system_prompt: str = "", ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=False) + if model == "gemini-3-pro-preview": + model = "gemini-3.1-pro-preview" # model "gemini-3-pro-preview" will be soon deprecated by Google + elif model == "gemini-3-1-pro": + model = "gemini-3.1-pro-preview" + elif model == "gemini-3-1-flash-lite": + model = "gemini-3.1-flash-lite-preview" - # Create parts list with text prompt as the first part parts: list[GeminiPart] = [GeminiPart(text=prompt)] - - # Add other modal parts if images is not None: parts.extend(await create_image_parts(cls, images)) if audio is not None: diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index da15e97ea..0716d6239 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -27,6 +27,12 @@ from comfy_api_nodes.util import ( ) +def _extract_grok_price(response) -> float | None: + if response.usage and response.usage.cost_in_usd_ticks is not None: + return response.usage.cost_in_usd_ticks / 10_000_000_000 + return None + + class GrokImageNode(IO.ComfyNode): @classmethod @@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode): category="api node/image/Grok", description="Generate images using Grok based on a text prompt", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), IO.String.Input( "prompt", multiline=True, @@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), + IO.Combo.Input("resolution", options=["1K", "2K"], optional=True), ], outputs=[ IO.Image.Output(), @@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), - expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""", + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + {"type":"usd","usd": $rate * widgets.number_of_images} + ) + """, ), ) @@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode): aspect_ratio: str, number_of_images: int, seed: int, + resolution: str = "1K", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) response = await sync_op( @@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode): aspect_ratio=aspect_ratio, n=number_of_images, seed=seed, + resolution=resolution.lower(), ), response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) @@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode): category="api node/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-image-beta"]), - IO.Image.Input("image"), + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), + IO.Image.Input("image", display_name="images"), IO.String.Input( "prompt", multiline=True, tooltip="The text prompt used to generate the image", ), - IO.Combo.Input("resolution", options=["1K"]), + IO.Combo.Input("resolution", options=["1K", "2K"]), IO.Int.Input( "number_of_images", default=1, @@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "9:16", + "16:9", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + ], + optional=True, + tooltip="Only allowed when multiple images are connected to the image input.", + ), ], outputs=[ IO.Image.Output(), @@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), - expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""", + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + {"type":"usd","usd": 0.002 + $rate * widgets.number_of_images} + ) + """, ), ) @@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode): resolution: str, number_of_images: int, seed: int, + aspect_ratio: str = "auto", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) - if get_number_of_images(image) != 1: - raise ValueError("Only one input image is supported.") + if model == "grok-imagine-image-pro": + if get_number_of_images(image) > 1: + raise ValueError("The pro model supports only 1 input image.") + elif get_number_of_images(image) > 3: + raise ValueError("A maximum of 3 input images is supported.") + if aspect_ratio != "auto" and get_number_of_images(image) == 1: + raise ValueError( + "Custom aspect ratio is only allowed when multiple images are connected to the image input." + ) response = await sync_op( cls, ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), data=ImageEditRequest( model=model, - image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"), + images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image], prompt=prompt, resolution=resolution.lower(), n=number_of_images, seed=seed, + aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, ), response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) @@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode): category="api node/video/Grok", description="Generate video from a prompt or an image", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), IO.String.Input( "prompt", multiline=True, @@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]), + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), expr=""" ( - $base := 0.181 * widgets.duration; + $rate := widgets.resolution = "720p" ? 0.07 : 0.05; + $base := $rate * widgets.duration; {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} ) """, @@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) @@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode): category="api node/video/Grok", description="Edit an existing video based on a text prompt.", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), IO.String.Input( "prompt", multiline=True, @@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""", + expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""", ), ) @@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index d1d9578ec..bd8bde997 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -5,18 +5,19 @@ from comfy_api_nodes.apis.hunyuan3d import ( Hunyuan3DViewImage, InputGenerateType, ResultFile3D, + SmartTopologyRequest, + TaskFile3DInput, TextureEditTaskRequest, + To3DPartTaskRequest, To3DProTaskCreateResponse, To3DProTaskQueryRequest, To3DProTaskRequest, To3DProTaskResultResponse, - To3DUVFileInput, To3DUVTaskRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_file_3d, - download_url_to_image_tensor, downscale_image_tensor_by_max_side, poll_op, sync_op, @@ -344,7 +345,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode): outputs=[ IO.File3DOBJ.Output(display_name="OBJ"), IO.File3DFBX.Output(display_name="FBX"), - IO.Image.Output(), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -375,7 +375,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode): ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"), response_model=To3DProTaskCreateResponse, data=To3DUVTaskRequest( - File=To3DUVFileInput( + File=TaskFile3DInput( Type=file_format.upper(), Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format), ) @@ -394,7 +394,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode): return IO.NodeOutput( await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), - await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url), ) @@ -463,7 +462,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode): ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"), response_model=To3DProTaskCreateResponse, data=TextureEditTaskRequest( - File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url), + File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), Prompt=prompt, EnablePBR=True, ), @@ -538,8 +537,8 @@ class Tencent3DPartNode(IO.ComfyNode): cls, ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"), response_model=To3DProTaskCreateResponse, - data=To3DUVTaskRequest( - File=To3DUVFileInput(Type=file_format.upper(), Url=model_url), + data=To3DPartTaskRequest( + File=TaskFile3DInput(Type=file_format.upper(), Url=model_url), ), is_rate_limited=_is_tencent_rate_limited, ) @@ -557,15 +556,107 @@ class Tencent3DPartNode(IO.ComfyNode): ) +class TencentSmartTopologyNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentSmartTopologyNode", + display_name="Hunyuan3D: Smart Topology", + category="api node/3d/Tencent", + description="Perform smart retopology on a 3D model. " + "Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DAny], + tooltip="Input 3D model (GLB or OBJ)", + ), + IO.Combo.Input( + "polygon_type", + options=["triangle", "quadrilateral"], + tooltip="Surface composition type.", + ), + IO.Combo.Input( + "face_level", + options=["medium", "high", "low"], + tooltip="Polygon reduction level.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DOBJ.Output(display_name="OBJ"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge(expr='{"type":"usd","usd":1.0}'), + ) + + SUPPORTED_FORMATS = {"glb", "obj"} + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + polygon_type: str, + face_level: str, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format not in cls.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported file format: '{file_format}'. " f"Supported: {', '.join(sorted(cls.SUPPORTED_FORMATS))}." + ) + model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology", method="POST"), + response_model=To3DProTaskCreateResponse, + data=SmartTopologyRequest( + File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), + PolygonType=polygon_type, + FaceLevel=face_level, + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed: [{response.Error.Code}] {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), + ) + + class TencentHunyuan3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ TencentTextToModelNode, TencentImageToModelNode, - # TencentModelTo3DUVNode, + TencentModelTo3DUVNode, # Tencent3DTextureEditNode, Tencent3DPartNode, + TencentSmartTopologyNode, ] diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 74fa078ff..8963c335d 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -2747,6 +2747,7 @@ class MotionControl(IO.ComfyNode): "but the character orientation matches the reference image (camera/other details via prompt).", ), IO.Combo.Input("mode", options=["pro", "std"]), + IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True), ], outputs=[ IO.Video.Output(), @@ -2777,6 +2778,7 @@ class MotionControl(IO.ComfyNode): keep_original_sound: bool, character_orientation: str, mode: str, + model: str = "kling-v2-6", ) -> IO.NodeOutput: validate_string(prompt, max_length=2500) validate_image_dimensions(reference_image, min_width=340, min_height=340) @@ -2797,6 +2799,7 @@ class MotionControl(IO.ComfyNode): keep_original_sound="yes" if keep_original_sound else "no", character_orientation=character_orientation, mode=mode, + model_name=model, ), ) if response.code: diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 94886af7b..79ffb77c1 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -83,7 +83,7 @@ class _PollUIState: _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] -QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] async def sync_op( diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 32fe921ff..c05571143 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -253,10 +253,12 @@ class LTXVAddGuide(io.ComfyNode): return frame_idx, latent_idx @classmethod - def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1): + def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1, causal_fix=None): keyframe_idxs, _ = get_keyframe_idxs(cond) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) - pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 + if causal_fix is None: + causal_fix = frame_idx == 0 or guiding_latent.shape[2] == 1 + pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=causal_fix) pixel_coords[:, 0] += frame_idx # The following adjusts keyframe end positions for small grid IC-LoRA. @@ -278,12 +280,12 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1): + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1, causal_fix=None): if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: raise ValueError("Adding guide to a combined AV latent is not supported.") - positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) - negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix) if guide_mask is not None: target_h = max(noise_mask.shape[3], guide_mask.shape[3]) diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py new file mode 100644 index 000000000..6417bacf1 --- /dev/null +++ b/comfy_extras/nodes_math.py @@ -0,0 +1,119 @@ +"""Math expression node using simpleeval for safe evaluation. + +Provides a ComfyMathExpression node that evaluates math expressions +against dynamically-grown numeric inputs. +""" + +from __future__ import annotations + +import math +import string + +from simpleeval import simple_eval +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +MAX_EXPONENT = 4000 + + +def _variadic_sum(*args): + """Support both sum(values) and sum(a, b, c).""" + if len(args) == 1 and hasattr(args[0], "__iter__"): + return sum(args[0]) + return sum(args) + + +def _safe_pow(base, exp): + """Wrap pow() with an exponent cap to prevent DoS via huge exponents. + + The ** operator is already guarded by simpleeval's safe_power, but + pow() as a callable bypasses that guard. + """ + if abs(exp) > MAX_EXPONENT: + raise ValueError(f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})") + return pow(base, exp) + + +MATH_FUNCTIONS = { + "sum": _variadic_sum, + "min": min, + "max": max, + "abs": abs, + "round": round, + "pow": _safe_pow, + "sqrt": math.sqrt, + "ceil": math.ceil, + "floor": math.floor, + "log": math.log, + "log2": math.log2, + "log10": math.log10, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "int": int, + "float": float, +} + + +class MathExpressionNode(io.ComfyNode): + """Evaluates a math expression against dynamically-grown inputs.""" + + @classmethod + def define_schema(cls) -> io.Schema: + autogrow = io.Autogrow.TemplateNames( + input=io.MultiType.Input("value", [io.Float, io.Int]), + names=list(string.ascii_lowercase), + min=1, + ) + return io.Schema( + node_id="ComfyMathExpression", + display_name="Math Expression", + category="math", + search_aliases=[ + "expression", "formula", "calculate", "calculator", + "eval", "math", + ], + inputs=[ + io.String.Input("expression", default="a + b", multiline=True), + io.Autogrow.Input("values", template=autogrow), + ], + outputs=[ + io.Float.Output(display_name="FLOAT"), + io.Int.Output(display_name="INT"), + ], + ) + + @classmethod + def execute( + cls, expression: str, values: io.Autogrow.Type + ) -> io.NodeOutput: + if not expression.strip(): + raise ValueError("Expression cannot be empty.") + + context: dict = dict(values) + context["values"] = list(values.values()) + + result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS) + # bool check must come first because bool is a subclass of int in Python + if isinstance(result, bool) or not isinstance(result, (int, float)): + raise ValueError( + f"Math Expression '{expression}' must evaluate to a numeric result, " + f"got {type(result).__name__}: {result!r}" + ) + if not math.isfinite(result): + raise ValueError( + f"Math Expression '{expression}' produced a non-finite result: {result}" + ) + return io.NodeOutput(float(result), int(result)) + + +class MathExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [MathExpressionNode] + + +async def comfy_entrypoint() -> MathExtension: + return MathExtension() diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 97b9e948d..db4f9d231 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode): pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) oom = False - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) tile //= 2 if tile < 128: raise e diff --git a/comfyui_version.py b/comfyui_version.py index 6a35c6de3..2723d02e7 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.15.1" +__version__ = "0.16.4" diff --git a/execution.py b/execution.py index 3d80606b4..d4112b207 100644 --- a/execution.py +++ b/execution.py @@ -613,7 +613,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, logging.error(traceback.format_exc()) tips = "" - if isinstance(ex, comfy.model_management.OOM_EXCEPTION): + if comfy.model_management.is_oom(ex): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") diff --git a/main.py b/main.py index 0f58d57b8..8905fd09a 100644 --- a/main.py +++ b/main.py @@ -3,18 +3,22 @@ comfy.options.enable_args_parsing() import os import importlib.util +import shutil +import importlib.metadata import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger -from app.assets.scanner import seed_assets import itertools import utils.extra_config +from utils.mime_types import init_mime_types +import faulthandler import logging import sys from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags +from app.database.db import init_db, dependencies_available if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. @@ -23,6 +27,8 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +faulthandler.enable(file=sys.stderr, all_threads=False) + import comfy_aimdo.control if enables_dynamic_vram(): @@ -62,8 +68,15 @@ if __name__ == "__main__": def handle_comfyui_manager_unavailable(): - if not args.windows_standalone_build: - logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt") + uv_available = shutil.which("uv") is not None + + pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}" + msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}" + if uv_available: + msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}" + msg += "\n" + logging.warning(msg) args.enable_manager = False @@ -161,6 +174,7 @@ def execute_prestartup_script(): logging.info("") apply_custom_paths() +init_mime_types() if args.enable_manager: comfyui_manager.prestartup() @@ -170,7 +184,6 @@ execute_prestartup_script() # Main code import asyncio -import shutil import threading import gc @@ -179,6 +192,7 @@ if 'torch' in sys.modules: import comfy.utils +from app.assets.seeder import asset_seeder import execution import server @@ -258,6 +272,7 @@ def prompt_worker(q, server_instance): for k in sensitive: extra_data[k] = sensitive[k] + asset_seeder.pause() e.execute(item[2], prompt_id, extra_data, item[4]) need_gc = True @@ -302,6 +317,7 @@ def prompt_worker(q, server_instance): last_gc_collect = current_time need_gc = False hook_breaker_ac10a0.restore_functions() + asset_seeder.resume() async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): @@ -352,12 +368,29 @@ def cleanup_temp(): def setup_database(): try: - from app.database.db import init_db, dependencies_available if dependencies_available(): init_db() - if not args.disable_assets_autoscan: - seed_assets(["models"], enable_logging=True) + if args.enable_assets: + if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True): + logging.info("Background asset scan initiated for models, input, output") except Exception as e: + if "database is locked" in str(e): + logging.error( + "Database is locked. Another ComfyUI process is already using this database.\n" + "To resolve this, specify a separate database file for this instance:\n" + " --database-url sqlite:///path/to/another.db" + ) + sys.exit(1) + if args.enable_assets: + logging.error( + f"Failed to initialize database: {e}\n" + "The --enable-assets flag requires a working database connection.\n" + "To resolve this, try one of the following:\n" + " 1. Install the latest requirements: pip install -r requirements.txt\n" + " 2. Specify an alternative database URL: --database-url sqlite:///path/to/your.db\n" + " 3. Use an in-memory database: --database-url sqlite:///:memory:" + ) + sys.exit(1) logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") @@ -429,6 +462,11 @@ if __name__ == "__main__": # Running directly, just start ComfyUI. logging.info("Python version: {}".format(sys.version)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + for package in ("comfy-aimdo", "comfy-kitchen"): + try: + logging.info("{} version: {}".format(package, importlib.metadata.version(package))) + except: + pass if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") @@ -440,5 +478,6 @@ if __name__ == "__main__": event_loop.run_until_complete(x) except KeyboardInterrupt: logging.info("\nStopped server") - - cleanup_temp() + finally: + asset_seeder.shutdown() + cleanup_temp() diff --git a/manager_requirements.txt b/manager_requirements.txt index c420cc48e..6bcc3fb50 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b1 +comfyui_manager==4.1b2 \ No newline at end of file diff --git a/nodes.py b/nodes.py index 5be9b16f9..0ef23b640 100644 --- a/nodes.py +++ b/nodes.py @@ -2449,6 +2449,7 @@ async def init_builtin_extra_nodes(): "nodes_replacements.py", "nodes_nag.py", "nodes_sdpose.py", + "nodes_math.py", ] import_failed = [] diff --git a/pyproject.toml b/pyproject.toml index 1b2318273..753b219b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.15.1" +version = "0.16.4" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index 608b0cfa6..bb58f8d01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.5 +comfyui-workflow-templates==0.9.18 comfyui-embedded-docs==0.4.3 torch torchsde @@ -20,10 +20,13 @@ tqdm psutil alembic SQLAlchemy +filelock av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.4 +comfy-aimdo>=0.2.9 requests +simpleeval>=1.0.0 +blake3 #non essential dependencies: kornia>=0.7.1 diff --git a/server.py b/server.py index 275bce5a7..76904ebc9 100644 --- a/server.py +++ b/server.py @@ -33,8 +33,8 @@ import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal -from app.assets.scanner import seed_assets -from app.assets.api.routes import register_assets_system +from app.assets.seeder import asset_seeder +from app.assets.api.routes import register_assets_routes from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -197,10 +197,6 @@ class PromptServer(): def __init__(self, loop): PromptServer.instance = self - mimetypes.init() - mimetypes.add_type('application/javascript; charset=utf-8', '.js') - mimetypes.add_type('image/webp', '.webp') - self.user_manager = UserManager() self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() @@ -239,7 +235,11 @@ class PromptServer(): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") - register_assets_system(self.app, self.user_manager) + if args.enable_assets: + register_assets_routes(self.app, self.user_manager) + else: + register_assets_routes(self.app) + asset_seeder.disable() routes = web.RouteTableDef() self.routes = routes self.last_node_id = None @@ -697,10 +697,7 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): - try: - seed_assets(["models"]) - except Exception as e: - logging.error(f"Failed to seed assets: {e}") + asset_seeder.start(roots=("models", "input", "output")) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py index 0a57dd7b5..6c5c56113 100644 --- a/tests-unit/assets_test/conftest.py +++ b/tests-unit/assets_test/conftest.py @@ -108,7 +108,7 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest) "main.py", f"--base-directory={str(comfy_tmp_base_dir)}", f"--database-url={db_url}", - "--disable-assets-autoscan", + "--enable-assets", "--listen", "127.0.0.1", "--port", @@ -212,7 +212,7 @@ def asset_factory(http: requests.Session, api_base: str): for aid in created: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}", timeout=30) + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) @pytest.fixture @@ -258,14 +258,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str): break for aid in ids: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}", timeout=30) - - -def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: - """Force a fast sync/seed pass by calling the seed endpoint.""" - session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30) - time.sleep(0.2) - - -def get_asset_filename(asset_hash: str, extension: str) -> str: - return asset_hash.removeprefix("blake3:") + extension + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) diff --git a/tests-unit/assets_test/helpers.py b/tests-unit/assets_test/helpers.py new file mode 100644 index 000000000..770e011f4 --- /dev/null +++ b/tests-unit/assets_test/helpers.py @@ -0,0 +1,28 @@ +"""Helper functions for assets integration tests.""" +import time + +import requests + + +def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: + """Force a synchronous sync/seed pass by calling the seed endpoint with wait=true. + + Retries on 409 (already running) until the previous scan finishes. + """ + deadline = time.monotonic() + 60 + while True: + r = session.post( + base_url + "/api/assets/seed?wait=true", + json={"roots": ["models", "input", "output"]}, + timeout=60, + ) + if r.status_code != 409: + assert r.status_code == 200, f"seed endpoint returned {r.status_code}: {r.text}" + return + if time.monotonic() > deadline: + raise TimeoutError("seed endpoint stuck in 409 (already running)") + time.sleep(0.25) + + +def get_asset_filename(asset_hash: str, extension: str) -> str: + return asset_hash.removeprefix("blake3:") + extension diff --git a/tests-unit/assets_test/queries/conftest.py b/tests-unit/assets_test/queries/conftest.py new file mode 100644 index 000000000..4ca0e86a9 --- /dev/null +++ b/tests-unit/assets_test/queries/conftest.py @@ -0,0 +1,20 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import Base + + +@pytest.fixture +def session(): + """In-memory SQLite session for fast unit tests.""" + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + with Session(engine) as sess: + yield sess + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(): + """Override parent autouse fixture - query tests don't need server cleanup.""" + yield diff --git a/tests-unit/assets_test/queries/test_asset.py b/tests-unit/assets_test/queries/test_asset.py new file mode 100644 index 000000000..08f84cd11 --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset.py @@ -0,0 +1,144 @@ +import uuid + +import pytest +from sqlalchemy.orm import Session + +from app.assets.helpers import get_utc_now +from app.assets.database.models import Asset +from app.assets.database.queries import ( + asset_exists_by_hash, + get_asset_by_hash, + upsert_asset, + bulk_insert_assets, +) + + +class TestAssetExistsByHash: + @pytest.mark.parametrize( + "setup_hash,query_hash,expected", + [ + (None, "nonexistent", False), # No asset exists + ("blake3:abc123", "blake3:abc123", True), # Asset exists with matching hash + (None, "", False), # Null hash in DB doesn't match empty string + ], + ids=["nonexistent", "existing", "null_hash_no_match"], + ) + def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected): + if setup_hash is not None or query_hash == "": + asset = Asset(hash=setup_hash, size_bytes=100) + session.add(asset) + session.commit() + + assert asset_exists_by_hash(session, asset_hash=query_hash) is expected + + +class TestGetAssetByHash: + @pytest.mark.parametrize( + "setup_hash,query_hash,should_find", + [ + (None, "nonexistent", False), + ("blake3:def456", "blake3:def456", True), + ], + ids=["nonexistent", "existing"], + ) + def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find): + if setup_hash is not None: + asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png") + session.add(asset) + session.commit() + + result = get_asset_by_hash(session, asset_hash=query_hash) + if should_find: + assert result is not None + assert result.size_bytes == 200 + assert result.mime_type == "image/png" + else: + assert result is None + + +class TestUpsertAsset: + @pytest.mark.parametrize( + "first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime", + [ + # New asset creation + (None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"), + # Existing asset, same values - no update + (500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"), + # Existing asset with size 0, update with new values + (0, None, 2048, "image/png", False, True, 2048, "image/png"), + # Existing asset, second call with size 0 - no update + (1000, None, 0, None, False, False, 1000, None), + ], + ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"], + ) + def test_upsert_scenarios( + self, + session: Session, + first_size, + first_mime, + second_size, + second_mime, + expect_created, + expect_updated, + final_size, + final_mime, + ): + asset_hash = f"blake3:test_{first_size}_{second_size}" + + # First upsert (if first_size is not None, we're testing the second call) + if first_size is not None: + upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=first_size, + mime_type=first_mime, + ) + session.commit() + + # The upsert call we're testing + asset, created, updated = upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=second_size, + mime_type=second_mime, + ) + session.commit() + + assert created is expect_created + assert updated is expect_updated + assert asset.size_bytes == final_size + assert asset.mime_type == final_mime + + +class TestBulkInsertAssets: + def test_inserts_multiple_assets(self, session: Session): + now = get_utc_now() + rows = [ + {"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": now}, + {"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": now}, + {"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": now}, + ] + bulk_insert_assets(session, rows) + session.commit() + + assets = session.query(Asset).all() + assert len(assets) == 3 + hashes = {a.hash for a in assets} + assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"} + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_assets(session, []) + session.commit() + assert session.query(Asset).count() == 0 + + def test_handles_large_batch(self, session: Session): + """Test chunking logic with more rows than MAX_BIND_PARAMS allows.""" + now = get_utc_now() + rows = [ + {"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": now} + for i in range(200) + ] + bulk_insert_assets(session, rows) + session.commit() + + assert session.query(Asset).count() == 200 diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py new file mode 100644 index 000000000..8f6c7fcdb --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -0,0 +1,517 @@ +import time +import uuid +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta +from app.assets.database.queries import ( + reference_exists_for_asset_id, + get_reference_by_id, + insert_reference, + get_or_create_reference, + update_reference_timestamps, + list_references_page, + fetch_reference_asset_and_tags, + fetch_reference_and_asset, + update_reference_access_time, + set_reference_metadata, + delete_reference_by_id, + set_reference_preview, + bulk_insert_references_ignore_conflicts, + get_reference_ids_by_ids, + ensure_tags_exist, + add_tags_to_reference, +) +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestReferenceExistsForAssetId: + def test_returns_false_when_no_reference(self, session: Session): + asset = _make_asset(session, "hash1") + assert reference_exists_for_asset_id(session, asset_id=asset.id) is False + + def test_returns_true_when_reference_exists(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset) + assert reference_exists_for_asset_id(session, asset_id=asset.id) is True + + +class TestGetReferenceById: + def test_returns_none_for_nonexistent(self, session: Session): + assert get_reference_by_id(session, reference_id="nonexistent") is None + + def test_returns_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="myfile.txt") + + result = get_reference_by_id(session, reference_id=ref.id) + assert result is not None + assert result.name == "myfile.txt" + + +class TestListReferencesPage: + def test_empty_db(self, session: Session): + refs, tag_map, total = list_references_page(session) + assert refs == [] + assert tag_map == {} + assert total == 0 + + def test_returns_references_with_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["alpha", "beta"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"]) + session.commit() + + refs, tag_map, total = list_references_page(session) + assert len(refs) == 1 + assert refs[0].id == ref.id + assert set(tag_map[ref.id]) == {"alpha", "beta"} + assert total == 1 + + def test_name_contains_filter(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="model_v1.safetensors") + _make_reference(session, asset, name="config.json") + session.commit() + + refs, _, total = list_references_page(session, name_contains="model") + assert total == 1 + assert refs[0].name == "model_v1.safetensors" + + def test_owner_visibility(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="public", owner_id="") + _make_reference(session, asset, name="private", owner_id="user1") + session.commit() + + # Empty owner sees only public + refs, _, total = list_references_page(session, owner_id="") + assert total == 1 + assert refs[0].name == "public" + + # Owner sees both + refs, _, total = list_references_page(session, owner_id="user1") + assert total == 2 + + def test_include_tags_filter(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="tagged") + _make_reference(session, asset, name="untagged") + ensure_tags_exist(session, ["wanted"]) + add_tags_to_reference(session, reference_id=ref1.id, tags=["wanted"]) + session.commit() + + refs, _, total = list_references_page(session, include_tags=["wanted"]) + assert total == 1 + assert refs[0].name == "tagged" + + def test_exclude_tags_filter(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="keep") + ref_exclude = _make_reference(session, asset, name="exclude") + ensure_tags_exist(session, ["bad"]) + add_tags_to_reference(session, reference_id=ref_exclude.id, tags=["bad"]) + session.commit() + + refs, _, total = list_references_page(session, exclude_tags=["bad"]) + assert total == 1 + assert refs[0].name == "keep" + + def test_sorting(self, session: Session): + asset = _make_asset(session, "hash1", size=100) + asset2 = _make_asset(session, "hash2", size=500) + _make_reference(session, asset, name="small") + _make_reference(session, asset2, name="large") + session.commit() + + refs, _, _ = list_references_page(session, sort="size", order="desc") + assert refs[0].name == "large" + + refs, _, _ = list_references_page(session, sort="name", order="asc") + assert refs[0].name == "large" + + +class TestFetchReferenceAssetAndTags: + def test_returns_none_for_nonexistent(self, session: Session): + result = fetch_reference_asset_and_tags(session, "nonexistent") + assert result is None + + def test_returns_tuple(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["tag1"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["tag1"]) + session.commit() + + result = fetch_reference_asset_and_tags(session, ref.id) + assert result is not None + ret_ref, ret_asset, ret_tags = result + assert ret_ref.id == ref.id + assert ret_asset.id == asset.id + assert ret_tags == ["tag1"] + + +class TestFetchReferenceAndAsset: + def test_returns_none_for_nonexistent(self, session: Session): + result = fetch_reference_and_asset(session, reference_id="nonexistent") + assert result is None + + def test_returns_tuple(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + result = fetch_reference_and_asset(session, reference_id=ref.id) + assert result is not None + ret_ref, ret_asset = result + assert ret_ref.id == ref.id + assert ret_asset.id == asset.id + + +class TestUpdateReferenceAccessTime: + def test_updates_last_access_time(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + original_time = ref.last_access_time + session.commit() + + import time + time.sleep(0.01) + + update_reference_access_time(session, reference_id=ref.id) + session.commit() + + session.refresh(ref) + assert ref.last_access_time > original_time + + +class TestDeleteReferenceById: + def test_deletes_existing(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + result = delete_reference_by_id(session, reference_id=ref.id, owner_id="") + assert result is True + assert get_reference_by_id(session, reference_id=ref.id) is None + + def test_returns_false_for_nonexistent(self, session: Session): + result = delete_reference_by_id(session, reference_id="nonexistent", owner_id="") + assert result is False + + def test_respects_owner_visibility(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + result = delete_reference_by_id(session, reference_id=ref.id, owner_id="user2") + assert result is False + assert get_reference_by_id(session, reference_id=ref.id) is not None + + +class TestSetReferencePreview: + def test_sets_preview(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + session.commit() + + set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id) + session.commit() + + session.refresh(ref) + assert ref.preview_id == preview_asset.id + + def test_clears_preview(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + ref.preview_id = preview_asset.id + session.commit() + + set_reference_preview(session, reference_id=ref.id, preview_asset_id=None) + session.commit() + + session.refresh(ref) + assert ref.preview_id is None + + def test_raises_for_nonexistent_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None) + + def test_raises_for_nonexistent_preview(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + with pytest.raises(ValueError, match="Preview Asset"): + set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent") + + +class TestInsertReference: + def test_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = insert_reference( + session, asset_id=asset.id, owner_id="user1", name="test.bin" + ) + session.commit() + + assert ref is not None + assert ref.name == "test.bin" + assert ref.owner_id == "user1" + + def test_allows_duplicate_names(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = insert_reference(session, asset_id=asset.id, owner_id="user1", name="dup.bin") + session.commit() + + # Duplicate names are now allowed + ref2 = insert_reference( + session, asset_id=asset.id, owner_id="user1", name="dup.bin" + ) + session.commit() + + assert ref1 is not None + assert ref2 is not None + assert ref1.id != ref2.id + + +class TestGetOrCreateReference: + def test_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref, created = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="new.bin" + ) + session.commit() + + assert created is True + assert ref.name == "new.bin" + + def test_always_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref1, created1 = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="existing.bin" + ) + session.commit() + + # Duplicate names are allowed, so always creates new + ref2, created2 = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="existing.bin" + ) + session.commit() + + assert created1 is True + assert created2 is True + assert ref1.id != ref2.id + + +class TestUpdateReferenceTimestamps: + def test_updates_timestamps(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + original_updated_at = ref.updated_at + session.commit() + + time.sleep(0.01) + update_reference_timestamps(session, ref) + session.commit() + + session.refresh(ref) + assert ref.updated_at > original_updated_at + + def test_updates_preview_id(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + session.commit() + + update_reference_timestamps(session, ref, preview_id=preview_asset.id) + session.commit() + + session.refresh(ref) + assert ref.preview_id == preview_asset.id + + +class TestSetReferenceMetadata: + def test_sets_metadata(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"key": "value"} + ) + session.commit() + + session.refresh(ref) + assert ref.user_metadata == {"key": "value"} + # Check metadata table + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "key" + assert meta[0].val_str == "value" + + def test_replaces_existing_metadata(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"old": "data"} + ) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"new": "data"} + ) + session.commit() + + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "new" + + def test_clears_metadata_with_empty_dict(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"key": "value"} + ) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={} + ) + session.commit() + + session.refresh(ref) + assert ref.user_metadata == {} + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 0 + + def test_raises_for_nonexistent(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_reference_metadata( + session, reference_id="nonexistent", user_metadata={"key": "value"} + ) + + +class TestBulkInsertReferencesIgnoreConflicts: + def test_inserts_multiple_references(self, session: Session): + asset = _make_asset(session, "hash1") + now = get_utc_now() + rows = [ + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "bulk1.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "bulk2.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + refs = session.query(AssetReference).all() + assert len(refs) == 2 + + def test_allows_duplicate_names(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="existing.bin", owner_id="") + session.commit() + + now = get_utc_now() + rows = [ + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "existing.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "new.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + # Duplicate names allowed, so all 3 rows exist + refs = session.query(AssetReference).all() + assert len(refs) == 3 + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_references_ignore_conflicts(session, []) + assert session.query(AssetReference).count() == 0 + + +class TestGetReferenceIdsByIds: + def test_returns_existing_ids(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="a.bin") + ref2 = _make_reference(session, asset, name="b.bin") + session.commit() + + found = get_reference_ids_by_ids(session, [ref1.id, ref2.id, "nonexistent"]) + + assert found == {ref1.id, ref2.id} + + def test_empty_list_returns_empty(self, session: Session): + found = get_reference_ids_by_ids(session, []) + assert found == set() diff --git a/tests-unit/assets_test/queries/test_cache_state.py b/tests-unit/assets_test/queries/test_cache_state.py new file mode 100644 index 000000000..ead60e570 --- /dev/null +++ b/tests-unit/assets_test/queries/test_cache_state.py @@ -0,0 +1,499 @@ +"""Tests for cache_state (AssetReference file path) query functions.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ( + list_references_by_asset_id, + upsert_reference, + get_unreferenced_unhashed_asset_ids, + delete_assets_by_ids, + get_references_for_prefixes, + bulk_update_needs_verify, + delete_references_by_ids, + delete_orphaned_seed_asset, + bulk_insert_references_ignore_conflicts, + get_references_by_paths_and_asset_ids, + mark_references_missing_outside_prefixes, + restore_references_by_paths, +) +from app.assets.helpers import select_best_live_path, get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + file_path: str, + name: str = "test", + mtime_ns: int | None = None, + needs_verify: bool = False, +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + asset_id=asset.id, + file_path=file_path, + name=name, + mtime_ns=mtime_ns, + needs_verify=needs_verify, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestListReferencesByAssetId: + def test_returns_empty_for_no_references(self, session: Session): + asset = _make_asset(session, "hash1") + refs = list_references_by_asset_id(session, asset_id=asset.id) + assert list(refs) == [] + + def test_returns_references_for_asset(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/path/a.bin", name="a") + _make_reference(session, asset, "/path/b.bin", name="b") + session.commit() + + refs = list_references_by_asset_id(session, asset_id=asset.id) + paths = [r.file_path for r in refs] + assert set(paths) == {"/path/a.bin", "/path/b.bin"} + + def test_does_not_return_other_assets_references(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + _make_reference(session, asset1, "/path/asset1.bin", name="a1") + _make_reference(session, asset2, "/path/asset2.bin", name="a2") + session.commit() + + refs = list_references_by_asset_id(session, asset_id=asset1.id) + paths = [r.file_path for r in refs] + assert paths == ["/path/asset1.bin"] + + +class TestSelectBestLivePath: + def test_returns_empty_for_empty_list(self): + result = select_best_live_path([]) + assert result == "" + + def test_returns_empty_when_no_files_exist(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, "/nonexistent/path.bin") + session.commit() + + result = select_best_live_path([ref]) + assert result == "" + + def test_prefers_verified_path(self, session: Session, tmp_path): + """needs_verify=False should be preferred.""" + asset = _make_asset(session, "hash1") + + verified_file = tmp_path / "verified.bin" + verified_file.write_bytes(b"data") + + unverified_file = tmp_path / "unverified.bin" + unverified_file.write_bytes(b"data") + + ref_verified = _make_reference( + session, asset, str(verified_file), name="verified", needs_verify=False + ) + ref_unverified = _make_reference( + session, asset, str(unverified_file), name="unverified", needs_verify=True + ) + session.commit() + + refs = [ref_unverified, ref_verified] + result = select_best_live_path(refs) + assert result == str(verified_file) + + def test_falls_back_to_existing_unverified(self, session: Session, tmp_path): + """If all references need verification, return first existing path.""" + asset = _make_asset(session, "hash1") + + existing_file = tmp_path / "exists.bin" + existing_file.write_bytes(b"data") + + ref = _make_reference(session, asset, str(existing_file), needs_verify=True) + session.commit() + + result = select_best_live_path([ref]) + assert result == str(existing_file) + + +class TestSelectBestLivePathWithMocking: + def test_handles_missing_file_path_attr(self): + """Gracefully handle references with None file_path.""" + + class MockRef: + file_path = None + needs_verify = False + + result = select_best_live_path([MockRef()]) + assert result == "" + + +class TestUpsertReference: + @pytest.mark.parametrize( + "initial_mtime,second_mtime,expect_created,expect_updated,final_mtime", + [ + # New reference creation + (None, 12345, True, False, 12345), + # Existing reference, same mtime - no update + (100, 100, False, False, 100), + # Existing reference, different mtime - update + (100, 200, False, True, 200), + ], + ids=["new_reference", "existing_no_change", "existing_update_mtime"], + ) + def test_upsert_scenarios( + self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime + ): + asset = _make_asset(session, "hash1") + file_path = f"/path_{initial_mtime}_{second_mtime}.bin" + name = f"file_{initial_mtime}_{second_mtime}" + + # Create initial reference if needed + if initial_mtime is not None: + upsert_reference(session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=initial_mtime) + session.commit() + + # The upsert call we're testing + created, updated = upsert_reference( + session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=second_mtime + ) + session.commit() + + assert created is expect_created + assert updated is expect_updated + ref = session.query(AssetReference).filter_by(file_path=file_path).one() + assert ref.mtime_ns == final_mtime + + def test_upsert_restores_missing_reference(self, session: Session): + """Upserting a reference that was marked missing should restore it.""" + asset = _make_asset(session, "hash1") + file_path = "/restored/file.bin" + + ref = _make_reference(session, asset, file_path, mtime_ns=100) + ref.is_missing = True + session.commit() + + created, updated = upsert_reference( + session, asset_id=asset.id, file_path=file_path, name="restored", mtime_ns=100 + ) + session.commit() + + assert created is False + assert updated is True + restored_ref = session.query(AssetReference).filter_by(file_path=file_path).one() + assert restored_ref.is_missing is False + + +class TestRestoreReferencesByPaths: + def test_restores_missing_references(self, session: Session): + asset = _make_asset(session, "hash1") + missing_path = "/missing/file.bin" + active_path = "/active/file.bin" + + missing_ref = _make_reference(session, asset, missing_path, name="missing") + missing_ref.is_missing = True + _make_reference(session, asset, active_path, name="active") + session.commit() + + restored = restore_references_by_paths(session, [missing_path]) + session.commit() + + assert restored == 1 + ref = session.query(AssetReference).filter_by(file_path=missing_path).one() + assert ref.is_missing is False + + def test_empty_list_restores_nothing(self, session: Session): + restored = restore_references_by_paths(session, []) + assert restored == 0 + + +class TestMarkReferencesMissingOutsidePrefixes: + def test_marks_references_missing_outside_prefixes(self, session: Session, tmp_path): + asset = _make_asset(session, "hash1") + valid_dir = tmp_path / "valid" + valid_dir.mkdir() + invalid_dir = tmp_path / "invalid" + invalid_dir.mkdir() + + valid_path = str(valid_dir / "file.bin") + invalid_path = str(invalid_dir / "file.bin") + + _make_reference(session, asset, valid_path, name="valid") + _make_reference(session, asset, invalid_path, name="invalid") + session.commit() + + marked = mark_references_missing_outside_prefixes(session, [str(valid_dir)]) + session.commit() + + assert marked == 1 + all_refs = session.query(AssetReference).all() + assert len(all_refs) == 2 + + valid_ref = next(r for r in all_refs if r.file_path == valid_path) + invalid_ref = next(r for r in all_refs if r.file_path == invalid_path) + assert valid_ref.is_missing is False + assert invalid_ref.is_missing is True + + def test_empty_prefixes_marks_nothing(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/some/path.bin") + session.commit() + + marked = mark_references_missing_outside_prefixes(session, []) + + assert marked == 0 + + +class TestGetUnreferencedUnhashedAssetIds: + def test_returns_unreferenced_unhashed_assets(self, session: Session): + # Unhashed asset (hash=None) with no references (no file_path) + no_refs = _make_asset(session, hash_val=None) + # Unhashed asset with active reference (not unreferenced) + with_active_ref = _make_asset(session, hash_val=None) + _make_reference(session, with_active_ref, "/has/ref.bin", name="has_ref") + # Unhashed asset with only missing reference (should be unreferenced) + with_missing_ref = _make_asset(session, hash_val=None) + missing_ref = _make_reference(session, with_missing_ref, "/missing/ref.bin", name="missing_ref") + missing_ref.is_missing = True + # Regular asset (hash not None) - should not be returned + _make_asset(session, hash_val="blake3:regular") + session.commit() + + unreferenced = get_unreferenced_unhashed_asset_ids(session) + + assert no_refs.id in unreferenced + assert with_missing_ref.id in unreferenced + assert with_active_ref.id not in unreferenced + + +class TestDeleteAssetsByIds: + def test_deletes_assets_and_references(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/test/path.bin", name="test") + session.commit() + + deleted = delete_assets_by_ids(session, [asset.id]) + session.commit() + + assert deleted == 1 + assert session.query(Asset).count() == 0 + assert session.query(AssetReference).count() == 0 + + def test_empty_list_deletes_nothing(self, session: Session): + _make_asset(session, "hash1") + session.commit() + + deleted = delete_assets_by_ids(session, []) + + assert deleted == 0 + assert session.query(Asset).count() == 1 + + +class TestGetReferencesForPrefixes: + def test_returns_references_matching_prefix(self, session: Session, tmp_path): + asset = _make_asset(session, "hash1") + dir1 = tmp_path / "dir1" + dir1.mkdir() + dir2 = tmp_path / "dir2" + dir2.mkdir() + + path1 = str(dir1 / "file.bin") + path2 = str(dir2 / "file.bin") + + _make_reference(session, asset, path1, name="file1", mtime_ns=100) + _make_reference(session, asset, path2, name="file2", mtime_ns=200) + session.commit() + + rows = get_references_for_prefixes(session, [str(dir1)]) + + assert len(rows) == 1 + assert rows[0].file_path == path1 + + def test_empty_prefixes_returns_empty(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/some/path.bin") + session.commit() + + rows = get_references_for_prefixes(session, []) + + assert rows == [] + + +class TestBulkSetNeedsVerify: + def test_sets_needs_verify_flag(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, "/path1.bin", needs_verify=False) + ref2 = _make_reference(session, asset, "/path2.bin", needs_verify=False) + session.commit() + + updated = bulk_update_needs_verify(session, [ref1.id, ref2.id], True) + session.commit() + + assert updated == 2 + session.refresh(ref1) + session.refresh(ref2) + assert ref1.needs_verify is True + assert ref2.needs_verify is True + + def test_empty_list_updates_nothing(self, session: Session): + updated = bulk_update_needs_verify(session, [], True) + assert updated == 0 + + +class TestDeleteReferencesByIds: + def test_deletes_references_by_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, "/path1.bin") + _make_reference(session, asset, "/path2.bin") + session.commit() + + deleted = delete_references_by_ids(session, [ref1.id]) + session.commit() + + assert deleted == 1 + assert session.query(AssetReference).count() == 1 + + def test_empty_list_deletes_nothing(self, session: Session): + deleted = delete_references_by_ids(session, []) + assert deleted == 0 + + +class TestDeleteOrphanedSeedAsset: + @pytest.mark.parametrize( + "create_asset,expected_deleted,expected_count", + [ + (True, True, 0), # Existing asset gets deleted + (False, False, 0), # Nonexistent returns False + ], + ids=["deletes_existing", "nonexistent_returns_false"], + ) + def test_delete_orphaned_seed_asset( + self, session: Session, create_asset, expected_deleted, expected_count + ): + asset_id = "nonexistent-id" + if create_asset: + asset = _make_asset(session, hash_val=None) + asset_id = asset.id + _make_reference(session, asset, "/test/path.bin", name="test") + session.commit() + + deleted = delete_orphaned_seed_asset(session, asset_id) + if create_asset: + session.commit() + + assert deleted is expected_deleted + assert session.query(Asset).count() == expected_count + + +class TestBulkInsertReferencesIgnoreConflicts: + def test_inserts_multiple_references(self, session: Session): + asset = _make_asset(session, "hash1") + now = get_utc_now() + rows = [ + { + "asset_id": asset.id, + "file_path": "/bulk1.bin", + "name": "bulk1", + "mtime_ns": 100, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "asset_id": asset.id, + "file_path": "/bulk2.bin", + "name": "bulk2", + "mtime_ns": 200, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + assert session.query(AssetReference).count() == 2 + + def test_ignores_conflicts(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/existing.bin", mtime_ns=100) + session.commit() + + now = get_utc_now() + rows = [ + { + "asset_id": asset.id, + "file_path": "/existing.bin", + "name": "existing", + "mtime_ns": 999, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "asset_id": asset.id, + "file_path": "/new.bin", + "name": "new", + "mtime_ns": 200, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + assert session.query(AssetReference).count() == 2 + existing = session.query(AssetReference).filter_by(file_path="/existing.bin").one() + assert existing.mtime_ns == 100 # Original value preserved + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_references_ignore_conflicts(session, []) + assert session.query(AssetReference).count() == 0 + + +class TestGetReferencesByPathsAndAssetIds: + def test_returns_matching_paths(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + + _make_reference(session, asset1, "/path1.bin") + _make_reference(session, asset2, "/path2.bin") + session.commit() + + path_to_asset = { + "/path1.bin": asset1.id, + "/path2.bin": asset2.id, + } + winners = get_references_by_paths_and_asset_ids(session, path_to_asset) + + assert winners == {"/path1.bin", "/path2.bin"} + + def test_excludes_non_matching_asset_ids(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + + _make_reference(session, asset1, "/path1.bin") + session.commit() + + # Path exists but with different asset_id + path_to_asset = {"/path1.bin": asset2.id} + winners = get_references_by_paths_and_asset_ids(session, path_to_asset) + + assert winners == set() + + def test_empty_dict_returns_empty(self, session: Session): + winners = get_references_by_paths_and_asset_ids(session, {}) + assert winners == set() diff --git a/tests-unit/assets_test/queries/test_metadata.py b/tests-unit/assets_test/queries/test_metadata.py new file mode 100644 index 000000000..6a545e819 --- /dev/null +++ b/tests-unit/assets_test/queries/test_metadata.py @@ -0,0 +1,184 @@ +"""Tests for metadata filtering logic in asset_reference queries.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta +from app.assets.database.queries import list_references_page +from app.assets.database.queries.asset_reference import convert_metadata_to_rows +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str) -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str, + metadata: dict | None = None, +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id="", + name=name, + asset_id=asset.id, + user_metadata=metadata, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + + if metadata: + for key, val in metadata.items(): + for row in convert_metadata_to_rows(key, val): + meta_row = AssetReferenceMeta( + asset_reference_id=ref.id, + key=row["key"], + ordinal=row.get("ordinal", 0), + val_str=row.get("val_str"), + val_num=row.get("val_num"), + val_bool=row.get("val_bool"), + val_json=row.get("val_json"), + ) + session.add(meta_row) + session.flush() + + return ref + + +class TestMetadataFilterByType: + """Table-driven tests for metadata filtering by different value types.""" + + @pytest.mark.parametrize( + "match_meta,nomatch_meta,filter_key,filter_val", + [ + # String matching + ({"category": "models"}, {"category": "images"}, "category", "models"), + # Integer matching + ({"epoch": 5}, {"epoch": 10}, "epoch", 5), + # Float matching + ({"score": 0.95}, {"score": 0.5}, "score", 0.95), + # Boolean True matching + ({"enabled": True}, {"enabled": False}, "enabled", True), + # Boolean False matching + ({"enabled": False}, {"enabled": True}, "enabled", False), + ], + ids=["string", "int", "float", "bool_true", "bool_false"], + ) + def test_filter_matches_correct_value( + self, session: Session, match_meta, nomatch_meta, filter_key, filter_val + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "match", match_meta) + _make_reference(session, asset, "nomatch", nomatch_meta) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={filter_key: filter_val} + ) + assert total == 1 + assert refs[0].name == "match" + + @pytest.mark.parametrize( + "stored_meta,filter_key,filter_val", + [ + # String no match + ({"category": "models"}, "category", "other"), + # Int no match + ({"epoch": 5}, "epoch", 99), + # Float no match + ({"score": 0.5}, "score", 0.99), + ], + ids=["string_no_match", "int_no_match", "float_no_match"], + ) + def test_filter_returns_empty_when_no_match( + self, session: Session, stored_meta, filter_key, filter_val + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "item", stored_meta) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={filter_key: filter_val} + ) + assert total == 0 + + +class TestMetadataFilterNull: + """Tests for null/missing key filtering.""" + + @pytest.mark.parametrize( + "match_name,match_meta,nomatch_name,nomatch_meta,filter_key", + [ + # Null matches missing key + ("missing_key", {}, "has_key", {"optional": "value"}, "optional"), + # Null matches explicit null + ("explicit_null", {"nullable": None}, "has_value", {"nullable": "present"}, "nullable"), + ], + ids=["missing_key", "explicit_null"], + ) + def test_null_filter_matches( + self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, match_name, match_meta) + _make_reference(session, asset, nomatch_name, nomatch_meta) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={filter_key: None}) + assert total == 1 + assert refs[0].name == match_name + + +class TestMetadataFilterList: + """Tests for list-based (OR) filtering.""" + + def test_filter_by_list_matches_any(self, session: Session): + """List values should match ANY of the values (OR).""" + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "cat_a", {"category": "a"}) + _make_reference(session, asset, "cat_b", {"category": "b"}) + _make_reference(session, asset, "cat_c", {"category": "c"}) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={"category": ["a", "b"]}) + assert total == 2 + names = {r.name for r in refs} + assert names == {"cat_a", "cat_b"} + + +class TestMetadataFilterMultipleKeys: + """Tests for multiple filter keys (AND semantics).""" + + def test_multiple_keys_must_all_match(self, session: Session): + """Multiple keys should ALL match (AND).""" + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "match", {"type": "model", "version": 2}) + _make_reference(session, asset, "wrong_type", {"type": "config", "version": 2}) + _make_reference(session, asset, "wrong_version", {"type": "model", "version": 1}) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={"type": "model", "version": 2} + ) + assert total == 1 + assert refs[0].name == "match" + + +class TestMetadataFilterEmptyDict: + """Tests for empty filter behavior.""" + + def test_empty_filter_returns_all(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "a", {"key": "val"}) + _make_reference(session, asset, "b", {}) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={}) + assert total == 2 diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py new file mode 100644 index 000000000..4ed99aa37 --- /dev/null +++ b/tests-unit/assets_test/queries/test_tags.py @@ -0,0 +1,366 @@ +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, AssetReferenceMeta, Tag +from app.assets.database.queries import ( + ensure_tags_exist, + get_reference_tags, + set_reference_tags, + add_tags_to_reference, + remove_tags_from_reference, + add_missing_tag_for_asset_id, + remove_missing_tag_for_asset_id, + list_tags_with_usage, + bulk_insert_tags_and_meta, +) +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None) -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestEnsureTagsExist: + def test_creates_new_tags(self, session: Session): + ensure_tags_exist(session, ["alpha", "beta"], tag_type="user") + session.commit() + + tags = session.query(Tag).all() + assert {t.name for t in tags} == {"alpha", "beta"} + + def test_is_idempotent(self, session: Session): + ensure_tags_exist(session, ["alpha"], tag_type="user") + ensure_tags_exist(session, ["alpha"], tag_type="user") + session.commit() + + assert session.query(Tag).count() == 1 + + def test_normalizes_tags(self, session: Session): + ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"]) + session.commit() + + tags = session.query(Tag).all() + assert {t.name for t in tags} == {"alpha", "beta"} + + def test_empty_list_is_noop(self, session: Session): + ensure_tags_exist(session, []) + session.commit() + assert session.query(Tag).count() == 0 + + def test_tag_type_is_set(self, session: Session): + ensure_tags_exist(session, ["system-tag"], tag_type="system") + session.commit() + + tag = session.query(Tag).filter_by(name="system-tag").one() + assert tag.tag_type == "system" + + +class TestGetReferenceTags: + def test_returns_empty_for_no_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + tags = get_reference_tags(session, reference_id=ref.id) + assert tags == [] + + def test_returns_tags_for_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + ensure_tags_exist(session, ["tag1", "tag2"]) + session.add_all([ + AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag1", origin="manual", added_at=get_utc_now()), + AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag2", origin="manual", added_at=get_utc_now()), + ]) + session.flush() + + tags = get_reference_tags(session, reference_id=ref.id) + assert set(tags) == {"tag1", "tag2"} + + +class TestSetReferenceTags: + def test_adds_new_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + result = set_reference_tags(session, reference_id=ref.id, tags=["a", "b"]) + session.commit() + + assert set(result.added) == {"a", "b"} + assert result.removed == [] + assert set(result.total) == {"a", "b"} + + def test_removes_old_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + set_reference_tags(session, reference_id=ref.id, tags=["a", "b", "c"]) + result = set_reference_tags(session, reference_id=ref.id, tags=["a"]) + session.commit() + + assert result.added == [] + assert set(result.removed) == {"b", "c"} + assert result.total == ["a"] + + def test_replaces_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + set_reference_tags(session, reference_id=ref.id, tags=["a", "b"]) + result = set_reference_tags(session, reference_id=ref.id, tags=["b", "c"]) + session.commit() + + assert result.added == ["c"] + assert result.removed == ["a"] + assert set(result.total) == {"b", "c"} + + +class TestAddTagsToReference: + def test_adds_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"]) + session.commit() + + assert set(result.added) == {"x", "y"} + assert result.already_present == [] + + def test_reports_already_present(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["x"]) + result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"]) + session.commit() + + assert result.added == ["y"] + assert result.already_present == ["x"] + + def test_raises_for_missing_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + add_tags_to_reference(session, reference_id="nonexistent", tags=["x"]) + + +class TestRemoveTagsFromReference: + def test_removes_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["a", "b", "c"]) + result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "b"]) + session.commit() + + assert set(result.removed) == {"a", "b"} + assert result.not_present == [] + assert result.total_tags == ["c"] + + def test_reports_not_present(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["a"]) + result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "x"]) + session.commit() + + assert result.removed == ["a"] + assert result.not_present == ["x"] + + def test_raises_for_missing_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + remove_tags_from_reference(session, reference_id="nonexistent", tags=["x"]) + + +class TestMissingTagFunctions: + def test_add_missing_tag_for_asset_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + + add_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert "missing" in tags + + def test_add_missing_tag_is_idempotent(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + + add_missing_tag_for_asset_id(session, asset_id=asset.id) + add_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="missing").all() + assert len(links) == 1 + + def test_remove_missing_tag_for_asset_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + add_missing_tag_for_asset_id(session, asset_id=asset.id) + + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert "missing" not in tags + + +class TestListTagsWithUsage: + def test_returns_tags_with_counts(self, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags_with_usage(session) + + tag_dict = {name: count for name, _, count in rows} + assert tag_dict["used"] == 1 + assert tag_dict["unused"] == 0 + assert total == 2 + + def test_exclude_zero_counts(self, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags_with_usage(session, include_zero=False) + + tag_names = {name for name, _, _ in rows} + assert "used" in tag_names + assert "unused" not in tag_names + + def test_prefix_filter(self, session: Session): + ensure_tags_exist(session, ["alpha", "beta", "alphabet"]) + session.commit() + + rows, total = list_tags_with_usage(session, prefix="alph") + + tag_names = {name for name, _, _ in rows} + assert tag_names == {"alpha", "alphabet"} + + def test_order_by_name(self, session: Session): + ensure_tags_exist(session, ["zebra", "alpha", "middle"]) + session.commit() + + rows, _ = list_tags_with_usage(session, order="name_asc") + + names = [name for name, _, _ in rows] + assert names == ["alpha", "middle", "zebra"] + + def test_owner_visibility(self, session: Session): + ensure_tags_exist(session, ["shared-tag", "owner-tag"]) + + asset = _make_asset(session, "hash1") + shared_ref = _make_reference(session, asset, name="shared", owner_id="") + owner_ref = _make_reference(session, asset, name="owned", owner_id="user1") + + add_tags_to_reference(session, reference_id=shared_ref.id, tags=["shared-tag"]) + add_tags_to_reference(session, reference_id=owner_ref.id, tags=["owner-tag"]) + session.commit() + + # Empty owner sees only shared + rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False) + tag_dict = {name: count for name, _, count in rows} + assert tag_dict.get("shared-tag", 0) == 1 + assert tag_dict.get("owner-tag", 0) == 0 + + # User1 sees both + rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False) + tag_dict = {name: count for name, _, count in rows} + assert tag_dict.get("shared-tag", 0) == 1 + assert tag_dict.get("owner-tag", 0) == 1 + + +class TestBulkInsertTagsAndMeta: + def test_inserts_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"]) + session.commit() + + now = get_utc_now() + tag_rows = [ + {"asset_reference_id": ref.id, "tag_name": "bulk-tag1", "origin": "manual", "added_at": now}, + {"asset_reference_id": ref.id, "tag_name": "bulk-tag2", "origin": "manual", "added_at": now}, + ] + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[]) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert set(tags) == {"bulk-tag1", "bulk-tag2"} + + def test_inserts_meta(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + meta_rows = [ + { + "asset_reference_id": ref.id, + "key": "meta-key", + "ordinal": 0, + "val_str": "meta-value", + "val_num": None, + "val_bool": None, + "val_json": None, + }, + ] + bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=meta_rows) + session.commit() + + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "meta-key" + assert meta[0].val_str == "meta-value" + + def test_ignores_conflicts(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["existing-tag"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["existing-tag"]) + session.commit() + + now = get_utc_now() + tag_rows = [ + {"asset_reference_id": ref.id, "tag_name": "existing-tag", "origin": "duplicate", "added_at": now}, + ] + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[]) + session.commit() + + # Should still have only one tag link + links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="existing-tag").all() + assert len(links) == 1 + # Origin should be original, not overwritten + assert links[0].origin == "manual" + + def test_empty_lists_is_noop(self, session: Session): + bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[]) + assert session.query(AssetReferenceTag).count() == 0 + assert session.query(AssetReferenceMeta).count() == 0 diff --git a/tests-unit/assets_test/services/__init__.py b/tests-unit/assets_test/services/__init__.py new file mode 100644 index 000000000..d0213422e --- /dev/null +++ b/tests-unit/assets_test/services/__init__.py @@ -0,0 +1 @@ +# Service layer tests diff --git a/tests-unit/assets_test/services/conftest.py b/tests-unit/assets_test/services/conftest.py new file mode 100644 index 000000000..31c763d48 --- /dev/null +++ b/tests-unit/assets_test/services/conftest.py @@ -0,0 +1,54 @@ +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import Base + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(): + """Override parent autouse fixture - service unit tests don't need server cleanup.""" + yield + + +@pytest.fixture +def db_engine(): + """In-memory SQLite engine for fast unit tests.""" + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def session(db_engine): + """Session fixture for tests that need direct DB access.""" + with Session(db_engine) as sess: + yield sess + + +@pytest.fixture +def mock_create_session(db_engine): + """Patch create_session to use our in-memory database.""" + from contextlib import contextmanager + from sqlalchemy.orm import Session as SASession + + @contextmanager + def _create_session(): + with SASession(db_engine) as sess: + yield sess + + with patch("app.assets.services.ingest.create_session", _create_session), \ + patch("app.assets.services.asset_management.create_session", _create_session), \ + patch("app.assets.services.tagging.create_session", _create_session): + yield _create_session + + +@pytest.fixture +def temp_dir(): + """Temporary directory for file operations.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) diff --git a/tests-unit/assets_test/services/test_asset_management.py b/tests-unit/assets_test/services/test_asset_management.py new file mode 100644 index 000000000..101ef7292 --- /dev/null +++ b/tests-unit/assets_test/services/test_asset_management.py @@ -0,0 +1,268 @@ +"""Tests for asset_management services.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference +from app.assets.helpers import get_utc_now +from app.assets.services import ( + get_asset_detail, + update_asset_metadata, + delete_asset_reference, + set_asset_preview, +) + + +def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestGetAssetDetail: + def test_returns_none_for_nonexistent(self, mock_create_session): + result = get_asset_detail(reference_id="nonexistent") + assert result is None + + def test_returns_asset_with_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["alpha", "beta"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"]) + session.commit() + + result = get_asset_detail(reference_id=ref.id) + + assert result is not None + assert result.ref.id == ref.id + assert result.asset.hash == asset.hash + assert set(result.tags) == {"alpha", "beta"} + + def test_respects_owner_visibility(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + # Wrong owner cannot see + result = get_asset_detail(reference_id=ref.id, owner_id="user2") + assert result is None + + # Correct owner can see + result = get_asset_detail(reference_id=ref.id, owner_id="user1") + assert result is not None + + +class TestUpdateAssetMetadata: + def test_updates_name(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, name="old_name.bin") + ref_id = ref.id + session.commit() + + update_asset_metadata( + reference_id=ref_id, + name="new_name.bin", + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.name == "new_name.bin" + + def test_updates_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["old"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["old"]) + session.commit() + + result = update_asset_metadata( + reference_id=ref.id, + tags=["new1", "new2"], + ) + + assert set(result.tags) == {"new1", "new2"} + assert "old" not in result.tags + + def test_updates_user_metadata(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ref_id = ref.id + session.commit() + + update_asset_metadata( + reference_id=ref_id, + user_metadata={"key": "value", "num": 42}, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.user_metadata["key"] == "value" + assert updated_ref.user_metadata["num"] == 42 + + def test_raises_for_nonexistent(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + update_asset_metadata(reference_id="nonexistent", name="fail") + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + update_asset_metadata( + reference_id=ref.id, + name="new", + owner_id="user2", + ) + + +class TestDeleteAssetReference: + def test_soft_deletes_reference(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ref_id = ref.id + session.commit() + + result = delete_asset_reference( + reference_id=ref_id, + owner_id="", + delete_content_if_orphan=False, + ) + + assert result is True + # Row still exists but is marked as soft-deleted + session.expire_all() + row = session.get(AssetReference, ref_id) + assert row is not None + assert row.deleted_at is not None + + def test_returns_false_for_nonexistent(self, mock_create_session): + result = delete_asset_reference( + reference_id="nonexistent", + owner_id="", + ) + assert result is False + + def test_returns_false_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + ref_id = ref.id + session.commit() + + result = delete_asset_reference( + reference_id=ref_id, + owner_id="user2", + ) + + assert result is False + assert session.get(AssetReference, ref_id) is not None + + def test_keeps_asset_if_other_references_exist(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref1 = _make_reference(session, asset, name="ref1") + _make_reference(session, asset, name="ref2") # Second ref keeps asset alive + asset_id = asset.id + session.commit() + + delete_asset_reference( + reference_id=ref1.id, + owner_id="", + delete_content_if_orphan=True, + ) + + # Asset should still exist + assert session.get(Asset, asset_id) is not None + + def test_deletes_orphaned_asset(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + asset_id = asset.id + ref_id = ref.id + session.commit() + + delete_asset_reference( + reference_id=ref_id, + owner_id="", + delete_content_if_orphan=True, + ) + + # Both ref and asset should be gone + assert session.get(AssetReference, ref_id) is None + assert session.get(Asset, asset_id) is None + + +class TestSetAssetPreview: + def test_sets_preview(self, mock_create_session, session: Session): + asset = _make_asset(session, hash_val="blake3:main") + preview_asset = _make_asset(session, hash_val="blake3:preview") + ref = _make_reference(session, asset) + ref_id = ref.id + preview_id = preview_asset.id + session.commit() + + set_asset_preview( + reference_id=ref_id, + preview_asset_id=preview_id, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.preview_id == preview_id + + def test_clears_preview(self, mock_create_session, session: Session): + asset = _make_asset(session) + preview_asset = _make_asset(session, hash_val="blake3:preview") + ref = _make_reference(session, asset) + ref.preview_id = preview_asset.id + ref_id = ref.id + session.commit() + + set_asset_preview( + reference_id=ref_id, + preview_asset_id=None, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.preview_id is None + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + set_asset_preview(reference_id="nonexistent") + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + set_asset_preview( + reference_id=ref.id, + preview_asset_id=None, + owner_id="user2", + ) diff --git a/tests-unit/assets_test/services/test_bulk_ingest.py b/tests-unit/assets_test/services/test_bulk_ingest.py new file mode 100644 index 000000000..26e22a01d --- /dev/null +++ b/tests-unit/assets_test/services/test_bulk_ingest.py @@ -0,0 +1,137 @@ +"""Tests for bulk ingest services.""" + +from pathlib import Path + +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets + + +class TestBatchInsertSeedAssets: + def test_populates_mime_type_for_model_files(self, session: Session, temp_dir: Path): + """Verify mime_type is stored in the Asset table for model files.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"fake safetensors content") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 24, + "mtime_ns": 1234567890000000000, + "info_name": "Test Model", + "tags": ["models"], + "fname": "model.safetensors", + "metadata": None, + "hash": None, + "mime_type": "application/safetensors", + } + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + + # Verify Asset has mime_type populated + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].mime_type == "application/safetensors" + + def test_mime_type_none_when_not_provided(self, session: Session, temp_dir: Path): + """Verify mime_type is None when not provided in spec.""" + file_path = temp_dir / "unknown.bin" + file_path.write_bytes(b"binary data") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 11, + "mtime_ns": 1234567890000000000, + "info_name": "Unknown File", + "tags": [], + "fname": "unknown.bin", + "metadata": None, + "hash": None, + "mime_type": None, + } + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].mime_type is None + + def test_various_model_mime_types(self, session: Session, temp_dir: Path): + """Verify various model file types get correct mime_type.""" + test_cases = [ + ("model.safetensors", "application/safetensors"), + ("model.pt", "application/pytorch"), + ("model.ckpt", "application/pickle"), + ("model.gguf", "application/gguf"), + ] + + specs: list[SeedAssetSpec] = [] + for filename, mime_type in test_cases: + file_path = temp_dir / filename + file_path.write_bytes(b"content") + specs.append( + { + "abs_path": str(file_path), + "size_bytes": 7, + "mtime_ns": 1234567890000000000, + "info_name": filename, + "tags": [], + "fname": filename, + "metadata": None, + "hash": None, + "mime_type": mime_type, + } + ) + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == len(test_cases) + + for filename, expected_mime in test_cases: + ref = session.query(AssetReference).filter_by(name=filename).first() + assert ref is not None + asset = session.query(Asset).filter_by(id=ref.asset_id).first() + assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}" + + +class TestMetadataExtraction: + def test_extracts_mime_type_for_model_files(self, temp_dir: Path): + """Verify metadata extraction returns correct mime_type for model files.""" + from app.assets.services.metadata_extract import extract_file_metadata + + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"fake safetensors content") + + meta = extract_file_metadata(str(file_path)) + + assert meta.content_type == "application/safetensors" + + def test_mime_type_for_various_model_formats(self, temp_dir: Path): + """Verify various model file types get correct mime_type from metadata.""" + from app.assets.services.metadata_extract import extract_file_metadata + + test_cases = [ + ("model.safetensors", "application/safetensors"), + ("model.sft", "application/safetensors"), + ("model.pt", "application/pytorch"), + ("model.pth", "application/pytorch"), + ("model.ckpt", "application/pickle"), + ("model.pkl", "application/pickle"), + ("model.gguf", "application/gguf"), + ] + + for filename, expected_mime in test_cases: + file_path = temp_dir / filename + file_path.write_bytes(b"content") + + meta = extract_file_metadata(str(file_path)) + + assert meta.content_type == expected_mime, f"Expected {expected_mime} for {filename}, got {meta.content_type}" diff --git a/tests-unit/assets_test/services/test_enrich.py b/tests-unit/assets_test/services/test_enrich.py new file mode 100644 index 000000000..2bd79a01a --- /dev/null +++ b/tests-unit/assets_test/services/test_enrich.py @@ -0,0 +1,207 @@ +"""Tests for asset enrichment (mime_type and hash population).""" +from pathlib import Path + +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.scanner import ( + ENRICHMENT_HASHED, + ENRICHMENT_METADATA, + ENRICHMENT_STUB, + enrich_asset, +) + + +def _create_stub_asset( + session: Session, + file_path: str, + asset_id: str = "test-asset-id", + reference_id: str = "test-ref-id", + name: str | None = None, +) -> tuple[Asset, AssetReference]: + """Create a stub asset with reference for testing enrichment.""" + asset = Asset( + id=asset_id, + hash=None, + size_bytes=100, + mime_type=None, + ) + session.add(asset) + session.flush() + + ref = AssetReference( + id=reference_id, + asset_id=asset_id, + name=name or f"test-asset-{asset_id}", + owner_id="system", + file_path=file_path, + mtime_ns=1234567890000000000, + enrichment_level=ENRICHMENT_STUB, + ) + session.add(ref) + session.flush() + + return asset, ref + + +class TestEnrichAsset: + def test_extracts_mime_type_and_updates_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify mime_type is written to the Asset table during enrichment.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"\x00" * 100) + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-1", "ref-1" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=False, + ) + + assert new_level == ENRICHMENT_METADATA + + session.expire_all() + updated_asset = session.get(Asset, "asset-1") + assert updated_asset is not None + assert updated_asset.mime_type == "application/safetensors" + + def test_computes_hash_and_updates_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify hash is written to the Asset table during enrichment.""" + file_path = temp_dir / "data.bin" + file_path.write_bytes(b"test content for hashing") + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-2", "ref-2" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + assert new_level == ENRICHMENT_HASHED + + session.expire_all() + updated_asset = session.get(Asset, "asset-2") + assert updated_asset is not None + assert updated_asset.hash is not None + assert updated_asset.hash.startswith("blake3:") + + def test_enrichment_updates_both_mime_and_hash( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify both mime_type and hash are set when full enrichment runs.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"\x00" * 50) + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-3", "ref-3" + ) + session.commit() + + enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + session.expire_all() + updated_asset = session.get(Asset, "asset-3") + assert updated_asset is not None + assert updated_asset.mime_type == "application/safetensors" + assert updated_asset.hash is not None + assert updated_asset.hash.startswith("blake3:") + + def test_missing_file_returns_stub_level( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify missing files don't cause errors and return STUB level.""" + file_path = temp_dir / "nonexistent.bin" + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-4", "ref-4" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + assert new_level == ENRICHMENT_STUB + + session.expire_all() + updated_asset = session.get(Asset, "asset-4") + assert updated_asset.mime_type is None + assert updated_asset.hash is None + + def test_duplicate_hash_merges_into_existing_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify duplicate files merge into existing asset instead of failing.""" + file_path_1 = temp_dir / "file1.bin" + file_path_2 = temp_dir / "file2.bin" + content = b"identical content" + file_path_1.write_bytes(content) + file_path_2.write_bytes(content) + + asset1, ref1 = _create_stub_asset( + session, str(file_path_1), "asset-dup-1", "ref-dup-1" + ) + asset2, ref2 = _create_stub_asset( + session, str(file_path_2), "asset-dup-2", "ref-dup-2" + ) + session.commit() + + enrich_asset( + session, + file_path=str(file_path_1), + reference_id=ref1.id, + asset_id=asset1.id, + extract_metadata=True, + compute_hash=True, + ) + + enrich_asset( + session, + file_path=str(file_path_2), + reference_id=ref2.id, + asset_id=asset2.id, + extract_metadata=True, + compute_hash=True, + ) + + session.expire_all() + + updated_asset1 = session.get(Asset, "asset-dup-1") + assert updated_asset1 is not None + assert updated_asset1.hash is not None + + updated_asset2 = session.get(Asset, "asset-dup-2") + assert updated_asset2 is None + + updated_ref2 = session.get(AssetReference, "ref-dup-2") + assert updated_ref2 is not None + assert updated_ref2.asset_id == "asset-dup-1" diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py new file mode 100644 index 000000000..367bc7721 --- /dev/null +++ b/tests-unit/assets_test/services/test_ingest.py @@ -0,0 +1,229 @@ +"""Tests for ingest services.""" +from pathlib import Path + +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, Tag +from app.assets.database.queries import get_reference_tags +from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset + + +class TestIngestFileFromPath: + def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "test_file.bin" + file_path.write_bytes(b"test content") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:abc123", + size_bytes=12, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + ) + + assert result.asset_created is True + assert result.ref_created is True + assert result.reference_id is not None + + # Verify DB state + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].hash == "blake3:abc123" + + refs = session.query(AssetReference).all() + assert len(refs) == 1 + assert refs[0].file_path == str(file_path) + + def test_creates_reference_when_name_provided(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"model data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:def456", + size_bytes=10, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + info_name="My Model", + owner_id="user1", + ) + + assert result.asset_created is True + assert result.reference_id is not None + + ref = session.query(AssetReference).first() + assert ref is not None + assert ref.name == "My Model" + assert ref.owner_id == "user1" + + def test_creates_tags_when_provided(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "tagged.bin" + file_path.write_bytes(b"data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:ghi789", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="Tagged Asset", + tags=["models", "checkpoints"], + ) + + assert result.reference_id is not None + + # Verify tags were created and linked + tags = session.query(Tag).all() + tag_names = {t.name for t in tags} + assert "models" in tag_names + assert "checkpoints" in tag_names + + ref_tags = get_reference_tags(session, reference_id=result.reference_id) + assert set(ref_tags) == {"models", "checkpoints"} + + def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "dup.bin" + file_path.write_bytes(b"content") + + # First ingest + r1 = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:repeat", + size_bytes=7, + mtime_ns=1234567890000000000, + ) + assert r1.asset_created is True + + # Second ingest with same hash - should update, not create + r2 = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:repeat", + size_bytes=7, + mtime_ns=1234567890000000001, # different mtime + ) + assert r2.asset_created is False + assert r2.ref_created is False + assert r2.ref_updated is True + + # Still only one asset + assets = session.query(Asset).all() + assert len(assets) == 1 + + def test_validates_preview_id(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "with_preview.bin" + file_path.write_bytes(b"data") + + # Create a preview asset first + preview_asset = Asset(hash="blake3:preview", size_bytes=100) + session.add(preview_asset) + session.commit() + preview_id = preview_asset.id + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:main", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="With Preview", + preview_id=preview_id, + ) + + assert result.reference_id is not None + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.preview_id == preview_id + + def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "bad_preview.bin" + file_path.write_bytes(b"data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:badpreview", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="Bad Preview", + preview_id="nonexistent-uuid", + ) + + assert result.reference_id is not None + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.preview_id is None + + +class TestRegisterExistingAsset: + def test_creates_reference_for_existing_asset(self, mock_create_session, session: Session): + # Create existing asset + asset = Asset(hash="blake3:existing", size_bytes=1024, mime_type="image/png") + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:existing", + name="Registered Asset", + user_metadata={"key": "value"}, + tags=["models"], + ) + + assert result.created is True + assert "models" in result.tags + + # Verify by re-fetching from DB + session.expire_all() + refs = session.query(AssetReference).filter_by(name="Registered Asset").all() + assert len(refs) == 1 + + def test_creates_new_reference_even_with_same_name(self, mock_create_session, session: Session): + # Create asset and reference + asset = Asset(hash="blake3:withref", size_bytes=512) + session.add(asset) + session.flush() + + from app.assets.helpers import get_utc_now + ref = AssetReference( + owner_id="", + name="Existing Ref", + asset_id=asset.id, + created_at=get_utc_now(), + updated_at=get_utc_now(), + last_access_time=get_utc_now(), + ) + session.add(ref) + session.flush() + ref_id = ref.id + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:withref", + name="Existing Ref", + owner_id="", + ) + + # Multiple files with same name are allowed + assert result.created is True + + # Verify two AssetReferences exist for this name + session.expire_all() + refs = session.query(AssetReference).filter_by(name="Existing Ref").all() + assert len(refs) == 2 + assert ref_id in [r.id for r in refs] + + def test_raises_for_nonexistent_hash(self, mock_create_session): + with pytest.raises(ValueError, match="No asset with hash"): + _register_existing_asset( + asset_hash="blake3:doesnotexist", + name="Fail", + ) + + def test_applies_tags_to_new_reference(self, mock_create_session, session: Session): + asset = Asset(hash="blake3:tagged", size_bytes=256) + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:tagged", + name="Tagged Ref", + tags=["alpha", "beta"], + ) + + assert result.created is True + assert set(result.tags) == {"alpha", "beta"} diff --git a/tests-unit/assets_test/services/test_tagging.py b/tests-unit/assets_test/services/test_tagging.py new file mode 100644 index 000000000..ab69e5dc1 --- /dev/null +++ b/tests-unit/assets_test/services/test_tagging.py @@ -0,0 +1,197 @@ +"""Tests for tagging services.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference +from app.assets.helpers import get_utc_now +from app.assets.services import apply_tags, remove_tags, list_tags + + +def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestApplyTags: + def test_adds_new_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + session.commit() + + result = apply_tags( + reference_id=ref.id, + tags=["alpha", "beta"], + ) + + assert set(result.added) == {"alpha", "beta"} + assert result.already_present == [] + assert set(result.total_tags) == {"alpha", "beta"} + + def test_reports_already_present(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["existing"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["existing"]) + session.commit() + + result = apply_tags( + reference_id=ref.id, + tags=["existing", "new"], + ) + + assert result.added == ["new"] + assert result.already_present == ["existing"] + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + apply_tags(reference_id="nonexistent", tags=["x"]) + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + apply_tags( + reference_id=ref.id, + tags=["new"], + owner_id="user2", + ) + + +class TestRemoveTags: + def test_removes_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["a", "b", "c"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["a", "b", "c"]) + session.commit() + + result = remove_tags( + reference_id=ref.id, + tags=["a", "b"], + ) + + assert set(result.removed) == {"a", "b"} + assert result.not_present == [] + assert result.total_tags == ["c"] + + def test_reports_not_present(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["present"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["present"]) + session.commit() + + result = remove_tags( + reference_id=ref.id, + tags=["present", "absent"], + ) + + assert result.removed == ["present"] + assert result.not_present == ["absent"] + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + remove_tags(reference_id="nonexistent", tags=["x"]) + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + remove_tags( + reference_id=ref.id, + tags=["x"], + owner_id="user2", + ) + + +class TestListTags: + def test_returns_tags_with_counts(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + asset = _make_asset(session) + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags() + + tag_dict = {name: count for name, _, count in rows} + assert tag_dict["used"] == 1 + assert tag_dict["unused"] == 0 + assert total == 2 + + def test_excludes_zero_counts(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + asset = _make_asset(session) + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags(include_zero=False) + + tag_names = {name for name, _, _ in rows} + assert "used" in tag_names + assert "unused" not in tag_names + + def test_prefix_filter(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["alpha", "beta", "alphabet"]) + session.commit() + + rows, _ = list_tags(prefix="alph") + + tag_names = {name for name, _, _ in rows} + assert tag_names == {"alpha", "alphabet"} + + def test_order_by_name(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["zebra", "alpha", "middle"]) + session.commit() + + rows, _ = list_tags(order="name_asc") + + names = [name for name, _, _ in rows] + assert names == ["alpha", "middle", "zebra"] + + def test_pagination(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["a", "b", "c", "d", "e"]) + session.commit() + + rows, total = list_tags(limit=2, offset=1, order="name_asc") + + assert total == 5 + assert len(rows) == 2 + names = [name for name, _, _ in rows] + assert names == ["b", "c"] + + def test_clamps_limit(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["a"]) + session.commit() + + # Service should clamp limit to max 1000 + rows, _ = list_tags(limit=2000) + assert len(rows) <= 1000 diff --git a/tests-unit/assets_test/test_assets_missing_sync.py b/tests-unit/assets_test/test_assets_missing_sync.py index 78fa7b404..47dc130cb 100644 --- a/tests-unit/assets_test/test_assets_missing_sync.py +++ b/tests-unit/assets_test/test_assets_missing_sync.py @@ -4,7 +4,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets diff --git a/tests-unit/assets_test/test_crud.py b/tests-unit/assets_test/test_crud.py index d2b69f475..07310223e 100644 --- a/tests-unit/assets_test/test_crud.py +++ b/tests-unit/assets_test/test_crud.py @@ -4,7 +4,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets def test_create_from_hash_success( @@ -24,11 +24,11 @@ def test_create_from_hash_success( assert b1["created_new"] is False aid = b1["id"] - # Calling again with the same name should return the same AssetInfo id + # Calling again with the same name creates a new AssetReference (duplicates allowed) r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) b2 = r2.json() assert r2.status_code == 201, b2 - assert b2["id"] == aid + assert b2["id"] != aid # new reference, not the same one def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict): @@ -42,8 +42,8 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse assert "user_metadata" in detail assert "filename" in detail["user_metadata"] - # DELETE - rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + # DELETE (hard delete to also remove underlying asset and file) + rd = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) assert rd.status_code == 204 # GET again -> 404 @@ -53,6 +53,35 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse assert body["error"]["code"] == "ASSET_NOT_FOUND" +def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + asset_hash = seeded_asset["asset_hash"] + + # Soft-delete (default, no delete_content param) + rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + assert rd.status_code == 204 + + # GET by reference ID -> 404 (soft-deleted references are hidden) + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + assert rg.status_code == 404 + + # Asset identity is preserved (underlying content still exists) + rh = http.head(f"{api_base}/api/assets/hash/{asset_hash}", timeout=120) + assert rh.status_code == 200 + + # Soft-deleted reference should not appear in listings + rl = http.get( + f"{api_base}/api/assets", + params={"include_tags": "unit-tests", "limit": "500"}, + timeout=120, + ) + ids = [a["id"] for a in rl.json().get("assets", [])] + assert aid not in ids + + # Clean up: hard-delete the soft-deleted reference and orphaned asset + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) + + def test_delete_upon_reference_count( http: requests.Session, api_base: str, seeded_asset: dict ): @@ -70,21 +99,32 @@ def test_delete_upon_reference_count( assert copy["asset_hash"] == src_hash assert copy["created_new"] is False - # Delete original reference -> asset identity must remain + # Soft-delete original reference (default) -> asset identity must remain aid1 = seeded_asset["id"] rd1 = http.delete(f"{api_base}/api/assets/{aid1}", timeout=120) assert rd1.status_code == 204 rh1 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) - assert rh1.status_code == 200 # identity still present + assert rh1.status_code == 200 # identity still present (second ref exists) - # Delete the last reference with default semantics -> identity and cached files removed + # Soft-delete the last reference -> asset identity preserved (no hard delete) aid2 = copy["id"] rd2 = http.delete(f"{api_base}/api/assets/{aid2}", timeout=120) assert rd2.status_code == 204 rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) - assert rh2.status_code == 404 # orphan content removed + assert rh2.status_code == 200 # asset identity preserved (soft delete) + + # Re-associate via from-hash, then hard-delete -> orphan content removed + r3 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + assert r3.status_code == 201, r3.json() + aid3 = r3.json()["id"] + + rd3 = http.delete(f"{api_base}/api/assets/{aid3}?delete_content=true", timeout=120) + assert rd3.status_code == 204 + + rh3 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) + assert rh3.status_code == 404 # orphan content removed def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict): @@ -126,42 +166,52 @@ def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api assert body == b"" -def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str): - bogus = str(uuid.uuid4()) - r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120) +@pytest.mark.parametrize( + "method,endpoint_template,payload,expected_status,error_code", + [ + # Delete nonexistent asset + ("delete", "/api/assets/{uuid}", None, 404, "ASSET_NOT_FOUND"), + # Bad hash algorithm in from-hash + ( + "post", + "/api/assets/from-hash", + {"hash": "sha256:" + "0" * 64, "name": "x.bin", "tags": ["models", "checkpoints", "unit-tests"]}, + 400, + "INVALID_BODY", + ), + # Get with bad UUID format + ("get", "/api/assets/not-a-uuid", None, 404, None), + # Get content with bad UUID format + ("get", "/api/assets/not-a-uuid/content", None, 404, None), + ], + ids=["delete_nonexistent", "bad_hash_algorithm", "get_bad_uuid", "content_bad_uuid"], +) +def test_error_responses( + http: requests.Session, api_base: str, method, endpoint_template, payload, expected_status, error_code +): + # Replace {uuid} placeholder with a random UUID for delete test + endpoint = endpoint_template.replace("{uuid}", str(uuid.uuid4())) + url = f"{api_base}{endpoint}" + + if method == "get": + r = http.get(url, timeout=120) + elif method == "post": + r = http.post(url, json=payload, timeout=120) + elif method == "delete": + r = http.delete(url, timeout=120) + + assert r.status_code == expected_status + if error_code: + body = r.json() + assert body["error"]["code"] == error_code + + +def test_create_from_hash_invalid_json(http: requests.Session, api_base: str): + """Invalid JSON body requires special handling (data= instead of json=).""" + r = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120) body = r.json() - assert r.status_code == 404 - assert body["error"]["code"] == "ASSET_NOT_FOUND" - - -def test_create_from_hash_invalids(http: requests.Session, api_base: str): - # Bad hash algorithm - bad = { - "hash": "sha256:" + "0" * 64, - "name": "x.bin", - "tags": ["models", "checkpoints", "unit-tests"], - } - r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_BODY" - - # Invalid JSON body - r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_JSON" - - -def test_get_update_download_bad_ids(http: requests.Session, api_base: str): - # All endpoints should be not found, as we UUID regex directly in the route definition. - bad_id = "not-a-uuid" - - r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120) - assert r1.status_code == 404 - - r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120) - assert r3.status_code == 404 + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_JSON" def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict): diff --git a/tests-unit/assets_test/test_downloads.py b/tests-unit/assets_test/test_downloads.py index cdebf9082..672ba9728 100644 --- a/tests-unit/assets_test/test_downloads.py +++ b/tests-unit/assets_test/test_downloads.py @@ -6,7 +6,7 @@ from typing import Optional import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict): @@ -117,7 +117,7 @@ def test_download_missing_file_returns_404( assert body["error"]["code"] == "FILE_NOT_FOUND" finally: # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. - dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + dr = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) dr.content diff --git a/tests-unit/assets_test/test_file_utils.py b/tests-unit/assets_test/test_file_utils.py new file mode 100644 index 000000000..e3591d49b --- /dev/null +++ b/tests-unit/assets_test/test_file_utils.py @@ -0,0 +1,121 @@ +import os +import sys + +import pytest + +from app.assets.services.file_utils import is_visible, list_files_recursively + + +class TestIsVisible: + def test_visible_file(self): + assert is_visible("file.txt") is True + + def test_hidden_file(self): + assert is_visible(".hidden") is False + + def test_hidden_directory(self): + assert is_visible(".git") is False + + def test_visible_directory(self): + assert is_visible("src") is True + + def test_dotdot_is_hidden(self): + assert is_visible("..") is False + + def test_dot_is_hidden(self): + assert is_visible(".") is False + + +class TestListFilesRecursively: + def test_skips_hidden_files(self, tmp_path): + (tmp_path / "visible.txt").write_text("a") + (tmp_path / ".hidden").write_text("b") + + result = list_files_recursively(str(tmp_path)) + + assert len(result) == 1 + assert result[0].endswith("visible.txt") + + def test_skips_hidden_directories(self, tmp_path): + hidden_dir = tmp_path / ".hidden_dir" + hidden_dir.mkdir() + (hidden_dir / "file.txt").write_text("a") + + visible_dir = tmp_path / "visible_dir" + visible_dir.mkdir() + (visible_dir / "file.txt").write_text("b") + + result = list_files_recursively(str(tmp_path)) + + assert len(result) == 1 + assert "visible_dir" in result[0] + assert ".hidden_dir" not in result[0] + + def test_empty_directory(self, tmp_path): + result = list_files_recursively(str(tmp_path)) + assert result == [] + + def test_nonexistent_directory(self, tmp_path): + result = list_files_recursively(str(tmp_path / "nonexistent")) + assert result == [] + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_follows_symlinked_directories(self, tmp_path): + target = tmp_path / "real_dir" + target.mkdir() + (target / "model.safetensors").write_text("data") + + root = tmp_path / "root" + root.mkdir() + (root / "link").symlink_to(target) + + result = list_files_recursively(str(root)) + + assert len(result) == 1 + assert result[0].endswith("model.safetensors") + assert "link" in result[0] + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_follows_symlinked_files(self, tmp_path): + real_file = tmp_path / "real.txt" + real_file.write_text("content") + + root = tmp_path / "root" + root.mkdir() + (root / "link.txt").symlink_to(real_file) + + result = list_files_recursively(str(root)) + + assert len(result) == 1 + assert result[0].endswith("link.txt") + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_circular_symlinks_do_not_loop(self, tmp_path): + dir_a = tmp_path / "a" + dir_a.mkdir() + (dir_a / "file.txt").write_text("a") + # a/b -> a (circular) + (dir_a / "b").symlink_to(dir_a) + + result = list_files_recursively(str(dir_a)) + + assert len(result) == 1 + assert result[0].endswith("file.txt") + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_mutual_circular_symlinks(self, tmp_path): + dir_a = tmp_path / "a" + dir_b = tmp_path / "b" + dir_a.mkdir() + dir_b.mkdir() + (dir_a / "file_a.txt").write_text("a") + (dir_b / "file_b.txt").write_text("b") + # a/link_b -> b and b/link_a -> a + (dir_a / "link_b").symlink_to(dir_b) + (dir_b / "link_a").symlink_to(dir_a) + + result = list_files_recursively(str(dir_a)) + basenames = sorted(os.path.basename(p) for p in result) + + assert "file_a.txt" in basenames + assert "file_b.txt" in basenames diff --git a/tests-unit/assets_test/test_list_filter.py b/tests-unit/assets_test/test_list_filter.py index 82e109832..dcb7a73ca 100644 --- a/tests-unit/assets_test/test_list_filter.py +++ b/tests-unit/assets_test/test_list_filter.py @@ -1,6 +1,7 @@ import time import uuid +import pytest import requests @@ -283,30 +284,21 @@ def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asse assert b2["has_more"] is False -def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base): - r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_QUERY" - - r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_QUERY" - - -def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str): - # limit too small - r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_QUERY" - - # bad metadata JSON - r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_QUERY" +@pytest.mark.parametrize( + "params,error_code", + [ + ({"offset": "-1"}, "INVALID_QUERY"), + ({"limit": "abc"}, "INVALID_QUERY"), + ({"limit": "0"}, "INVALID_QUERY"), + ({"metadata_filter": "{not json"}, "INVALID_QUERY"), + ], + ids=["negative_offset", "non_int_limit", "zero_limit", "invalid_metadata_json"], +) +def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str, params, error_code): + r = http.get(api_base + "/api/assets", params=params, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == error_code def test_list_assets_name_contains_literal_underscore( diff --git a/tests-unit/assets_test/test_prune_orphaned_assets.py b/tests-unit/assets_test/test_prune_orphaned_assets.py index f602e5a77..1fbd4d4e2 100644 --- a/tests-unit/assets_test/test_prune_orphaned_assets.py +++ b/tests-unit/assets_test/test_prune_orphaned_assets.py @@ -3,7 +3,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets @pytest.fixture diff --git a/tests-unit/assets_test/test_sync_references.py b/tests-unit/assets_test/test_sync_references.py new file mode 100644 index 000000000..94cc255bc --- /dev/null +++ b/tests-unit/assets_test/test_sync_references.py @@ -0,0 +1,482 @@ +"""Tests for sync_references_with_filesystem in scanner.py.""" + +import os +import tempfile +from datetime import datetime +from pathlib import Path +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceTag, + Base, + Tag, +) +from app.assets.database.queries.asset_reference import ( + bulk_insert_references_ignore_conflicts, + get_references_for_prefixes, + get_unenriched_references, + restore_references_by_paths, +) +from app.assets.scanner import sync_references_with_filesystem +from app.assets.services.file_utils import get_mtime_ns + + +@pytest.fixture +def db_engine(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def session(db_engine): + with Session(db_engine) as sess: + yield sess + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +def _create_file(temp_dir: Path, name: str, content: bytes = b"\x00" * 100) -> str: + """Create a file and return its absolute path (no symlink resolution).""" + p = temp_dir / name + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(content) + return os.path.abspath(str(p)) + + +def _stat_mtime_ns(path: str) -> int: + return get_mtime_ns(os.stat(path, follow_symlinks=True)) + + +def _make_asset( + session: Session, + asset_id: str, + file_path: str, + ref_id: str, + *, + asset_hash: str | None = None, + size_bytes: int = 100, + mtime_ns: int | None = None, + needs_verify: bool = False, + is_missing: bool = False, +) -> tuple[Asset, AssetReference]: + """Insert an Asset + AssetReference and flush.""" + asset = session.get(Asset, asset_id) + if asset is None: + asset = Asset(id=asset_id, hash=asset_hash, size_bytes=size_bytes) + session.add(asset) + session.flush() + + ref = AssetReference( + id=ref_id, + asset_id=asset_id, + name=f"test-{ref_id}", + owner_id="system", + file_path=file_path, + mtime_ns=mtime_ns, + needs_verify=needs_verify, + is_missing=is_missing, + ) + session.add(ref) + session.flush() + return asset, ref + + +def _ensure_missing_tag(session: Session): + """Ensure the 'missing' tag exists.""" + if not session.get(Tag, "missing"): + session.add(Tag(name="missing", tag_type="system")) + session.flush() + + +class _VerifyCase: + def __init__(self, id, stat_unchanged, needs_verify_before, expect_needs_verify): + self.id = id + self.stat_unchanged = stat_unchanged + self.needs_verify_before = needs_verify_before + self.expect_needs_verify = expect_needs_verify + + +VERIFY_CASES = [ + _VerifyCase( + id="unchanged_clears_verify", + stat_unchanged=True, + needs_verify_before=True, + expect_needs_verify=False, + ), + _VerifyCase( + id="unchanged_keeps_clear", + stat_unchanged=True, + needs_verify_before=False, + expect_needs_verify=False, + ), + _VerifyCase( + id="changed_sets_verify", + stat_unchanged=False, + needs_verify_before=False, + expect_needs_verify=True, + ), + _VerifyCase( + id="changed_keeps_verify", + stat_unchanged=False, + needs_verify_before=True, + expect_needs_verify=True, + ), +] + + +@pytest.mark.parametrize("case", VERIFY_CASES, ids=lambda c: c.id) +def test_needs_verify_toggling(session, temp_dir, case): + """needs_verify is set/cleared based on mtime+size match.""" + fp = _create_file(temp_dir, "model.bin") + real_mtime = _stat_mtime_ns(fp) + + mtime_for_db = real_mtime if case.stat_unchanged else real_mtime + 1 + _make_asset( + session, "a1", fp, "r1", + asset_hash="blake3:abc", + mtime_ns=mtime_for_db, + needs_verify=case.needs_verify_before, + ) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.needs_verify is case.expect_needs_verify + + +class _MissingCase: + def __init__(self, id, file_exists, expect_is_missing): + self.id = id + self.file_exists = file_exists + self.expect_is_missing = expect_is_missing + + +MISSING_CASES = [ + _MissingCase(id="existing_file_not_missing", file_exists=True, expect_is_missing=False), + _MissingCase(id="missing_file_marked_missing", file_exists=False, expect_is_missing=True), +] + + +@pytest.mark.parametrize("case", MISSING_CASES, ids=lambda c: c.id) +def test_is_missing_flag(session, temp_dir, case): + """is_missing reflects whether the file exists on disk.""" + if case.file_exists: + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + else: + fp = str(temp_dir / "gone.bin") + mtime = 999 + + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.is_missing is case.expect_is_missing + + +def test_seed_asset_all_missing_deletes_asset(session, temp_dir): + """Seed asset with all refs missing gets deleted entirely.""" + fp = str(temp_dir / "gone.bin") + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + assert session.get(Asset, "seed1") is None + assert session.get(AssetReference, "r1") is None + + +def test_seed_asset_some_exist_returns_survivors(session, temp_dir): + """Seed asset with at least one existing ref survives and is returned.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + survivors = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + session.commit() + + assert session.get(Asset, "seed1") is not None + assert os.path.abspath(fp) in survivors + + +def test_hashed_asset_prunes_missing_refs_when_one_is_ok(session, temp_dir): + """Hashed asset with one stat-unchanged ref deletes missing refs.""" + fp_ok = _create_file(temp_dir, "good.bin") + fp_gone = str(temp_dir / "gone.bin") + mtime = _stat_mtime_ns(fp_ok) + + _make_asset(session, "h1", fp_ok, "r_ok", asset_hash="blake3:aaa", mtime_ns=mtime) + # Second ref on same asset, file missing + ref_gone = AssetReference( + id="r_gone", asset_id="h1", name="gone", + owner_id="system", file_path=fp_gone, mtime_ns=999, + ) + session.add(ref_gone) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + assert session.get(AssetReference, "r_ok") is not None + assert session.get(AssetReference, "r_gone") is None + + +def test_hashed_asset_all_missing_keeps_refs(session, temp_dir): + """Hashed asset with all refs missing keeps refs (no pruning).""" + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + assert session.get(AssetReference, "r1") is not None + ref = session.get(AssetReference, "r1") + assert ref.is_missing is True + + +def test_missing_tag_added_when_all_refs_gone(session, temp_dir): + """Missing tag is added to hashed asset when all refs are missing.""" + _ensure_missing_tag(session) + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=True, + ) + session.commit() + + session.expire_all() + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is not None + + +def test_missing_tag_removed_when_ref_ok(session, temp_dir): + """Missing tag is removed from hashed asset when a ref is stat-unchanged.""" + _ensure_missing_tag(session) + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=mtime) + # Pre-add a stale missing tag + session.add(AssetReferenceTag( + asset_reference_id="r1", tag_name="missing", origin="automatic", + )) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=True, + ) + session.commit() + + session.expire_all() + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is None + + +def test_missing_tags_not_touched_when_flag_false(session, temp_dir): + """Missing tags are not modified when update_missing_tags=False.""" + _ensure_missing_tag(session) + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=False, + ) + session.commit() + + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is None # tag was never added + + +def test_returns_none_when_collect_false(session, temp_dir): + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + result = sync_references_with_filesystem( + session, "models", collect_existing_paths=False, + ) + + assert result is None + + +def test_returns_empty_set_for_no_prefixes(session): + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[]): + result = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + + assert result == set() + + +def test_no_references_is_noop(session, temp_dir): + """No crash and no side effects when there are no references.""" + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + survivors = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + session.commit() + + assert survivors == set() + + +# --------------------------------------------------------------------------- +# Soft-delete persistence across scanner operations +# --------------------------------------------------------------------------- + +def _soft_delete_ref(session: Session, ref_id: str) -> None: + """Mark a reference as soft-deleted (mimics the API DELETE behaviour).""" + ref = session.get(AssetReference, ref_id) + ref.deleted_at = datetime(2025, 1, 1) + session.flush() + + +def test_soft_deleted_ref_excluded_from_get_references_for_prefixes(session, temp_dir): + """get_references_for_prefixes skips soft-deleted references.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + rows = get_references_for_prefixes(session, [str(temp_dir)], include_missing=True) + assert len(rows) == 0 + + +def test_sync_does_not_resurrect_soft_deleted_ref(session, temp_dir): + """Scanner sync leaves soft-deleted refs untouched even when file exists on disk.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.deleted_at is not None, "soft-deleted ref must stay deleted after sync" + + +def test_bulk_insert_does_not_overwrite_soft_deleted_ref(session, temp_dir): + """bulk_insert_references_ignore_conflicts cannot replace a soft-deleted row.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + now = datetime.now(tz=None) + bulk_insert_references_ignore_conflicts(session, [ + { + "id": "r_new", + "asset_id": "a1", + "file_path": fp, + "name": "model.bin", + "owner_id": "", + "mtime_ns": mtime, + "preview_id": None, + "user_metadata": None, + "created_at": now, + "updated_at": now, + "last_access_time": now, + } + ]) + session.commit() + + session.expire_all() + # Original row is still the soft-deleted one + ref = session.get(AssetReference, "r1") + assert ref is not None + assert ref.deleted_at is not None + # The new row was not inserted (conflict on file_path) + assert session.get(AssetReference, "r_new") is None + + +def test_restore_references_by_paths_skips_soft_deleted(session, temp_dir): + """restore_references_by_paths does not clear is_missing on soft-deleted refs.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset( + session, "a1", fp, "r1", + asset_hash="blake3:abc", mtime_ns=mtime, is_missing=True, + ) + _soft_delete_ref(session, "r1") + session.commit() + + restored = restore_references_by_paths(session, [fp]) + session.commit() + + assert restored == 0 + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.is_missing is True, "is_missing must not be cleared on soft-deleted ref" + assert ref.deleted_at is not None + + +def test_get_unenriched_references_excludes_soft_deleted(session, temp_dir): + """Enrichment queries do not pick up soft-deleted references.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + rows = get_unenriched_references(session, [str(temp_dir)], max_level=2) + assert len(rows) == 0 + + +def test_sync_ignores_soft_deleted_seed_asset(session, temp_dir): + """Soft-deleted seed ref is not garbage-collected even when file is missing.""" + fp = str(temp_dir / "gone.bin") # file does not exist + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999) + _soft_delete_ref(session, "r1") + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + # Asset and ref must still exist — scanner did not see the soft-deleted row + assert session.get(Asset, "seed1") is not None + assert session.get(AssetReference, "r1") is not None diff --git a/tests-unit/assets_test/test_tags.py b/tests-unit/assets_test/test_tags_api.py similarity index 98% rename from tests-unit/assets_test/test_tags.py rename to tests-unit/assets_test/test_tags_api.py index 6b1047802..595bf29c6 100644 --- a/tests-unit/assets_test/test_tags.py +++ b/tests-unit/assets_test/test_tags_api.py @@ -69,8 +69,8 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, used_names = [t["name"] for t in body2["tags"]] assert custom_tag in used_names - # Delete the asset so the tag usage drops to zero - rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120) + # Hard-delete the asset so the tag usage drops to zero + rd = http.delete(f"{api_base}/api/assets/{_asset['id']}?delete_content=true", timeout=120) assert rd.status_code == 204 # Now the custom tag must not be returned when include_zero=false diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py index 137d7391a..d68e5b5d7 100644 --- a/tests-unit/assets_test/test_uploads.py +++ b/tests-unit/assets_test/test_uploads.py @@ -18,25 +18,25 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma assert r1.status_code == 201, a1 assert a1["created_new"] is True - # Second upload with the same data and name should return created_new == False and the same asset + # Second upload with the same data and name creates a new AssetReference (duplicates allowed) + # Returns 200 because Asset already exists, but a new AssetReference is created files = {"file": (name, data, "application/octet-stream")} form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) a2 = r2.json() - assert r2.status_code == 200, a2 - assert a2["created_new"] is False + assert r2.status_code in (200, 201), a2 assert a2["asset_hash"] == a1["asset_hash"] - assert a2["id"] == a1["id"] # old reference + assert a2["id"] != a1["id"] # new reference with same content - # Third upload with the same data but new name should return created_new == False and the new AssetReference + # Third upload with the same data but different name also creates new AssetReference files = {"file": (name, data, "application/octet-stream")} form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)} - r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) - a3 = r2.json() - assert r2.status_code == 200, a3 - assert a3["created_new"] is False + r3 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + a3 = r3.json() + assert r3.status_code in (200, 201), a3 assert a3["asset_hash"] == a1["asset_hash"] - assert a3["id"] != a1["id"] # old reference + assert a3["id"] != a1["id"] + assert a3["id"] != a2["id"] def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str): @@ -116,7 +116,7 @@ def test_concurrent_upload_identical_bytes_different_names( ): """ Two concurrent uploads of identical bytes but different names. - Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True. + Expect a single Asset (same hash), two AssetReference rows, and exactly one created_new=True. """ scope = f"concupload-{uuid.uuid4().hex[:6]}" name1, name2 = "cu_a.bin", "cu_b.bin" diff --git a/tests-unit/comfy_extras_test/nodes_math_test.py b/tests-unit/comfy_extras_test/nodes_math_test.py new file mode 100644 index 000000000..fa4cdcac3 --- /dev/null +++ b/tests-unit/comfy_extras_test/nodes_math_test.py @@ -0,0 +1,197 @@ +import math + +import pytest +from collections import OrderedDict +from unittest.mock import patch, MagicMock + +mock_nodes = MagicMock() +mock_nodes.MAX_RESOLUTION = 16384 +mock_server = MagicMock() + +with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}): + from comfy_extras.nodes_math import MathExpressionNode + + +class TestMathExpressionExecute: + @staticmethod + def _exec(expression: str, **kwargs) -> object: + values = OrderedDict(kwargs) + return MathExpressionNode.execute(expression, values) + + def test_addition(self): + result = self._exec("a + b", a=3, b=4) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_subtraction(self): + result = self._exec("a - b", a=10, b=3) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_multiplication(self): + result = self._exec("a * b", a=3, b=5) + assert result[0] == 15.0 + assert result[1] == 15 + + def test_division(self): + result = self._exec("a / b", a=10, b=4) + assert result[0] == 2.5 + assert result[1] == 2 + + def test_single_input(self): + result = self._exec("a * 2", a=5) + assert result[0] == 10.0 + assert result[1] == 10 + + def test_three_inputs(self): + result = self._exec("a + b + c", a=1, b=2, c=3) + assert result[0] == 6.0 + assert result[1] == 6 + + def test_float_inputs(self): + result = self._exec("a + b", a=1.5, b=2.5) + assert result[0] == 4.0 + assert result[1] == 4 + + def test_mixed_int_float_inputs(self): + result = self._exec("a * b", a=1024, b=1.5) + assert result[0] == 1536.0 + assert result[1] == 1536 + + def test_mixed_resolution_scale(self): + result = self._exec("a * b", a=512, b=0.75) + assert result[0] == 384.0 + assert result[1] == 384 + + def test_sum_values_array(self): + result = self._exec("sum(values)", a=1, b=2, c=3) + assert result[0] == 6.0 + + def test_sum_variadic(self): + result = self._exec("sum(a, b, c)", a=1, b=2, c=3) + assert result[0] == 6.0 + + def test_min_values(self): + result = self._exec("min(values)", a=5, b=2, c=8) + assert result[0] == 2.0 + + def test_max_values(self): + result = self._exec("max(values)", a=5, b=2, c=8) + assert result[0] == 8.0 + + def test_abs_function(self): + result = self._exec("abs(a)", a=-7) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_sqrt(self): + result = self._exec("sqrt(a)", a=16) + assert result[0] == 4.0 + assert result[1] == 4 + + def test_ceil(self): + result = self._exec("ceil(a)", a=2.3) + assert result[0] == 3.0 + assert result[1] == 3 + + def test_floor(self): + result = self._exec("floor(a)", a=2.7) + assert result[0] == 2.0 + assert result[1] == 2 + + def test_sin(self): + result = self._exec("sin(a)", a=0) + assert result[0] == 0.0 + + def test_log10(self): + result = self._exec("log10(a)", a=100) + assert result[0] == 2.0 + assert result[1] == 2 + + def test_float_output_type(self): + result = self._exec("a + b", a=1, b=2) + assert isinstance(result[0], float) + + def test_int_output_type(self): + result = self._exec("a + b", a=1, b=2) + assert isinstance(result[1], int) + + def test_non_numeric_result_raises(self): + with pytest.raises(ValueError, match="must evaluate to a numeric result"): + self._exec("'hello'", a=42) + + def test_undefined_function_raises(self): + with pytest.raises(Exception, match="not defined"): + self._exec("str(a)", a=42) + + def test_boolean_result_raises(self): + with pytest.raises(ValueError, match="got bool"): + self._exec("a > b", a=5, b=3) + + def test_empty_expression_raises(self): + with pytest.raises(ValueError, match="Expression cannot be empty"): + self._exec("", a=1) + + def test_whitespace_only_expression_raises(self): + with pytest.raises(ValueError, match="Expression cannot be empty"): + self._exec(" ", a=1) + + # --- Missing function coverage (round, pow, log, log2, cos, tan) --- + + def test_round(self): + result = self._exec("round(a)", a=2.7) + assert result[0] == 3.0 + assert result[1] == 3 + + def test_round_with_ndigits(self): + result = self._exec("round(a, 2)", a=3.14159) + assert result[0] == pytest.approx(3.14) + + def test_pow(self): + result = self._exec("pow(a, b)", a=2, b=10) + assert result[0] == 1024.0 + assert result[1] == 1024 + + def test_log(self): + result = self._exec("log(a)", a=math.e) + assert result[0] == pytest.approx(1.0) + + def test_log2(self): + result = self._exec("log2(a)", a=8) + assert result[0] == pytest.approx(3.0) + + def test_cos(self): + result = self._exec("cos(a)", a=0) + assert result[0] == 1.0 + + def test_tan(self): + result = self._exec("tan(a)", a=0) + assert result[0] == 0.0 + + # --- int/float converter functions --- + + def test_int_converter(self): + result = self._exec("int(a / b)", a=7, b=2) + assert result[1] == 3 + + def test_float_converter(self): + result = self._exec("float(a)", a=5) + assert result[0] == 5.0 + + # --- Error path tests --- + + def test_division_by_zero_raises(self): + with pytest.raises(ZeroDivisionError): + self._exec("a / b", a=1, b=0) + + def test_sqrt_negative_raises(self): + with pytest.raises(ValueError, match="math domain error"): + self._exec("sqrt(a)", a=-1) + + def test_overflow_inf_raises(self): + with pytest.raises(ValueError, match="non-finite result"): + self._exec("a * b", a=1e308, b=10) + + def test_pow_huge_exponent_raises(self): + with pytest.raises(ValueError, match="Exponent .* exceeds maximum"): + self._exec("pow(a, b)", a=10, b=10000000) diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index 2355b8000..3a6790ee0 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -2,4 +2,3 @@ pytest>=7.8.0 pytest-aiohttp pytest-asyncio websocket-client -blake3 diff --git a/tests-unit/seeder_test/test_seeder.py b/tests-unit/seeder_test/test_seeder.py new file mode 100644 index 000000000..db3795e48 --- /dev/null +++ b/tests-unit/seeder_test/test_seeder.py @@ -0,0 +1,900 @@ +"""Unit tests for the _AssetSeeder background scanning class.""" + +import threading +from unittest.mock import patch + +import pytest + +from app.assets.database.queries.asset_reference import UnenrichedReferenceRow +from app.assets.seeder import _AssetSeeder, Progress, ScanInProgressError, ScanPhase, State + + +@pytest.fixture +def fresh_seeder(): + """Create a fresh _AssetSeeder instance for testing.""" + seeder = _AssetSeeder() + yield seeder + seeder.shutdown(timeout=1.0) + + +@pytest.fixture +def mock_dependencies(): + """Mock all external dependencies for isolated testing.""" + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + yield + + +class TestSeederStateTransitions: + """Test state machine transitions.""" + + def test_initial_state_is_idle(self, fresh_seeder: _AssetSeeder): + assert fresh_seeder.get_status().state == State.IDLE + + def test_start_transitions_to_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + started = fresh_seeder.start(roots=("models",)) + assert started is True + assert reached.wait(timeout=2.0) + assert fresh_seeder.get_status().state == State.RUNNING + + barrier.set() + + def test_start_while_running_returns_false( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + second_start = fresh_seeder.start(roots=("models",)) + assert second_start is False + + barrier.set() + + def test_cancel_transitions_to_cancelling( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + cancelled = fresh_seeder.cancel() + assert cancelled is True + assert fresh_seeder.get_status().state == State.CANCELLING + + barrier.set() + + def test_cancel_when_idle_returns_false(self, fresh_seeder: _AssetSeeder): + cancelled = fresh_seeder.cancel() + assert cancelled is False + + def test_state_returns_to_idle_after_completion( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=5.0) + assert completed is True + assert fresh_seeder.get_status().state == State.IDLE + + +class TestSeederWait: + """Test wait() behavior.""" + + def test_wait_blocks_until_complete( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=5.0) + assert completed is True + assert fresh_seeder.get_status().state == State.IDLE + + def test_wait_returns_false_on_timeout( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=10.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=0.1) + assert completed is False + + barrier.set() + + def test_wait_when_idle_returns_true(self, fresh_seeder: _AssetSeeder): + completed = fresh_seeder.wait(timeout=1.0) + assert completed is True + + +class TestSeederProgress: + """Test progress tracking.""" + + def test_get_status_returns_progress_during_scan( + self, fresh_seeder: _AssetSeeder + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_build(*args, **kwargs): + reached.set() + barrier.wait(timeout=5.0) + return ([], set(), 0) + + paths = ["/path/file1.safetensors", "/path/file2.safetensors"] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), + patch("app.assets.seeder.build_asset_specs", side_effect=slow_build), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + status = fresh_seeder.get_status() + assert status.state == State.RUNNING + assert status.progress is not None + assert status.progress.total == 2 + + barrier.set() + + def test_progress_callback_is_invoked( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + progress_updates: list[Progress] = [] + + def callback(p: Progress): + progress_updates.append(p) + + with patch( + "app.assets.seeder.collect_paths_for_roots", + return_value=[f"/path/file{i}.safetensors" for i in range(10)], + ): + fresh_seeder.start(roots=("models",), progress_callback=callback) + fresh_seeder.wait(timeout=5.0) + + assert len(progress_updates) > 0 + + +class TestSeederCancellation: + """Test cancellation behavior.""" + + def test_scan_commits_partial_progress_on_cancellation( + self, fresh_seeder: _AssetSeeder + ): + insert_count = 0 + barrier = threading.Event() + first_insert_done = threading.Event() + + def slow_insert(specs, tags): + nonlocal insert_count + insert_count += 1 + if insert_count == 1: + first_insert_done.set() + if insert_count >= 2: + barrier.wait(timeout=5.0) + return len(specs) + + paths = [f"/path/file{i}.safetensors" for i in range(1500)] + specs = [ + { + "abs_path": p, + "size_bytes": 100, + "mtime_ns": 0, + "info_name": f"file{i}", + "tags": [], + "fname": f"file{i}", + "metadata": None, + "hash": None, + "mime_type": None, + } + for i, p in enumerate(paths) + ] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), + patch( + "app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0) + ), + patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + assert first_insert_done.wait(timeout=2.0) + + fresh_seeder.cancel() + barrier.set() + fresh_seeder.wait(timeout=5.0) + + assert 1 <= insert_count < 3 # 1500 paths / 500 batch = 3; cancel stopped early + + +class TestSeederErrorHandling: + """Test error handling behavior.""" + + def test_database_errors_captured_in_status(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch( + "app.assets.seeder.collect_paths_for_roots", + return_value=["/path/file.safetensors"], + ), + patch( + "app.assets.seeder.build_asset_specs", + return_value=( + [ + { + "abs_path": "/path/file.safetensors", + "size_bytes": 100, + "mtime_ns": 0, + "info_name": "file", + "tags": [], + "fname": "file", + "metadata": None, + "hash": None, + "mime_type": None, + } + ], + set(), + 0, + ), + ), + patch( + "app.assets.seeder.insert_asset_specs", + side_effect=Exception("DB connection failed"), + ), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert len(status.errors) > 0 + assert "DB connection failed" in status.errors[0] + + def test_dependencies_unavailable_captured_in_errors( + self, fresh_seeder: _AssetSeeder + ): + with patch("app.assets.seeder.dependencies_available", return_value=False): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert len(status.errors) > 0 + assert "dependencies" in status.errors[0].lower() + + def test_thread_crash_resets_state_to_idle(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch( + "app.assets.seeder.sync_root_safely", + side_effect=RuntimeError("Unexpected crash"), + ), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert status.state == State.IDLE + assert len(status.errors) > 0 + + +class TestSeederThreadSafety: + """Test thread safety of concurrent operations.""" + + def test_concurrent_start_calls_spawn_only_one_thread( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + results = [] + + def try_start(): + results.append(fresh_seeder.start(roots=("models",))) + + threads = [threading.Thread(target=try_start) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + barrier.set() + + assert sum(results) == 1 + + def test_get_status_safe_during_scan( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + statuses = [] + for _ in range(100): + statuses.append(fresh_seeder.get_status()) + + barrier.set() + + assert all( + s.state in (State.RUNNING, State.IDLE, State.CANCELLING) + for s in statuses + ) + + +class TestSeederMarkMissing: + """Test mark_missing_outside_prefixes behavior.""" + + def test_mark_missing_when_idle(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch( + "app.assets.seeder.get_all_known_prefixes", + return_value=["/models", "/input", "/output"], + ), + patch( + "app.assets.seeder.mark_missing_outside_prefixes_safely", return_value=5 + ) as mock_mark, + ): + result = fresh_seeder.mark_missing_outside_prefixes() + assert result == 5 + mock_mark.assert_called_once_with(["/models", "/input", "/output"]) + + def test_mark_missing_raises_when_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + with pytest.raises(ScanInProgressError): + fresh_seeder.mark_missing_outside_prefixes() + + barrier.set() + + def test_mark_missing_returns_zero_when_dependencies_unavailable( + self, fresh_seeder: _AssetSeeder + ): + with patch("app.assets.seeder.dependencies_available", return_value=False): + result = fresh_seeder.mark_missing_outside_prefixes() + assert result == 0 + + def test_prune_first_flag_triggers_mark_missing_before_scan( + self, fresh_seeder: _AssetSeeder + ): + call_order = [] + + def track_mark(prefixes): + call_order.append("mark_missing") + return 3 + + def track_sync(root): + call_order.append(f"sync_{root}") + return set() + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]), + patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark), + patch("app.assets.seeder.sync_root_safely", side_effect=track_sync), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",), prune_first=True) + fresh_seeder.wait(timeout=5.0) + + assert call_order[0] == "mark_missing" + assert "sync_models" in call_order + + +class TestSeederPhases: + """Test phased scanning behavior.""" + + def test_start_fast_only_runs_fast_phase(self, fresh_seeder: _AssetSeeder): + """Verify start_fast only runs the fast phase.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start_fast(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 1 + assert len(enrich_called) == 0 + + def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: _AssetSeeder): + """Verify start_enrich only runs the enrich phase.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start_enrich(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 0 + assert len(enrich_called) == 1 + + def test_full_scan_runs_both_phases(self, fresh_seeder: _AssetSeeder): + """Verify full scan runs both fast and enrich phases.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.FULL) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 1 + assert len(enrich_called) == 1 + + +class TestSeederPauseResume: + """Test pause/resume behavior.""" + + def test_pause_transitions_to_paused( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + paused = fresh_seeder.pause() + assert paused is True + assert fresh_seeder.get_status().state == State.PAUSED + + barrier.set() + + def test_pause_when_idle_returns_false(self, fresh_seeder: _AssetSeeder): + paused = fresh_seeder.pause() + assert paused is False + + def test_resume_returns_to_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + fresh_seeder.pause() + assert fresh_seeder.get_status().state == State.PAUSED + + resumed = fresh_seeder.resume() + assert resumed is True + assert fresh_seeder.get_status().state == State.RUNNING + + barrier.set() + + def test_resume_when_not_paused_returns_false( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + resumed = fresh_seeder.resume() + assert resumed is False + + barrier.set() + + def test_cancel_while_paused_works( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached_checkpoint = threading.Event() + + def slow_collect(*args): + reached_checkpoint.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached_checkpoint.wait(timeout=2.0) + + fresh_seeder.pause() + assert fresh_seeder.get_status().state == State.PAUSED + + cancelled = fresh_seeder.cancel() + assert cancelled is True + + barrier.set() + fresh_seeder.wait(timeout=5.0) + assert fresh_seeder.get_status().state == State.IDLE + +class TestSeederStopRestart: + """Test stop and restart behavior.""" + + def test_stop_is_alias_for_cancel( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + stopped = fresh_seeder.stop() + assert stopped is True + assert fresh_seeder.get_status().state == State.CANCELLING + + barrier.set() + + def test_restart_cancels_and_starts_new_scan( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + start_count = 0 + + def slow_collect(*args): + nonlocal start_count + start_count += 1 + if start_count == 1: + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + barrier.set() + restarted = fresh_seeder.restart() + assert restarted is True + + fresh_seeder.wait(timeout=5.0) + assert start_count == 2 + + def test_restart_preserves_previous_params(self, fresh_seeder: _AssetSeeder): + """Verify restart uses previous params when not overridden.""" + collected_roots = [] + + def track_collect(roots): + collected_roots.append(roots) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("input", "output")) + fresh_seeder.wait(timeout=5.0) + + fresh_seeder.restart() + fresh_seeder.wait(timeout=5.0) + + assert len(collected_roots) == 2 + assert collected_roots[0] == ("input", "output") + assert collected_roots[1] == ("input", "output") + + def test_restart_can_override_params(self, fresh_seeder: _AssetSeeder): + """Verify restart can override previous params.""" + collected_roots = [] + + def track_collect(roots): + collected_roots.append(roots) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + fresh_seeder.restart(roots=("input",)) + fresh_seeder.wait(timeout=5.0) + + assert len(collected_roots) == 2 + assert collected_roots[0] == ("models",) + assert collected_roots[1] == ("input",) + + +def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow: + return UnenrichedReferenceRow( + reference_id=ref_id, asset_id=asset_id, + file_path=f"/fake/{ref_id}.bin", enrichment_level=0, + ) + + +class TestEnrichPhaseDefensiveLogic: + """Test skip_ids filtering and consecutive_empty termination.""" + + def test_failed_refs_are_skipped_on_subsequent_batches( + self, fresh_seeder: _AssetSeeder, + ): + """References that fail enrichment are filtered out of future batches.""" + row_a = _make_row("r1") + row_b = _make_row("r2") + call_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 2: + return [row_a, row_b] + return [] + + enriched_refs: list[list[str]] = [] + + def fake_enrich(rows, **kwargs): + ref_ids = [r.reference_id for r in rows] + enriched_refs.append(ref_ids) + # r1 always fails, r2 succeeds + failed = [r.reference_id for r in rows if r.reference_id == "r1"] + enriched = len(rows) - len(failed) + return enriched, failed + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # First batch: both refs attempted + assert "r1" in enriched_refs[0] + assert "r2" in enriched_refs[0] + # Second batch: r1 filtered out + assert "r1" not in enriched_refs[1] + assert "r2" in enriched_refs[1] + + def test_stops_after_consecutive_empty_batches( + self, fresh_seeder: _AssetSeeder, + ): + """Enrich phase terminates after 3 consecutive batches with zero progress.""" + row = _make_row("r1") + batch_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal batch_count + batch_count += 1 + # Always return the same row (simulating a permanently failing ref) + return [row] + + def fake_enrich(rows, **kwargs): + # Always fail — zero enriched, all failed + return 0, [r.reference_id for r in rows] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # Should stop after exactly 3 consecutive empty batches + # Batch 1: returns row, enrich fails → filtered out in batch 2+ + # But get_unenriched keeps returning it, filter removes it → empty → break + # Actually: batch 1 has row, fails. Batch 2 get_unenriched returns [row], + # skip_ids filters it → empty list → breaks via `if not unenriched: break` + # So it terminates in 2 calls to get_unenriched. + assert batch_count == 2 + + def test_consecutive_empty_counter_resets_on_success( + self, fresh_seeder: _AssetSeeder, + ): + """A successful batch resets the consecutive empty counter.""" + call_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 6: + return [_make_row(f"r{call_count}", f"a{call_count}")] + return [] + + def fake_enrich(rows, **kwargs): + ref_id = rows[0].reference_id + # Fail batches 1-2, succeed batch 3, fail batches 4-5, succeed batch 6 + if ref_id in ("r1", "r2", "r4", "r5"): + return 0, [ref_id] + return 1, [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # All 6 batches should run + 1 final call returning empty + assert call_count == 7 + status = fresh_seeder.get_status() + assert status.state == State.IDLE diff --git a/utils/mime_types.py b/utils/mime_types.py new file mode 100644 index 000000000..916e963c5 --- /dev/null +++ b/utils/mime_types.py @@ -0,0 +1,37 @@ +"""Centralized MIME type initialization. + +Call init_mime_types() once at startup to initialize the MIME type database +and register all custom types used across ComfyUI. +""" + +import mimetypes + +_initialized = False + + +def init_mime_types(): + """Initialize the MIME type database and register all custom types. + + Safe to call multiple times; only runs once. + """ + global _initialized + if _initialized: + return + _initialized = True + + mimetypes.init() + + # Web types (used by server.py for static file serving) + mimetypes.add_type('application/javascript; charset=utf-8', '.js') + mimetypes.add_type('image/webp', '.webp') + + # Model and data file types (used by asset scanning / metadata extraction) + mimetypes.add_type("application/safetensors", ".safetensors") + mimetypes.add_type("application/safetensors", ".sft") + mimetypes.add_type("application/pytorch", ".pt") + mimetypes.add_type("application/pytorch", ".pth") + mimetypes.add_type("application/pickle", ".ckpt") + mimetypes.add_type("application/pickle", ".pkl") + mimetypes.add_type("application/gguf", ".gguf") + mimetypes.add_type("application/yaml", ".yaml") + mimetypes.add_type("application/yaml", ".yml")