Compare commits

..

5 Commits

Author SHA1 Message Date
rattus
535c16ce6e Widen OOM_EXCEPTION to AcceleratorError form (#12835)
Pytorch only filters for OOMs in its own allocators however there are
paths that can OOM on allocators made outside the pytorch allocators.
These manifest as an AllocatorError as pytorch does not have universal
error translation to its OOM type on exception. Handle it. A log I have
for this also shows a double report of the error async, so call the
async discarder to cleanup and make these OOMs look like OOMs.
2026-03-10 00:41:02 -04:00
rattus
a912809c25 model_detection: deep clone pre edited edited weights (#12862)
Deep clone these weights as needed to avoid segfaulting when it tries
to touch the original mmap.
2026-03-09 23:50:10 -04:00
comfyanonymous
c4fb0271cd Add a way for nodes to add pre attn patches to flux model. (#12861) 2026-03-09 23:37:58 -04:00
Dr.Lt.Data
740d998c9c fix(manager): improve install guidance when comfyui-manager is not installed (#12810) 2026-03-09 22:49:31 -04:00
ComfyUI Wiki
814dab9f46 Update workflow templates to v0.9.18 (#12857) 2026-03-09 22:03:22 -04:00
27 changed files with 188 additions and 354 deletions

View File

@@ -1,31 +0,0 @@
"""
Add system_metadata and prompt_id columns to asset_references.
Revision ID: 0003_add_metadata_prompt
Revises: 0002_merge_to_asset_references
Create Date: 2026-03-09
"""
from alembic import op
import sqlalchemy as sa
revision = "0003_add_metadata_prompt"
down_revision = "0002_merge_to_asset_references"
branch_labels = None
depends_on = None
def upgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.add_column(
sa.Column("system_metadata", sa.JSON(), nullable=True)
)
batch_op.add_column(
sa.Column("prompt_id", sa.String(length=36), nullable=True)
)
def downgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.drop_column("prompt_id")
batch_op.drop_column("system_metadata")

View File

@@ -38,7 +38,6 @@ from app.assets.services import (
update_asset_metadata,
upload_from_temp_path,
)
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -123,29 +122,6 @@ def _validate_sort_field(requested: str | None) -> str:
return "created_at"
def _build_asset_response(result) -> schemas_out.Asset:
"""Build an Asset response from a service result."""
preview_url = None
if result.ref.preview_id:
preview_url = f"/api/assets/{result.ref.preview_id}/content?disposition=inline"
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else 0,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
preview_url=preview_url,
preview_id=result.ref.preview_id,
user_metadata=result.ref.user_metadata or {},
metadata=result.ref.system_metadata,
prompt_id=result.ref.prompt_id,
created_at=result.ref.created_at,
updated_at=result.ref.updated_at,
last_access_time=result.ref.last_access_time,
)
@ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response:
@@ -188,7 +164,20 @@ async def list_assets_route(request: web.Request) -> web.Response:
order=order,
)
summaries = [_build_asset_response(item) for item in result.items]
summaries = [
schemas_out.AssetSummary(
id=item.ref.id,
name=item.ref.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes) if item.asset else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.ref.created_at,
updated_at=item.ref.updated_at,
last_access_time=item.ref.last_access_time,
)
for item in result.items
]
payload = schemas_out.AssetsList(
assets=summaries,
@@ -218,7 +207,18 @@ async def get_asset_route(request: web.Request) -> web.Response:
{"id": reference_id},
)
payload = _build_asset_response(result)
payload = schemas_out.AssetDetail(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
)
except ValueError as e:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
@@ -312,27 +312,29 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
400, "INVALID_JSON", "Request body must be valid JSON."
)
# Derive name from hash if not provided
name = body.name
if name is None:
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
result = create_from_hash(
hash_str=body.hash,
name=name,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
)
if result is None:
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
)
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated(
**asset.model_dump(),
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new,
)
return web.json_response(payload_out.model_dump(mode="json"), status=201)
@@ -356,9 +358,6 @@ async def upload_asset(request: web.Request) -> web.Response:
"name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash,
"id": parsed.provided_id,
"mime_type": parsed.provided_mime_type,
"preview_id": parsed.provided_preview_id,
}
)
except ValidationError as ve:
@@ -379,21 +378,6 @@ async def upload_asset(request: web.Request) -> web.Response:
)
try:
# Idempotent create: if spec.id is provided, check if reference already exists
if spec.id:
existing = get_asset_detail(
reference_id=spec.id,
owner_id=owner_id,
)
if existing:
delete_temp_file_if_exists(parsed.tmp_path)
asset = _build_asset_response(existing)
payload_out = schemas_out.AssetCreated(
**asset.model_dump(),
created_new=False,
)
return web.json_response(payload_out.model_dump(mode="json"), status=200)
# Fast path: hash exists, create AssetReference without writing anything
if spec.hash and parsed.provided_hash_exists is True:
result = create_from_hash(
@@ -402,7 +386,6 @@ async def upload_asset(request: web.Request) -> web.Response:
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
mime_type=spec.mime_type,
)
if result is None:
delete_temp_file_if_exists(parsed.tmp_path)
@@ -427,9 +410,6 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=parsed.file_client_name,
owner_id=owner_id,
expected_hash=spec.hash,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
asset_id=spec.id,
)
except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path)
@@ -448,13 +428,21 @@ async def upload_asset(request: web.Request) -> web.Response:
logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated(
**asset.model_dump(),
payload = schemas_out.AssetCreated(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new,
)
status = 201 if result.created_new else 200
return web.json_response(payload_out.model_dump(mode="json"), status=status)
return web.json_response(payload.model_dump(mode="json"), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@@ -476,10 +464,15 @@ async def update_asset_route(request: web.Request) -> web.Response:
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
preview_id=body.preview_id,
)
payload = _build_asset_response(result)
payload = schemas_out.AssetUpdated(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
updated_at=result.ref.updated_at,
)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve:
@@ -594,7 +587,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsAdd(
added=result.added,
already_present=result.already_present,
tags=result.total_tags,
total_tags=result.total_tags,
)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
@@ -641,7 +634,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsRemove(
removed=result.removed,
not_present=result.not_present,
tags=result.total_tags,
total_tags=result.total_tags,
)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
@@ -660,28 +653,6 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.get("/api/assets/tags/refine")
@_require_assets_feature_enabled
async def get_tags_refine(request: web.Request) -> web.Response:
"""GET request to get tag histogram for filtered assets."""
query_dict = get_query_dict(request)
try:
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
except ValidationError as ve:
return _build_validation_error_response("INVALID_QUERY", ve)
tag_counts = list_tag_histogram(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
)
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed")
@_require_assets_feature_enabled
async def seed_assets(request: web.Request) -> web.Response:

View File

@@ -45,9 +45,6 @@ class ParsedUpload:
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
provided_id: str | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel):
@@ -101,18 +98,11 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel):
name: str | None = None
user_metadata: dict[str, Any] | None = None
mime_type: str | None = None
preview_id: str | None = None
@model_validator(mode="after")
def _validate_at_least_one_field(self):
if all(
v is None
for v in (self.name, self.user_metadata, self.mime_type, self.preview_id)
):
raise ValueError(
"Provide at least one of: name, user_metadata, mime_type, preview_id."
)
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
return self
@@ -120,10 +110,9 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str | None = None
tags: list[str] = Field(default_factory=list, min_length=1)
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
mime_type: str | None = None
@field_validator("hash")
@classmethod
@@ -149,44 +138,6 @@ class CreateFromHashBody(BaseModel):
return []
class TagsRefineQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@@ -235,27 +186,21 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: optional list; if provided, first is root ('models'|'input'|'output');
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
- id: optional UUID for idempotent creation
- mime_type: optional MIME type override
- preview_id: optional asset ID for preview
Files are stored using the content hash as filename stem.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(default_factory=list)
tags: list[str] = Field(..., min_length=1)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
id: str | None = Field(default=None)
mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
@@ -333,13 +278,14 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after")
def _validate_order(self):
if self.tags:
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
return self

