feat(assets): align local API with cloud spec

Unify response models, add missing fields, and align input schemas with
the cloud OpenAPI spec at cloud.comfy.org/openapi.

- Replace AssetSummary/AssetDetail/AssetUpdated with single Asset model
- Add is_immutable, metadata (system_metadata), prompt_id fields
- Support mime_type and preview_id in update endpoint
- Make CreateFromHashBody.name optional, add mime_type, require >=1 tag
- Add id/mime_type/preview_id to upload, relax tags to optional
- Rename total_tags → tags in tag add/remove responses
- Add GET /api/assets/tags/refine histogram endpoint
- Add DB migration for system_metadata and prompt_id columns

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Luke Mino-Altherr
2026-03-09 22:36:00 -07:00
parent 06f85e2c79
commit b5b5acf58c
13 changed files with 339 additions and 126 deletions

View File

@@ -38,6 +38,7 @@ from app.assets.services import (
update_asset_metadata,
upload_from_temp_path,
)
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -122,6 +123,29 @@ def _validate_sort_field(requested: str | None) -> str:
return "created_at"
def _build_asset_response(result) -> schemas_out.Asset:
"""Build an Asset response from a service result."""
preview_url = None
if result.ref.preview_id:
preview_url = f"/api/assets/{result.ref.preview_id}/content?disposition=inline"
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else 0,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
preview_url=preview_url,
preview_id=result.ref.preview_id,
user_metadata=result.ref.user_metadata or {},
metadata=result.ref.system_metadata,
prompt_id=result.ref.prompt_id,
created_at=result.ref.created_at,
updated_at=result.ref.updated_at,
last_access_time=result.ref.last_access_time,
)
@ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response:
@@ -164,20 +188,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
order=order,
)
summaries = [
schemas_out.AssetSummary(
id=item.ref.id,
name=item.ref.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes) if item.asset else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.ref.created_at,
updated_at=item.ref.updated_at,
last_access_time=item.ref.last_access_time,
)
for item in result.items
]
summaries = [_build_asset_response(item) for item in result.items]
payload = schemas_out.AssetsList(
assets=summaries,
@@ -207,18 +218,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
{"id": reference_id},
)
payload = schemas_out.AssetDetail(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
)
payload = _build_asset_response(result)
except ValueError as e:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
@@ -312,29 +312,27 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
400, "INVALID_JSON", "Request body must be valid JSON."
)
# Derive name from hash if not provided
name = body.name
if name is None:
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
result = create_from_hash(
hash_str=body.hash,
name=body.name,
name=name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
)
if result is None:
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
)
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
**asset.model_dump(),
created_new=result.created_new,
)
return web.json_response(payload_out.model_dump(mode="json"), status=201)
@@ -358,6 +356,9 @@ async def upload_asset(request: web.Request) -> web.Response:
"name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash,
"id": parsed.provided_id,
"mime_type": parsed.provided_mime_type,
"preview_id": parsed.provided_preview_id,
}
)
except ValidationError as ve:
@@ -378,6 +379,21 @@ async def upload_asset(request: web.Request) -> web.Response:
)
try:
# Idempotent create: if spec.id is provided, check if reference already exists
if spec.id:
existing = get_asset_detail(
reference_id=spec.id,
owner_id=owner_id,
)
if existing:
delete_temp_file_if_exists(parsed.tmp_path)
asset = _build_asset_response(existing)
payload_out = schemas_out.AssetCreated(
**asset.model_dump(),
created_new=False,
)
return web.json_response(payload_out.model_dump(mode="json"), status=200)
# Fast path: hash exists, create AssetReference without writing anything
if spec.hash and parsed.provided_hash_exists is True:
result = create_from_hash(
@@ -386,6 +402,7 @@ async def upload_asset(request: web.Request) -> web.Response:
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
mime_type=spec.mime_type,
)
if result is None:
delete_temp_file_if_exists(parsed.tmp_path)
@@ -410,6 +427,9 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=parsed.file_client_name,
owner_id=owner_id,
expected_hash=spec.hash,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
asset_id=spec.id,
)
except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path)
@@ -428,21 +448,13 @@ async def upload_asset(request: web.Request) -> web.Response:
logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
payload = schemas_out.AssetCreated(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated(
**asset.model_dump(),
created_new=result.created_new,
)
status = 201 if result.created_new else 200
return web.json_response(payload.model_dump(mode="json"), status=status)
return web.json_response(payload_out.model_dump(mode="json"), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@@ -464,15 +476,10 @@ async def update_asset_route(request: web.Request) -> web.Response:
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
preview_id=body.preview_id,
)
payload = schemas_out.AssetUpdated(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
updated_at=result.ref.updated_at,
)
payload = _build_asset_response(result)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve:
@@ -587,7 +594,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsAdd(
added=result.added,
already_present=result.already_present,
total_tags=result.total_tags,
tags=result.total_tags,
)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
@@ -634,7 +641,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsRemove(
removed=result.removed,
not_present=result.not_present,
total_tags=result.total_tags,
tags=result.total_tags,
)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
@@ -653,6 +660,28 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.get("/api/assets/tags/refine")
@_require_assets_feature_enabled
async def get_tags_refine(request: web.Request) -> web.Response:
"""GET request to get tag histogram for filtered assets."""
query_dict = get_query_dict(request)
try:
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
except ValidationError as ve:
return _build_validation_error_response("INVALID_QUERY", ve)
tag_counts = list_tag_histogram(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
)
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed")
@_require_assets_feature_enabled
async def seed_assets(request: web.Request) -> web.Response:

View File

@@ -45,6 +45,9 @@ class ParsedUpload:
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
provided_id: str | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel):
@@ -98,11 +101,18 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel):
name: str | None = None
user_metadata: dict[str, Any] | None = None
mime_type: str | None = None
preview_id: str | None = None
@model_validator(mode="after")
def _validate_at_least_one_field(self):
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
if all(
v is None
for v in (self.name, self.user_metadata, self.mime_type, self.preview_id)
):
raise ValueError(
"Provide at least one of: name, user_metadata, mime_type, preview_id."
)
return self
@@ -110,9 +120,10 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str
tags: list[str] = Field(default_factory=list)
name: str | None = None
tags: list[str] = Field(default_factory=list, min_length=1)
user_metadata: dict[str, Any] = Field(default_factory=dict)
mime_type: str | None = None
@field_validator("hash")
@classmethod
@@ -138,6 +149,44 @@ class CreateFromHashBody(BaseModel):
return []
class TagsRefineQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@@ -186,21 +235,27 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
- tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
- id: optional UUID for idempotent creation
- mime_type: optional MIME type override
- preview_id: optional asset ID for preview
Files are stored using the content hash as filename stem.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
tags: list[str] = Field(default_factory=list)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
id: str | None = Field(default=None)
mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
@@ -278,14 +333,13 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
if self.tags:
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
return self

View File

@@ -4,16 +4,21 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel):
class Asset(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
size: int = 0
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
preview_url: str | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
preview_id: str | None = None
user_metadata: dict[str, Any] = Field(default_factory=dict)
is_immutable: bool = False
metadata: dict[str, Any] | None = None
prompt_id: str | None = None
created_at: datetime
updated_at: datetime
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@@ -23,50 +28,16 @@ class AssetSummary(BaseModel):
return v.isoformat() if v else None
class AssetCreated(Asset):
created_new: bool
class AssetsList(BaseModel):
assets: list[AssetSummary]
assets: list[Asset]
total: int
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
@@ -83,11 +54,15 @@ class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
tags: list[str] = Field(default_factory=list)
class TagHistogram(BaseModel):
tag_counts: dict[str, int]

View File

@@ -52,6 +52,9 @@ async def parse_multipart_upload(
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
provided_id: str | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
file_written = 0
tmp_path: str | None = None
@@ -128,6 +131,12 @@ async def parse_multipart_upload(
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
elif fname == "id":
provided_id = ((await field.text()) or "").strip() or None
elif fname == "mime_type":
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
@@ -152,6 +161,9 @@ async def parse_multipart_upload(
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
provided_id=provided_id,
provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id,
)