Compare commits

..

4 Commits

Author SHA1 Message Date
Luke Mino-Altherr
b5b5acf58c feat(assets): align local API with cloud spec
Unify response models, add missing fields, and align input schemas with
the cloud OpenAPI spec at cloud.comfy.org/openapi.

- Replace AssetSummary/AssetDetail/AssetUpdated with single Asset model
- Add is_immutable, metadata (system_metadata), prompt_id fields
- Support mime_type and preview_id in update endpoint
- Make CreateFromHashBody.name optional, add mime_type, require >=1 tag
- Add id/mime_type/preview_id to upload, relax tags to optional
- Rename total_tags → tags in tag add/remove responses
- Add GET /api/assets/tags/refine histogram endpoint
- Add DB migration for system_metadata and prompt_id columns

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 22:37:04 -07:00
Jukka Seppänen
06f85e2c79 Fix text encoder lora loading for wrapped models (#12852) 2026-03-09 16:08:51 -04:00
comfyanonymous
e4b0bb8305 Import assets seeder later, print some package versions. (#12841) 2026-03-08 16:25:30 -04:00
rattus
7723f20bbe comfy-aimdo 0.2.9 (#12840)
Comfy-aimdo 0.2.9 fixes a context issue where if a non-main thread does
a spurious garbage collection, cudaFrees are attempted with bad
context.

Some new APIs for displaying aimdo stats in UI widgets are also added.
These are purely additive getters that dont touch cuda APIs.
2026-03-08 16:17:40 -04:00
20 changed files with 356 additions and 603 deletions

View File

@@ -0,0 +1,31 @@
"""
Add system_metadata and prompt_id columns to asset_references.
Revision ID: 0003_add_metadata_prompt
Revises: 0002_merge_to_asset_references
Create Date: 2026-03-09
"""
from alembic import op
import sqlalchemy as sa
revision = "0003_add_metadata_prompt"
down_revision = "0002_merge_to_asset_references"
branch_labels = None
depends_on = None
def upgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.add_column(
sa.Column("system_metadata", sa.JSON(), nullable=True)
)
batch_op.add_column(
sa.Column("prompt_id", sa.String(length=36), nullable=True)
)
def downgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.drop_column("prompt_id")
batch_op.drop_column("system_metadata")

View File

@@ -38,6 +38,7 @@ from app.assets.services import (
update_asset_metadata,
upload_from_temp_path,
)
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -122,6 +123,29 @@ def _validate_sort_field(requested: str | None) -> str:
return "created_at"
def _build_asset_response(result) -> schemas_out.Asset:
"""Build an Asset response from a service result."""
preview_url = None
if result.ref.preview_id:
preview_url = f"/api/assets/{result.ref.preview_id}/content?disposition=inline"
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else 0,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
preview_url=preview_url,
preview_id=result.ref.preview_id,
user_metadata=result.ref.user_metadata or {},
metadata=result.ref.system_metadata,
prompt_id=result.ref.prompt_id,
created_at=result.ref.created_at,
updated_at=result.ref.updated_at,
last_access_time=result.ref.last_access_time,
)
@ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response:
@@ -164,20 +188,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
order=order,
)
summaries = [
schemas_out.AssetSummary(
id=item.ref.id,
name=item.ref.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes) if item.asset else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.ref.created_at,
updated_at=item.ref.updated_at,
last_access_time=item.ref.last_access_time,
)
for item in result.items
]
summaries = [_build_asset_response(item) for item in result.items]
payload = schemas_out.AssetsList(
assets=summaries,
@@ -207,18 +218,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
{"id": reference_id},
)
payload = schemas_out.AssetDetail(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
)
payload = _build_asset_response(result)
except ValueError as e:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
@@ -312,29 +312,27 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
400, "INVALID_JSON", "Request body must be valid JSON."
)
# Derive name from hash if not provided
name = body.name
if name is None:
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
result = create_from_hash(
hash_str=body.hash,
name=body.name,
name=name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
)
if result is None:
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
)
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
**asset.model_dump(),
created_new=result.created_new,
)
return web.json_response(payload_out.model_dump(mode="json"), status=201)
@@ -358,6 +356,9 @@ async def upload_asset(request: web.Request) -> web.Response:
"name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash,
"id": parsed.provided_id,
"mime_type": parsed.provided_mime_type,
"preview_id": parsed.provided_preview_id,
}
)
except ValidationError as ve:
@@ -378,6 +379,21 @@ async def upload_asset(request: web.Request) -> web.Response:
)
try:
# Idempotent create: if spec.id is provided, check if reference already exists
if spec.id:
existing = get_asset_detail(
reference_id=spec.id,
owner_id=owner_id,
)
if existing:
delete_temp_file_if_exists(parsed.tmp_path)
asset = _build_asset_response(existing)
payload_out = schemas_out.AssetCreated(
**asset.model_dump(),
created_new=False,
)
return web.json_response(payload_out.model_dump(mode="json"), status=200)
# Fast path: hash exists, create AssetReference without writing anything
if spec.hash and parsed.provided_hash_exists is True:
result = create_from_hash(
@@ -386,6 +402,7 @@ async def upload_asset(request: web.Request) -> web.Response:
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
mime_type=spec.mime_type,
)
if result is None:
delete_temp_file_if_exists(parsed.tmp_path)
@@ -410,6 +427,9 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=parsed.file_client_name,
owner_id=owner_id,
expected_hash=spec.hash,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
asset_id=spec.id,
)
except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path)
@@ -428,21 +448,13 @@ async def upload_asset(request: web.Request) -> web.Response:
logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
payload = schemas_out.AssetCreated(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated(
**asset.model_dump(),
created_new=result.created_new,
)
status = 201 if result.created_new else 200
return web.json_response(payload.model_dump(mode="json"), status=status)
return web.json_response(payload_out.model_dump(mode="json"), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@@ -464,15 +476,10 @@ async def update_asset_route(request: web.Request) -> web.Response:
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
preview_id=body.preview_id,
)
payload = schemas_out.AssetUpdated(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
updated_at=result.ref.updated_at,
)
payload = _build_asset_response(result)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve:
@@ -587,7 +594,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsAdd(
added=result.added,
already_present=result.already_present,
total_tags=result.total_tags,
tags=result.total_tags,
)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
@@ -634,7 +641,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsRemove(
removed=result.removed,
not_present=result.not_present,
total_tags=result.total_tags,
tags=result.total_tags,
)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
@@ -653,6 +660,28 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.get("/api/assets/tags/refine")
@_require_assets_feature_enabled
async def get_tags_refine(request: web.Request) -> web.Response:
"""GET request to get tag histogram for filtered assets."""
query_dict = get_query_dict(request)
try:
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
except ValidationError as ve:
return _build_validation_error_response("INVALID_QUERY", ve)
tag_counts = list_tag_histogram(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
)
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed")
@_require_assets_feature_enabled
async def seed_assets(request: web.Request) -> web.Response:

View File

@@ -45,6 +45,9 @@ class ParsedUpload:
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
provided_id: str | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel):
@@ -98,11 +101,18 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel):
name: str | None = None
user_metadata: dict[str, Any] | None = None
mime_type: str | None = None
preview_id: str | None = None
@model_validator(mode="after")
def _validate_at_least_one_field(self):
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
if all(
v is None
for v in (self.name, self.user_metadata, self.mime_type, self.preview_id)
):
raise ValueError(
"Provide at least one of: name, user_metadata, mime_type, preview_id."
)
return self
@@ -110,9 +120,10 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str
tags: list[str] = Field(default_factory=list)
name: str | None = None
tags: list[str] = Field(default_factory=list, min_length=1)
user_metadata: dict[str, Any] = Field(default_factory=dict)
mime_type: str | None = None
@field_validator("hash")
@classmethod
@@ -138,6 +149,44 @@ class CreateFromHashBody(BaseModel):
return []
class TagsRefineQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@@ -186,21 +235,27 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
- tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
- id: optional UUID for idempotent creation
- mime_type: optional MIME type override
- preview_id: optional asset ID for preview
Files are stored using the content hash as filename stem.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
tags: list[str] = Field(default_factory=list)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
id: str | None = Field(default=None)
mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
@@ -278,14 +333,13 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
if self.tags:
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
return self