View File

@@ -4,21 +4,16 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class Asset(BaseModel):
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int = 0
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
preview_url: str | None = None
preview_id: str | None = None
user_metadata: dict[str, Any] = Field(default_factory=dict)
is_immutable: bool = False
metadata: dict[str, Any] | None = None
prompt_id: str | None = None
created_at: datetime
updated_at: datetime
created_at: datetime | None = None
updated_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@@ -28,16 +23,50 @@ class Asset(BaseModel):
return v.isoformat() if v else None
class AssetCreated(Asset):
created_new: bool
class AssetsList(BaseModel):
assets: list[Asset]
assets: list[AssetSummary]
total: int
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
@@ -54,15 +83,11 @@ class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
tags: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
tags: list[str] = Field(default_factory=list)
class TagHistogram(BaseModel):
tag_counts: dict[str, int]
total_tags: list[str] = Field(default_factory=list)

View File

@@ -52,9 +52,6 @@ async def parse_multipart_upload(
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
provided_id: str | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
file_written = 0
tmp_path: str | None = None
@@ -131,12 +128,6 @@ async def parse_multipart_upload(
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
elif fname == "id":
provided_id = ((await field.text()) or "").strip() or None
elif fname == "mime_type":
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
@@ -161,9 +152,6 @@ async def parse_multipart_upload(
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
provided_id=provided_id,
provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id,
)

View File

@@ -96,10 +96,6 @@ class AssetReference(Base):
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True)
)
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True), nullable=True, default=None
)
prompt_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)

