From d115a4d3431a67f5ac785ca063b462444de104aa Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Mon, 9 Mar 2026 22:36:00 -0700 Subject: [PATCH] feat(assets): align local API with cloud spec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Unify response models, add missing fields, and align input schemas with the cloud OpenAPI spec at cloud.comfy.org/openapi. - Replace AssetSummary/AssetDetail/AssetUpdated with single Asset model - Add is_immutable, metadata (system_metadata), prompt_id fields - Support mime_type and preview_id in update endpoint - Make CreateFromHashBody.name optional, add mime_type, require >=1 tag - Add id/mime_type/preview_id to upload, relax tags to optional - Rename total_tags → tags in tag add/remove responses - Add GET /api/assets/tags/refine histogram endpoint - Add DB migration for system_metadata and prompt_id columns Co-Authored-By: Claude Opus 4.6 --- .../versions/0003_add_metadata_prompt.py | 31 ++++ app/assets/api/routes.py | 147 +++++++++++------- app/assets/api/schemas_in.py | 86 ++++++++-- app/assets/api/schemas_out.py | 65 +++----- app/assets/api/upload.py | 12 ++ app/assets/database/models.py | 4 + app/assets/database/queries/__init__.py | 2 + .../database/queries/asset_reference.py | 87 +---------- app/assets/database/queries/common.py | 87 ++++++++++- app/assets/database/queries/tags.py | 50 ++++++ app/assets/services/asset_management.py | 17 ++ app/assets/services/ingest.py | 13 +- app/assets/services/schemas.py | 10 +- app/assets/services/tagging.py | 23 +++ tests-unit/assets_test/test_tags_api.py | 2 +- 15 files changed, 426 insertions(+), 210 deletions(-) create mode 100644 alembic_db/versions/0003_add_metadata_prompt.py diff --git a/alembic_db/versions/0003_add_metadata_prompt.py b/alembic_db/versions/0003_add_metadata_prompt.py new file mode 100644 index 000000000..484d92923 --- /dev/null +++ b/alembic_db/versions/0003_add_metadata_prompt.py @@ -0,0 +1,31 @@ +""" +Add system_metadata and prompt_id columns to asset_references. + +Revision ID: 0003_add_metadata_prompt +Revises: 0002_merge_to_asset_references +Create Date: 2026-03-09 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0003_add_metadata_prompt" +down_revision = "0002_merge_to_asset_references" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("asset_references") as batch_op: + batch_op.add_column( + sa.Column("system_metadata", sa.JSON(), nullable=True) + ) + batch_op.add_column( + sa.Column("prompt_id", sa.String(length=36), nullable=True) + ) + + +def downgrade() -> None: + with op.batch_alter_table("asset_references") as batch_op: + batch_op.drop_column("prompt_id") + batch_op.drop_column("system_metadata") diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 40dee9f46..489ace2f1 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -38,6 +38,7 @@ from app.assets.services import ( update_asset_metadata, upload_from_temp_path, ) +from app.assets.services.tagging import list_tag_histogram ROUTES = web.RouteTableDef() USER_MANAGER: user_manager.UserManager | None = None @@ -122,6 +123,29 @@ def _validate_sort_field(requested: str | None) -> str: return "created_at" +def _build_asset_response(result) -> schemas_out.Asset: + """Build an Asset response from a service result.""" + preview_url = None + if result.ref.preview_id: + preview_url = f"/api/assets/{result.ref.preview_id}/content?disposition=inline" + return schemas_out.Asset( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + size=int(result.asset.size_bytes) if result.asset else 0, + mime_type=result.asset.mime_type if result.asset else None, + tags=result.tags, + preview_url=preview_url, + preview_id=result.ref.preview_id, + user_metadata=result.ref.user_metadata or {}, + metadata=result.ref.system_metadata, + prompt_id=result.ref.prompt_id, + created_at=result.ref.created_at, + updated_at=result.ref.updated_at, + last_access_time=result.ref.last_access_time, + ) + + @ROUTES.head("/api/assets/hash/{hash}") @_require_assets_feature_enabled async def head_asset_by_hash(request: web.Request) -> web.Response: @@ -164,20 +188,7 @@ async def list_assets_route(request: web.Request) -> web.Response: order=order, ) - summaries = [ - schemas_out.AssetSummary( - id=item.ref.id, - name=item.ref.name, - asset_hash=item.asset.hash if item.asset else None, - size=int(item.asset.size_bytes) if item.asset else None, - mime_type=item.asset.mime_type if item.asset else None, - tags=item.tags, - created_at=item.ref.created_at, - updated_at=item.ref.updated_at, - last_access_time=item.ref.last_access_time, - ) - for item in result.items - ] + summaries = [_build_asset_response(item) for item in result.items] payload = schemas_out.AssetsList( assets=summaries, @@ -207,18 +218,7 @@ async def get_asset_route(request: web.Request) -> web.Response: {"id": reference_id}, ) - payload = schemas_out.AssetDetail( - id=result.ref.id, - name=result.ref.name, - asset_hash=result.asset.hash if result.asset else None, - size=int(result.asset.size_bytes) if result.asset else None, - mime_type=result.asset.mime_type if result.asset else None, - tags=result.tags, - user_metadata=result.ref.user_metadata or {}, - preview_id=result.ref.preview_id, - created_at=result.ref.created_at, - last_access_time=result.ref.last_access_time, - ) + payload = _build_asset_response(result) except ValueError as e: return _build_error_response( 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} @@ -312,29 +312,27 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response: 400, "INVALID_JSON", "Request body must be valid JSON." ) + # Derive name from hash if not provided + name = body.name + if name is None: + name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash + result = create_from_hash( hash_str=body.hash, - name=body.name, + name=name, tags=body.tags, user_metadata=body.user_metadata, owner_id=USER_MANAGER.get_request_user_id(request), + mime_type=body.mime_type, ) if result is None: return _build_error_response( 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" ) + asset = _build_asset_response(result) payload_out = schemas_out.AssetCreated( - id=result.ref.id, - name=result.ref.name, - asset_hash=result.asset.hash, - size=int(result.asset.size_bytes), - mime_type=result.asset.mime_type, - tags=result.tags, - user_metadata=result.ref.user_metadata or {}, - preview_id=result.ref.preview_id, - created_at=result.ref.created_at, - last_access_time=result.ref.last_access_time, + **asset.model_dump(), created_new=result.created_new, ) return web.json_response(payload_out.model_dump(mode="json"), status=201) @@ -358,6 +356,9 @@ async def upload_asset(request: web.Request) -> web.Response: "name": parsed.provided_name, "user_metadata": parsed.user_metadata_raw, "hash": parsed.provided_hash, + "id": parsed.provided_id, + "mime_type": parsed.provided_mime_type, + "preview_id": parsed.provided_preview_id, } ) except ValidationError as ve: @@ -378,6 +379,21 @@ async def upload_asset(request: web.Request) -> web.Response: ) try: + # Idempotent create: if spec.id is provided, check if reference already exists + if spec.id: + existing = get_asset_detail( + reference_id=spec.id, + owner_id=owner_id, + ) + if existing: + delete_temp_file_if_exists(parsed.tmp_path) + asset = _build_asset_response(existing) + payload_out = schemas_out.AssetCreated( + **asset.model_dump(), + created_new=False, + ) + return web.json_response(payload_out.model_dump(mode="json"), status=200) + # Fast path: hash exists, create AssetReference without writing anything if spec.hash and parsed.provided_hash_exists is True: result = create_from_hash( @@ -386,6 +402,7 @@ async def upload_asset(request: web.Request) -> web.Response: tags=spec.tags, user_metadata=spec.user_metadata or {}, owner_id=owner_id, + mime_type=spec.mime_type, ) if result is None: delete_temp_file_if_exists(parsed.tmp_path) @@ -410,6 +427,9 @@ async def upload_asset(request: web.Request) -> web.Response: client_filename=parsed.file_client_name, owner_id=owner_id, expected_hash=spec.hash, + mime_type=spec.mime_type, + preview_id=spec.preview_id, + asset_id=spec.id, ) except AssetValidationError as e: delete_temp_file_if_exists(parsed.tmp_path) @@ -428,21 +448,13 @@ async def upload_asset(request: web.Request) -> web.Response: logging.exception("upload_asset failed for owner_id=%s", owner_id) return _build_error_response(500, "INTERNAL", "Unexpected server error.") - payload = schemas_out.AssetCreated( - id=result.ref.id, - name=result.ref.name, - asset_hash=result.asset.hash, - size=int(result.asset.size_bytes), - mime_type=result.asset.mime_type, - tags=result.tags, - user_metadata=result.ref.user_metadata or {}, - preview_id=result.ref.preview_id, - created_at=result.ref.created_at, - last_access_time=result.ref.last_access_time, + asset = _build_asset_response(result) + payload_out = schemas_out.AssetCreated( + **asset.model_dump(), created_new=result.created_new, ) status = 201 if result.created_new else 200 - return web.json_response(payload.model_dump(mode="json"), status=status) + return web.json_response(payload_out.model_dump(mode="json"), status=status) @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") @@ -464,15 +476,10 @@ async def update_asset_route(request: web.Request) -> web.Response: name=body.name, user_metadata=body.user_metadata, owner_id=USER_MANAGER.get_request_user_id(request), + mime_type=body.mime_type, + preview_id=body.preview_id, ) - payload = schemas_out.AssetUpdated( - id=result.ref.id, - name=result.ref.name, - asset_hash=result.asset.hash if result.asset else None, - tags=result.tags, - user_metadata=result.ref.user_metadata or {}, - updated_at=result.ref.updated_at, - ) + payload = _build_asset_response(result) except PermissionError as pe: return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) except ValueError as ve: @@ -587,7 +594,7 @@ async def add_asset_tags(request: web.Request) -> web.Response: payload = schemas_out.TagsAdd( added=result.added, already_present=result.already_present, - total_tags=result.total_tags, + tags=result.total_tags, ) except PermissionError as pe: return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) @@ -634,7 +641,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response: payload = schemas_out.TagsRemove( removed=result.removed, not_present=result.not_present, - total_tags=result.total_tags, + tags=result.total_tags, ) except PermissionError as pe: return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) @@ -653,6 +660,28 @@ async def delete_asset_tags(request: web.Request) -> web.Response: return web.json_response(payload.model_dump(mode="json"), status=200) +@ROUTES.get("/api/assets/tags/refine") +@_require_assets_feature_enabled +async def get_tags_refine(request: web.Request) -> web.Response: + """GET request to get tag histogram for filtered assets.""" + query_dict = get_query_dict(request) + try: + q = schemas_in.TagsRefineQuery.model_validate(query_dict) + except ValidationError as ve: + return _build_validation_error_response("INVALID_QUERY", ve) + + tag_counts = list_tag_histogram( + owner_id=USER_MANAGER.get_request_user_id(request), + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + ) + payload = schemas_out.TagHistogram(tag_counts=tag_counts) + return web.json_response(payload.model_dump(mode="json"), status=200) + + @ROUTES.post("/api/assets/seed") @_require_assets_feature_enabled async def seed_assets(request: web.Request) -> web.Response: diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index d255c938e..48d11a391 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -45,6 +45,9 @@ class ParsedUpload: user_metadata_raw: str | None provided_hash: str | None provided_hash_exists: bool | None + provided_id: str | None = None + provided_mime_type: str | None = None + provided_preview_id: str | None = None class ListAssetsQuery(BaseModel): @@ -98,11 +101,18 @@ class ListAssetsQuery(BaseModel): class UpdateAssetBody(BaseModel): name: str | None = None user_metadata: dict[str, Any] | None = None + mime_type: str | None = None + preview_id: str | None = None @model_validator(mode="after") def _validate_at_least_one_field(self): - if self.name is None and self.user_metadata is None: - raise ValueError("Provide at least one of: name, user_metadata.") + if all( + v is None + for v in (self.name, self.user_metadata, self.mime_type, self.preview_id) + ): + raise ValueError( + "Provide at least one of: name, user_metadata, mime_type, preview_id." + ) return self @@ -110,9 +120,10 @@ class CreateFromHashBody(BaseModel): model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) hash: str - name: str - tags: list[str] = Field(default_factory=list) + name: str | None = None + tags: list[str] = Field(default_factory=list, min_length=1) user_metadata: dict[str, Any] = Field(default_factory=dict) + mime_type: str | None = None @field_validator("hash") @classmethod @@ -138,6 +149,44 @@ class CreateFromHashBody(BaseModel): return [] +class TagsRefineQuery(BaseModel): + include_tags: list[str] = Field(default_factory=list) + exclude_tags: list[str] = Field(default_factory=list) + name_contains: str | None = None + metadata_filter: dict[str, Any] | None = None + limit: conint(ge=1, le=1000) = 100 + + @field_validator("include_tags", "exclude_tags", mode="before") + @classmethod + def _split_csv_tags(cls, v): + if v is None: + return [] + if isinstance(v, str): + return [t.strip() for t in v.split(",") if t.strip()] + if isinstance(v, list): + out: list[str] = [] + for item in v: + if isinstance(item, str): + out.extend([t.strip() for t in item.split(",") if t.strip()]) + return out + return v + + @field_validator("metadata_filter", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + try: + parsed = json.loads(v) + except Exception as e: + raise ValueError(f"metadata_filter must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("metadata_filter must be a JSON object") + return parsed + return None + + class TagsListQuery(BaseModel): model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) @@ -186,21 +235,27 @@ class TagsRemove(TagsAdd): class UploadAssetSpec(BaseModel): """Upload Asset operation. - - tags: ordered; first is root ('models'|'input'|'output'); + - tags: optional list; if provided, first is root ('models'|'input'|'output'); if root == 'models', second must be a valid category - name: display name - user_metadata: arbitrary JSON object (optional) - hash: optional canonical 'blake3:' for validation / fast-path + - id: optional UUID for idempotent creation + - mime_type: optional MIME type override + - preview_id: optional asset ID for preview Files are stored using the content hash as filename stem. """ model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) - tags: list[str] = Field(..., min_length=1) + tags: list[str] = Field(default_factory=list) name: str | None = Field(default=None, max_length=512, description="Display Name") user_metadata: dict[str, Any] = Field(default_factory=dict) hash: str | None = Field(default=None) + id: str | None = Field(default=None) + mime_type: str | None = Field(default=None) + preview_id: str | None = Field(default=None) @field_validator("hash", mode="before") @classmethod @@ -278,14 +333,13 @@ class UploadAssetSpec(BaseModel): @model_validator(mode="after") def _validate_order(self): - if not self.tags: - raise ValueError("tags must be provided and non-empty") - root = self.tags[0] - if root not in {"models", "input", "output"}: - raise ValueError("first tag must be one of: models, input, output") - if root == "models": - if len(self.tags) < 2: - raise ValueError( - "models uploads require a category tag as the second tag" - ) + if self.tags: + root = self.tags[0] + if root not in {"models", "input", "output"}: + raise ValueError("first tag must be one of: models, input, output") + if root == "models": + if len(self.tags) < 2: + raise ValueError( + "models uploads require a category tag as the second tag" + ) return self diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index f36447856..e2d52c75f 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -4,16 +4,21 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field, field_serializer -class AssetSummary(BaseModel): +class Asset(BaseModel): id: str name: str asset_hash: str | None = None - size: int | None = None + size: int = 0 mime_type: str | None = None tags: list[str] = Field(default_factory=list) preview_url: str | None = None - created_at: datetime | None = None - updated_at: datetime | None = None + preview_id: str | None = None + user_metadata: dict[str, Any] = Field(default_factory=dict) + is_immutable: bool = False + metadata: dict[str, Any] | None = None + prompt_id: str | None = None + created_at: datetime + updated_at: datetime last_access_time: datetime | None = None model_config = ConfigDict(from_attributes=True) @@ -23,50 +28,16 @@ class AssetSummary(BaseModel): return v.isoformat() if v else None +class AssetCreated(Asset): + created_new: bool + + class AssetsList(BaseModel): - assets: list[AssetSummary] + assets: list[Asset] total: int has_more: bool -class AssetUpdated(BaseModel): - id: str - name: str - asset_hash: str | None = None - tags: list[str] = Field(default_factory=list) - user_metadata: dict[str, Any] = Field(default_factory=dict) - updated_at: datetime | None = None - - model_config = ConfigDict(from_attributes=True) - - @field_serializer("updated_at") - def _serialize_updated_at(self, v: datetime | None, _info): - return v.isoformat() if v else None - - -class AssetDetail(BaseModel): - id: str - name: str - asset_hash: str | None = None - size: int | None = None - mime_type: str | None = None - tags: list[str] = Field(default_factory=list) - user_metadata: dict[str, Any] = Field(default_factory=dict) - preview_id: str | None = None - created_at: datetime | None = None - last_access_time: datetime | None = None - - model_config = ConfigDict(from_attributes=True) - - @field_serializer("created_at", "last_access_time") - def _serialize_datetime(self, v: datetime | None, _info): - return v.isoformat() if v else None - - -class AssetCreated(AssetDetail): - created_new: bool - - class TagUsage(BaseModel): name: str count: int @@ -83,11 +54,15 @@ class TagsAdd(BaseModel): model_config = ConfigDict(str_strip_whitespace=True) added: list[str] = Field(default_factory=list) already_present: list[str] = Field(default_factory=list) - total_tags: list[str] = Field(default_factory=list) + tags: list[str] = Field(default_factory=list) class TagsRemove(BaseModel): model_config = ConfigDict(str_strip_whitespace=True) removed: list[str] = Field(default_factory=list) not_present: list[str] = Field(default_factory=list) - total_tags: list[str] = Field(default_factory=list) + tags: list[str] = Field(default_factory=list) + + +class TagHistogram(BaseModel): + tag_counts: dict[str, int] diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py index 721c12f4d..c36257ae0 100644 --- a/app/assets/api/upload.py +++ b/app/assets/api/upload.py @@ -52,6 +52,9 @@ async def parse_multipart_upload( user_metadata_raw: str | None = None provided_hash: str | None = None provided_hash_exists: bool | None = None + provided_id: str | None = None + provided_mime_type: str | None = None + provided_preview_id: str | None = None file_written = 0 tmp_path: str | None = None @@ -128,6 +131,12 @@ async def parse_multipart_upload( provided_name = (await field.text()) or None elif fname == "user_metadata": user_metadata_raw = (await field.text()) or None + elif fname == "id": + provided_id = ((await field.text()) or "").strip() or None + elif fname == "mime_type": + provided_mime_type = ((await field.text()) or "").strip() or None + elif fname == "preview_id": + provided_preview_id = ((await field.text()) or "").strip() or None if not file_present and not (provided_hash and provided_hash_exists): raise UploadError( @@ -152,6 +161,9 @@ async def parse_multipart_upload( user_metadata_raw=user_metadata_raw, provided_hash=provided_hash, provided_hash_exists=provided_hash_exists, + provided_id=provided_id, + provided_mime_type=provided_mime_type, + provided_preview_id=provided_preview_id, ) diff --git a/app/assets/database/models.py b/app/assets/database/models.py index 03c1c1707..22340ebd5 100644 --- a/app/assets/database/models.py +++ b/app/assets/database/models.py @@ -96,6 +96,10 @@ class AssetReference(Base): user_metadata: Mapped[dict[str, Any] | None] = mapped_column( JSON(none_as_null=True) ) + system_metadata: Mapped[dict[str, Any] | None] = mapped_column( + JSON(none_as_null=True), nullable=True, default=None + ) + prompt_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=False), nullable=False, default=get_utc_now ) diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py index 7888d0645..5283b400e 100644 --- a/app/assets/database/queries/__init__.py +++ b/app/assets/database/queries/__init__.py @@ -54,6 +54,7 @@ from app.assets.database.queries.tags import ( bulk_insert_tags_and_meta, ensure_tags_exist, get_reference_tags, + list_tag_counts_for_filtered_assets, list_tags_with_usage, remove_missing_tag_for_asset_id, remove_tags_from_reference, @@ -99,6 +100,7 @@ __all__ = [ "insert_reference", "list_references_by_asset_id", "list_references_page", + "list_tag_counts_for_filtered_assets", "list_tags_with_usage", "mark_references_missing_outside_prefixes", "reassign_asset_references", diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 6524791cc..d096670c1 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -24,6 +24,8 @@ from app.assets.database.models import ( ) from app.assets.database.queries.common import ( MAX_BIND_PARAMS, + apply_metadata_filter, + apply_tag_filters, build_prefix_like_conditions, build_visible_owner_clause, calculate_rows_per_statement, @@ -79,83 +81,6 @@ def convert_metadata_to_rows(key: str, value) -> list[dict]: return [{"key": key, "ordinal": 0, "val_json": value}] -def _apply_tag_filters( - stmt: sa.sql.Select, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, -) -> sa.sql.Select: - """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = normalize_tags(include_tags) - exclude_tags = normalize_tags(exclude_tags) - - if include_tags: - for tag_name in include_tags: - stmt = stmt.where( - exists().where( - (AssetReferenceTag.asset_reference_id == AssetReference.id) - & (AssetReferenceTag.tag_name == tag_name) - ) - ) - - if exclude_tags: - stmt = stmt.where( - ~exists().where( - (AssetReferenceTag.asset_reference_id == AssetReference.id) - & (AssetReferenceTag.tag_name.in_(exclude_tags)) - ) - ) - return stmt - - -def _apply_metadata_filter( - stmt: sa.sql.Select, - metadata_filter: dict | None = None, -) -> sa.sql.Select: - """Apply filters using asset_reference_meta projection table.""" - if not metadata_filter: - return stmt - - def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - return sa.exists().where( - AssetReferenceMeta.asset_reference_id == AssetReference.id, - AssetReferenceMeta.key == key, - *preds, - ) - - def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: - if value is None: - no_row_for_key = sa.not_( - sa.exists().where( - AssetReferenceMeta.asset_reference_id == AssetReference.id, - AssetReferenceMeta.key == key, - ) - ) - null_row = _exists_for_pred( - key, - AssetReferenceMeta.val_json.is_(None), - AssetReferenceMeta.val_str.is_(None), - AssetReferenceMeta.val_num.is_(None), - AssetReferenceMeta.val_bool.is_(None), - ) - return sa.or_(no_row_for_key, null_row) - - if isinstance(value, bool): - return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value)) - if isinstance(value, (int, float, Decimal)): - num = value if isinstance(value, Decimal) else Decimal(str(value)) - return _exists_for_pred(key, AssetReferenceMeta.val_num == num) - if isinstance(value, str): - return _exists_for_pred(key, AssetReferenceMeta.val_str == value) - return _exists_for_pred(key, AssetReferenceMeta.val_json == value) - - for k, v in metadata_filter.items(): - if isinstance(v, list): - ors = [_exists_clause_for_value(k, elem) for elem in v] - if ors: - stmt = stmt.where(sa.or_(*ors)) - else: - stmt = stmt.where(_exists_clause_for_value(k, v)) - return stmt def get_reference_by_id( @@ -336,8 +261,8 @@ def list_references_page( escaped, esc = escape_sql_like_string(name_contains) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) - base = _apply_tag_filters(base, include_tags, exclude_tags) - base = _apply_metadata_filter(base, metadata_filter) + base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_metadata_filter(base, metadata_filter) sort = (sort or "created_at").lower() order = (order or "desc").lower() @@ -366,8 +291,8 @@ def list_references_page( count_stmt = count_stmt.where( AssetReference.name.ilike(f"%{escaped}%", escape=esc) ) - count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) + count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_metadata_filter(count_stmt, metadata_filter) total = int(session.execute(count_stmt).scalar_one() or 0) refs = session.execute(base).unique().scalars().all() diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py index 194c39a1e..94ec5a526 100644 --- a/app/assets/database/queries/common.py +++ b/app/assets/database/queries/common.py @@ -1,12 +1,14 @@ """Shared utilities for database query modules.""" import os -from typing import Iterable +from decimal import Decimal +from typing import Iterable, Sequence import sqlalchemy as sa +from sqlalchemy import exists -from app.assets.database.models import AssetReference -from app.assets.helpers import escape_sql_like_string +from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag +from app.assets.helpers import escape_sql_like_string, normalize_tags MAX_BIND_PARAMS = 800 @@ -52,3 +54,82 @@ def build_prefix_like_conditions( escaped, esc = escape_sql_like_string(base) conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) return conds + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None = None, +) -> sa.sql.Select: + """Apply filters using asset_reference_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + no_row_for_key = sa.not_( + sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + ) + ) + null_row = _exists_for_pred( + key, + AssetReferenceMeta.val_json.is_(None), + AssetReferenceMeta.val_str.is_(None), + AssetReferenceMeta.val_num.is_(None), + AssetReferenceMeta.val_bool.is_(None), + ) + return sa.or_(no_row_for_key, null_row) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetReferenceMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetReferenceMeta.val_str == value) + return _exists_for_pred(key, AssetReferenceMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index 8b25fee67..05acbdbd9 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -8,12 +8,15 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from app.assets.database.models import ( + Asset, AssetReference, AssetReferenceMeta, AssetReferenceTag, Tag, ) from app.assets.database.queries.common import ( + apply_metadata_filter, + apply_tag_filters, build_visible_owner_clause, iter_row_chunks, ) @@ -320,6 +323,53 @@ def list_tags_with_usage( return rows_norm, int(total or 0) +def list_tag_counts_for_filtered_assets( + session: Session, + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 100, +) -> dict[str, int]: + """Return tag counts for assets matching the given filters. + + Uses the same filtering logic as list_references_page but returns + {tag_name: count} instead of paginated references. + """ + # Build a subquery of matching reference IDs + ref_sq = ( + select(AssetReference.id) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + ) + + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) + + ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags) + ref_sq = apply_metadata_filter(ref_sq, metadata_filter) + ref_sq = ref_sq.subquery() + + # Count tags across those references + q = ( + select( + AssetReferenceTag.tag_name, + func.count(AssetReferenceTag.asset_reference_id).label("cnt"), + ) + .where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id))) + .group_by(AssetReferenceTag.tag_name) + .order_by(func.count(AssetReferenceTag.asset_reference_id).desc()) + .limit(limit) + ) + + rows = session.execute(q).all() + return {tag_name: int(cnt) for tag_name, cnt in rows} + + def bulk_insert_tags_and_meta( session: Session, tag_rows: list[dict], diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 3fe7115c8..b85e77edb 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -20,6 +20,7 @@ from app.assets.database.queries import ( set_reference_metadata, set_reference_preview, set_reference_tags, + update_asset_hash_and_mime, update_reference_access_time, update_reference_name, update_reference_updated_at, @@ -67,6 +68,8 @@ def update_asset_metadata( user_metadata: UserMetadata = None, tag_origin: str = "manual", owner_id: str = "", + mime_type: str | None = None, + preview_id: str | None = None, ) -> AssetDetailResult: with create_session() as session: ref = get_reference_with_owner_check(session, reference_id, owner_id) @@ -103,6 +106,20 @@ def update_asset_metadata( ) touched = True + if mime_type is not None: + update_asset_hash_and_mime( + session, asset_id=ref.asset_id, mime_type=mime_type + ) + touched = True + + if preview_id is not None: + set_reference_preview( + session, + reference_id=reference_id, + preview_asset_id=preview_id, + ) + touched = True + if touched and user_metadata is None: update_reference_updated_at(session, reference_id=reference_id) diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index 44d7aef36..42ce08c41 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -18,6 +18,7 @@ from app.assets.database.queries import ( remove_missing_tag_for_asset_id, set_reference_metadata, set_reference_tags, + update_asset_hash_and_mime, upsert_asset, upsert_reference, validate_tags_exist, @@ -242,6 +243,9 @@ def upload_from_temp_path( client_filename: str | None = None, owner_id: str = "", expected_hash: str | None = None, + mime_type: str | None = None, + preview_id: str | None = None, + asset_id: str | None = None, ) -> UploadResult: try: digest, _ = hashing.compute_blake3_hash(temp_path) @@ -291,7 +295,7 @@ def upload_from_temp_path( dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) validate_path_within_base(dest_abs, base_dir) - content_type = ( + content_type = mime_type or ( mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] or mimetypes.guess_type(hashed_basename, strict=False)[0] or "application/octet-stream" @@ -315,7 +319,7 @@ def upload_from_temp_path( mime_type=content_type, info_name=_sanitize_filename(name or client_filename, fallback=digest), owner_id=owner_id, - preview_id=None, + preview_id=preview_id, user_metadata=user_metadata or {}, tags=tags, tag_origin="manual", @@ -348,6 +352,7 @@ def create_from_hash( tags: list[str] | None = None, user_metadata: dict | None = None, owner_id: str = "", + mime_type: str | None = None, ) -> UploadResult | None: canonical = hash_str.strip().lower() @@ -356,6 +361,10 @@ def create_from_hash( if not asset: return None + if mime_type and asset.mime_type != mime_type: + update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type) + session.commit() + result = _register_existing_asset( asset_hash=canonical, name=_sanitize_filename( diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py index 8b1f1f4dc..d63c1f60d 100644 --- a/app/assets/services/schemas.py +++ b/app/assets/services/schemas.py @@ -23,9 +23,11 @@ class ReferenceData: file_path: str | None user_metadata: UserMetadata preview_id: str | None - created_at: datetime - updated_at: datetime - last_access_time: datetime | None + system_metadata: dict[str, Any] | None = None + prompt_id: str | None = None + created_at: datetime = None # type: ignore[assignment] + updated_at: datetime = None # type: ignore[assignment] + last_access_time: datetime | None = None @dataclass(frozen=True) @@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData: file_path=ref.file_path, user_metadata=ref.user_metadata, preview_id=ref.preview_id, + system_metadata=ref.system_metadata, + prompt_id=ref.prompt_id, created_at=ref.created_at, updated_at=ref.updated_at, last_access_time=ref.last_access_time, diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py index 28900464d..37b612753 100644 --- a/app/assets/services/tagging.py +++ b/app/assets/services/tagging.py @@ -1,3 +1,5 @@ +from typing import Sequence + from app.assets.database.queries import ( AddTagsResult, RemoveTagsResult, @@ -6,6 +8,7 @@ from app.assets.database.queries import ( list_tags_with_usage, remove_tags_from_reference, ) +from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets from app.assets.services.schemas import TagUsage from app.database.db import create_session @@ -73,3 +76,23 @@ def list_tags( ) return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total + + +def list_tag_histogram( + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 100, +) -> dict[str, int]: + with create_session() as session: + return list_tag_counts_for_filtered_assets( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + ) diff --git a/tests-unit/assets_test/test_tags_api.py b/tests-unit/assets_test/test_tags_api.py index 595bf29c6..cc351ef1b 100644 --- a/tests-unit/assets_test/test_tags_api.py +++ b/tests-unit/assets_test/test_tags_api.py @@ -97,7 +97,7 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset # normalized, deduplicated; 'unit-tests' was already present from the seed assert set(b1["added"]) == {"newtag", "beta"} assert set(b1["already_present"]) == {"unit-tests"} - assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] + assert "newtag" in b1["tags"] and "beta" in b1["tags"] rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) g = rg.json()