View File

@@ -4,16 +4,21 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel):
class Asset(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
size: int = 0
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
preview_url: str | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
preview_id: str | None = None
user_metadata: dict[str, Any] = Field(default_factory=dict)
is_immutable: bool = False
metadata: dict[str, Any] | None = None
prompt_id: str | None = None
created_at: datetime
updated_at: datetime
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@@ -23,50 +28,16 @@ class AssetSummary(BaseModel):
return v.isoformat() if v else None
class AssetCreated(Asset):
created_new: bool
class AssetsList(BaseModel):
assets: list[AssetSummary]
assets: list[Asset]
total: int
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
@@ -83,11 +54,15 @@ class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
tags: list[str] = Field(default_factory=list)
class TagHistogram(BaseModel):
tag_counts: dict[str, int]

View File

@@ -52,6 +52,9 @@ async def parse_multipart_upload(
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
provided_id: str | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
file_written = 0
tmp_path: str | None = None
@@ -128,6 +131,12 @@ async def parse_multipart_upload(
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
elif fname == "id":
provided_id = ((await field.text()) or "").strip() or None
elif fname == "mime_type":
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
@@ -152,6 +161,9 @@ async def parse_multipart_upload(
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
provided_id=provided_id,
provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id,
)

View File

@@ -96,6 +96,10 @@ class AssetReference(Base):
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True)
)
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True), nullable=True, default=None
)
prompt_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)