View File

@@ -54,7 +54,6 @@ from app.assets.database.queries.tags import (
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_reference_tags,
list_tag_counts_for_filtered_assets,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_reference,
@@ -100,7 +99,6 @@ __all__ = [
"insert_reference",
"list_references_by_asset_id",
"list_references_page",
"list_tag_counts_for_filtered_assets",
"list_tags_with_usage",
"mark_references_missing_outside_prefixes",
"reassign_asset_references",

View File

@@ -320,59 +320,6 @@ def list_tags_with_usage(
return rows_norm, int(total or 0)
def list_tag_counts_for_filtered_assets(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
"""Return tag counts for assets matching the given filters.
Uses the same filtering logic as list_references_page but returns
{tag_name: count} instead of paginated references.
"""
from app.assets.database.queries.asset_reference import (
_apply_tag_filters,
_apply_metadata_filter,
)
from app.assets.database.models import Asset
# Build a subquery of matching reference IDs
ref_sq = (
select(AssetReference.id)
.join(Asset, Asset.id == AssetReference.asset_id)
.where(build_visible_owner_clause(owner_id))
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
ref_sq = _apply_tag_filters(ref_sq, include_tags, exclude_tags)
ref_sq = _apply_metadata_filter(ref_sq, metadata_filter)
ref_sq = ref_sq.subquery()
# Count tags across those references
q = (
select(
AssetReferenceTag.tag_name,
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
.group_by(AssetReferenceTag.tag_name)
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc())
.limit(limit)
)
rows = session.execute(q).all()
return {tag_name: int(cnt) for tag_name, cnt in rows}
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],

View File

@@ -67,8 +67,6 @@ def update_asset_metadata(
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> AssetDetailResult:
with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id)
@@ -105,21 +103,6 @@ def update_asset_metadata(
)
touched = True
if mime_type is not None:
from app.assets.database.queries import update_asset_hash_and_mime
update_asset_hash_and_mime(
session, asset_id=ref.asset_id, mime_type=mime_type
)
touched = True
if preview_id is not None:
set_reference_preview(
session,
reference_id=reference_id,
preview_asset_id=preview_id,
)
touched = True
if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id)

View File

@@ -242,9 +242,6 @@ def upload_from_temp_path(
client_filename: str | None = None,
owner_id: str = "",
expected_hash: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
asset_id: str | None = None,
) -> UploadResult:
try:
digest, _ = hashing.compute_blake3_hash(temp_path)
@@ -294,7 +291,7 @@ def upload_from_temp_path(
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir)
content_type = mime_type or (
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
@@ -318,7 +315,7 @@ def upload_from_temp_path(
mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id,
preview_id=preview_id,
preview_id=None,
user_metadata=user_metadata or {},
tags=tags,
tag_origin="manual",
@@ -351,7 +348,6 @@ def create_from_hash(
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
mime_type: str | None = None,
) -> UploadResult | None:
canonical = hash_str.strip().lower()
@@ -360,11 +356,6 @@ def create_from_hash(
if not asset:
return None
if mime_type and asset.mime_type != mime_type:
from app.assets.database.queries import update_asset_hash_and_mime
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
session.commit()
result = _register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(

View File

@@ -23,11 +23,9 @@ class ReferenceData:
file_path: str | None
user_metadata: UserMetadata
preview_id: str | None
system_metadata: dict[str, Any] | None = None
prompt_id: str | None = None
created_at: datetime = None # type: ignore[assignment]
updated_at: datetime = None # type: ignore[assignment]
last_access_time: datetime | None = None
created_at: datetime
updated_at: datetime
last_access_time: datetime | None
@dataclass(frozen=True)
@@ -95,8 +93,6 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
file_path=ref.file_path,
user_metadata=ref.user_metadata,
preview_id=ref.preview_id,
system_metadata=ref.system_metadata,
prompt_id=ref.prompt_id,
created_at=ref.created_at,
updated_at=ref.updated_at,
last_access_time=ref.last_access_time,

View File

@@ -1,5 +1,3 @@
from typing import Sequence
from app.assets.database.queries import (
AddTagsResult,
RemoveTagsResult,
@@ -8,7 +6,6 @@ from app.assets.database.queries import (
list_tags_with_usage,
remove_tags_from_reference,
)
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
from app.assets.services.schemas import TagUsage
from app.database.db import create_session
@@ -76,23 +73,3 @@ def list_tags(
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
def list_tag_histogram(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
with create_session() as session:
return list_tag_counts_for_filtered_assets(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
)

View File

@@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module):
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
# run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
@@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module):
del qkv
q, k = self.norm(q, k, v)
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v

View File

@@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
freqs_cis = freqs_cis[:, :, :x_.shape[2]]
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

View File

@@ -170,7 +170,7 @@ class Flux(nn.Module):
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
img = out["img"]
txt = out["txt"]
img_ids = out["img_ids"]

View File

@@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
except Exception as e:
model_management.raise_non_oom(e)
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:

View File

@@ -258,7 +258,8 @@ def slice_attention(q, k, v):
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
except Exception as e:
model_management.raise_non_oom(e)
model_management.soft_empty_cache(True)
steps *= 2
if steps > 128:
@@ -314,7 +315,8 @@ def pytorch_attention(q, k, v):
try:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True
if oom_fallback:

View File

@@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking(
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except model_management.OOM_EXCEPTION:
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)

View File

@@ -1,4 +1,5 @@
import json
import comfy.memory_management
import comfy.supported_models
import comfy.supported_models_base
import comfy.utils
@@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
new[:old_weight.shape[0]] = old_weight
old_weight = new
if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
old_weight = old_weight.clone()
w = old_weight.narrow(offset[0], offset[1], offset[2])
else:
if comfy.memory_management.aimdo_enabled:
weight = weight.clone()
old_weight = weight
w = weight
w[:] = fun(weight)

View File

@@ -270,6 +270,18 @@ try:
except:
OOM_EXCEPTION = Exception
def is_oom(e):
if isinstance(e, OOM_EXCEPTION):
return True
if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2:
discard_cuda_async_error()
return True
return False
def raise_non_oom(e):
if not is_oom(e):
raise e
XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:

View File

@@ -954,7 +954,8 @@ class VAE:
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
@@ -1029,7 +1030,8 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
except model_management.OOM_EXCEPTION:
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.

View File

@@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode):
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
except Exception as e:
model_management.raise_non_oom(e)
tile //= 2
if tile < 128:
raise e

View File

@@ -612,7 +612,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
logging.error(traceback.format_exc())
tips = ""
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
if comfy.model_management.is_oom(ex):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.")

13
main.py
View File

@@ -3,6 +3,7 @@ comfy.options.enable_args_parsing()
import os
import importlib.util
import shutil
import importlib.metadata
import folder_paths
import time
@@ -64,8 +65,15 @@ if __name__ == "__main__":
def handle_comfyui_manager_unavailable():
if not args.windows_standalone_build:
logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt")
uv_available = shutil.which("uv") is not None
pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}"
if uv_available:
msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}"
msg += "\n"
logging.warning(msg)
args.enable_manager = False
@@ -173,7 +181,6 @@ execute_prestartup_script()
# Main code
import asyncio
import shutil
import threading
import gc

View File

@@ -1 +1 @@
comfyui_manager==4.1b1
comfyui_manager==4.1b2

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.11
comfyui-workflow-templates==0.9.18
comfyui-embedded-docs==0.4.3
torch
torchsde

View File

@@ -97,7 +97,7 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset
# normalized, deduplicated; 'unit-tests' was already present from the seed
assert set(b1["added"]) == {"newtag", "beta"}
assert set(b1["already_present"]) == {"unit-tests"}
assert "newtag" in b1["tags"] and "beta" in b1["tags"]
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
g = rg.json()