View File

@@ -54,6 +54,7 @@ from app.assets.database.queries.tags import (
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_reference_tags,
list_tag_counts_for_filtered_assets,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_reference,
@@ -99,6 +100,7 @@ __all__ = [
"insert_reference",
"list_references_by_asset_id",
"list_references_page",
"list_tag_counts_for_filtered_assets",
"list_tags_with_usage",
"mark_references_missing_outside_prefixes",
"reassign_asset_references",

View File

@@ -320,6 +320,59 @@ def list_tags_with_usage(
return rows_norm, int(total or 0)
def list_tag_counts_for_filtered_assets(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
"""Return tag counts for assets matching the given filters.
Uses the same filtering logic as list_references_page but returns
{tag_name: count} instead of paginated references.
"""
from app.assets.database.queries.asset_reference import (
_apply_tag_filters,
_apply_metadata_filter,
)
from app.assets.database.models import Asset
# Build a subquery of matching reference IDs
ref_sq = (
select(AssetReference.id)
.join(Asset, Asset.id == AssetReference.asset_id)
.where(build_visible_owner_clause(owner_id))
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
ref_sq = _apply_tag_filters(ref_sq, include_tags, exclude_tags)
ref_sq = _apply_metadata_filter(ref_sq, metadata_filter)
ref_sq = ref_sq.subquery()
# Count tags across those references
q = (
select(
AssetReferenceTag.tag_name,
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
.group_by(AssetReferenceTag.tag_name)
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc())
.limit(limit)
)
rows = session.execute(q).all()
return {tag_name: int(cnt) for tag_name, cnt in rows}
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],

View File

@@ -67,6 +67,8 @@ def update_asset_metadata(
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> AssetDetailResult:
with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id)
@@ -103,6 +105,21 @@ def update_asset_metadata(
)
touched = True
if mime_type is not None:
from app.assets.database.queries import update_asset_hash_and_mime
update_asset_hash_and_mime(
session, asset_id=ref.asset_id, mime_type=mime_type
)
touched = True
if preview_id is not None:
set_reference_preview(
session,
reference_id=reference_id,
preview_asset_id=preview_id,
)
touched = True
if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id)

View File

@@ -3,8 +3,12 @@ import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import IO, Any, Callable, Iterator
import logging
from blake3 import blake3
try:
from blake3 import blake3
except ModuleNotFoundError:
logging.warning("WARNING: blake3 package not installed")
DEFAULT_CHUNK = 8 * 1024 * 1024

View File

@@ -242,6 +242,9 @@ def upload_from_temp_path(
client_filename: str | None = None,
owner_id: str = "",
expected_hash: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
asset_id: str | None = None,
) -> UploadResult:
try:
digest, _ = hashing.compute_blake3_hash(temp_path)
@@ -291,7 +294,7 @@ def upload_from_temp_path(
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir)
content_type = (
content_type = mime_type or (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
@@ -315,7 +318,7 @@ def upload_from_temp_path(
mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id,
preview_id=None,
preview_id=preview_id,
user_metadata=user_metadata or {},
tags=tags,
tag_origin="manual",
@@ -348,6 +351,7 @@ def create_from_hash(
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
mime_type: str | None = None,
) -> UploadResult | None:
canonical = hash_str.strip().lower()
@@ -356,6 +360,11 @@ def create_from_hash(
if not asset:
return None
if mime_type and asset.mime_type != mime_type:
from app.assets.database.queries import update_asset_hash_and_mime
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
session.commit()
result = _register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(

View File

@@ -23,9 +23,11 @@ class ReferenceData:
file_path: str | None
user_metadata: UserMetadata
preview_id: str | None
created_at: datetime
updated_at: datetime
last_access_time: datetime | None
system_metadata: dict[str, Any] | None = None
prompt_id: str | None = None
created_at: datetime = None # type: ignore[assignment]
updated_at: datetime = None # type: ignore[assignment]
last_access_time: datetime | None = None
@dataclass(frozen=True)
@@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
file_path=ref.file_path,
user_metadata=ref.user_metadata,
preview_id=ref.preview_id,
system_metadata=ref.system_metadata,
prompt_id=ref.prompt_id,
created_at=ref.created_at,
updated_at=ref.updated_at,
last_access_time=ref.last_access_time,

View File

@@ -1,3 +1,5 @@
from typing import Sequence
from app.assets.database.queries import (
AddTagsResult,
RemoveTagsResult,
@@ -6,6 +8,7 @@ from app.assets.database.queries import (
list_tags_with_usage,
remove_tags_from_reference,
)
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
from app.assets.services.schemas import TagUsage
from app.database.db import create_session
@@ -73,3 +76,23 @@ def list_tags(
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
def list_tag_histogram(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
with create_session() as session:
return list_tag_counts_for_filtered_assets(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
)

View File

@@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
if tp > 0 and not k.startswith("clip_"):
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False

View File

@@ -1,68 +0,0 @@
from pydantic import BaseModel, Field
class RevePostprocessingOperation(BaseModel):
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
upscale_factor: int | None = Field(
None,
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
ge=2,
le=4,
)
class ReveImageCreateRequest(BaseModel):
prompt: str = Field(...)
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageEditRequest(BaseModel):
edit_instruction: str = Field(...)
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageRemixRequest(BaseModel):
prompt: str = Field(...)
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageResponse(BaseModel):
image: str | None = Field(None, description="The base64 encoded image data.")
request_id: str | None = Field(None, description="A unique id for the request.")
credits_used: float | None = Field(None, description="The number of credits used for this request.")
version: str | None = Field(None, description="The specific model version used.")
content_violation: bool | None = Field(
None, description="Indicates whether the generated image violates the content policy."
)

View File

@@ -1,395 +0,0 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.reve import (
ReveImageCreateRequest,
ReveImageEditRequest,
ReveImageRemixRequest,
RevePostprocessingOperation,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
sync_op_raw,
tensor_to_base64_string,
validate_string,
)
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
ops = []
if upscale["upscale"] == "enabled":
ops.append(
RevePostprocessingOperation(
process="upscale",
upscale_factor=upscale["upscale_factor"],
)
)
if remove_background:
ops.append(RevePostprocessingOperation(process="remove_background"))
return ops or None
def _postprocessing_inputs():
return [
IO.DynamicCombo.Input(
"upscale",
options=[
IO.DynamicCombo.Option("disabled", []),
IO.DynamicCombo.Option(
"enabled",
[
IO.Int.Input(
"upscale_factor",
default=2,
min=2,
max=4,
step=1,
tooltip="Upscale factor (2x, 3x, or 4x).",
),
],
),
],
tooltip="Upscale the generated image. May add additional cost.",
),
IO.Boolean.Input(
"remove_background",
default=False,
tooltip="Remove the background from the generated image. May add additional cost.",
),
]
def _reve_price_extractor(headers: dict) -> float | None:
credits_used = headers.get("x-reve-credits-used")
if credits_used is not None:
return float(credits_used) / 750
return None
def _reve_response_header_validator(headers: dict) -> None:
error_code = headers.get("x-reve-error-code")
if error_code:
raise ValueError(f"Reve API error: {error_code}")
if headers.get("x-reve-content-violation", "").lower() == "true":
raise ValueError("The generated image was flagged for content policy violation.")
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
return [
IO.DynamicCombo.Option(
version,
[
IO.Combo.Input(
"aspect_ratio",
options=aspect_ratios,
tooltip="Aspect ratio of the output image.",
),
IO.Int.Input(
"test_time_scaling",
default=1,
min=1,
max=6,
step=1,
tooltip="Higher values produce better images but cost more credits.",
advanced=True,
),
],
)
for version in versions
]
class ReveImageCreateNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageCreateNode",
display_name="Reve Image Create",
category="api node/image/Reve",
description="Generate images from text descriptions using Reve.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-create@20250915"],
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for generation.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.024,"format":{"approximate":true,"note":"(base)"}}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/create",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageCreateRequest(
prompt=prompt,
aspect_ratio=model["aspect_ratio"],
version=model["model"],
test_time_scaling=model["test_time_scaling"],
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="api node/image/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
IO.String.Input(
"edit_instruction",
multiline=True,
default="",
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-edit@20250915", "reve-edit-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for editing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.007 : 0.04;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
edit_instruction: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(edit_instruction, min_length=1, max_length=2560)
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/edit",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageEditRequest(
edit_instruction=edit_instruction,
reference_image=tensor_to_base64_string(image),
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageRemixNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="api node/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="image_",
min=1,
max=6,
),
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. "
"May include XML img tags to reference specific images by index, "
"e.g. <img>0</img>, <img>1</img>, etc.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-remix@20250915", "reve-remix-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for remixing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.007 : 0.04;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
reference_images: IO.Autogrow.Type,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
if not reference_images:
raise ValueError("At least one reference image is required.")
ref_base64_list = []
for key in reference_images:
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
if len(ref_base64_list) > 6:
raise ValueError("Maximum 6 reference images are allowed.")
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/remix",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageRemixRequest(
prompt=prompt,
reference_images=ref_base64_list,
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ReveImageCreateNode,
ReveImageEditNode,
ReveImageRemixNode,
]
async def comfy_entrypoint() -> ReveExtension:
return ReveExtension()

View File

@@ -67,7 +67,6 @@ class _RequestConfig:
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
response_header_validator: Callable[[dict[str, str]], None] | None = None
@dataclass
@@ -203,13 +202,11 @@ async def sync_op_raw(
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes.
- response_header_validator: optional callback receiving response headers dict
"""
if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True)
@@ -235,7 +232,6 @@ async def sync_op_raw(
price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
response_header_validator=response_header_validator,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -773,12 +769,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
resp_headers = dict(resp.headers)
if cfg.price_extractor:
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(resp_headers)
if cfg.response_header_validator:
cfg.response_header_validator(resp_headers)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@@ -786,7 +776,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=resp_headers,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
return bytes_payload

View File

@@ -3,11 +3,11 @@ comfy.options.enable_args_parsing()
import os
import importlib.util
import importlib.metadata
import folder_paths
import time
from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger
from app.assets.seeder import asset_seeder
import itertools
import utils.extra_config
from utils.mime_types import init_mime_types
@@ -182,6 +182,7 @@ if 'torch' in sys.modules:
import comfy.utils
from app.assets.seeder import asset_seeder
import execution
import server
@@ -451,6 +452,11 @@ if __name__ == "__main__":
# Running directly, just start ComfyUI.
logging.info("Python version: {}".format(sys.version))
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
for package in ("comfy-aimdo", "comfy-kitchen"):
try:
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
except:
pass
if sys.version_info.major == 3 and sys.version_info.minor < 10:
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")

View File

@@ -23,7 +23,7 @@ SQLAlchemy
filelock
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.7
comfy-aimdo>=0.2.9
requests
simpleeval>=1.0.0
blake3

View File

@@ -97,7 +97,7 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset
# normalized, deduplicated; 'unit-tests' was already present from the seed
assert set(b1["added"]) == {"newtag", "beta"}
assert set(b1["already_present"]) == {"unit-tests"}
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
assert "newtag" in b1["tags"] and "beta" in b1["tags"]
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
g = rg.json()