Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-04-07 04:59:58 +00:00)

Compare commits: v0.16.2 ... pyisolate- (53 commits)
Commits (SHA1):
82af45530d, 9e3e939db1, b11129e169, a6b5e6545d, d90e28863e, 683e2d6a73, 878684d8b2, c02372936d, 6aa0b838a0, 54461f9ecc, b602cc4533, 08b92a48c3, c5e7b9cdaf, 9ce4c3dd87, abc87d3669, f6274c06b4, 4f4f8659c2, 3365008dfe, 980621da83, 9642e4407b, 3ad36d6be6, 8086468d2a, 535c16ce6e, a912809c25, c4fb0271cd, 740d998c9c, 814dab9f46, 06f85e2c79, e4b0bb8305, 7723f20bbe, 29b24cb517, a7a6335be5, bcf1a1fab1, 6ac8152fc8, afc00f0055, d69d30819b, f466b06601, 34e55f0061, 3b93d5d571, e544c65db9, 1c21828236, 58017e8726, 17b43c2b87, 8befce5c7b, 623a9d21e9, 9250191c65, a0f8784e9f, 7962db477a, 3c8ba051b6, a1c3124821, 9ca799362d, 22f5e43c12, 3cfd5e3311
.gitignore (vendored) — 1 line changed
@@ -24,3 +24,4 @@ web_custom_versions/
openapi.yaml
filtered-openapi.yaml
uv.lock
.pyisolate_venvs/
alembic_db/versions/0002_merge_to_asset_references.py (new file, 267 lines)
@@ -0,0 +1,267 @@
"""
Merge AssetInfo and AssetCacheState into unified asset_references table.

This migration drops old tables and creates the new unified schema.
All existing data is discarded.

Revision ID: 0002_merge_to_asset_references
Revises: 0001_assets
Create Date: 2025-02-11
"""

from alembic import op
import sqlalchemy as sa

revision = "0002_merge_to_asset_references"
down_revision = "0001_assets"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop old tables (order matters due to FK constraints)
    op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
    op.drop_table("asset_info_meta")

    op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
    op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
    op.drop_table("asset_info_tags")

    op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
    op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
    op.drop_table("asset_cache_state")

    op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
    op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
    op.drop_index("ix_assets_info_created_at", table_name="assets_info")
    op.drop_index("ix_assets_info_name", table_name="assets_info")
    op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
    op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
    op.drop_table("assets_info")

    # Truncate assets table (cascades handled by dropping dependent tables first)
    op.execute("DELETE FROM assets")

    # Create asset_references table
    op.create_table(
        "asset_references",
        sa.Column("id", sa.String(length=36), primary_key=True),
        sa.Column(
            "asset_id",
            sa.String(length=36),
            sa.ForeignKey("assets.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("file_path", sa.Text(), nullable=True),
        sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
        sa.Column(
            "needs_verify",
            sa.Boolean(),
            nullable=False,
            server_default=sa.text("false"),
        ),
        sa.Column(
            "is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false")
        ),
        sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
        sa.Column("name", sa.String(length=512), nullable=False),
        sa.Column(
            "preview_id",
            sa.String(length=36),
            sa.ForeignKey("assets.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column("user_metadata", sa.JSON(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
        sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
        sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True),
        sa.CheckConstraint(
            "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
        ),
        sa.CheckConstraint(
            "enrichment_level >= 0 AND enrichment_level <= 2",
            name="ck_ar_enrichment_level_range",
        ),
    )
    op.create_index(
        "uq_asset_references_file_path", "asset_references", ["file_path"], unique=True
    )
    op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"])
    op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"])
    op.create_index("ix_asset_references_name", "asset_references", ["name"])
    op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"])
    op.create_index(
        "ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"]
    )
    op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"])
    op.create_index(
        "ix_asset_references_last_access_time", "asset_references", ["last_access_time"]
    )
    op.create_index(
        "ix_asset_references_owner_name", "asset_references", ["owner_id", "name"]
    )
    op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"])

    # Create asset_reference_tags table
    op.create_table(
        "asset_reference_tags",
        sa.Column(
            "asset_reference_id",
            sa.String(length=36),
            sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column(
            "tag_name",
            sa.String(length=512),
            sa.ForeignKey("tags.name", ondelete="RESTRICT"),
            nullable=False,
        ),
        sa.Column(
            "origin", sa.String(length=32), nullable=False, server_default="manual"
        ),
        sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
        sa.PrimaryKeyConstraint(
            "asset_reference_id", "tag_name", name="pk_asset_reference_tags"
        ),
    )
    op.create_index(
        "ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"]
    )
    op.create_index(
        "ix_asset_reference_tags_asset_reference_id",
        "asset_reference_tags",
        ["asset_reference_id"],
    )

    # Create asset_reference_meta table
    op.create_table(
        "asset_reference_meta",
        sa.Column(
            "asset_reference_id",
            sa.String(length=36),
            sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("key", sa.String(length=256), nullable=False),
        sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("val_str", sa.String(length=2048), nullable=True),
        sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
        sa.Column("val_bool", sa.Boolean(), nullable=True),
        sa.Column("val_json", sa.JSON(), nullable=True),
        sa.PrimaryKeyConstraint(
            "asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta"
        ),
    )
    op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"])
    op.create_index(
        "ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"]
    )
    op.create_index(
        "ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"]
    )
    op.create_index(
        "ix_asset_reference_meta_key_val_bool",
        "asset_reference_meta",
        ["key", "val_bool"],
    )


def downgrade() -> None:
    """Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema.

    NOTE: Data is not recoverable. The upgrade discards all rows from the old
    tables and truncates assets. After downgrade the old schema will be empty.
    A filesystem rescan will repopulate data once the older code is running.
    """
    # Drop new tables (order matters due to FK constraints)
    op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta")
    op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta")
    op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta")
    op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta")
    op.drop_table("asset_reference_meta")

    op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags")
    op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags")
    op.drop_table("asset_reference_tags")

    op.drop_index("ix_asset_references_deleted_at", table_name="asset_references")
    op.drop_index("ix_asset_references_owner_name", table_name="asset_references")
    op.drop_index("ix_asset_references_last_access_time", table_name="asset_references")
    op.drop_index("ix_asset_references_created_at", table_name="asset_references")
    op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references")
    op.drop_index("ix_asset_references_is_missing", table_name="asset_references")
    op.drop_index("ix_asset_references_name", table_name="asset_references")
    op.drop_index("ix_asset_references_owner_id", table_name="asset_references")
    op.drop_index("ix_asset_references_asset_id", table_name="asset_references")
    op.drop_index("uq_asset_references_file_path", table_name="asset_references")
    op.drop_table("asset_references")

    # Truncate assets (upgrade deleted all rows; downgrade starts fresh too)
    op.execute("DELETE FROM assets")

    # Recreate old tables from 0001_assets schema
    op.create_table(
        "assets_info",
        sa.Column("id", sa.String(length=36), primary_key=True),
        sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
        sa.Column("name", sa.String(length=512), nullable=False),
        sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
        sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
        sa.Column("user_metadata", sa.JSON(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
        sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
        sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
    )
    op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
    op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
    op.create_index("ix_assets_info_name", "assets_info", ["name"])
    op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
    op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
    op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])

    op.create_table(
        "asset_cache_state",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
        sa.Column("file_path", sa.Text(), nullable=False),
        sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
        sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
        sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
        sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
    )
    op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
    op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])

    op.create_table(
        "asset_info_tags",
        sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
        sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
        sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
        sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
        sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
    )
    op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
    op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])

    op.create_table(
        "asset_info_meta",
        sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
        sa.Column("key", sa.String(length=256), nullable=False),
        sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("val_str", sa.String(length=2048), nullable=True),
        sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
        sa.Column("val_bool", sa.Boolean(), nullable=True),
        sa.Column("val_json", sa.JSON(), nullable=True),
        sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
    )
    op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
    op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
    op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
    op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
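Because both upgrade() and downgrade() above are destructive (old rows are discarded and assets is truncated), the revision is normally applied through Alembic's command API or CLI. A minimal sketch, assuming a standard alembic.ini at the repository root (the config path is an assumption, not part of this compare):

# sketch: apply or roll back this revision programmatically (alembic.ini path is assumed)
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "0002_merge_to_asset_references")   # runs upgrade()
# command.downgrade(cfg, "0001_assets")                  # runs downgrade(); the dropped data is NOT restored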
File diff suppressed because it is too large
@@ -1,6 +1,8 @@
import json
from dataclasses import dataclass
from typing import Any, Literal

from app.assets.helpers import validate_blake3_hash
from pydantic import (
    BaseModel,
    ConfigDict,
@@ -10,6 +12,41 @@ from pydantic import (
    model_validator,
)


class UploadError(Exception):
    """Error during upload parsing with HTTP status and code."""

    def __init__(self, status: int, code: str, message: str):
        super().__init__(message)
        self.status = status
        self.code = code
        self.message = message


class AssetValidationError(Exception):
    """Validation error in asset processing (invalid tags, metadata, etc.)."""

    def __init__(self, code: str, message: str):
        super().__init__(message)
        self.code = code
        self.message = message


@dataclass
class ParsedUpload:
    """Result of parsing a multipart upload request."""

    file_present: bool
    file_written: int
    file_client_name: str | None
    tmp_path: str | None
    tags_raw: list[str]
    provided_name: str | None
    user_metadata_raw: str | None
    provided_hash: str | None
    provided_hash_exists: bool | None


class ListAssetsQuery(BaseModel):
    include_tags: list[str] = Field(default_factory=list)
    exclude_tags: list[str] = Field(default_factory=list)
@@ -21,7 +58,9 @@ class ListAssetsQuery(BaseModel):
    limit: conint(ge=1, le=500) = 20
    offset: conint(ge=0) = 0

    sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
    sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
        "created_at"
    )
    order: Literal["asc", "desc"] = "desc"

    @field_validator("include_tags", "exclude_tags", mode="before")
@@ -61,7 +100,7 @@ class UpdateAssetBody(BaseModel):
    user_metadata: dict[str, Any] | None = None

    @model_validator(mode="after")
    def _at_least_one(self):
    def _validate_at_least_one_field(self):
        if self.name is None and self.user_metadata is None:
            raise ValueError("Provide at least one of: name, user_metadata.")
        return self
@@ -78,19 +117,11 @@ class CreateFromHashBody(BaseModel):
    @field_validator("hash")
    @classmethod
    def _require_blake3(cls, v):
        s = (v or "").strip().lower()
        if ":" not in s:
            raise ValueError("hash must be 'blake3:<hex>'")
        algo, digest = s.split(":", 1)
        if algo != "blake3":
            raise ValueError("only canonical 'blake3:<hex>' is accepted here")
        if not digest or any(c for c in digest if c not in "0123456789abcdef"):
            raise ValueError("hash digest must be lowercase hex")
        return s
        return validate_blake3_hash(v or "")

    @field_validator("tags", mode="before")
    @classmethod
    def _tags_norm(cls, v):
    def _normalize_tags_field(cls, v):
        if v is None:
            return []
        if isinstance(v, list):
@@ -154,15 +185,16 @@ class TagsRemove(TagsAdd):

class UploadAssetSpec(BaseModel):
    """Upload Asset operation.

    - tags: ordered; first is root ('models'|'input'|'output');
      if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
      if root == 'models', second must be a valid category
    - name: display name
    - user_metadata: arbitrary JSON object (optional)
    - hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
    - hash: optional canonical 'blake3:<hex>' for validation / fast-path

    Files created via this endpoint are stored on disk using the **content hash** as the filename stem
    and the original extension is preserved when available.
    Files are stored using the content hash as filename stem.
    """

    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    tags: list[str] = Field(..., min_length=1)
@@ -175,17 +207,10 @@ class UploadAssetSpec(BaseModel):
    def _parse_hash(cls, v):
        if v is None:
            return None
        s = str(v).strip().lower()
        s = str(v).strip()
        if not s:
            return None
        if ":" not in s:
            raise ValueError("hash must be 'blake3:<hex>'")
        algo, digest = s.split(":", 1)
        if algo != "blake3":
            raise ValueError("only canonical 'blake3:<hex>' is accepted here")
        if not digest or any(c for c in digest if c not in "0123456789abcdef"):
            raise ValueError("hash digest must be lowercase hex")
        return f"{algo}:{digest}"
        return validate_blake3_hash(s)

    @field_validator("tags", mode="before")
    @classmethod
@@ -260,5 +285,7 @@ class UploadAssetSpec(BaseModel):
            raise ValueError("first tag must be one of: models, input, output")
        if root == "models":
            if len(self.tags) < 2:
                raise ValueError("models uploads require a category tag as the second tag")
                raise ValueError(
                    "models uploads require a category tag as the second tag"
                )
        return self

@@ -19,7 +19,7 @@ class AssetSummary(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    @field_serializer("created_at", "updated_at", "last_access_time")
    def _ser_dt(self, v: datetime | None, _info):
    def _serialize_datetime(self, v: datetime | None, _info):
        return v.isoformat() if v else None


@@ -40,7 +40,7 @@ class AssetUpdated(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    @field_serializer("updated_at")
    def _ser_updated(self, v: datetime | None, _info):
    def _serialize_updated_at(self, v: datetime | None, _info):
        return v.isoformat() if v else None


@@ -59,7 +59,7 @@ class AssetDetail(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    @field_serializer("created_at", "last_access_time")
    def _ser_dt(self, v: datetime | None, _info):
    def _serialize_datetime(self, v: datetime | None, _info):
        return v.isoformat() if v else None

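The validators above now delegate to app.assets.helpers.validate_blake3_hash, whose body is not part of this compare. A minimal sketch of such a helper, reconstructed from the inline checks it replaces (the real helper may differ; this is an assumption, not the project's code):

def validate_blake3_hash(value: str) -> str:
    # assumption: mirrors the removed inline validation ("blake3:<hex>", lowercase hex digest)
    s = (value or "").strip().lower()
    if ":" not in s:
        raise ValueError("hash must be 'blake3:<hex>'")
    algo, digest = s.split(":", 1)
    if algo != "blake3":
        raise ValueError("only canonical 'blake3:<hex>' is accepted here")
    if not digest or any(c not in "0123456789abcdef" for c in digest):
        raise ValueError("hash digest must be lowercase hex")
    return f"{algo}:{digest}"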
app/assets/api/upload.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import logging
import os
import uuid
from typing import Callable

from aiohttp import web

import folder_paths
from app.assets.api.schemas_in import ParsedUpload, UploadError
from app.assets.helpers import validate_blake3_hash


def normalize_and_validate_hash(s: str) -> str:
    """Validate and normalize a hash string.

    Returns canonical 'blake3:<hex>' or raises UploadError.
    """
    try:
        return validate_blake3_hash(s)
    except ValueError:
        raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")


async def parse_multipart_upload(
    request: web.Request,
    check_hash_exists: Callable[[str], bool],
) -> ParsedUpload:
    """
    Parse a multipart/form-data upload request.

    Args:
        request: The aiohttp request
        check_hash_exists: Callable(hash_str) -> bool to check if a hash exists

    Returns:
        ParsedUpload with parsed fields and temp file path

    Raises:
        UploadError: On validation or I/O errors
    """
    if not (request.content_type or "").lower().startswith("multipart/"):
        raise UploadError(
            415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
        )

    reader = await request.multipart()

    file_present = False
    file_client_name: str | None = None
    tags_raw: list[str] = []
    provided_name: str | None = None
    user_metadata_raw: str | None = None
    provided_hash: str | None = None
    provided_hash_exists: bool | None = None

    file_written = 0
    tmp_path: str | None = None

    while True:
        field = await reader.next()
        if field is None:
            break

        fname = getattr(field, "name", "") or ""

        if fname == "hash":
            try:
                s = ((await field.text()) or "").strip().lower()
            except Exception:
                raise UploadError(
                    400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
                )

            if s:
                provided_hash = normalize_and_validate_hash(s)
                try:
                    provided_hash_exists = check_hash_exists(provided_hash)
                except Exception as e:
                    logging.exception(
                        "check_hash_exists failed for hash=%s: %s", provided_hash, e
                    )
                    raise UploadError(
                        500,
                        "HASH_CHECK_FAILED",
                        "Backend error while checking asset hash.",
                    )

        elif fname == "file":
            file_present = True
            file_client_name = (field.filename or "").strip()

            if provided_hash and provided_hash_exists is True:
                # Hash exists - drain file but don't write to disk
                try:
                    while True:
                        chunk = await field.read_chunk(8 * 1024 * 1024)
                        if not chunk:
                            break
                        file_written += len(chunk)
                except Exception:
                    raise UploadError(
                        500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
                    )
                continue

            uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
            unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
            os.makedirs(unique_dir, exist_ok=True)
            tmp_path = os.path.join(unique_dir, ".upload.part")

            try:
                with open(tmp_path, "wb") as f:
                    while True:
                        chunk = await field.read_chunk(8 * 1024 * 1024)
                        if not chunk:
                            break
                        f.write(chunk)
                        file_written += len(chunk)
            except Exception:
                delete_temp_file_if_exists(tmp_path)
                raise UploadError(
                    500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
                )

        elif fname == "tags":
            tags_raw.append((await field.text()) or "")
        elif fname == "name":
            provided_name = (await field.text()) or None
        elif fname == "user_metadata":
            user_metadata_raw = (await field.text()) or None

    if not file_present and not (provided_hash and provided_hash_exists):
        raise UploadError(
            400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
        )

    if (
        file_present
        and file_written == 0
        and not (provided_hash and provided_hash_exists)
    ):
        delete_temp_file_if_exists(tmp_path)
        raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")

    return ParsedUpload(
        file_present=file_present,
        file_written=file_written,
        file_client_name=file_client_name,
        tmp_path=tmp_path,
        tags_raw=tags_raw,
        provided_name=provided_name,
        user_metadata_raw=user_metadata_raw,
        provided_hash=provided_hash,
        provided_hash_exists=provided_hash_exists,
    )


def delete_temp_file_if_exists(tmp_path: str | None) -> None:
    """Safely remove a temp file and its parent directory if empty."""
    if tmp_path:
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError as e:
            logging.debug("Failed to delete temp file %s: %s", tmp_path, e)
        try:
            parent = os.path.dirname(tmp_path)
            if parent and os.path.isdir(parent):
                os.rmdir(parent)  # only succeeds if empty
        except OSError:
            pass
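For context, a handler wiring this parser into an aiohttp route might look roughly like the sketch below; the route path and the check_hash_exists callable are assumptions for illustration, not part of this diff.

# sketch: hypothetical aiohttp handler around parse_multipart_upload (route and hash check are assumed)
from aiohttp import web

from app.assets.api.schemas_in import UploadError
from app.assets.api.upload import parse_multipart_upload


async def upload_asset(request: web.Request) -> web.Response:
    def check_hash_exists(h: str) -> bool:
        return False  # placeholder; a real check would look the hash up in the assets table

    try:
        parsed = await parse_multipart_upload(request, check_hash_exists)
    except UploadError as e:
        return web.json_response({"code": e.code, "message": e.message}, status=e.status)
    return web.json_response(
        {"received_bytes": parsed.file_written, "name": parsed.provided_name}
    )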
@@ -1,204 +0,0 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite

from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta

MAX_BIND_PARAMS = 800

def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
    if not rows:
        return []
    rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
    for i in range(0, len(rows), rows_per_stmt):
        yield rows[i:i + rows_per_stmt]

def _iter_chunks(seq, n: int):
    for i in range(0, len(seq), n):
        yield seq[i:i + n]

def _rows_per_stmt(cols: int) -> int:
    return max(1, MAX_BIND_PARAMS // max(1, cols))


def seed_from_paths_batch(
    session: Session,
    *,
    specs: list[dict],
    owner_id: str = "",
) -> dict:
    """Each spec is a dict with keys:
    - abs_path: str
    - size_bytes: int
    - mtime_ns: int
    - info_name: str
    - tags: list[str]
    - fname: Optional[str]
    """
    if not specs:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}

    now = utcnow()
    asset_rows: list[dict] = []
    state_rows: list[dict] = []
    path_to_asset: dict[str, str] = {}
    asset_to_info: dict[str, dict] = {}  # asset_id -> prepared info row
    path_list: list[str] = []

    for sp in specs:
        ap = os.path.abspath(sp["abs_path"])
        aid = str(uuid.uuid4())
        iid = str(uuid.uuid4())
        path_list.append(ap)
        path_to_asset[ap] = aid

        asset_rows.append(
            {
                "id": aid,
                "hash": None,
                "size_bytes": sp["size_bytes"],
                "mime_type": None,
                "created_at": now,
            }
        )
        state_rows.append(
            {
                "asset_id": aid,
                "file_path": ap,
                "mtime_ns": sp["mtime_ns"],
            }
        )
        asset_to_info[aid] = {
            "id": iid,
            "owner_id": owner_id,
            "name": sp["info_name"],
            "asset_id": aid,
            "preview_id": None,
            "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
            "created_at": now,
            "updated_at": now,
            "last_access_time": now,
            "_tags": sp["tags"],
            "_filename": sp["fname"],
        }

    # insert all seed Assets (hash=NULL)
    ins_asset = sqlite.insert(Asset)
    for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
        session.execute(ins_asset, chunk)

    # try to claim AssetCacheState (file_path)
    # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
    ins_state = (
        sqlite.insert(AssetCacheState)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )
    for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
        session.execute(ins_state, chunk)

    # Query to find which of our paths won (were actually inserted)
    winners_by_path: set[str] = set()
    for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
        result = session.execute(
            sqlalchemy.select(AssetCacheState.file_path)
            .where(AssetCacheState.file_path.in_(chunk))
            .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
        )
        winners_by_path.update(result.scalars().all())

    all_paths_set = set(path_list)
    losers_by_path = all_paths_set - winners_by_path
    lost_assets = [path_to_asset[p] for p in losers_by_path]
    if lost_assets:  # losers get their Asset removed
        for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
            session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))

    if not winners_by_path:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}

    # insert AssetInfo only for winners
    # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
    winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
    ins_info = (
        sqlite.insert(AssetInfo)
        .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
    )
    for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
        session.execute(ins_info, chunk)

    # Query to find which info rows were actually inserted (by matching our generated IDs)
    all_info_ids = [row["id"] for row in winner_info_rows]
    inserted_info_ids: set[str] = set()
    for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
        result = session.execute(
            sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
        )
        inserted_info_ids.update(result.scalars().all())

    # build and insert tag + meta rows for the AssetInfo
    tag_rows: list[dict] = []
    meta_rows: list[dict] = []
    if inserted_info_ids:
        for row in winner_info_rows:
            iid = row["id"]
            if iid not in inserted_info_ids:
                continue
            for t in row["_tags"]:
                tag_rows.append({
                    "asset_info_id": iid,
                    "tag_name": t,
                    "origin": "automatic",
                    "added_at": now,
                })
            if row["_filename"]:
                meta_rows.append(
                    {
                        "asset_info_id": iid,
                        "key": "filename",
                        "ordinal": 0,
                        "val_str": row["_filename"],
                        "val_num": None,
                        "val_bool": None,
                        "val_json": None,
                    }
                )

    bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
    return {
        "inserted_infos": len(inserted_info_ids),
        "won_states": len(winners_by_path),
        "lost_states": len(losers_by_path),
    }


def bulk_insert_tags_and_meta(
    session: Session,
    *,
    tag_rows: list[dict],
    meta_rows: list[dict],
    max_bind_params: int,
) -> None:
    """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
    - tag_rows keys: asset_info_id, tag_name, origin, added_at
    - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
    """
    if tag_rows:
        ins_links = (
            sqlite.insert(AssetInfoTag)
            .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
        )
        for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
            session.execute(ins_links, chunk)
    if meta_rows:
        ins_meta = (
            sqlite.insert(AssetInfoMeta)
            .on_conflict_do_nothing(
                index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
            )
        )
        for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
            session.execute(ins_meta, chunk)
@@ -2,8 +2,8 @@ from __future__ import annotations

import uuid
from datetime import datetime

from typing import Any

from sqlalchemy import (
    JSON,
    BigInteger,
@@ -16,102 +16,102 @@ from sqlalchemy import (
    Numeric,
    String,
    Text,
    UniqueConstraint,
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship

from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
from app.assets.helpers import get_utc_now
from app.database.models import Base


class Asset(Base):
    __tablename__ = "assets"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    mime_type: Mapped[str | None] = mapped_column(String(255))
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=utcnow
        DateTime(timezone=False), nullable=False, default=get_utc_now
    )

    infos: Mapped[list[AssetInfo]] = relationship(
        "AssetInfo",
    references: Mapped[list[AssetReference]] = relationship(
        "AssetReference",
        back_populates="asset",
        primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
        foreign_keys=lambda: [AssetInfo.asset_id],
        primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id),
        foreign_keys=lambda: [AssetReference.asset_id],
        cascade="all,delete-orphan",
        passive_deletes=True,
    )

    preview_of: Mapped[list[AssetInfo]] = relationship(
        "AssetInfo",
    preview_of: Mapped[list[AssetReference]] = relationship(
        "AssetReference",
        back_populates="preview_asset",
        primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
        foreign_keys=lambda: [AssetInfo.preview_id],
        primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
        foreign_keys=lambda: [AssetReference.preview_id],
        viewonly=True,
    )

    cache_states: Mapped[list[AssetCacheState]] = relationship(
        back_populates="asset",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

    __table_args__ = (
        Index("uq_assets_hash", "hash", unique=True),
        Index("ix_assets_mime_type", "mime_type"),
        CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
        return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"


class AssetCacheState(Base):
    __tablename__ = "asset_cache_state"
class AssetReference(Base):
    """Unified model combining file cache state and user-facing metadata.

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
    file_path: Mapped[str] = mapped_column(Text, nullable=False)
    mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    Each row represents either:
    - A filesystem reference (file_path is set) with cache state
    - An API-created reference (file_path is NULL) without cache state
    """

    asset: Mapped[Asset] = relationship(back_populates="cache_states")
    __tablename__ = "asset_references"

    __table_args__ = (
        Index("ix_asset_cache_state_file_path", "file_path"),
        Index("ix_asset_cache_state_asset_id", "asset_id"),
        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
        UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
    id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    asset_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)
    # Cache state fields (from former AssetCacheState)
    file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
    mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

    def __repr__(self) -> str:
        return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"


class AssetInfo(Base):
    __tablename__ = "assets_info"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # Info fields (from former AssetInfo)
    owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
    name: Mapped[str] = mapped_column(String(512), nullable=False)
    asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
    preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
    user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
    last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
    preview_id: Mapped[str | None] = mapped_column(
        String(36), ForeignKey("assets.id", ondelete="SET NULL")
    )
    user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
        JSON(none_as_null=True)
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=get_utc_now
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=get_utc_now
    )
    last_access_time: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=get_utc_now
    )
    deleted_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=False), nullable=True, default=None
    )

    asset: Mapped[Asset] = relationship(
        "Asset",
        back_populates="infos",
        back_populates="references",
        foreign_keys=[asset_id],
        lazy="selectin",
    )
@@ -121,51 +121,59 @@ class AssetInfo(Base):
        foreign_keys=[preview_id],
    )

    metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
        back_populates="asset_info",
    metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
        back_populates="asset_reference",
        cascade="all,delete-orphan",
        passive_deletes=True,
    )

    tag_links: Mapped[list[AssetInfoTag]] = relationship(
        back_populates="asset_info",
    tag_links: Mapped[list[AssetReferenceTag]] = relationship(
        back_populates="asset_reference",
        cascade="all,delete-orphan",
        passive_deletes=True,
        overlaps="tags,asset_infos",
        overlaps="tags,asset_references",
    )

    tags: Mapped[list[Tag]] = relationship(
        secondary="asset_info_tags",
        back_populates="asset_infos",
        secondary="asset_reference_tags",
        back_populates="asset_references",
        lazy="selectin",
        viewonly=True,
        overlaps="tag_links,asset_info_links,asset_infos,tag",
        overlaps="tag_links,asset_reference_links,asset_references,tag",
    )

    __table_args__ = (
        UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
        Index("ix_assets_info_owner_name", "owner_id", "name"),
        Index("ix_assets_info_owner_id", "owner_id"),
        Index("ix_assets_info_asset_id", "asset_id"),
        Index("ix_assets_info_name", "name"),
        Index("ix_assets_info_created_at", "created_at"),
        Index("ix_assets_info_last_access_time", "last_access_time"),
        Index("uq_asset_references_file_path", "file_path", unique=True),
        Index("ix_asset_references_asset_id", "asset_id"),
        Index("ix_asset_references_owner_id", "owner_id"),
        Index("ix_asset_references_name", "name"),
        Index("ix_asset_references_is_missing", "is_missing"),
        Index("ix_asset_references_enrichment_level", "enrichment_level"),
        Index("ix_asset_references_created_at", "created_at"),
        Index("ix_asset_references_last_access_time", "last_access_time"),
        Index("ix_asset_references_deleted_at", "deleted_at"),
        Index("ix_asset_references_owner_name", "owner_id", "name"),
        CheckConstraint(
            "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
        ),
        CheckConstraint(
            "enrichment_level >= 0 AND enrichment_level <= 2",
            name="ck_ar_enrichment_level_range",
        ),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        data = to_dict(self, include_none=include_none)
        data["tags"] = [t.name for t in self.tags]
        return data

    def __repr__(self) -> str:
        return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
        path_part = f" path={self.file_path!r}" if self.file_path else ""
        return f"<AssetReference id={self.id} name={self.name!r}{path_part}>"


class AssetInfoMeta(Base):
    __tablename__ = "asset_info_meta"
class AssetReferenceMeta(Base):
    __tablename__ = "asset_reference_meta"

    asset_info_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
    asset_reference_id: Mapped[str] = mapped_column(
        String(36),
        ForeignKey("asset_references.id", ondelete="CASCADE"),
        primary_key=True,
    )
    key: Mapped[str] = mapped_column(String(256), primary_key=True)
    ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
@@ -175,36 +183,40 @@ class AssetInfoMeta(Base):
    val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
    val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)

    asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
    asset_reference: Mapped[AssetReference] = relationship(
        back_populates="metadata_entries"
    )

    __table_args__ = (
        Index("ix_asset_info_meta_key", "key"),
        Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
        Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
        Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
        Index("ix_asset_reference_meta_key", "key"),
        Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
        Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
        Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
    )


class AssetInfoTag(Base):
    __tablename__ = "asset_info_tags"
class AssetReferenceTag(Base):
    __tablename__ = "asset_reference_tags"

    asset_info_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
    asset_reference_id: Mapped[str] = mapped_column(
        String(36),
        ForeignKey("asset_references.id", ondelete="CASCADE"),
        primary_key=True,
    )
    tag_name: Mapped[str] = mapped_column(
        String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
    )
    origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
    added_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=utcnow
        DateTime(timezone=False), nullable=False, default=get_utc_now
    )

    asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
    tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
    asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links")
    tag: Mapped[Tag] = relationship(back_populates="asset_reference_links")

    __table_args__ = (
        Index("ix_asset_info_tags_tag_name", "tag_name"),
        Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
        Index("ix_asset_reference_tags_tag_name", "tag_name"),
        Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"),
    )


@@ -214,20 +226,18 @@ class Tag(Base):
    name: Mapped[str] = mapped_column(String(512), primary_key=True)
    tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")

    asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
    asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
        back_populates="tag",
        overlaps="asset_infos,tags",
        overlaps="asset_references,tags",
    )
    asset_infos: Mapped[list[AssetInfo]] = relationship(
        secondary="asset_info_tags",
    asset_references: Mapped[list[AssetReference]] = relationship(
        secondary="asset_reference_tags",
        back_populates="tags",
        viewonly=True,
        overlaps="asset_info_links,tag_links,tags,asset_info",
        overlaps="asset_reference_links,tag_links,tags,asset_reference",
    )

    __table_args__ = (
        Index("ix_tags_tag_type", "tag_type"),
    )
    __table_args__ = (Index("ix_tags_tag_type", "tag_type"),)

    def __repr__(self) -> str:
        return f"<Tag {self.name}>"

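As a quick illustration of the merged model above, a filesystem-backed entry now needs only one AssetReference row instead of an AssetInfo plus AssetCacheState pair. A minimal sketch; the session setup and the example values are assumptions, not code from this compare:

# sketch: one AssetReference row carries both cache state and user-facing fields (values are illustrative)
from app.assets.database.models import Asset, AssetReference

asset = Asset(size_bytes=0)  # hash may stay NULL until the file is hashed
ref = AssetReference(
    asset=asset,
    name="example.safetensors",                # assumed display name
    file_path="/models/example.safetensors",   # filesystem-backed reference; NULL for API-created ones
    mtime_ns=0,
)
# session.add(ref); session.commit()  # assuming a configured SQLAlchemy Session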
@@ -1,976 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Any
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import (
|
||||
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||
)
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
def asset_info_exists_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (session.execute(q)).first() is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
def get_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
.order_by(AssetInfoTag.added_at)
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset, list[str]] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
def list_cache_states_by_asset_id(
|
||||
session: Session, *, asset_id: str
|
||||
) -> Sequence[AssetCacheState]:
|
||||
return (
|
||||
session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def touch_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def create_asset_info_for_existing_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
with session.begin_nested():
|
||||
session.add(info)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
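A hedged usage sketch for create_asset_info_for_existing_asset (illustrative; the hash, name, and tag values are placeholders). It raises ValueError when the hash has not been ingested yet and quietly returns the pre-existing row on an (asset_id, name, owner_id) conflict.

def register_named_copy(session, content_hash):
    # Attach a user-visible name and tags to an Asset that is already in the database.
    info = create_asset_info_for_existing_asset(
        session,
        asset_hash=content_hash,           # must already exist, otherwise ValueError
        name="flux.safetensors",
        user_metadata={"source": "manual-upload"},
        tags=["models", "checkpoints"],
        owner_id="",
    )
    return info.id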
def set_asset_info_tags(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
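A short sketch of how the replace-style tag API above reports its delta (illustrative; the id and tag values are placeholders).

def retag(session, info_id):
    result = set_asset_info_tags(
        session,
        asset_info_id=info_id,
        tags=["models", "checkpoints", "sdxl"],
        origin="manual",
    )
    # Only the difference is applied, e.g.:
    # {"added": ["sdxl"], "removed": ["vae"], "total": ["models", "checkpoints", "sdxl"]}
    print(result["added"], result["removed"], result["total"])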
def replace_asset_info_metadata_projection(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: dict | None = None,
|
||||
) -> None:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
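A sketch of calling the projection rewrite above (illustrative). The per-type column mapping noted in the comments is an assumption inferred from the val_str/val_num/val_bool columns consumed here; the actual split is decided by project_kv, which lives elsewhere in this module.

def set_metadata(session, info_id):
    # Rewrites user_metadata and rebuilds its queryable AssetInfoMeta projection.
    replace_asset_info_metadata_projection(
        session,
        asset_info_id=info_id,
        user_metadata={
            "filename": "flux/123/flux.safetensors",  # presumably projected into val_str
            "epochs": 30,                             # presumably projected into val_num
            "nsfw": False,                            # presumably projected into val_bool
        },
    )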
def ingest_fs_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
res = session.execute(
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
ins = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
|
||||
res = session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
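A usage sketch for ingest_fs_asset (illustrative; the mime type is a placeholder and the content hash would normally come from the project's hashing helper). Re-running it on an unchanged file is a no-op apart from the access-time bookkeeping.

import os

def ingest_one(session, path, content_hash):
    st = os.stat(path)
    # The returned dict reports what actually changed, e.g.
    # {"asset_created": True, "asset_updated": False, "state_created": True,
    #  "state_updated": False, "asset_info_id": "..."}
    return ingest_fs_asset(
        session,
        asset_hash=content_hash,
        abs_path=path,
        size_bytes=st.st_size,
        mtime_ns=st.st_mtime_ns,
        mime_type="application/octet-stream",
        info_name=os.path.basename(path),
        tags=["models", "checkpoints"],
        tag_origin="automatic",
    )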
def update_asset_info_full(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
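A usage sketch for list_tags_with_usage (illustrative): a prefix-based autocomplete that hides tags with zero visible usage.

def autocomplete_tags(session, typed):
    rows, _total = list_tags_with_usage(
        session,
        prefix=typed,
        limit=10,
        include_zero=False,
        order="count_desc",
    )
    # Each row is (tag_name, tag_type, usage_count), most-used first.
    return [name for (name, _tag_type, _count) in rows]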
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
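A sketch of the additive tag API above (illustrative). Because duplicate links are absorbed inside the nested transaction, a concurrent writer simply shows up under "already_present".

def mark_favorite(session, info_id):
    result = add_tags_to_asset_info(
        session,
        asset_info_id=info_id,
        tags=["favorite"],
        origin="manual",
    )
    print(result["added"], result["already_present"], result["total_tags"])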
def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_asset_info_preview(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
app/assets/database/queries/__init__.py (new file, 121 lines)
@@ -0,0 +1,121 @@
|
||||
from app.assets.database.queries.asset import (
|
||||
asset_exists_by_hash,
|
||||
bulk_insert_assets,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
reassign_asset_references,
|
||||
update_asset_hash_and_mime,
|
||||
upsert_asset,
|
||||
)
|
||||
from app.assets.database.queries.asset_reference import (
|
||||
CacheStateRow,
|
||||
UnenrichedReferenceRow,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_update_enrichment_level,
|
||||
bulk_update_is_missing,
|
||||
bulk_update_needs_verify,
|
||||
convert_metadata_to_rows,
|
||||
delete_assets_by_ids,
|
||||
delete_orphaned_seed_asset,
|
||||
delete_reference_by_id,
|
||||
delete_references_by_ids,
|
||||
fetch_reference_and_asset,
|
||||
fetch_reference_asset_and_tags,
|
||||
get_or_create_reference,
|
||||
get_reference_by_file_path,
|
||||
get_reference_by_id,
|
||||
get_reference_with_owner_check,
|
||||
get_reference_ids_by_ids,
|
||||
get_references_by_paths_and_asset_ids,
|
||||
get_references_for_prefixes,
|
||||
get_unenriched_references,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
insert_reference,
|
||||
list_references_by_asset_id,
|
||||
list_references_page,
|
||||
mark_references_missing_outside_prefixes,
|
||||
reference_exists_for_asset_id,
|
||||
restore_references_by_paths,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
soft_delete_reference_by_id,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_reference_timestamps,
|
||||
update_reference_updated_at,
|
||||
upsert_reference,
|
||||
)
|
||||
from app.assets.database.queries.tags import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
SetTagsResult,
|
||||
add_missing_tag_for_asset_id,
|
||||
add_tags_to_reference,
|
||||
bulk_insert_tags_and_meta,
|
||||
ensure_tags_exist,
|
||||
get_reference_tags,
|
||||
list_tags_with_usage,
|
||||
remove_missing_tag_for_asset_id,
|
||||
remove_tags_from_reference,
|
||||
set_reference_tags,
|
||||
validate_tags_exist,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AddTagsResult",
|
||||
"CacheStateRow",
|
||||
"RemoveTagsResult",
|
||||
"SetTagsResult",
|
||||
"UnenrichedReferenceRow",
|
||||
"add_missing_tag_for_asset_id",
|
||||
"add_tags_to_reference",
|
||||
"asset_exists_by_hash",
|
||||
"bulk_insert_assets",
|
||||
"bulk_insert_references_ignore_conflicts",
|
||||
"bulk_insert_tags_and_meta",
|
||||
"bulk_update_enrichment_level",
|
||||
"bulk_update_is_missing",
|
||||
"bulk_update_needs_verify",
|
||||
"convert_metadata_to_rows",
|
||||
"delete_assets_by_ids",
|
||||
"delete_orphaned_seed_asset",
|
||||
"delete_reference_by_id",
|
||||
"delete_references_by_ids",
|
||||
"ensure_tags_exist",
|
||||
"fetch_reference_and_asset",
|
||||
"fetch_reference_asset_and_tags",
|
||||
"get_asset_by_hash",
|
||||
"get_existing_asset_ids",
|
||||
"get_or_create_reference",
|
||||
"get_reference_by_file_path",
|
||||
"get_reference_by_id",
|
||||
"get_reference_with_owner_check",
|
||||
"get_reference_ids_by_ids",
|
||||
"get_reference_tags",
|
||||
"get_references_by_paths_and_asset_ids",
|
||||
"get_references_for_prefixes",
|
||||
"get_unenriched_references",
|
||||
"get_unreferenced_unhashed_asset_ids",
|
||||
"insert_reference",
|
||||
"list_references_by_asset_id",
|
||||
"list_references_page",
|
||||
"list_tags_with_usage",
|
||||
"mark_references_missing_outside_prefixes",
|
||||
"reassign_asset_references",
|
||||
"reference_exists_for_asset_id",
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"remove_tags_from_reference",
|
||||
"restore_references_by_paths",
|
||||
"set_reference_metadata",
|
||||
"set_reference_preview",
|
||||
"soft_delete_reference_by_id",
|
||||
"set_reference_tags",
|
||||
"update_asset_hash_and_mime",
|
||||
"update_reference_access_time",
|
||||
"update_reference_name",
|
||||
"update_reference_timestamps",
|
||||
"update_reference_updated_at",
|
||||
"upsert_asset",
|
||||
"upsert_reference",
|
||||
"validate_tags_exist",
|
||||
]
|
||||
app/assets/database/queries/asset.py (new file, 140 lines)
@@ -0,0 +1,140 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks
|
||||
|
||||
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True))
|
||||
.select_from(Asset)
|
||||
.where(Asset.hash == asset_hash)
|
||||
.limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
(session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def upsert_asset(
|
||||
session: Session,
|
||||
asset_hash: str,
|
||||
size_bytes: int,
|
||||
mime_type: str | None = None,
|
||||
) -> tuple[Asset, bool, bool]:
|
||||
"""Upsert an Asset by hash. Returns (asset, created, updated)."""
|
||||
vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
|
||||
if mime_type:
|
||||
vals["mime_type"] = mime_type
|
||||
|
||||
ins = (
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
res = session.execute(ins)
|
||||
created = int(res.rowcount or 0) > 0
|
||||
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
|
||||
updated = False
|
||||
if not created:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
updated = True
|
||||
|
||||
return asset, created, updated
|
||||
|
||||
|
||||
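A usage sketch for upsert_asset (illustrative; hash and size are placeholders). The two flags distinguish the INSERT winning the ON CONFLICT race from an existing row having its size or mime type corrected.

def record_asset(session, content_hash, size):
    asset, created, updated = upsert_asset(
        session,
        asset_hash=content_hash,
        size_bytes=size,
        mime_type="application/octet-stream",
    )
    print(asset.id, created, updated)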
def bulk_insert_assets(
|
||||
session: Session,
|
||||
rows: list[dict],
|
||||
) -> None:
|
||||
"""Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash."""
|
||||
if not rows:
|
||||
return
|
||||
ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
|
||||
session.execute(ins, chunk)
|
||||
|
||||
|
||||
def get_existing_asset_ids(
|
||||
session: Session,
|
||||
asset_ids: list[str],
|
||||
) -> set[str]:
|
||||
"""Return the subset of asset_ids that exist in the database."""
|
||||
if not asset_ids:
|
||||
return set()
|
||||
found: set[str] = set()
|
||||
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
|
||||
rows = session.execute(
|
||||
select(Asset.id).where(Asset.id.in_(chunk))
|
||||
).fetchall()
|
||||
found.update(row[0] for row in rows)
|
||||
return found
|
||||
|
||||
|
||||
def update_asset_hash_and_mime(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
asset_hash: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> bool:
|
||||
"""Update asset hash and/or mime_type. Returns True if asset was found."""
|
||||
asset = session.get(Asset, asset_id)
|
||||
if not asset:
|
||||
return False
|
||||
if asset_hash is not None:
|
||||
asset.hash = asset_hash
|
||||
if mime_type is not None:
|
||||
asset.mime_type = mime_type
|
||||
return True
|
||||
|
||||
|
||||
def reassign_asset_references(
|
||||
session: Session,
|
||||
from_asset_id: str,
|
||||
to_asset_id: str,
|
||||
reference_id: str,
|
||||
) -> None:
|
||||
"""Reassign a reference from one asset to another.
|
||||
|
||||
Used when merging a stub asset into an existing asset with the same hash.
|
||||
"""
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if ref and ref.asset_id == from_asset_id:
|
||||
ref.asset_id = to_asset_id
|
||||
|
||||
session.flush()
|
||||
app/assets/database/queries/asset_reference.py (new file, 1033 lines)
(diff suppressed because it is too large)
app/assets/database/queries/common.py (new file, 54 lines)
@@ -0,0 +1,54 @@
|
||||
"""Shared utilities for database query modules."""
|
||||
|
||||
import os
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from app.assets.database.models import AssetReference
|
||||
from app.assets.helpers import escape_sql_like_string
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
|
||||
def calculate_rows_per_statement(cols: int) -> int:
|
||||
"""Calculate how many rows can fit in one statement given column count."""
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
|
||||
|
||||
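A worked example of the bind-parameter budget above: with MAX_BIND_PARAMS = 800, each statement batches as many whole rows as fit.

assert calculate_rows_per_statement(5) == 160   # 800 // 5
assert calculate_rows_per_statement(7) == 114   # 800 // 7
assert calculate_rows_per_statement(0) == 800   # degenerate column count is clamped to 1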
def iter_chunks(seq, n: int):
|
||||
"""Yield successive n-sized chunks from seq."""
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i : i + n]
|
||||
|
||||
|
||||
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
|
||||
"""Yield chunks of rows sized to fit within bind param limits."""
|
||||
if not rows:
|
||||
return
|
||||
yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row))
|
||||
|
||||
|
||||
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads.
|
||||
|
||||
Owner-less rows are visible to everyone.
|
||||
"""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetReference.owner_id == ""
|
||||
return AssetReference.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def build_prefix_like_conditions(
|
||||
prefixes: list[str],
|
||||
) -> list[sa.sql.ColumnElement]:
|
||||
"""Build LIKE conditions for matching file paths under directory prefixes."""
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_sql_like_string(base)
|
||||
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
||||
return conds
|
||||
app/assets/database/queries/tags.py (new file, 356 lines)
@@ -0,0 +1,356 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import (
|
||||
AssetReference,
|
||||
AssetReferenceMeta,
|
||||
AssetReferenceTag,
|
||||
Tag,
|
||||
)
|
||||
from app.assets.database.queries.common import (
|
||||
build_visible_owner_clause,
|
||||
iter_row_chunks,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AddTagsResult:
|
||||
added: list[str]
|
||||
already_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RemoveTagsResult:
|
||||
removed: list[str]
|
||||
not_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SetTagsResult:
|
||||
added: list[str]
|
||||
removed: list[str]
|
||||
total: list[str]
|
||||
|
||||
|
||||
def validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
"""Raise ValueError if any of the given tag names do not exist."""
|
||||
existing_tag_names = set(
|
||||
name
|
||||
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
|
||||
)
|
||||
missing = [t for t in tags if t not in existing_tag_names]
|
||||
if missing:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
|
||||
def ensure_tags_exist(
|
||||
session: Session, names: Iterable[str], tag_type: str = "user"
|
||||
) -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetReferenceTag.tag_name).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def set_reference_tags(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> SetTagsResult:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetReferenceTag).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id,
|
||||
AssetReferenceTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
|
||||
|
||||
|
||||
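A sketch of the new dataclass-returning tag API (illustrative; the id and tag values are placeholders).

def sync_tags(session, reference_id):
    result = set_reference_tags(
        session,
        reference_id,
        tags=["models", "loras", "style"],
        origin="automatic",
    )
    # SetTagsResult is a frozen dataclass, so the delta is attribute access
    # rather than dict lookups.
    print(result.added, result.removed, result.total)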
def add_tags_to_reference(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
reference_row: AssetReference | None = None,
|
||||
) -> AddTagsResult:
|
||||
if not reference_row:
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return AddTagsResult(added=[], already_present=[], total_tags=total)
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_reference_tags(session, reference_id=reference_id))
|
||||
return AddTagsResult(
|
||||
added=sorted(((after - current) & want)),
|
||||
already_present=sorted(want & current),
|
||||
total_tags=sorted(after),
|
||||
)
|
||||
|
||||
|
||||
def remove_tags_from_reference(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> RemoveTagsResult:
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return RemoveTagsResult(removed=[], not_present=[], total_tags=total)
|
||||
|
||||
existing = set(get_reference_tags(session, reference_id))
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetReferenceTag).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id,
|
||||
AssetReferenceTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total)
|
||||
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sa.select(
|
||||
AssetReference.id.label("asset_reference_id"),
|
||||
sa.literal("missing").label("tag_name"),
|
||||
sa.literal(origin).label("origin"),
|
||||
sa.literal(get_utc_now()).label("added_at"),
|
||||
)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(
|
||||
sa.not_(
|
||||
sa.exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name == "missing")
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetReferenceTag)
|
||||
.from_select(
|
||||
["asset_reference_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
AssetReferenceTag.asset_reference_id,
|
||||
AssetReferenceTag.tag_name,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetReferenceTag).where(
|
||||
AssetReferenceTag.asset_reference_id.in_(
|
||||
sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id)
|
||||
),
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetReferenceTag.tag_name.label("tag_name"),
|
||||
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetReferenceTag)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
visible_tags_sq = (
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
)
|
||||
total_q = total_q.where(Tag.name.in_(visible_tags_sq))
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
) -> None:
|
||||
"""Batch insert into asset_reference_tags and asset_reference_meta.
|
||||
|
||||
Uses ON CONFLICT DO NOTHING.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at
|
||||
meta_rows: Dicts with: asset_reference_id, key, ordinal, val_*
|
||||
"""
|
||||
if tag_rows:
|
||||
ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
AssetReferenceTag.asset_reference_id,
|
||||
AssetReferenceTag.tag_name,
|
||||
]
|
||||
)
|
||||
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
|
||||
session.execute(ins_tags, chunk)
|
||||
|
||||
if meta_rows:
|
||||
ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
AssetReferenceMeta.asset_reference_id,
|
||||
AssetReferenceMeta.key,
|
||||
AssetReferenceMeta.ordinal,
|
||||
]
|
||||
)
|
||||
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):
|
||||
session.execute(ins_meta, chunk)
|
||||
@@ -1,62 +0,0 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import normalize_tags, utcnow
|
||||
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
return session.execute(ins)
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sqlalchemy.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sqlalchemy.literal("missing").label("tag_name"),
|
||||
sqlalchemy.literal(origin).label("origin"),
|
||||
sqlalchemy.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sqlalchemy.not_(
|
||||
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sqlalchemy.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
@@ -1,75 +0,0 @@
|
||||
from blake3 import blake3
|
||||
from typing import IO
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 * 1024  # 8 MB
|
||||
|
||||
# NOTE: this allows hashing different representations of a file-like object
|
||||
def blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""
|
||||
Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
# duck typing to check if input is a file-like object
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash_async(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
|
||||
|
||||
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
|
||||
"""
|
||||
Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
# in case file object is already open and not at the beginning, track so can be restored after hashing
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
# seek to the beginning before reading
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
# restore original position in file object, if needed
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(orig_pos)
|
||||
@@ -1,226 +1,42 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Any
|
||||
|
||||
import folder_paths
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
def get_query_dict(request: web.Request) -> dict[str, Any]:
    """
    Gets a dictionary of query parameters from the request.

    'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
    """
    query_dict = {
        key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
        for key in request.query.keys()
    }
    return query_dict


def select_best_live_path(states: Sequence) -> str:
    """
    Return the best on-disk path among cache states:
    1) Prefer a path that exists with needs_verify == False (already verified).
    2) Otherwise, pick the first path that exists.
    3) Otherwise return empty string.
    """
    alive = [
        s
        for s in states
        if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
    ]
    if not alive:
        return ""
    for s in alive:
        if not getattr(s, "needs_verify", False):
            return s.file_path
    return alive[0].file_path
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
def prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char in a LIKE prefix.
|
||||
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
Returns (escaped_prefix, escape_char).
|
||||
"""
|
||||
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
|
||||
|
||||
def utcnow() -> datetime:
|
||||
def get_utc_now() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
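Illustrative calls for compute_relative_filename; the absolute paths are placeholders and the exact results depend on how folder_paths is configured on a given install, so treat these as the intended shape rather than fixed values.

print(compute_relative_filename("/srv/comfy/models/checkpoints/flux/123/flux.safetensors"))
# -> "flux/123/flux.safetensors"   (the category folder "checkpoints" is dropped)
print(compute_relative_filename("/srv/comfy/output/renders/img_0001.png"))
# -> "renders/img_0001.png"        (input/output paths keep every component)
print(compute_relative_filename("/tmp/unrelated.bin"))
# -> None                          (outside every configured root)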
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
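# Mirrors the docstring example above (ModelsDir stands in for a configured models base):
#   get_name_and_tags_from_asset_path("/.../ModelsDir/vae/test_tag/ae.safetensors")
#     -> ("ae.safetensors", ["models", "vae", "test_tag"])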
|
||||
|
||||
def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
"""
|
||||
@@ -228,85 +44,22 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
- Stripping whitespace and converting to lowercase.
|
||||
- Removing duplicates.
|
||||
"""
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
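# Quick example of the rules above, using the deduplicating variant:
#   normalize_tags([" Models", "VAE", "vae", ""]) -> ["models", "vae"]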
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
def project_kv(key: str, value):
    """
    Turn a metadata key/value into typed projection rows.
    Returns list[dict] with keys:
        key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
    """
    rows: list[dict] = []

    def _null_row(ordinal: int) -> dict:
        return {
            "key": key, "ordinal": ordinal,
            "val_str": None, "val_num": None, "val_bool": None, "val_json": None
        }

    if value is None:
        rows.append(_null_row(0))
        return rows

    if is_scalar(value):
        if isinstance(value, bool):
            rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
        elif isinstance(value, (int, float, Decimal)):
            num = value if isinstance(value, Decimal) else Decimal(str(value))
            rows.append({"key": key, "ordinal": 0, "val_num": num})
        elif isinstance(value, str):
            rows.append({"key": key, "ordinal": 0, "val_str": value})
        else:
            rows.append({"key": key, "ordinal": 0, "val_json": value})
        return rows

    if isinstance(value, list):
        if all(is_scalar(x) for x in value):
            for i, x in enumerate(value):
                if x is None:
                    rows.append(_null_row(i))
                elif isinstance(x, bool):
                    rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
                elif isinstance(x, (int, float, Decimal)):
                    num = x if isinstance(x, Decimal) else Decimal(str(x))
                    rows.append({"key": key, "ordinal": i, "val_num": num})
                elif isinstance(x, str):
                    rows.append({"key": key, "ordinal": i, "val_str": x})
                else:
                    rows.append({"key": key, "ordinal": i, "val_json": x})
            return rows
        for i, x in enumerate(value):
            rows.append({"key": key, "ordinal": i, "val_json": x})
        return rows

    rows.append({"key": key, "ordinal": 0, "val_json": value})
    return rows


def validate_blake3_hash(s: str) -> str:
    """Validate and normalize a blake3 hash string.

    Returns canonical 'blake3:<hex>' or raises ValueError.
    """
    s = s.strip().lower()
    if not s or ":" not in s:
        raise ValueError("hash must be 'blake3:<hex>'")
    algo, digest = s.split(":", 1)
    if (
        algo != "blake3"
        or len(digest) != 64
        or any(c for c in digest if c not in "0123456789abcdef")
    ):
        raise ValueError("hash must be 'blake3:<hex>'")
    return f"{algo}:{digest}"
|
||||
|
||||
@@ -1,516 +0,0 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
with create_session() as session:
|
||||
infos, tag_map, total = list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: str | None = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Create new asset or update existing asset from a temporary file path.
|
||||
"""
|
||||
try:
|
||||
# NOTE: blake3 is not a required dependency right now, so this will fail if blake3 is not installed in the local environment
|
||||
import app.assets.hashing as hashing
|
||||
digest = hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
with create_session() as session:
|
||||
result = ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
created_result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return created_result
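# Hedged call sketch; the UploadAssetSpec field names are inferred only from the usages
# above (spec.name, spec.tags, spec.user_metadata) and are not a verified schema:
#   spec = schemas_in.UploadAssetSpec(name="ae.safetensors", tags=["models", "vae"], user_metadata={})
#   created = upload_asset_from_temp_path(spec, temp_path="/tmp/upload.bin", client_filename="ae.safetensors")
#   created.created_new is False when the blake3 digest already exists and only a new reference is added.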
|
||||
|
||||
|
||||
def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
result = schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
result = schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
@@ -1,263 +1,567 @@
|
||||
import contextlib
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import sqlalchemy
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal, TypedDict
|
||||
|
||||
import folder_paths
|
||||
from app.database.db import create_session, dependencies_available
|
||||
from app.assets.helpers import (
|
||||
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
|
||||
list_tree, prefixes_for_root, escape_like_prefix,
RootType,
)
|
||||
from app.assets.database.queries import (
|
||||
add_missing_tag_for_asset_id,
|
||||
bulk_update_enrichment_level,
|
||||
bulk_update_is_missing,
|
||||
bulk_update_needs_verify,
|
||||
delete_orphaned_seed_asset,
|
||||
delete_references_by_ids,
|
||||
ensure_tags_exist,
|
||||
get_asset_by_hash,
|
||||
get_references_for_prefixes,
|
||||
get_unenriched_references,
|
||||
mark_references_missing_outside_prefixes,
|
||||
reassign_asset_references,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
update_asset_hash_and_mime,
|
||||
)
|
||||
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
|
||||
from app.assets.database.bulk_ops import seed_from_paths_batch
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
|
||||
from app.assets.services.bulk_ingest import (
|
||||
SeedAssetSpec,
|
||||
batch_insert_seed_assets,
|
||||
)
|
||||
from app.assets.services.file_utils import (
|
||||
get_mtime_ns,
|
||||
is_visible,
|
||||
list_files_recursively,
|
||||
verify_file_unchanged,
|
||||
)
|
||||
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
|
||||
from app.assets.services.metadata_extract import extract_file_metadata
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_comfy_models_folders,
|
||||
get_name_and_tags_from_asset_path,
|
||||
)
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
|
||||
"""
|
||||
Scan the given roots and seed the assets into the database.
|
||||
"""
|
||||
if not dependencies_available():
|
||||
if enable_logging:
|
||||
logging.warning("Database dependencies not available, skipping assets scan")
|
||||
return
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
orphans_pruned = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
try:
|
||||
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
||||
if survivors:
|
||||
existing_paths.update(survivors)
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
class _RefInfo(TypedDict):
|
||||
ref_id: str
|
||||
file_path: str
|
||||
exists: bool
|
||||
stat_unchanged: bool
|
||||
needs_verify: bool
|
||||
|
||||
try:
|
||||
orphans_pruned = _prune_orphaned_assets(roots)
|
||||
except Exception as e:
|
||||
logging.exception("orphan pruning failed: %s", e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
class _AssetAccumulator(TypedDict):
|
||||
hash: str | None
|
||||
size_db: int
|
||||
refs: list[_RefInfo]
|
||||
|
||||
specs: list[dict] = []
|
||||
tag_pool: set[str] = set()
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped_existing += 1
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
|
||||
|
||||
def get_prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
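# Illustrative output (actual prefixes depend on folder_paths configuration; paths are hypothetical):
#   get_prefixes_for_root("input")  -> ["/home/user/ComfyUI/input"]
#   get_prefixes_for_root("models") -> ["/home/user/ComfyUI/models/checkpoints", "/home/user/ComfyUI/models/vae", ...]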
|
||||
|
||||
|
||||
def get_all_known_prefixes() -> list[str]:
|
||||
"""Get all known asset prefixes across all root types."""
|
||||
all_roots: tuple[RootType, ...] = ("models", "input", "output")
|
||||
return [p for root in all_roots for p in get_prefixes_for_root(root)]
|
||||
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
if not all(is_visible(part) for part in Path(rel_path).parts):
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=False)
|
||||
except OSError:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
# skip empty files
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": compute_relative_filename(abs_p),
|
||||
}
|
||||
)
|
||||
for t in tags:
|
||||
tag_pool.add(t)
|
||||
# if no file specs, nothing to do
|
||||
if not specs:
|
||||
return
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
|
||||
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
|
||||
created += result["inserted_infos"]
|
||||
sess.commit()
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
orphans_pruned,
|
||||
len(paths),
|
||||
)
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
abs_p = Path(abs_path)
|
||||
for b in bases:
|
||||
if abs_p.is_relative_to(os.path.abspath(b)):
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
|
||||
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||
if not all_prefixes:
|
||||
return 0
|
||||
|
||||
def make_prefix_condition(prefix: str):
|
||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||
|
||||
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||
|
||||
orphan_subq = (
|
||||
sqlalchemy.select(Asset.id)
|
||||
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||
).scalar_subquery()
|
||||
|
||||
with create_session() as sess:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||
sess.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
def sync_references_with_filesystem(
|
||||
session,
|
||||
root: RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> set[str] | None:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""Reconcile asset references with filesystem for a root.
|
||||
|
||||
- Toggle needs_verify per reference using mtime/size stat check
|
||||
- For hashed assets with at least one stat-unchanged ref: delete stale missing refs
|
||||
- For seed assets with all refs missing: delete Asset and its references
|
||||
- Optionally add/remove 'missing' tags based on stat check in this root
|
||||
- Optionally return surviving absolute paths
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
root: Root type to scan
|
||||
collect_existing_paths: If True, return set of surviving file paths
|
||||
update_missing_tags: If True, update 'missing' tags based on file status
|
||||
|
||||
Returns:
|
||||
Set of surviving absolute paths if collect_existing_paths=True, else None
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
prefixes = get_prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
rows = get_references_for_prefixes(
|
||||
session, prefixes, include_missing=update_missing_tags
|
||||
)
|
||||
|
||||
by_asset: dict[str, _AssetAccumulator] = {}
|
||||
for row in rows:
|
||||
acc = by_asset.get(row.asset_id)
|
||||
if acc is None:
|
||||
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
|
||||
by_asset[row.asset_id] = acc
|
||||
|
||||
stat_unchanged = False
|
||||
try:
|
||||
exists = True
|
||||
stat_unchanged = verify_file_unchanged(
|
||||
mtime_db=row.mtime_ns,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(row.file_path, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except PermissionError:
|
||||
exists = True
|
||||
logging.debug("Permission denied accessing %s", row.file_path)
|
||||
except OSError as e:
|
||||
exists = False
|
||||
logging.debug("OSError checking %s: %s", row.file_path, e)
|
||||
|
||||
acc["refs"].append(
|
||||
{
|
||||
"ref_id": row.reference_id,
|
||||
"file_path": row.file_path,
|
||||
"exists": exists,
|
||||
"stat_unchanged": stat_unchanged,
|
||||
"needs_verify": row.needs_verify,
|
||||
}
|
||||
)
|
||||
|
||||
to_set_verify: list[str] = []
|
||||
to_clear_verify: list[str] = []
|
||||
stale_ref_ids: list[str] = []
|
||||
to_mark_missing: list[str] = []
|
||||
to_clear_missing: list[str] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
refs = acc["refs"]
|
||||
any_unchanged = any(r["stat_unchanged"] for r in refs)
|
||||
all_missing = all(not r["exists"] for r in refs)
|
||||
|
||||
for r in refs:
|
||||
if not r["exists"]:
|
||||
to_mark_missing.append(r["ref_id"])
|
||||
continue
|
||||
if r["stat_unchanged"]:
|
||||
to_clear_missing.append(r["ref_id"])
|
||||
if r["needs_verify"]:
|
||||
to_clear_verify.append(r["ref_id"])
|
||||
if not r["stat_unchanged"] and not r["needs_verify"]:
|
||||
to_set_verify.append(r["ref_id"])
|
||||
|
||||
if a_hash is None:
|
||||
if refs and all_missing:
|
||||
delete_orphaned_seed_asset(session, aid)
|
||||
else:
|
||||
for r in refs:
|
||||
if r["exists"]:
|
||||
survivors.add(os.path.abspath(r["file_path"]))
|
||||
continue
|
||||
|
||||
if any_unchanged:
|
||||
for r in refs:
|
||||
if not r["exists"]:
|
||||
stale_ref_ids.append(r["ref_id"])
|
||||
if update_missing_tags:
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=aid)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
"Failed to remove missing tag for asset %s: %s", aid, e
|
||||
)
|
||||
elif update_missing_tags:
|
||||
try:
|
||||
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
|
||||
except Exception as e:
|
||||
logging.warning("Failed to add missing tag for asset %s: %s", aid, e)
|
||||
|
||||
for r in refs:
|
||||
if r["exists"]:
|
||||
survivors.add(os.path.abspath(r["file_path"]))
|
||||
|
||||
delete_references_by_ids(session, stale_ref_ids)
|
||||
stale_set = set(stale_ref_ids)
|
||||
to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set]
|
||||
bulk_update_is_missing(session, to_mark_missing, value=True)
|
||||
bulk_update_is_missing(session, to_clear_missing, value=False)
|
||||
bulk_update_needs_verify(session, to_set_verify, value=True)
|
||||
bulk_update_needs_verify(session, to_clear_verify, value=False)
|
||||
|
||||
return survivors if collect_existing_paths else None
|
||||
|
||||
|
||||
def sync_root_safely(root: RootType) -> set[str]:
|
||||
"""Sync a single root's references with the filesystem.
|
||||
|
||||
Returns survivors (existing paths) or empty set on failure.
|
||||
"""
|
||||
try:
|
||||
with create_session() as sess:
|
||||
survivors = sync_references_with_filesystem(
|
||||
sess,
|
||||
root,
|
||||
collect_existing_paths=True,
|
||||
update_missing_tags=True,
|
||||
)
|
||||
sess.commit()
|
||||
return survivors or set()
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", root, e)
|
||||
return set()
|
||||
|
||||
|
||||
def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
|
||||
"""Mark references as missing when outside the given prefixes.
|
||||
|
||||
This is a non-destructive soft-delete. Returns count marked or 0 on failure.
|
||||
"""
|
||||
try:
|
||||
with create_session() as sess:
|
||||
count = mark_references_missing_outside_prefixes(sess, prefixes)
|
||||
sess.commit()
|
||||
return count
|
||||
except Exception as e:
|
||||
logging.exception("marking missing assets failed: %s", e)
|
||||
return 0
|
||||
|
||||
|
||||
def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
|
||||
"""Collect all file paths for the given roots."""
|
||||
paths: list[str] = []
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
|
||||
return paths
|
||||
|
||||
|
||||
def build_asset_specs(
|
||||
paths: list[str],
|
||||
existing_paths: set[str],
|
||||
enable_metadata_extraction: bool = True,
|
||||
compute_hashes: bool = False,
|
||||
) -> tuple[list[SeedAssetSpec], set[str], int]:
|
||||
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count).
|
||||
|
||||
Args:
|
||||
paths: List of file paths to process
|
||||
existing_paths: Set of paths that already exist in the database
|
||||
enable_metadata_extraction: If True, extract tier 1 & 2 metadata
|
||||
compute_hashes: If True, compute blake3 hashes (slow for large files)
|
||||
"""
|
||||
specs: list[SeedAssetSpec] = []
|
||||
tag_pool: set[str] = set()
|
||||
skipped = 0
|
||||
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped += 1
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=True)
|
||||
except OSError:
|
||||
continue
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
rel_fname = compute_relative_filename(abs_p)
|
||||
|
||||
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
|
||||
metadata = None
|
||||
if enable_metadata_extraction:
|
||||
metadata = extract_file_metadata(
|
||||
abs_p,
|
||||
stat_result=stat_p,
|
||||
relative_filename=rel_fname,
|
||||
)
|
||||
|
||||
# Compute hash if requested
|
||||
asset_hash: str | None = None
|
||||
if compute_hashes:
|
||||
try:
|
||||
digest, _ = compute_blake3_hash(abs_p)
|
||||
asset_hash = "blake3:" + digest
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", abs_p, e)
|
||||
|
||||
mime_type = metadata.content_type if metadata else None
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": get_mtime_ns(stat_p),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": rel_fname,
|
||||
"metadata": metadata,
|
||||
"hash": asset_hash,
|
||||
"mime_type": mime_type,
|
||||
}
|
||||
)
|
||||
tag_pool.update(tags)
|
||||
|
||||
return specs, tag_pool, skipped
|
||||
|
||||
|
||||
|
||||
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
||||
"""Insert asset specs into database, returning count of created refs."""
|
||||
if not specs:
|
||||
return 0
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
|
||||
sess.commit()
|
||||
return result.inserted_refs
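# How the three helpers above are typically chained for a fast scan (a sketch of the flow,
# not necessarily the seeder's exact wiring):
#   paths = collect_paths_for_roots(("models",))
#   specs, tag_pool, skipped = build_asset_specs(paths, existing_paths=set(), compute_hashes=False)
#   created_refs = insert_asset_specs(specs, tag_pool)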
|
||||
|
||||
|
||||
# Enrichment level constants
|
||||
ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only
|
||||
ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type)
|
||||
ENRICHMENT_HASHED = 2 # Hash computed (blake3)
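# Conceptual progression of a single reference through these levels (per enrich_asset below):
#   ENRICHMENT_STUB (0)     -> ENRICHMENT_METADATA (1) once extract_file_metadata succeeds
#   ENRICHMENT_METADATA (1) -> ENRICHMENT_HASHED (2)   once a blake3 hash is computed and stored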
|
||||
|
||||
|
||||
def get_unenriched_assets_for_roots(
|
||||
roots: tuple[RootType, ...],
|
||||
max_level: int = ENRICHMENT_STUB,
|
||||
limit: int = 1000,
|
||||
) -> list:
|
||||
"""Get assets that need enrichment for the given roots.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
max_level: Maximum enrichment level to include
|
||||
limit: Maximum number of rows to return
|
||||
|
||||
Returns:
|
||||
List of UnenrichedReferenceRow
|
||||
"""
|
||||
prefixes: list[str] = []
|
||||
for root in roots:
|
||||
prefixes.extend(get_prefixes_for_root(root))
|
||||
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
with create_session() as sess:
|
||||
rows = (
|
||||
sess.execute(
|
||||
sqlalchemy.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sqlalchemy.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
return get_unenriched_references(
|
||||
sess, prefixes, max_level=max_level, limit=limit
|
||||
)
|
||||
|
||||
|
||||
def enrich_asset(
|
||||
session,
|
||||
file_path: str,
|
||||
reference_id: str,
|
||||
asset_id: str,
|
||||
extract_metadata: bool = True,
|
||||
compute_hash: bool = False,
|
||||
interrupt_check: Callable[[], bool] | None = None,
|
||||
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
|
||||
) -> int:
|
||||
"""Enrich a single asset with metadata and/or hash.
|
||||
|
||||
Args:
|
||||
session: Database session (caller manages lifecycle)
|
||||
file_path: Absolute path to the file
|
||||
reference_id: ID of the reference to update
|
||||
asset_id: ID of the asset to update (for mime_type and hash)
|
||||
extract_metadata: If True, extract safetensors header and mime type
|
||||
compute_hash: If True, compute blake3 hash
|
||||
interrupt_check: Optional non-blocking callable that returns True if
|
||||
the operation should be interrupted (e.g. paused or cancelled)
|
||||
hash_checkpoints: Optional dict for saving/restoring hash progress
|
||||
across interruptions, keyed by file path
|
||||
|
||||
Returns:
|
||||
New enrichment level achieved
|
||||
"""
|
||||
new_level = ENRICHMENT_STUB
|
||||
|
||||
try:
|
||||
stat_p = os.stat(file_path, follow_symlinks=True)
|
||||
except OSError:
|
||||
return new_level
|
||||
|
||||
rel_fname = compute_relative_filename(file_path)
|
||||
mime_type: str | None = None
|
||||
metadata = None
|
||||
|
||||
if extract_metadata:
|
||||
metadata = extract_file_metadata(
|
||||
file_path,
|
||||
stat_result=stat_p,
|
||||
relative_filename=rel_fname,
|
||||
)
|
||||
if metadata:
|
||||
mime_type = metadata.content_type
|
||||
new_level = ENRICHMENT_METADATA
|
||||
|
||||
full_hash: str | None = None
|
||||
if compute_hash:
|
||||
try:
|
||||
mtime_before = get_mtime_ns(stat_p)
|
||||
size_before = stat_p.st_size
|
||||
|
||||
# Restore checkpoint if available and file unchanged
|
||||
checkpoint = None
|
||||
if hash_checkpoints is not None:
|
||||
checkpoint = hash_checkpoints.get(file_path)
|
||||
if checkpoint is not None:
|
||||
cur_stat = os.stat(file_path, follow_symlinks=True)
|
||||
if (checkpoint.mtime_ns != get_mtime_ns(cur_stat)
|
||||
or checkpoint.file_size != cur_stat.st_size):
|
||||
checkpoint = None
|
||||
hash_checkpoints.pop(file_path, None)
|
||||
else:
|
||||
mtime_before = get_mtime_ns(cur_stat)
|
||||
|
||||
digest, new_checkpoint = compute_blake3_hash(
|
||||
file_path,
|
||||
interrupt_check=interrupt_check,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
if digest is None:
|
||||
# Interrupted — save checkpoint for later resumption
|
||||
if hash_checkpoints is not None and new_checkpoint is not None:
|
||||
new_checkpoint.mtime_ns = mtime_before
|
||||
new_checkpoint.file_size = size_before
|
||||
hash_checkpoints[file_path] = new_checkpoint
|
||||
return new_level
|
||||
|
||||
# Completed — clear any saved checkpoint
|
||||
if hash_checkpoints is not None:
|
||||
hash_checkpoints.pop(file_path, None)
|
||||
|
||||
stat_after = os.stat(file_path, follow_symlinks=True)
|
||||
mtime_after = get_mtime_ns(stat_after)
|
||||
if mtime_before != mtime_after:
|
||||
logging.warning("File modified during hashing, discarding hash: %s", file_path)
|
||||
else:
|
||||
full_hash = f"blake3:{digest}"
|
||||
metadata_ok = not extract_metadata or metadata is not None
|
||||
if metadata_ok:
|
||||
new_level = ENRICHMENT_HASHED
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||
|
||||
if extract_metadata and metadata:
|
||||
user_metadata = metadata.to_user_metadata()
|
||||
set_reference_metadata(session, reference_id, user_metadata)
|
||||
|
||||
if full_hash:
|
||||
existing = get_asset_by_hash(session, full_hash)
|
||||
if existing and existing.id != asset_id:
|
||||
reassign_asset_references(session, asset_id, existing.id, reference_id)
|
||||
delete_orphaned_seed_asset(session, asset_id)
|
||||
if mime_type:
|
||||
update_asset_hash_and_mime(session, existing.id, mime_type=mime_type)
|
||||
else:
|
||||
update_asset_hash_and_mime(session, asset_id, full_hash, mime_type)
|
||||
elif mime_type:
|
||||
update_asset_hash_and_mime(session, asset_id, mime_type=mime_type)
|
||||
|
||||
bulk_update_enrichment_level(session, [reference_id], new_level)
|
||||
session.commit()
|
||||
|
||||
return new_level
|
||||
|
||||
|
||||
def enrich_assets_batch(
|
||||
rows: list,
|
||||
extract_metadata: bool = True,
|
||||
compute_hash: bool = False,
|
||||
interrupt_check: Callable[[], bool] | None = None,
|
||||
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
|
||||
) -> tuple[int, list[str]]:
|
||||
"""Enrich a batch of assets.
|
||||
|
||||
Uses a single DB session for the entire batch, committing after each
|
||||
individual asset to avoid long-held transactions while eliminating
|
||||
per-asset session creation overhead.
|
||||
|
||||
Args:
|
||||
rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots
|
||||
extract_metadata: If True, extract metadata for each asset
|
||||
compute_hash: If True, compute hash for each asset
|
||||
interrupt_check: Optional non-blocking callable that returns True if
|
||||
the operation should be interrupted (e.g. paused or cancelled)
|
||||
hash_checkpoints: Optional dict for saving/restoring hash progress
|
||||
across interruptions, keyed by file path
|
||||
|
||||
Returns:
|
||||
Tuple of (enriched_count, failed_reference_ids)
|
||||
"""
|
||||
enriched = 0
|
||||
failed_ids: list[str] = []
|
||||
|
||||
with create_session() as sess:
|
||||
for row in rows:
|
||||
if interrupt_check is not None and interrupt_check():
|
||||
break
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
new_level = enrich_asset(
|
||||
sess,
|
||||
file_path=row.file_path,
|
||||
reference_id=row.reference_id,
|
||||
asset_id=row.asset_id,
|
||||
extract_metadata=extract_metadata,
|
||||
compute_hash=compute_hash,
|
||||
interrupt_check=interrupt_check,
|
||||
hash_checkpoints=hash_checkpoints,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = sess.get(Asset, aid)
|
||||
if asset:
|
||||
sess.delete(asset)
|
||||
if new_level > row.enrichment_level:
|
||||
enriched += 1
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
failed_ids.append(row.reference_id)
|
||||
except Exception as e:
|
||||
logging.warning("Failed to enrich %s: %s", row.file_path, e)
|
||||
sess.rollback()
|
||||
failed_ids.append(row.reference_id)
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
return enriched, failed_ids
|
||||
|
||||
794
app/assets/seeder.py
Normal file
@@ -0,0 +1,794 @@
|
||||
"""Background asset seeder with thread management and cancellation support."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
from app.assets.scanner import (
|
||||
ENRICHMENT_METADATA,
|
||||
ENRICHMENT_STUB,
|
||||
RootType,
|
||||
build_asset_specs,
|
||||
collect_paths_for_roots,
|
||||
enrich_assets_batch,
|
||||
get_all_known_prefixes,
|
||||
get_prefixes_for_root,
|
||||
get_unenriched_assets_for_roots,
|
||||
insert_asset_specs,
|
||||
mark_missing_outside_prefixes_safely,
|
||||
sync_root_safely,
|
||||
)
|
||||
from app.database.db import dependencies_available
|
||||
|
||||
|
||||
class ScanInProgressError(Exception):
|
||||
"""Raised when an operation cannot proceed because a scan is running."""
|
||||
|
||||
|
||||
class State(Enum):
|
||||
"""Seeder state machine states."""
|
||||
|
||||
IDLE = "IDLE"
|
||||
RUNNING = "RUNNING"
|
||||
PAUSED = "PAUSED"
|
||||
CANCELLING = "CANCELLING"
|
||||
|
||||
|
||||
class ScanPhase(Enum):
|
||||
"""Scan phase options."""
|
||||
|
||||
FAST = "fast" # Phase 1: filesystem only (stubs)
|
||||
ENRICH = "enrich" # Phase 2: metadata + hash
|
||||
FULL = "full" # Both phases sequentially
|
||||
|
||||
|
||||
@dataclass
|
||||
class Progress:
|
||||
"""Progress information for a scan operation."""
|
||||
|
||||
scanned: int = 0
|
||||
total: int = 0
|
||||
created: int = 0
|
||||
skipped: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanStatus:
|
||||
"""Current status of the asset seeder."""
|
||||
|
||||
state: State
|
||||
progress: Progress | None
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
ProgressCallback = Callable[[Progress], None]
|
||||
|
||||
|
||||
class _AssetSeeder:
|
||||
"""Background asset scanning manager.
|
||||
|
||||
Spawns ephemeral daemon threads for scanning.
|
||||
Each scan creates a new thread that exits when complete.
|
||||
Use the module-level ``asset_seeder`` instance.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._state = State.IDLE
|
||||
self._progress: Progress | None = None
|
||||
self._last_progress: Progress | None = None
|
||||
self._errors: list[str] = []
|
||||
self._thread: threading.Thread | None = None
|
||||
self._cancel_event = threading.Event()
|
||||
self._run_gate = threading.Event()
|
||||
self._run_gate.set() # Start unpaused (set = running, clear = paused)
|
||||
self._roots: tuple[RootType, ...] = ()
|
||||
self._phase: ScanPhase = ScanPhase.FULL
|
||||
self._compute_hashes: bool = False
|
||||
self._prune_first: bool = False
|
||||
self._progress_callback: ProgressCallback | None = None
|
||||
self._disabled: bool = False
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the asset seeder, preventing any scans from starting."""
|
||||
self._disabled = True
|
||||
logging.info("Asset seeder disabled")
|
||||
|
||||
def is_disabled(self) -> bool:
|
||||
"""Check if the asset seeder is disabled."""
|
||||
return self._disabled
|
||||
|
||||
def start(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
phase: ScanPhase = ScanPhase.FULL,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
prune_first: bool = False,
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start a background scan for the given roots.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan (models, input, output)
|
||||
phase: Scan phase to run (FAST, ENRICH, or FULL for both)
|
||||
progress_callback: Optional callback called with progress updates
|
||||
prune_first: If True, prune orphaned assets before scanning
|
||||
compute_hashes: If True, compute blake3 hashes (slow)
|
||||
|
||||
Returns:
|
||||
True if scan was started, False if already running
|
||||
"""
|
||||
if self._disabled:
|
||||
logging.debug("Asset seeder is disabled, skipping start")
|
||||
return False
|
||||
logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value)
|
||||
with self._lock:
|
||||
if self._state != State.IDLE:
|
||||
logging.info("Asset seeder already running, skipping start")
|
||||
return False
|
||||
self._state = State.RUNNING
|
||||
self._progress = Progress()
|
||||
self._errors = []
|
||||
self._roots = roots
|
||||
self._phase = phase
|
||||
self._prune_first = prune_first
|
||||
self._compute_hashes = compute_hashes
|
||||
self._progress_callback = progress_callback
|
||||
self._cancel_event.clear()
|
||||
self._run_gate.set() # Ensure unpaused when starting
|
||||
self._thread = threading.Thread(
|
||||
target=self._run_scan,
|
||||
name="_AssetSeeder",
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
return True
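# Hedged usage sketch, assuming the module-level instance the class docstring refers to
# is named `asset_seeder` (that name is an assumption):
#   asset_seeder.start(roots=("models",), phase=ScanPhase.FAST)
#   asset_seeder.pause(); asset_seeder.resume()
#   asset_seeder.cancel()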
|
||||
|
||||
def start_fast(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
prune_first: bool = False,
|
||||
) -> bool:
|
||||
"""Start a fast scan (phase 1 only) - creates stub records.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
progress_callback: Optional callback for progress updates
|
||||
prune_first: If True, prune orphaned assets before scanning
|
||||
|
||||
Returns:
|
||||
True if scan was started, False if already running
|
||||
"""
|
||||
return self.start(
|
||||
roots=roots,
|
||||
phase=ScanPhase.FAST,
|
||||
progress_callback=progress_callback,
|
||||
prune_first=prune_first,
|
||||
compute_hashes=False,
|
||||
)
|
||||
|
||||
def start_enrich(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start an enrichment scan (phase 2 only) - extracts metadata and hashes.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
progress_callback: Optional callback for progress updates
|
||||
compute_hashes: If True, compute blake3 hashes
|
||||
|
||||
Returns:
|
||||
True if scan was started, False if already running
|
||||
"""
|
||||
return self.start(
|
||||
roots=roots,
|
||||
phase=ScanPhase.ENRICH,
|
||||
progress_callback=progress_callback,
|
||||
prune_first=False,
|
||||
compute_hashes=compute_hashes,
|
||||
)
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""Request cancellation of the current scan.
|
||||
|
||||
Returns:
|
||||
True if cancellation was requested, False if no scan is running or paused
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state not in (State.RUNNING, State.PAUSED):
|
||||
return False
|
||||
logging.info("Asset seeder cancelling (was %s)", self._state.value)
|
||||
self._state = State.CANCELLING
|
||||
self._cancel_event.set()
|
||||
self._run_gate.set() # Unblock if paused so thread can exit
|
||||
return True
|
||||
|
||||
def stop(self) -> bool:
|
||||
"""Stop the current scan (alias for cancel).
|
||||
|
||||
Returns:
|
||||
True if stop was requested, False if not running
|
||||
"""
|
||||
return self.cancel()
|
||||
|
||||
def pause(self) -> bool:
|
||||
"""Pause the current scan.
|
||||
|
||||
The scan will complete its current batch before pausing.
|
||||
|
||||
Returns:
|
||||
True if pause was requested, False if not running
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.RUNNING:
|
||||
return False
|
||||
logging.info("Asset seeder pausing")
|
||||
self._state = State.PAUSED
|
||||
self._run_gate.clear()
|
||||
return True
|
||||
|
||||
def resume(self) -> bool:
|
||||
"""Resume a paused scan.
|
||||
|
||||
This is a no-op if the scan is not in the PAUSED state.
|
||||
|
||||
Returns:
|
||||
True if resumed, False if not paused
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.PAUSED:
|
||||
return False
|
||||
logging.info("Asset seeder resuming")
|
||||
self._state = State.RUNNING
|
||||
self._run_gate.set()
|
||||
self._emit_event("assets.seed.resumed", {})
|
||||
return True
|
||||
|
||||
def restart(
|
||||
self,
|
||||
roots: tuple[RootType, ...] | None = None,
|
||||
phase: ScanPhase | None = None,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
prune_first: bool | None = None,
|
||||
compute_hashes: bool | None = None,
|
||||
timeout: float = 5.0,
|
||||
) -> bool:
|
||||
"""Cancel any running scan and start a new one.
|
||||
|
||||
Args:
|
||||
roots: Roots to scan (defaults to previous roots)
|
||||
phase: Scan phase (defaults to previous phase)
|
||||
progress_callback: Progress callback (defaults to previous)
|
||||
prune_first: Prune before scan (defaults to previous)
|
||||
compute_hashes: Compute hashes (defaults to previous)
|
||||
timeout: Max seconds to wait for current scan to stop
|
||||
|
||||
Returns:
|
||||
True if new scan was started, False if failed to stop previous
|
||||
"""
|
||||
logging.info("Asset seeder restart requested")
|
||||
with self._lock:
|
||||
prev_roots = self._roots
|
||||
prev_phase = self._phase
|
||||
prev_callback = self._progress_callback
|
||||
prev_prune = self._prune_first
|
||||
prev_hashes = self._compute_hashes
|
||||
|
||||
self.cancel()
|
||||
if not self.wait(timeout=timeout):
|
||||
return False
|
||||
|
||||
cb = progress_callback if progress_callback is not None else prev_callback
|
||||
return self.start(
|
||||
roots=roots if roots is not None else prev_roots,
|
||||
phase=phase if phase is not None else prev_phase,
|
||||
progress_callback=cb,
|
||||
prune_first=prune_first if prune_first is not None else prev_prune,
|
||||
compute_hashes=(
|
||||
compute_hashes if compute_hashes is not None else prev_hashes
|
||||
),
|
||||
)
|
||||
|
||||
def wait(self, timeout: float | None = None) -> bool:
|
||||
"""Wait for the current scan to complete.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait, or None for no timeout
|
||||
|
||||
Returns:
|
||||
True if the scan completed (or no scan was running), False if the timeout expired
|
||||
"""
|
||||
with self._lock:
|
||||
thread = self._thread
|
||||
if thread is None:
|
||||
return True
|
||||
thread.join(timeout=timeout)
|
||||
return not thread.is_alive()
|
||||
|
||||
def get_status(self) -> ScanStatus:
|
||||
"""Get the current status and progress of the seeder."""
|
||||
with self._lock:
|
||||
src = self._progress or self._last_progress
|
||||
return ScanStatus(
|
||||
state=self._state,
|
||||
progress=Progress(
|
||||
scanned=src.scanned,
|
||||
total=src.total,
|
||||
created=src.created,
|
||||
skipped=src.skipped,
|
||||
)
|
||||
if src
|
||||
else None,
|
||||
errors=list(self._errors),
|
||||
)
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Gracefully shutdown: cancel any running scan and wait for thread.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait for thread to exit
|
||||
"""
|
||||
self.cancel()
|
||||
self.wait(timeout=timeout)
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
|
||||
def mark_missing_outside_prefixes(self) -> int:
|
||||
"""Mark references as missing when outside all known root prefixes.
|
||||
|
||||
This is a non-destructive soft-delete operation. Assets and their
|
||||
metadata are preserved, but references are flagged as missing.
|
||||
They can be restored if the file reappears in a future scan.
|
||||
|
||||
This operation is decoupled from scanning to prevent partial scans
|
||||
from accidentally marking assets belonging to other roots.
|
||||
|
||||
Should be called explicitly when cleanup is desired, typically after
|
||||
a full scan of all roots or during maintenance.
|
||||
|
||||
Returns:
|
||||
Number of references marked as missing
|
||||
|
||||
Raises:
|
||||
ScanInProgressError: If a scan is currently running
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.IDLE:
|
||||
raise ScanInProgressError(
|
||||
"Cannot mark missing assets while scan is running"
|
||||
)
|
||||
self._state = State.RUNNING
|
||||
|
||||
try:
|
||||
if not dependencies_available():
|
||||
logging.warning(
|
||||
"Database dependencies not available, skipping mark missing"
|
||||
)
|
||||
return 0
|
||||
|
||||
all_prefixes = get_all_known_prefixes()
|
||||
marked = mark_missing_outside_prefixes_safely(all_prefixes)
|
||||
if marked > 0:
|
||||
logging.info("Marked %d references as missing", marked)
|
||||
return marked
|
||||
finally:
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _is_cancelled(self) -> bool:
|
||||
"""Check if cancellation has been requested."""
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def _is_paused_or_cancelled(self) -> bool:
|
||||
"""Non-blocking check: True if paused or cancelled.
|
||||
|
||||
Use as interrupt_check for I/O-bound work (e.g. hashing) so that
|
||||
file handles are released immediately on pause rather than held
|
||||
open while blocked. The caller is responsible for blocking on
|
||||
_check_pause_and_cancel() afterward.
|
||||
"""
|
||||
return not self._run_gate.is_set() or self._cancel_event.is_set()
|
||||
|
||||
def _check_pause_and_cancel(self) -> bool:
|
||||
"""Block while paused, then check if cancelled.
|
||||
|
||||
Call this at checkpoint locations in scan loops. It will:
|
||||
1. Block indefinitely while paused (until resume or cancel)
|
||||
2. Return True if cancelled, False to continue
|
||||
|
||||
Returns:
|
||||
True if scan should stop, False to continue
|
||||
"""
|
||||
if not self._run_gate.is_set():
|
||||
self._emit_event("assets.seed.paused", {})
|
||||
self._run_gate.wait() # Blocks if paused
|
||||
return self._is_cancelled()
|
||||
|
||||
def _emit_event(self, event_type: str, data: dict) -> None:
|
||||
"""Emit a WebSocket event if server is available."""
|
||||
try:
|
||||
from server import PromptServer
|
||||
|
||||
if hasattr(PromptServer, "instance") and PromptServer.instance:
|
||||
PromptServer.instance.send_sync(event_type, data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _update_progress(
|
||||
self,
|
||||
scanned: int | None = None,
|
||||
total: int | None = None,
|
||||
created: int | None = None,
|
||||
skipped: int | None = None,
|
||||
) -> None:
|
||||
"""Update progress counters (thread-safe)."""
|
||||
callback: ProgressCallback | None = None
|
||||
progress: Progress | None = None
|
||||
|
||||
with self._lock:
|
||||
if self._progress is None:
|
||||
return
|
||||
if scanned is not None:
|
||||
self._progress.scanned = scanned
|
||||
if total is not None:
|
||||
self._progress.total = total
|
||||
if created is not None:
|
||||
self._progress.created = created
|
||||
if skipped is not None:
|
||||
self._progress.skipped = skipped
|
||||
if self._progress_callback:
|
||||
callback = self._progress_callback
|
||||
progress = Progress(
|
||||
scanned=self._progress.scanned,
|
||||
total=self._progress.total,
|
||||
created=self._progress.created,
|
||||
skipped=self._progress.skipped,
|
||||
)
|
||||
|
||||
if callback and progress:
|
||||
try:
|
||||
callback(progress)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_MAX_ERRORS = 200
|
||||
|
||||
def _add_error(self, message: str) -> None:
|
||||
"""Add an error message (thread-safe), capped at _MAX_ERRORS."""
|
||||
with self._lock:
|
||||
if len(self._errors) < self._MAX_ERRORS:
|
||||
self._errors.append(message)
|
||||
|
||||
def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
|
||||
"""Log the directories that will be scanned."""
|
||||
import folder_paths
|
||||
|
||||
for root in roots:
|
||||
if root == "models":
|
||||
logging.info(
|
||||
"Asset scan [models] directory: %s",
|
||||
os.path.abspath(folder_paths.models_dir),
|
||||
)
|
||||
else:
|
||||
prefixes = get_prefixes_for_root(root)
|
||||
if prefixes:
|
||||
logging.info("Asset scan [%s] directories: %s", root, prefixes)
|
||||
|
||||
def _run_scan(self) -> None:
|
||||
"""Main scan loop running in background thread."""
|
||||
t_start = time.perf_counter()
|
||||
roots = self._roots
|
||||
phase = self._phase
|
||||
cancelled = False
|
||||
total_created = 0
|
||||
total_enriched = 0
|
||||
skipped_existing = 0
|
||||
total_paths = 0
|
||||
|
||||
try:
|
||||
if not dependencies_available():
|
||||
self._add_error("Database dependencies not available")
|
||||
self._emit_event(
|
||||
"assets.seed.error",
|
||||
{"message": "Database dependencies not available"},
|
||||
)
|
||||
return
|
||||
|
||||
if self._prune_first:
|
||||
all_prefixes = get_all_known_prefixes()
|
||||
marked = mark_missing_outside_prefixes_safely(all_prefixes)
|
||||
if marked > 0:
|
||||
logging.info("Marked %d refs as missing before scan", marked)
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
logging.info("Asset scan cancelled after pruning phase")
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._log_scan_config(roots)
|
||||
|
||||
# Phase 1: Fast scan (stub records)
|
||||
if phase in (ScanPhase.FAST, ScanPhase.FULL):
|
||||
created, skipped, paths = self._run_fast_phase(roots)
|
||||
total_created, skipped_existing, total_paths = created, skipped, paths
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.fast_complete",
|
||||
{
|
||||
"roots": list(roots),
|
||||
"created": total_created,
|
||||
"skipped": skipped_existing,
|
||||
"total": total_paths,
|
||||
},
|
||||
)
|
||||
|
||||
# Phase 2: Enrichment scan (metadata + hashes)
|
||||
if phase in (ScanPhase.ENRICH, ScanPhase.FULL):
|
||||
if self._check_pause_and_cancel():
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
enrich_cancelled, total_enriched = self._run_enrich_phase(roots)
|
||||
|
||||
if enrich_cancelled:
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.enrich_complete",
|
||||
{
|
||||
"roots": list(roots),
|
||||
"enriched": total_enriched,
|
||||
},
|
||||
)
|
||||
|
||||
elapsed = time.perf_counter() - t_start
|
||||
logging.info(
|
||||
"Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d",
|
||||
roots,
|
||||
phase.value,
|
||||
elapsed,
|
||||
total_created,
|
||||
total_enriched,
|
||||
skipped_existing,
|
||||
)
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.completed",
|
||||
{
|
||||
"phase": phase.value,
|
||||
"total": total_paths,
|
||||
"created": total_created,
|
||||
"enriched": total_enriched,
|
||||
"skipped": skipped_existing,
|
||||
"elapsed": round(elapsed, 3),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._add_error(f"Scan failed: {e}")
|
||||
logging.exception("Asset scan failed")
|
||||
self._emit_event("assets.seed.error", {"message": str(e)})
|
||||
finally:
|
||||
if cancelled:
|
||||
self._emit_event(
|
||||
"assets.seed.cancelled",
|
||||
{
|
||||
"scanned": self._progress.scanned if self._progress else 0,
|
||||
"total": total_paths,
|
||||
"created": total_created,
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||
"""Run phase 1: fast scan to create stub records.
|
||||
|
||||
Returns:
|
||||
Tuple of (total_created, skipped_existing, total_paths)
|
||||
"""
|
||||
t_fast_start = time.perf_counter()
|
||||
total_created = 0
|
||||
skipped_existing = 0
|
||||
|
||||
existing_paths: set[str] = set()
|
||||
t_sync = time.perf_counter()
|
||||
for r in roots:
|
||||
if self._check_pause_and_cancel():
|
||||
return total_created, skipped_existing, 0
|
||||
existing_paths.update(sync_root_safely(r))
|
||||
logging.debug(
|
||||
"Fast scan: sync_root phase took %.3fs (%d existing paths)",
|
||||
time.perf_counter() - t_sync,
|
||||
len(existing_paths),
|
||||
)
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
return total_created, skipped_existing, 0
|
||||
|
||||
t_collect = time.perf_counter()
|
||||
paths = collect_paths_for_roots(roots)
|
||||
logging.debug(
|
||||
"Fast scan: collect_paths took %.3fs (%d paths found)",
|
||||
time.perf_counter() - t_collect,
|
||||
len(paths),
|
||||
)
|
||||
total_paths = len(paths)
|
||||
self._update_progress(total=total_paths)
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.started",
|
||||
{"roots": list(roots), "total": total_paths, "phase": "fast"},
|
||||
)
|
||||
|
||||
# Use stub specs (no metadata extraction, no hashing)
|
||||
t_specs = time.perf_counter()
|
||||
specs, tag_pool, skipped_existing = build_asset_specs(
|
||||
paths,
|
||||
existing_paths,
|
||||
enable_metadata_extraction=False,
|
||||
compute_hashes=False,
|
||||
)
|
||||
logging.debug(
|
||||
"Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)",
|
||||
time.perf_counter() - t_specs,
|
||||
len(specs),
|
||||
skipped_existing,
|
||||
)
|
||||
self._update_progress(skipped=skipped_existing)
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
return total_created, skipped_existing, total_paths
|
||||
|
||||
batch_size = 500
|
||||
last_progress_time = time.perf_counter()
|
||||
progress_interval = 1.0
|
||||
|
||||
for i in range(0, len(specs), batch_size):
|
||||
if self._check_pause_and_cancel():
|
||||
logging.info(
|
||||
"Fast scan cancelled after %d/%d files (created=%d)",
|
||||
i,
|
||||
len(specs),
|
||||
total_created,
|
||||
)
|
||||
return total_created, skipped_existing, total_paths
|
||||
|
||||
batch = specs[i : i + batch_size]
|
||||
batch_tags = {t for spec in batch for t in spec["tags"]}
|
||||
try:
|
||||
created = insert_asset_specs(batch, batch_tags)
|
||||
total_created += created
|
||||
except Exception as e:
|
||||
self._add_error(f"Batch insert failed at offset {i}: {e}")
|
||||
logging.exception("Batch insert failed at offset %d", i)
|
||||
|
||||
scanned = i + len(batch)
|
||||
now = time.perf_counter()
|
||||
self._update_progress(scanned=scanned, created=total_created)
|
||||
|
||||
if now - last_progress_time >= progress_interval:
|
||||
self._emit_event(
|
||||
"assets.seed.progress",
|
||||
{
|
||||
"phase": "fast",
|
||||
"scanned": scanned,
|
||||
"total": len(specs),
|
||||
"created": total_created,
|
||||
},
|
||||
)
|
||||
last_progress_time = now
|
||||
|
||||
self._update_progress(scanned=len(specs), created=total_created)
|
||||
logging.info(
|
||||
"Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)",
|
||||
time.perf_counter() - t_fast_start,
|
||||
total_created,
|
||||
skipped_existing,
|
||||
total_paths,
|
||||
)
|
||||
return total_created, skipped_existing, total_paths
|
||||
|
||||
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
|
||||
"""Run phase 2: enrich existing records with metadata and hashes.
|
||||
|
||||
Returns:
|
||||
Tuple of (cancelled, total_enriched)
|
||||
"""
|
||||
total_enriched = 0
|
||||
batch_size = 100
|
||||
last_progress_time = time.perf_counter()
|
||||
progress_interval = 1.0
|
||||
|
||||
# Get the target enrichment level based on compute_hashes
|
||||
if not self._compute_hashes:
|
||||
target_max_level = ENRICHMENT_STUB
|
||||
else:
|
||||
target_max_level = ENRICHMENT_METADATA
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.started",
|
||||
{"roots": list(roots), "phase": "enrich"},
|
||||
)
|
||||
|
||||
skip_ids: set[str] = set()
|
||||
consecutive_empty = 0
|
||||
max_consecutive_empty = 3
|
||||
|
||||
# Hash checkpoints survive across batches so interrupted hashes
|
||||
# can be resumed without re-reading the entire file.
|
||||
hash_checkpoints: dict[str, object] = {}
|
||||
|
||||
while True:
|
||||
if self._check_pause_and_cancel():
|
||||
logging.info("Enrich scan cancelled after %d assets", total_enriched)
|
||||
return True, total_enriched
|
||||
|
||||
# Fetch next batch of unenriched assets
|
||||
unenriched = get_unenriched_assets_for_roots(
|
||||
roots,
|
||||
max_level=target_max_level,
|
||||
limit=batch_size,
|
||||
)
|
||||
|
||||
# Filter out previously failed references
|
||||
if skip_ids:
|
||||
unenriched = [r for r in unenriched if r.reference_id not in skip_ids]
|
||||
|
||||
if not unenriched:
|
||||
break
|
||||
|
||||
enriched, failed_ids = enrich_assets_batch(
|
||||
unenriched,
|
||||
extract_metadata=True,
|
||||
compute_hash=self._compute_hashes,
|
||||
interrupt_check=self._is_paused_or_cancelled,
|
||||
hash_checkpoints=hash_checkpoints,
|
||||
)
|
||||
total_enriched += enriched
|
||||
skip_ids.update(failed_ids)
|
||||
|
||||
if enriched == 0:
|
||||
consecutive_empty += 1
|
||||
if consecutive_empty >= max_consecutive_empty:
|
||||
logging.warning(
|
||||
"Enrich phase stopping: %d consecutive batches with no progress (%d skipped)",
|
||||
consecutive_empty,
|
||||
len(skip_ids),
|
||||
)
|
||||
break
|
||||
else:
|
||||
consecutive_empty = 0
|
||||
|
||||
now = time.perf_counter()
|
||||
if now - last_progress_time >= progress_interval:
|
||||
self._emit_event(
|
||||
"assets.seed.progress",
|
||||
{
|
||||
"phase": "enrich",
|
||||
"enriched": total_enriched,
|
||||
},
|
||||
)
|
||||
last_progress_time = now
|
||||
|
||||
return False, total_enriched
|
||||
|
||||
|
||||
asset_seeder = _AssetSeeder()
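
A minimal usage sketch of the seeder lifecycle described above. The import path, the callback name, and the timeout are assumptions for illustration only; they are not part of this diff.

# Illustrative sketch: driving the module-level seeder.
# Assumes this module is importable as app.assets.seeder (path is a guess).
from app.assets.seeder import asset_seeder

def on_progress(p):  # hypothetical callback; receives a Progress snapshot
    print(f"scanned {p.scanned}/{p.total}, created {p.created}, skipped {p.skipped}")

if asset_seeder.start(roots=("models",), progress_callback=on_progress):
    asset_seeder.pause()    # current batch finishes, then the run gate blocks
    asset_seeder.resume()   # re-sets the run gate
    asset_seeder.wait(timeout=60.0)

status = asset_seeder.get_status()
print(status.state, status.errors)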
|
||||
87
app/assets/services/__init__.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from app.assets.services.asset_management import (
|
||||
asset_exists,
|
||||
delete_asset_reference,
|
||||
get_asset_by_hash,
|
||||
get_asset_detail,
|
||||
list_assets_page,
|
||||
resolve_asset_for_download,
|
||||
set_asset_preview,
|
||||
update_asset_metadata,
|
||||
)
|
||||
from app.assets.services.bulk_ingest import (
|
||||
BulkInsertResult,
|
||||
batch_insert_seed_assets,
|
||||
cleanup_unreferenced_assets,
|
||||
)
|
||||
from app.assets.services.file_utils import (
|
||||
get_mtime_ns,
|
||||
get_size_and_mtime_ns,
|
||||
list_files_recursively,
|
||||
verify_file_unchanged,
|
||||
)
|
||||
from app.assets.services.ingest import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
create_from_hash,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.database.queries import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
)
|
||||
from app.assets.services.schemas import (
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
AssetSummaryData,
|
||||
DownloadResolutionResult,
|
||||
IngestResult,
|
||||
ListAssetsResult,
|
||||
ReferenceData,
|
||||
RegisterAssetResult,
|
||||
TagUsage,
|
||||
UploadResult,
|
||||
UserMetadata,
|
||||
)
|
||||
from app.assets.services.tagging import (
|
||||
apply_tags,
|
||||
list_tags,
|
||||
remove_tags,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AddTagsResult",
|
||||
"AssetData",
|
||||
"AssetDetailResult",
|
||||
"AssetSummaryData",
|
||||
"ReferenceData",
|
||||
"BulkInsertResult",
|
||||
"DependencyMissingError",
|
||||
"DownloadResolutionResult",
|
||||
"HashMismatchError",
|
||||
"IngestResult",
|
||||
"ListAssetsResult",
|
||||
"RegisterAssetResult",
|
||||
"RemoveTagsResult",
|
||||
"TagUsage",
|
||||
"UploadResult",
|
||||
"UserMetadata",
|
||||
"apply_tags",
|
||||
"asset_exists",
|
||||
"batch_insert_seed_assets",
|
||||
"create_from_hash",
|
||||
"delete_asset_reference",
|
||||
"get_asset_by_hash",
|
||||
"get_asset_detail",
|
||||
"get_mtime_ns",
|
||||
"get_size_and_mtime_ns",
|
||||
"list_assets_page",
|
||||
"list_files_recursively",
|
||||
"list_tags",
|
||||
"cleanup_unreferenced_assets",
|
||||
"remove_tags",
|
||||
"resolve_asset_for_download",
|
||||
"set_asset_preview",
|
||||
"update_asset_metadata",
|
||||
"upload_from_temp_path",
|
||||
"verify_file_unchanged",
|
||||
]
|
||||
309
app/assets/services/asset_management.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import contextlib
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
reference_exists_for_asset_id,
|
||||
delete_reference_by_id,
|
||||
fetch_reference_and_asset,
|
||||
soft_delete_reference_by_id,
|
||||
fetch_reference_asset_and_tags,
|
||||
get_asset_by_hash as queries_get_asset_by_hash,
|
||||
get_reference_by_id,
|
||||
get_reference_with_owner_check,
|
||||
list_references_page,
|
||||
list_references_by_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
set_reference_tags,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_reference_updated_at,
|
||||
)
|
||||
from app.assets.helpers import select_best_live_path
|
||||
from app.assets.services.path_utils import compute_relative_filename
|
||||
from app.assets.services.schemas import (
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
AssetSummaryData,
|
||||
DownloadResolutionResult,
|
||||
ListAssetsResult,
|
||||
UserMetadata,
|
||||
extract_asset_data,
|
||||
extract_reference_data,
|
||||
)
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def get_asset_detail(
|
||||
reference_id: str,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult | None:
|
||||
with create_session() as session:
|
||||
result = fetch_reference_asset_and_tags(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
ref, asset, tags = result
|
||||
return AssetDetailResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
|
||||
def update_asset_metadata(
|
||||
reference_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: UserMetadata = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
touched = False
|
||||
if name is not None and name != ref.name:
|
||||
update_reference_name(session, reference_id=reference_id, name=name)
|
||||
touched = True
|
||||
|
||||
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||
|
||||
new_meta: dict | None = None
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
elif computed_filename:
|
||||
current_meta = ref.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
|
||||
if new_meta is not None:
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
set_reference_metadata(
|
||||
session, reference_id=reference_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
update_reference_updated_at(session, reference_id=reference_id)
|
||||
|
||||
result = fetch_reference_asset_and_tags(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("State changed during update")
|
||||
|
||||
ref, asset, tag_list = result
|
||||
detail = AssetDetailResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_list,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
def delete_asset_reference(
|
||||
reference_id: str,
|
||||
owner_id: str,
|
||||
delete_content_if_orphan: bool = True,
|
||||
) -> bool:
|
||||
with create_session() as session:
|
||||
if not delete_content_if_orphan:
|
||||
# Soft delete: mark the reference as deleted but keep everything
|
||||
deleted = soft_delete_reference_by_id(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
session.commit()
|
||||
return deleted
|
||||
|
||||
ref_row = get_reference_by_id(session, reference_id=reference_id)
|
||||
asset_id = ref_row.asset_id if ref_row else None
|
||||
file_path = ref_row.file_path if ref_row else None
|
||||
|
||||
deleted = delete_reference_by_id(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = reference_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
# Orphaned asset - delete it and its files
|
||||
refs = list_references_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [
|
||||
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
|
||||
]
|
||||
# Also include the just-deleted file path
|
||||
if file_path:
|
||||
file_paths.append(file_path)
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Delete files after commit
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
reference_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
set_reference_preview(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
result = fetch_reference_asset_and_tags(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
|
||||
ref, asset, tags = result
|
||||
detail = AssetDetailResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def get_asset_by_hash(asset_hash: str) -> AssetData | None:
|
||||
with create_session() as session:
|
||||
asset = queries_get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
return extract_asset_data(asset)
|
||||
|
||||
|
||||
def list_assets_page(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> ListAssetsResult:
|
||||
with create_session() as session:
|
||||
refs, tag_map, total = list_references_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
items: list[AssetSummaryData] = []
|
||||
for ref in refs:
|
||||
items.append(
|
||||
AssetSummaryData(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(ref.asset),
|
||||
tags=tag_map.get(ref.id, []),
|
||||
)
|
||||
)
|
||||
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
|
||||
|
||||
def resolve_asset_for_download(
|
||||
reference_id: str,
|
||||
owner_id: str = "",
|
||||
) -> DownloadResolutionResult:
|
||||
with create_session() as session:
|
||||
pair = fetch_reference_and_asset(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
ref, asset = pair
|
||||
|
||||
# For references with file_path, use that directly
|
||||
if ref.file_path and os.path.isfile(ref.file_path):
|
||||
abs_path = ref.file_path
|
||||
else:
|
||||
# For API-created refs without file_path, find a path from other refs
|
||||
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = select_best_live_path(refs)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError(
|
||||
f"No live path for AssetReference {reference_id} "
|
||||
f"(asset id={asset.id}, name={ref.name})"
|
||||
)
|
||||
|
||||
# Capture ORM attributes before commit (commit expires loaded objects)
|
||||
ref_name = ref.name
|
||||
asset_mime = asset.mime_type
|
||||
|
||||
update_reference_access_time(session, reference_id=reference_id)
|
||||
session.commit()
|
||||
|
||||
ctype = (
|
||||
asset_mime
|
||||
or mimetypes.guess_type(ref_name or abs_path)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
download_name = ref_name or os.path.basename(abs_path)
|
||||
return DownloadResolutionResult(
|
||||
abs_path=abs_path,
|
||||
content_type=ctype,
|
||||
download_name=download_name,
|
||||
)
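
A hedged sketch of how these service functions might be consumed, e.g. by an HTTP handler. The package path and the `id` attribute on ReferenceData are assumptions, not confirmed by this diff.

# Illustrative only: list a page of references, then resolve one for download.
from app.assets.services.asset_management import (
    list_assets_page,
    resolve_asset_for_download,
)

page = list_assets_page(include_tags=["models"], limit=10)
print(f"{page.total} matching assets")
for item in page.items:
    dl = resolve_asset_for_download(item.ref.id)  # `id` attribute is assumed
    print(dl.abs_path, dl.content_type, dl.download_name)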
|
||||
280
app/assets/services/bulk_ingest.py
Normal file
@@ -0,0 +1,280 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.queries import (
|
||||
bulk_insert_assets,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_insert_tags_and_meta,
|
||||
delete_assets_by_ids,
|
||||
get_existing_asset_ids,
|
||||
get_reference_ids_by_ids,
|
||||
get_references_by_paths_and_asset_ids,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
restore_references_by_paths,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.assets.services.metadata_extract import ExtractedMetadata
|
||||
|
||||
|
||||
class SeedAssetSpec(TypedDict):
|
||||
"""Spec for seeding an asset from filesystem."""
|
||||
|
||||
abs_path: str
|
||||
size_bytes: int
|
||||
mtime_ns: int
|
||||
info_name: str
|
||||
tags: list[str]
|
||||
fname: str
|
||||
metadata: ExtractedMetadata | None
|
||||
hash: str | None
|
||||
mime_type: str | None
|
||||
|
||||
|
||||
class AssetRow(TypedDict):
|
||||
"""Row data for inserting an Asset."""
|
||||
|
||||
id: str
|
||||
hash: str | None
|
||||
size_bytes: int
|
||||
mime_type: str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ReferenceRow(TypedDict):
|
||||
"""Row data for inserting an AssetReference."""
|
||||
|
||||
id: str
|
||||
asset_id: str
|
||||
file_path: str
|
||||
mtime_ns: int
|
||||
owner_id: str
|
||||
name: str
|
||||
preview_id: str | None
|
||||
user_metadata: dict[str, Any] | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime
|
||||
|
||||
|
||||
class TagRow(TypedDict):
|
||||
"""Row data for inserting a Tag."""
|
||||
|
||||
asset_reference_id: str
|
||||
tag_name: str
|
||||
origin: str
|
||||
added_at: datetime
|
||||
|
||||
|
||||
class MetadataRow(TypedDict):
|
||||
"""Row data for inserting asset metadata."""
|
||||
|
||||
asset_reference_id: str
|
||||
key: str
|
||||
ordinal: int
|
||||
val_str: str | None
|
||||
val_num: float | None
|
||||
val_bool: bool | None
|
||||
val_json: dict[str, Any] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BulkInsertResult:
|
||||
"""Result of bulk asset insertion."""
|
||||
|
||||
inserted_refs: int
|
||||
won_paths: int
|
||||
lost_paths: int
|
||||
|
||||
|
||||
def batch_insert_seed_assets(
|
||||
session: Session,
|
||||
specs: list[SeedAssetSpec],
|
||||
owner_id: str = "",
|
||||
) -> BulkInsertResult:
|
||||
"""Seed assets from filesystem specs in batch.
|
||||
|
||||
Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: str
- metadata: ExtractedMetadata | None
- hash: str | None
- mime_type: str | None
|
||||
|
||||
This function orchestrates:
|
||||
1. Insert seed Assets (hash=NULL)
|
||||
2. Claim references with ON CONFLICT DO NOTHING on file_path
|
||||
3. Query to find winners (paths where our asset_id was inserted)
|
||||
4. Delete Assets for losers (path already claimed by another asset)
|
||||
5. Insert tags and metadata for successfully inserted references
|
||||
|
||||
Returns:
|
||||
BulkInsertResult with inserted_refs, won_paths, lost_paths
|
||||
"""
|
||||
if not specs:
|
||||
return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0)
|
||||
|
||||
current_time = get_utc_now()
|
||||
asset_rows: list[AssetRow] = []
|
||||
reference_rows: list[ReferenceRow] = []
|
||||
path_to_asset_id: dict[str, str] = {}
|
||||
asset_id_to_ref_data: dict[str, dict] = {}
|
||||
absolute_path_list: list[str] = []
|
||||
|
||||
for spec in specs:
|
||||
absolute_path = os.path.abspath(spec["abs_path"])
|
||||
asset_id = str(uuid.uuid4())
|
||||
reference_id = str(uuid.uuid4())
|
||||
absolute_path_list.append(absolute_path)
|
||||
path_to_asset_id[absolute_path] = asset_id
|
||||
|
||||
mime_type = spec.get("mime_type")
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": asset_id,
|
||||
"hash": spec.get("hash"),
|
||||
"size_bytes": spec["size_bytes"],
|
||||
"mime_type": mime_type,
|
||||
"created_at": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
# Build user_metadata from extracted metadata or fallback to filename
|
||||
extracted_metadata = spec.get("metadata")
|
||||
if extracted_metadata:
|
||||
user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata()
|
||||
elif spec["fname"]:
|
||||
user_metadata = {"filename": spec["fname"]}
|
||||
else:
|
||||
user_metadata = None
|
||||
|
||||
reference_rows.append(
|
||||
{
|
||||
"id": reference_id,
|
||||
"asset_id": asset_id,
|
||||
"file_path": absolute_path,
|
||||
"mtime_ns": spec["mtime_ns"],
|
||||
"owner_id": owner_id,
|
||||
"name": spec["info_name"],
|
||||
"preview_id": None,
|
||||
"user_metadata": user_metadata,
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
"last_access_time": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
asset_id_to_ref_data[asset_id] = {
|
||||
"reference_id": reference_id,
|
||||
"tags": spec["tags"],
|
||||
"filename": spec["fname"],
|
||||
"extracted_metadata": extracted_metadata,
|
||||
}
|
||||
|
||||
bulk_insert_assets(session, asset_rows)
|
||||
|
||||
# Filter reference rows to only those whose assets were actually inserted
|
||||
# (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING)
|
||||
inserted_asset_ids = get_existing_asset_ids(
|
||||
session, [r["asset_id"] for r in reference_rows]
|
||||
)
|
||||
reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids]
|
||||
|
||||
bulk_insert_references_ignore_conflicts(session, reference_rows)
|
||||
restore_references_by_paths(session, absolute_path_list)
|
||||
winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id)
|
||||
|
||||
inserted_paths = {
|
||||
path
|
||||
for path in absolute_path_list
|
||||
if path_to_asset_id[path] in inserted_asset_ids
|
||||
}
|
||||
losing_paths = inserted_paths - winning_paths
|
||||
lost_asset_ids = [path_to_asset_id[path] for path in losing_paths]
|
||||
|
||||
if lost_asset_ids:
|
||||
delete_assets_by_ids(session, lost_asset_ids)
|
||||
|
||||
if not winning_paths:
|
||||
return BulkInsertResult(
|
||||
inserted_refs=0,
|
||||
won_paths=0,
|
||||
lost_paths=len(losing_paths),
|
||||
)
|
||||
|
||||
# Get reference IDs for winners
|
||||
winning_ref_ids = [
|
||||
asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"]
|
||||
for path in winning_paths
|
||||
]
|
||||
inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids)
|
||||
|
||||
tag_rows: list[TagRow] = []
|
||||
metadata_rows: list[MetadataRow] = []
|
||||
|
||||
if inserted_ref_ids:
|
||||
for path in winning_paths:
|
||||
asset_id = path_to_asset_id[path]
|
||||
ref_data = asset_id_to_ref_data[asset_id]
|
||||
ref_id = ref_data["reference_id"]
|
||||
|
||||
if ref_id not in inserted_ref_ids:
|
||||
continue
|
||||
|
||||
for tag in ref_data["tags"]:
|
||||
tag_rows.append(
|
||||
{
|
||||
"asset_reference_id": ref_id,
|
||||
"tag_name": tag,
|
||||
"origin": "automatic",
|
||||
"added_at": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
# Use extracted metadata for meta rows if available
|
||||
extracted_metadata = ref_data.get("extracted_metadata")
|
||||
if extracted_metadata:
|
||||
metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id))
|
||||
elif ref_data["filename"]:
|
||||
# Fallback: just store filename
|
||||
metadata_rows.append(
|
||||
{
|
||||
"asset_reference_id": ref_id,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": ref_data["filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
|
||||
|
||||
return BulkInsertResult(
|
||||
inserted_refs=len(inserted_ref_ids),
|
||||
won_paths=len(winning_paths),
|
||||
lost_paths=len(losing_paths),
|
||||
)
|
||||
|
||||
|
||||
def cleanup_unreferenced_assets(session: Session) -> int:
|
||||
"""Hard-delete unhashed assets with no active references.
|
||||
|
||||
This is a destructive operation intended for explicit cleanup.
|
||||
Only deletes assets where hash=None and all references are missing.
|
||||
|
||||
Returns:
|
||||
Number of assets deleted
|
||||
"""
|
||||
unreferenced_ids = get_unreferenced_unhashed_asset_ids(session)
|
||||
return delete_assets_by_ids(session, unreferenced_ids)
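
A hedged sketch of feeding batch_insert_seed_assets with a single SeedAssetSpec. The file path and tags are invented; create_session and get_size_and_mtime_ns are used as they appear elsewhere in this diff.

# Illustrative only: seeding one file through the bulk ingest path as a stub record.
import os
from app.database.db import create_session
from app.assets.services.bulk_ingest import batch_insert_seed_assets
from app.assets.services.file_utils import get_size_and_mtime_ns

path = "/abs/path/to/models/checkpoints/example.safetensors"  # hypothetical
size_bytes, mtime_ns = get_size_and_mtime_ns(path)
spec = {
    "abs_path": path,
    "size_bytes": size_bytes,
    "mtime_ns": mtime_ns,
    "info_name": os.path.basename(path),
    "tags": ["models", "checkpoints"],
    "fname": "checkpoints/example.safetensors",
    "metadata": None,   # stub record: metadata filled in by the enrich phase
    "hash": None,       # stub record: hash filled in by the enrich phase
    "mime_type": None,
}

with create_session() as session:
    result = batch_insert_seed_assets(session, [spec])
    session.commit()
print(result.inserted_refs, result.won_paths, result.lost_paths)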
|
||||
70
app/assets/services/file_utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
|
||||
|
||||
def get_mtime_ns(stat_result: os.stat_result) -> int:
|
||||
"""Extract mtime in nanoseconds from a stat result."""
|
||||
return getattr(
|
||||
stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)
|
||||
)
|
||||
|
||||
|
||||
def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]:
|
||||
"""Get file size in bytes and mtime in nanoseconds."""
|
||||
st = os.stat(path, follow_symlinks=follow_symlinks)
|
||||
return st.st_size, get_mtime_ns(st)
|
||||
|
||||
|
||||
def verify_file_unchanged(
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
"""Check if a file is unchanged based on mtime and size.
|
||||
|
||||
Returns True if the file's mtime and size match the database values.
|
||||
Returns False if mtime_db is None or values don't match.
|
||||
|
||||
size_db=None means don't check size; 0 is a valid recorded size.
|
||||
"""
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = get_mtime_ns(stat_result)
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
if size_db is not None:
|
||||
return int(stat_result.st_size) == int(size_db)
|
||||
return True
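
A small sketch of the intended pairing with os.stat; the stored values and path below are placeholders, not real data.

# Illustrative only: decide whether cached hash/metadata can be reused.
import os
from app.assets.services.file_utils import verify_file_unchanged

stored_mtime_ns = 1_700_000_000_000_000_000  # placeholder DB value
stored_size = 4096                            # placeholder DB value

st = os.stat("/abs/path/to/file.bin")         # hypothetical path
if verify_file_unchanged(stored_mtime_ns, stored_size, st):
    print("unchanged: reuse cached data")
else:
    print("changed or unknown: re-enrich")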
|
||||
|
||||
|
||||
def is_visible(name: str) -> bool:
|
||||
"""Return True if a file or directory name is visible (not hidden)."""
|
||||
return not name.startswith(".")
|
||||
|
||||
|
||||
def list_files_recursively(base_dir: str) -> list[str]:
|
||||
"""Recursively list all files in a directory, following symlinks."""
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
# Track seen real directory identities to prevent circular symlink loops
|
||||
seen_dirs: set[tuple[int, int]] = set()
|
||||
for dirpath, subdirs, filenames in os.walk(
|
||||
base_abs, topdown=True, followlinks=True
|
||||
):
|
||||
try:
|
||||
st = os.stat(dirpath)
|
||||
dir_id = (st.st_dev, st.st_ino)
|
||||
except OSError:
|
||||
subdirs.clear()
|
||||
continue
|
||||
if dir_id in seen_dirs:
|
||||
subdirs.clear()
|
||||
continue
|
||||
seen_dirs.add(dir_id)
|
||||
subdirs[:] = [d for d in subdirs if is_visible(d)]
|
||||
for name in filenames:
|
||||
if not is_visible(name):
|
||||
continue
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
99
app/assets/services/hashing.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import io
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import IO, Any, Callable, Iterator
|
||||
import logging
|
||||
|
||||
try:
|
||||
from blake3 import blake3
|
||||
except ModuleNotFoundError:
|
||||
logging.warning("WARNING: blake3 package not installed")
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 * 1024
|
||||
|
||||
InterruptCheck = Callable[[], bool]
|
||||
|
||||
|
||||
@dataclass
|
||||
class HashCheckpoint:
|
||||
"""Saved state for resuming an interrupted hash computation."""
|
||||
|
||||
bytes_processed: int
|
||||
hasher: Any # blake3 hasher instance
|
||||
mtime_ns: int = 0
|
||||
file_size: int = 0
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]:
|
||||
"""Yield (file_object, is_path) with appropriate setup/teardown."""
|
||||
if hasattr(fp, "read"):
|
||||
seekable = getattr(fp, "seekable", lambda: False)()
|
||||
orig_pos = None
|
||||
if seekable:
|
||||
try:
|
||||
orig_pos = fp.tell()
|
||||
if orig_pos != 0:
|
||||
fp.seek(0)
|
||||
except io.UnsupportedOperation:
|
||||
orig_pos = None
|
||||
try:
|
||||
yield fp, False
|
||||
finally:
|
||||
if orig_pos is not None:
|
||||
fp.seek(orig_pos)
|
||||
else:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
yield f, True
|
||||
|
||||
|
||||
def compute_blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
interrupt_check: InterruptCheck | None = None,
|
||||
checkpoint: HashCheckpoint | None = None,
|
||||
) -> tuple[str | None, HashCheckpoint | None]:
|
||||
"""Compute BLAKE3 hash of a file, with optional checkpoint support.
|
||||
|
||||
Args:
|
||||
fp: File path or file-like object
|
||||
chunk_size: Size of chunks to read at a time
|
||||
interrupt_check: Optional callable that returns True if the operation
|
||||
should be interrupted (e.g. paused or cancelled). Must be
|
||||
non-blocking so file handles are released immediately. Checked
|
||||
between chunk reads.
|
||||
checkpoint: Optional checkpoint to resume from (file paths only)
|
||||
|
||||
Returns:
|
||||
Tuple of (hex_digest, None) on completion, or
|
||||
(None, checkpoint) on interruption (file paths only), or
|
||||
(None, None) on interruption of a file object
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
with _open_for_hashing(fp) as (f, is_path):
|
||||
if checkpoint is not None and is_path:
|
||||
f.seek(checkpoint.bytes_processed)
|
||||
h = checkpoint.hasher
|
||||
bytes_processed = checkpoint.bytes_processed
|
||||
else:
|
||||
h = blake3()
|
||||
bytes_processed = 0
|
||||
|
||||
while True:
|
||||
if interrupt_check is not None and interrupt_check():
|
||||
if is_path:
|
||||
return None, HashCheckpoint(
|
||||
bytes_processed=bytes_processed,
|
||||
hasher=h,
|
||||
)
|
||||
return None, None
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
bytes_processed += len(chunk)
|
||||
|
||||
return h.hexdigest(), None
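
A hedged sketch of the checkpoint/resume contract documented above; the interrupt predicate is a stand-in for the seeder's _is_paused_or_cancelled and the path is invented.

# Illustrative only: interrupt a hash, then resume it from the returned checkpoint.
from app.assets.services.hashing import compute_blake3_hash

interrupted = {"flag": True}  # simulate a pause request already pending

def should_stop() -> bool:    # stand-in for _AssetSeeder._is_paused_or_cancelled
    return interrupted["flag"]

path = "/abs/path/to/large_model.safetensors"  # hypothetical

digest, checkpoint = compute_blake3_hash(path, interrupt_check=should_stop)
if digest is None and checkpoint is not None:
    # Interrupted: the file handle is closed, progress lives in `checkpoint`.
    interrupted["flag"] = False                 # "resume"
    digest, _ = compute_blake3_hash(path, checkpoint=checkpoint)

print("blake3:" + digest)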
|
||||
375
app/assets/services/ingest.py
Normal file
@@ -0,0 +1,375 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Any, Sequence
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.assets.services.hashing as hashing
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
fetch_reference_and_asset,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_tags,
|
||||
upsert_asset,
|
||||
upsert_reference,
|
||||
validate_tags_exist,
|
||||
)
|
||||
from app.assets.helpers import normalize_tags
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
)
|
||||
from app.assets.services.schemas import (
|
||||
IngestResult,
|
||||
RegisterAssetResult,
|
||||
UploadResult,
|
||||
UserMetadata,
|
||||
extract_asset_data,
|
||||
extract_reference_data,
|
||||
)
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def _ingest_file_from_path(
|
||||
abs_path: str,
|
||||
asset_hash: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: UserMetadata = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> IngestResult:
|
||||
locator = os.path.abspath(abs_path)
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
asset_created = False
|
||||
asset_updated = False
|
||||
ref_created = False
|
||||
ref_updated = False
|
||||
reference_id: str | None = None
|
||||
|
||||
with create_session() as session:
|
||||
if preview_id:
|
||||
if preview_id not in get_existing_asset_ids(session, [preview_id]):
|
||||
preview_id = None
|
||||
|
||||
asset, asset_created, asset_updated = upsert_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=size_bytes,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
ref_created, ref_updated = upsert_reference(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
file_path=locator,
|
||||
name=info_name or os.path.basename(locator),
|
||||
mtime_ns=mtime_ns,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
# Get the reference we just created/updated
|
||||
ref = get_reference_by_file_path(session, locator)
|
||||
if ref:
|
||||
reference_id = ref.id
|
||||
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
norm = normalize_tags(list(tags))
|
||||
if norm:
|
||||
if require_existing_tags:
|
||||
validate_tags_exist(session, norm)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=norm,
|
||||
origin=tag_origin,
|
||||
create_if_missing=not require_existing_tags,
|
||||
)
|
||||
|
||||
_update_metadata_with_filename(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
file_path=ref.file_path,
|
||||
current_metadata=ref.user_metadata,
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
|
||||
session.commit()
|
||||
|
||||
return IngestResult(
|
||||
asset_created=asset_created,
|
||||
asset_updated=asset_updated,
|
||||
ref_created=ref_created,
|
||||
ref_updated=ref_updated,
|
||||
reference_id=reference_id,
|
||||
)
|
||||
|
||||
|
||||
def _register_existing_asset(
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: UserMetadata = None,
|
||||
tags: list[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> RegisterAssetResult:
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"No asset with hash {asset_hash}")
|
||||
|
||||
ref, ref_created = get_or_create_reference(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
)
|
||||
|
||||
if not ref_created:
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
result = RegisterAssetResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created=False,
|
||||
)
|
||||
session.commit()
|
||||
return result
|
||||
|
||||
new_meta = dict(user_metadata)
|
||||
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta:
|
||||
set_reference_metadata(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
session.refresh(ref)
|
||||
result = RegisterAssetResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created=True,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def _update_metadata_with_filename(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
file_path: str | None,
|
||||
current_metadata: dict | None,
|
||||
user_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
computed_filename = compute_relative_filename(file_path) if file_path else None
|
||||
|
||||
current_meta = current_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
set_reference_metadata(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
return n if n else fallback
|
||||
|
||||
|
||||
class HashMismatchError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DependencyMissingError(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def upload_from_temp_path(
|
||||
temp_path: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_hash: str | None = None,
|
||||
) -> UploadResult:
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(temp_path)
|
||||
except ImportError as e:
|
||||
raise DependencyMissingError(str(e))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_hash and asset_hash != expected_hash.strip().lower():
|
||||
raise HashMismatchError("Uploaded file hash does not match provided hash.")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _sanitize_filename(name or client_filename, fallback=digest)
|
||||
result = _register_existing_asset(
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
asset=result.asset,
|
||||
tags=result.tags,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
if not tags:
|
||||
raise ValueError("tags are required for new asset uploads")
|
||||
base_dir, subdirs = resolve_destination_from_tags(tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
validate_path_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
ingest_result = _ingest_file_from_path(
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_sanitize_filename(name or client_filename, fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
reference_id = ingest_result.reference_id
|
||||
if not reference_id:
|
||||
raise RuntimeError("failed to create asset reference")
|
||||
|
||||
with create_session() as session:
|
||||
pair = fetch_reference_and_asset(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
ref, asset = pair
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
|
||||
return UploadResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created_new=ingest_result.asset_created,
|
||||
)
|
||||
|
||||
|
||||
def create_from_hash(
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> UploadResult | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash=canonical,
|
||||
name=_sanitize_filename(
|
||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||
),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
asset=result.asset,
|
||||
tags=result.tags,
|
||||
created_new=False,
|
||||
)
|
||||
327
app/assets/services/metadata_extract.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""Metadata extraction for asset scanning.
|
||||
|
||||
Tier 1: Filesystem metadata (zero parsing)
|
||||
Tier 2: Safetensors header metadata (fast JSON read only)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from utils.mime_types import init_mime_types
|
||||
|
||||
init_mime_types()
|
||||
|
||||
# Supported safetensors extensions
|
||||
SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"})
|
||||
|
||||
# Maximum safetensors header size to read (8MB)
|
||||
MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedMetadata:
|
||||
"""Metadata extracted from a file during scanning."""
|
||||
|
||||
# Tier 1: Filesystem (always available)
|
||||
filename: str = ""
|
||||
file_path: str = "" # Full absolute path to the file
|
||||
content_length: int = 0
|
||||
content_type: str | None = None
|
||||
format: str = "" # file extension without dot
|
||||
|
||||
# Tier 2: Safetensors header (if available)
|
||||
base_model: str | None = None
|
||||
trained_words: list[str] | None = None
|
||||
air: str | None = None # CivitAI AIR identifier
|
||||
has_preview_images: bool = False
|
||||
|
||||
# Source provenance (populated if embedded in safetensors)
|
||||
source_url: str | None = None
|
||||
source_arn: str | None = None
|
||||
repo_url: str | None = None
|
||||
preview_url: str | None = None
|
||||
source_hash: str | None = None
|
||||
|
||||
# HuggingFace specific
|
||||
repo_id: str | None = None
|
||||
revision: str | None = None
|
||||
filepath: str | None = None
|
||||
resolve_url: str | None = None
|
||||
|
||||
def to_user_metadata(self) -> dict[str, Any]:
|
||||
"""Convert to user_metadata dict for AssetReference.user_metadata JSON field."""
|
||||
data: dict[str, Any] = {
|
||||
"filename": self.filename,
|
||||
"content_length": self.content_length,
|
||||
"format": self.format,
|
||||
}
|
||||
if self.file_path:
|
||||
data["file_path"] = self.file_path
|
||||
if self.content_type:
|
||||
data["content_type"] = self.content_type
|
||||
|
||||
# Tier 2 fields
|
||||
if self.base_model:
|
||||
data["base_model"] = self.base_model
|
||||
if self.trained_words:
|
||||
data["trained_words"] = self.trained_words
|
||||
if self.air:
|
||||
data["air"] = self.air
|
||||
if self.has_preview_images:
|
||||
data["has_preview_images"] = True
|
||||
|
||||
# Source provenance
|
||||
if self.source_url:
|
||||
data["source_url"] = self.source_url
|
||||
if self.source_arn:
|
||||
data["source_arn"] = self.source_arn
|
||||
if self.repo_url:
|
||||
data["repo_url"] = self.repo_url
|
||||
if self.preview_url:
|
||||
data["preview_url"] = self.preview_url
|
||||
if self.source_hash:
|
||||
data["source_hash"] = self.source_hash
|
||||
|
||||
# HuggingFace
|
||||
if self.repo_id:
|
||||
data["repo_id"] = self.repo_id
|
||||
if self.revision:
|
||||
data["revision"] = self.revision
|
||||
if self.filepath:
|
||||
data["filepath"] = self.filepath
|
||||
if self.resolve_url:
|
||||
data["resolve_url"] = self.resolve_url
|
||||
|
||||
return data
|
||||
|
||||
def to_meta_rows(self, reference_id: str) -> list[dict]:
|
||||
"""Convert to asset_reference_meta rows for typed/indexed querying."""
|
||||
rows: list[dict] = []
|
||||
|
||||
def add_str(key: str, val: str | None, ordinal: int = 0) -> None:
|
||||
if val:
|
||||
rows.append({
|
||||
"asset_reference_id": reference_id,
|
||||
"key": key,
|
||||
"ordinal": ordinal,
|
||||
"val_str": val[:2048] if len(val) > 2048 else val,
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
})
|
||||
|
||||
def add_num(key: str, val: int | float | None) -> None:
|
||||
if val is not None:
|
||||
rows.append({
|
||||
"asset_reference_id": reference_id,
|
||||
"key": key,
|
||||
"ordinal": 0,
|
||||
"val_str": None,
|
||||
"val_num": val,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
})
|
||||
|
||||
def add_bool(key: str, val: bool | None) -> None:
|
||||
if val is not None:
|
||||
rows.append({
|
||||
"asset_reference_id": reference_id,
|
||||
"key": key,
|
||||
"ordinal": 0,
|
||||
"val_str": None,
|
||||
"val_num": None,
|
||||
"val_bool": val,
|
||||
"val_json": None,
|
||||
})
|
||||
|
||||
# Tier 1
|
||||
add_str("filename", self.filename)
|
||||
add_num("content_length", self.content_length)
|
||||
add_str("content_type", self.content_type)
|
||||
add_str("format", self.format)
|
||||
|
||||
# Tier 2
|
||||
add_str("base_model", self.base_model)
|
||||
add_str("air", self.air)
|
||||
has_previews = self.has_preview_images if self.has_preview_images else None
|
||||
add_bool("has_preview_images", has_previews)
|
||||
|
||||
# trained_words as multiple rows with ordinals
|
||||
if self.trained_words:
|
||||
for i, word in enumerate(self.trained_words[:100]): # limit to 100 words
|
||||
add_str("trained_words", word, ordinal=i)
|
||||
|
||||
# Source provenance
|
||||
add_str("source_url", self.source_url)
|
||||
add_str("source_arn", self.source_arn)
|
||||
add_str("repo_url", self.repo_url)
|
||||
add_str("preview_url", self.preview_url)
|
||||
add_str("source_hash", self.source_hash)
|
||||
|
||||
# HuggingFace
|
||||
add_str("repo_id", self.repo_id)
|
||||
add_str("revision", self.revision)
|
||||
add_str("filepath", self.filepath)
|
||||
add_str("resolve_url", self.resolve_url)
|
||||
|
||||
return rows
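For concreteness, each populated field becomes one typed row keyed by name (strings in val_str, numbers in val_num, booleans in val_bool); the values below are illustrative:

meta = ExtractedMetadata(filename="clip_g.safetensors", content_length=1234, format="safetensors")
rows = meta.to_meta_rows(reference_id="ref-1")
# rows[0] -> {"asset_reference_id": "ref-1", "key": "filename", "ordinal": 0,
#             "val_str": "clip_g.safetensors", "val_num": None, "val_bool": None, "val_json": None}
# content_length lands in val_num, and each trained word would get its own row with an ordinal.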
|
||||
|
||||
|
||||
def _read_safetensors_header(
|
||||
path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE
|
||||
) -> dict[str, Any] | None:
|
||||
"""Read only the JSON header from a safetensors file.
|
||||
|
||||
This is very fast - reads 8 bytes for header length, then the JSON header.
|
||||
No tensor data is loaded.
|
||||
|
||||
Args:
|
||||
path: Absolute path to safetensors file
|
||||
max_size: Maximum header size to read (default 8MB)
|
||||
|
||||
Returns:
|
||||
Parsed header dict or None if failed
|
||||
"""
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
header_bytes = f.read(8)
|
||||
if len(header_bytes) < 8:
|
||||
return None
|
||||
length_of_header = struct.unpack("<Q", header_bytes)[0]
|
||||
if length_of_header > max_size:
|
||||
return None
|
||||
header_data = f.read(length_of_header)
|
||||
if len(header_data) < length_of_header:
|
||||
return None
|
||||
return json.loads(header_data.decode("utf-8"))
|
||||
except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
|
||||
return None
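Usage is a single call per candidate file; the safetensors layout is an 8-byte little-endian header length followed by the JSON header, so nothing model-sized is read. A hedged example (path hypothetical):

header = _read_safetensors_header("/data/models/loras/example.safetensors")
if header is not None:
    st_meta = header.get("__metadata__", {})  # training/provenance metadata, if present
    print(st_meta.get("ss_base_model_version"))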
|
||||
|
||||
|
||||
def _extract_safetensors_metadata(
|
||||
header: dict[str, Any], meta: ExtractedMetadata
|
||||
) -> None:
|
||||
"""Extract metadata from safetensors header __metadata__ section.
|
||||
|
||||
Modifies meta in-place.
|
||||
"""
|
||||
st_meta = header.get("__metadata__", {})
|
||||
if not isinstance(st_meta, dict):
|
||||
return
|
||||
|
||||
# Common model metadata
|
||||
meta.base_model = (
|
||||
st_meta.get("ss_base_model_version")
|
||||
or st_meta.get("modelspec.base_model")
|
||||
or st_meta.get("base_model")
|
||||
)
|
||||
|
||||
# Trained words / trigger words
|
||||
trained_words = st_meta.get("ss_tag_frequency")
|
||||
if trained_words and isinstance(trained_words, str):
|
||||
try:
|
||||
tag_freq = json.loads(trained_words)
|
||||
# Extract unique tags from all datasets
|
||||
all_tags: set[str] = set()
|
||||
for dataset_tags in tag_freq.values():
|
||||
if isinstance(dataset_tags, dict):
|
||||
all_tags.update(dataset_tags.keys())
|
||||
if all_tags:
|
||||
meta.trained_words = sorted(all_tags)[:100]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Direct trained_words field (some formats)
|
||||
if not meta.trained_words:
|
||||
tw = st_meta.get("trained_words")
|
||||
if isinstance(tw, str):
|
||||
try:
|
||||
parsed = json.loads(tw)
|
||||
if isinstance(parsed, list):
|
||||
meta.trained_words = [str(x) for x in parsed]
|
||||
else:
|
||||
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
|
||||
except json.JSONDecodeError:
|
||||
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
|
||||
elif isinstance(tw, list):
|
||||
meta.trained_words = [str(x) for x in tw]
|
||||
|
||||
# CivitAI AIR
|
||||
meta.air = st_meta.get("air") or st_meta.get("modelspec.air")
|
||||
|
||||
# Preview images (ssmd_cover_images)
|
||||
cover_images = st_meta.get("ssmd_cover_images")
|
||||
if cover_images:
|
||||
meta.has_preview_images = True
|
||||
|
||||
# Source provenance fields
|
||||
meta.source_url = st_meta.get("source_url")
|
||||
meta.source_arn = st_meta.get("source_arn")
|
||||
meta.repo_url = st_meta.get("repo_url")
|
||||
meta.preview_url = st_meta.get("preview_url")
|
||||
meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash")
|
||||
|
||||
# HuggingFace fields
|
||||
meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id")
|
||||
meta.revision = st_meta.get("revision") or st_meta.get("hf_revision")
|
||||
meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath")
|
||||
meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url")
|
||||
|
||||
|
||||
def extract_file_metadata(
|
||||
abs_path: str,
|
||||
stat_result: os.stat_result | None = None,
|
||||
relative_filename: str | None = None,
|
||||
) -> ExtractedMetadata:
|
||||
"""Extract metadata from a file using tier 1 and tier 2 methods.
|
||||
|
||||
Tier 1: Filesystem metadata from path and stat
|
||||
Tier 2: Safetensors header parsing if applicable
|
||||
|
||||
Args:
|
||||
abs_path: Absolute path to the file
|
||||
stat_result: Optional pre-fetched stat result (saves a syscall)
|
||||
relative_filename: Optional relative filename to use instead of basename
|
||||
(e.g., "flux/123/model.safetensors" for model paths)
|
||||
|
||||
Returns:
|
||||
ExtractedMetadata with all available fields populated
|
||||
"""
|
||||
meta = ExtractedMetadata()
|
||||
|
||||
# Tier 1: Filesystem metadata
|
||||
meta.filename = relative_filename or os.path.basename(abs_path)
|
||||
meta.file_path = abs_path
|
||||
_, ext = os.path.splitext(abs_path)
|
||||
meta.format = ext.lstrip(".").lower() if ext else ""
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(abs_path)
|
||||
meta.content_type = mime_type
|
||||
|
||||
# Size from stat
|
||||
if stat_result is None:
|
||||
try:
|
||||
stat_result = os.stat(abs_path, follow_symlinks=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if stat_result:
|
||||
meta.content_length = stat_result.st_size
|
||||
|
||||
# Tier 2: Safetensors header (if applicable and enabled)
|
||||
if ext.lower() in SAFETENSORS_EXTENSIONS:
|
||||
header = _read_safetensors_header(abs_path)
|
||||
if header:
|
||||
try:
|
||||
_extract_safetensors_metadata(header, meta)
|
||||
except Exception as e:
|
||||
logging.debug("Safetensors meta extract failed %s: %s", abs_path, e)
|
||||
|
||||
return meta
|
||||
167
app/assets/services/path_utils.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import folder_paths
|
||||
from app.assets.helpers import normalize_tags
|
||||
|
||||
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build list of (folder_name, base_paths[]) for all model locations.
|
||||
|
||||
Includes every category registered in folder_names_and_paths,
|
||||
regardless of whether its paths are under the main models_dir,
|
||||
but excludes non-model entries like custom_nodes.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
if name in _NON_MODEL_FOLDER_NAMES:
|
||||
continue
|
||||
paths, _exts = values[0], values[1]
|
||||
if paths:
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
if not tags:
|
||||
raise ValueError("tags must not be empty")
|
||||
root = tags[0].lower()
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
elif root == "input":
|
||||
base_dir = os.path.abspath(folder_paths.get_input_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
elif root == "output":
|
||||
base_dir = os.path.abspath(folder_paths.get_output_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
else:
|
||||
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
|
||||
_sep_chars = frozenset(("/", "\\", os.sep))
|
||||
for i in raw_subdirs:
|
||||
if i in (".", "..") or _sep_chars & set(i):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
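A short usage sketch, assuming "loras" is a configured model category (the resolved base path depends on the local folder_paths setup):

base_dir, subdirs = resolve_destination_from_tags(["models", "loras", "flux"])
# base_dir -> first configured base path for the "loras" category
# subdirs  -> ["flux"]; callers join these onto base_dir to build the destination directory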
|
||||
|
||||
|
||||
def validate_path_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = Path(os.path.abspath(candidate))
|
||||
base_abs = Path(os.path.abspath(base))
|
||||
if not cand_abs.is_relative_to(base_abs):
|
||||
raise ValueError("destination escapes base directory")
|
||||
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, e.g.:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_asset_category_and_relative_path(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc.); drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_asset_category_and_relative_path(
|
||||
file_path: str,
|
||||
) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Determine which root category a file path belongs to.
|
||||
|
||||
Categories:
|
||||
- 'input': under folder_paths.get_input_directory()
|
||||
- 'output': under folder_paths.get_output_directory()
|
||||
- 'models': under any base path from get_comfy_models_folders()
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _check_is_within(child: str, parent: str) -> bool:
|
||||
return Path(child).is_relative_to(parent)
|
||||
|
||||
def _compute_relative(child: str, parent: str) -> str:
|
||||
# Normalize relative path, stripping any leading ".." components
|
||||
# by anchoring to root (os.sep) then computing relpath back from it.
|
||||
return os.path.relpath(
|
||||
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
|
||||
)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _check_is_within(fp_abs, input_base):
|
||||
return "input", _compute_relative(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _check_is_within(fp_abs, output_base):
|
||||
return "output", _compute_relative(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _check_is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not within input, output, or configured model bases: {file_path}"
|
||||
)
|
||||
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return (name, tags) derived from a filesystem path.
|
||||
|
||||
- name: base filename with extension
|
||||
- tags: [root_category] + parent folder names in order
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
"""
|
||||
root_category, some_path = get_asset_category_and_relative_path(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [
|
||||
part for part in p.parent.parts if part not in (".", "..", p.anchor)
|
||||
]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
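An illustrative example, assuming "checkpoints" is a registered models category and the file lives under it (normalize_tags is expected to lowercase and deduplicate, so the result is roughly):

name, tags = get_name_and_tags_from_asset_path(
    "/home/user/ComfyUI/models/checkpoints/sdxl/base.safetensors"
)
# name -> "base.safetensors"
# tags -> ["models", "checkpoints", "sdxl"]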
|
||||
109
app/assets/services/schemas.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
|
||||
UserMetadata = dict[str, Any] | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetData:
|
||||
hash: str | None
|
||||
size_bytes: int | None
|
||||
mime_type: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReferenceData:
|
||||
"""Data transfer object for AssetReference."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
file_path: str | None
|
||||
user_metadata: UserMetadata
|
||||
preview_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetDetailResult:
|
||||
ref: ReferenceData
|
||||
asset: AssetData | None
|
||||
tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegisterAssetResult:
|
||||
ref: ReferenceData
|
||||
asset: AssetData
|
||||
tags: list[str]
|
||||
created: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IngestResult:
|
||||
asset_created: bool
|
||||
asset_updated: bool
|
||||
ref_created: bool
|
||||
ref_updated: bool
|
||||
reference_id: str | None
|
||||
|
||||
|
||||
class TagUsage(NamedTuple):
|
||||
name: str
|
||||
tag_type: str
|
||||
count: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetSummaryData:
|
||||
ref: ReferenceData
|
||||
asset: AssetData | None
|
||||
tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListAssetsResult:
|
||||
items: list[AssetSummaryData]
|
||||
total: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DownloadResolutionResult:
|
||||
abs_path: str
|
||||
content_type: str
|
||||
download_name: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UploadResult:
|
||||
ref: ReferenceData
|
||||
asset: AssetData
|
||||
tags: list[str]
|
||||
created_new: bool
|
||||
|
||||
|
||||
def extract_reference_data(ref: AssetReference) -> ReferenceData:
|
||||
return ReferenceData(
|
||||
id=ref.id,
|
||||
name=ref.name,
|
||||
file_path=ref.file_path,
|
||||
user_metadata=ref.user_metadata,
|
||||
preview_id=ref.preview_id,
|
||||
created_at=ref.created_at,
|
||||
updated_at=ref.updated_at,
|
||||
last_access_time=ref.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def extract_asset_data(asset: Asset | None) -> AssetData | None:
|
||||
if asset is None:
|
||||
return None
|
||||
return AssetData(
|
||||
hash=asset.hash,
|
||||
size_bytes=asset.size_bytes,
|
||||
mime_type=asset.mime_type,
|
||||
)
|
||||
75
app/assets/services/tagging.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from app.assets.database.queries import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
add_tags_to_reference,
|
||||
get_reference_with_owner_check,
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_reference,
|
||||
)
|
||||
from app.assets.services.schemas import TagUsage
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def apply_tags(
|
||||
reference_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AddTagsResult:
|
||||
with create_session() as session:
|
||||
ref_row = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
result = add_tags_to_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
reference_row=ref_row,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def remove_tags(
|
||||
reference_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> RemoveTagsResult:
|
||||
with create_session() as session:
|
||||
get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
result = remove_tags_from_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[TagUsage], int]:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from filelock import FileLock, Timeout
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
@@ -14,8 +15,12 @@ try:
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database.models import Base
|
||||
import app.assets.database.models # noqa: F401 — register models with Base.metadata
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
@@ -65,9 +70,69 @@ def get_db_path():
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
|
||||
|
||||
_db_lock = None
|
||||
|
||||
def _acquire_file_lock(db_path):
|
||||
"""Acquire an OS-level file lock to prevent multi-process access.
|
||||
|
||||
Uses filelock for cross-platform support (macOS, Linux, Windows).
|
||||
The OS automatically releases the lock when the process exits, even on crashes.
|
||||
"""
|
||||
global _db_lock
|
||||
lock_path = db_path + ".lock"
|
||||
_db_lock = FileLock(lock_path)
|
||||
try:
|
||||
_db_lock.acquire(timeout=0)
|
||||
except Timeout:
|
||||
raise RuntimeError(
|
||||
f"Could not acquire lock on database '{db_path}'. "
|
||||
"Another ComfyUI process may already be using it. "
|
||||
"Use --database-url to specify a separate database file."
|
||||
)
|
||||
|
||||
|
||||
def _is_memory_db(db_url):
|
||||
"""Check if the database URL refers to an in-memory SQLite database."""
|
||||
return db_url in ("sqlite:///:memory:", "sqlite://")
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
|
||||
if _is_memory_db(db_url):
|
||||
_init_memory_db(db_url)
|
||||
else:
|
||||
_init_file_db(db_url)
|
||||
|
||||
|
||||
def _init_memory_db(db_url):
|
||||
"""Initialize an in-memory SQLite database using metadata.create_all.
|
||||
|
||||
Alembic migrations don't work with in-memory SQLite because each
|
||||
connection gets its own separate database — tables created by Alembic's
|
||||
internal connection are lost immediately.
|
||||
"""
|
||||
engine = create_engine(
|
||||
db_url,
|
||||
poolclass=StaticPool,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
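The StaticPool plus check_same_thread=False is what keeps the single in-memory database visible everywhere; with SQLAlchemy's default per-thread pooling for :memory:, any other connection (a worker thread here standing in for Alembic's own connection) sees a separate, empty database. A standalone sketch of that failure mode:

import threading
from sqlalchemy import create_engine, text

eng = create_engine("sqlite:///:memory:")  # default pool: one connection per thread

with eng.connect() as c:
    c.execute(text("CREATE TABLE t (x INTEGER)"))
    c.commit()

def worker():
    with eng.connect() as c:
        c.execute(text("SELECT * FROM t"))  # OperationalError: no such table: t

threading.Thread(target=worker).start()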
|
||||
|
||||
|
||||
def _init_file_db(db_url):
|
||||
"""Initialize a file-backed SQLite database using Alembic migrations."""
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
|
||||
@@ -75,6 +140,14 @@ def init_db():
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
|
||||
# Enable foreign key enforcement for SQLite
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
@@ -104,6 +177,12 @@ def init_db():
|
||||
logging.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
# Acquire an OS-level file lock after migrations are complete.
|
||||
# Alembic uses its own connection, so we must wait until it's done
|
||||
# before locking — otherwise our own lock blocks the migration.
|
||||
conn.close()
|
||||
_acquire_file_lock(db_path)
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class AudioEncoderModel():
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
comfy.model_management.archive_model_dtypes(self.model)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
@@ -179,6 +179,8 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.")
|
||||
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
@@ -232,7 +234,7 @@ database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
|
||||
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -14,6 +14,9 @@ if TYPE_CHECKING:
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from comfy.cli_args import args
|
||||
import uuid
|
||||
import os
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
# #######################################################################################################
|
||||
@@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
|
||||
HookedOnly = "hooked_only"
|
||||
|
||||
|
||||
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
def __init__(self):
|
||||
if _ISOLATION_HOOKREF_MODE:
|
||||
self._pyisolate_id = str(uuid.uuid4())
|
||||
|
||||
def _ensure_pyisolate_id(self):
|
||||
pyisolate_id = getattr(self, "_pyisolate_id", None)
|
||||
if pyisolate_id is None:
|
||||
pyisolate_id = str(uuid.uuid4())
|
||||
self._pyisolate_id = pyisolate_id
|
||||
return pyisolate_id
|
||||
|
||||
def __eq__(self, other):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return self is other
|
||||
if not isinstance(other, _HookRef):
|
||||
return False
|
||||
return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()
|
||||
|
||||
def __hash__(self):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return id(self)
|
||||
return hash(self._ensure_pyisolate_id())
|
||||
|
||||
def __str__(self):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return super().__str__()
|
||||
return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
|
||||
|
||||
|
||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
@@ -168,6 +200,8 @@ class WeightHook(Hook):
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if self.weights is None:
|
||||
self.weights = {}
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
|
||||
436
comfy/isolation/__init__.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
||||
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||
|
||||
load_isolated_node = None
|
||||
find_manifest_directories = None
|
||||
build_stub_class = None
|
||||
get_class_types_for_extension = None
|
||||
scan_shm_forensics = None
|
||||
start_shm_forensics = None
|
||||
|
||||
if _IMPORT_TORCH:
|
||||
import folder_paths
|
||||
from .extension_loader import load_isolated_node
|
||||
from .manifest_loader import find_manifest_directories
|
||||
from .runtime_helpers import build_stub_class, get_class_types_for_extension
|
||||
from .shm_forensics import scan_shm_forensics, start_shm_forensics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyisolate import ExtensionManager
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
isolated_node_timings: List[tuple[float, Path, int]] = []
|
||||
|
||||
if _IMPORT_TORCH:
|
||||
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
|
||||
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
|
||||
|
||||
|
||||
def initialize_proxies() -> None:
|
||||
from .child_hooks import is_child_process
|
||||
|
||||
is_child = is_child_process()
|
||||
|
||||
if is_child:
|
||||
from .child_hooks import initialize_child_process
|
||||
|
||||
initialize_child_process()
|
||||
else:
|
||||
from .host_hooks import initialize_host_process
|
||||
|
||||
initialize_host_process()
|
||||
if start_shm_forensics is not None:
|
||||
start_shm_forensics()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IsolatedNodeSpec:
|
||||
node_name: str
|
||||
display_name: str
|
||||
stub_class: type
|
||||
module_path: Path
|
||||
|
||||
|
||||
_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = []
|
||||
_CLAIMED_PATHS: Set[Path] = set()
|
||||
_ISOLATION_SCAN_ATTEMPTED = False
|
||||
_EXTENSION_MANAGERS: List["ExtensionManager"] = []
|
||||
_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {}
|
||||
_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None
|
||||
_EARLY_START_TIME: Optional[float] = None
|
||||
|
||||
|
||||
def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None:
|
||||
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
|
||||
if _ISOLATION_BACKGROUND_TASK is not None:
|
||||
return
|
||||
_EARLY_START_TIME = time.perf_counter()
|
||||
_ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes())
|
||||
|
||||
|
||||
async def await_isolation_loading() -> List[IsolatedNodeSpec]:
|
||||
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
|
||||
if _ISOLATION_BACKGROUND_TASK is not None:
|
||||
specs = await _ISOLATION_BACKGROUND_TASK
|
||||
return specs
|
||||
return await initialize_isolation_nodes()
|
||||
|
||||
|
||||
async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
|
||||
global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS
|
||||
|
||||
if _ISOLATED_NODE_SPECS:
|
||||
return _ISOLATED_NODE_SPECS
|
||||
|
||||
if _ISOLATION_SCAN_ATTEMPTED:
|
||||
return []
|
||||
|
||||
_ISOLATION_SCAN_ATTEMPTED = True
|
||||
if find_manifest_directories is None or load_isolated_node is None or build_stub_class is None:
|
||||
return []
|
||||
manifest_entries = find_manifest_directories()
|
||||
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
|
||||
|
||||
if not manifest_entries:
|
||||
return []
|
||||
|
||||
os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1"
|
||||
concurrency_limit = max(1, (os.cpu_count() or 4) // 2)
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
async def load_with_semaphore(
|
||||
node_dir: Path, manifest: Path
|
||||
) -> List[IsolatedNodeSpec]:
|
||||
async with semaphore:
|
||||
load_start = time.perf_counter()
|
||||
spec_list = await load_isolated_node(
|
||||
node_dir,
|
||||
manifest,
|
||||
logger,
|
||||
lambda name, info, extension: build_stub_class(
|
||||
name,
|
||||
info,
|
||||
extension,
|
||||
_RUNNING_EXTENSIONS,
|
||||
logger,
|
||||
),
|
||||
PYISOLATE_VENV_ROOT,
|
||||
_EXTENSION_MANAGERS,
|
||||
)
|
||||
spec_list = [
|
||||
IsolatedNodeSpec(
|
||||
node_name=node_name,
|
||||
display_name=display_name,
|
||||
stub_class=stub_cls,
|
||||
module_path=node_dir,
|
||||
)
|
||||
for node_name, display_name, stub_cls in spec_list
|
||||
]
|
||||
isolated_node_timings.append(
|
||||
(time.perf_counter() - load_start, node_dir, len(spec_list))
|
||||
)
|
||||
return spec_list
|
||||
|
||||
tasks = [
|
||||
load_with_semaphore(node_dir, manifest)
|
||||
for node_dir, manifest in manifest_entries
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
specs: List[IsolatedNodeSpec] = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(
|
||||
"%s Isolated node failed during startup; continuing: %s",
|
||||
LOG_PREFIX,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
specs.extend(result)
|
||||
|
||||
_ISOLATED_NODE_SPECS = specs
|
||||
return list(_ISOLATED_NODE_SPECS)
|
||||
|
||||
|
||||
def _get_class_types_for_extension(extension_name: str) -> Set[str]:
|
||||
"""Get all node class types (node names) belonging to an extension."""
|
||||
extension = _RUNNING_EXTENSIONS.get(extension_name)
|
||||
if not extension:
|
||||
return set()
|
||||
|
||||
ext_path = Path(extension.module_path)
|
||||
class_types = set()
|
||||
for spec in _ISOLATED_NODE_SPECS:
|
||||
if spec.module_path.resolve() == ext_path.resolve():
|
||||
class_types.add(spec.node_name)
|
||||
|
||||
return class_types
|
||||
|
||||
|
||||
async def notify_execution_graph(needed_class_types: Set[str], caches: list | None = None) -> None:
|
||||
"""Evict running extensions not needed for current execution.
|
||||
|
||||
When *caches* is provided, cache entries for evicted extensions' node
|
||||
class_types are invalidated to prevent stale ``RemoteObjectHandle``
|
||||
references from surviving in the output cache.
|
||||
"""
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:notify_graph_wait_idle",
|
||||
)
|
||||
|
||||
evicted_class_types: Set[str] = set()
|
||||
|
||||
async def _stop_extension(
|
||||
ext_name: str, extension: "ComfyNodeExtension", reason: str
|
||||
) -> None:
|
||||
# Collect class_types BEFORE stopping so we can invalidate cache entries.
|
||||
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||
evicted_class_types.update(ext_class_types)
|
||||
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
|
||||
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
|
||||
stop_result = extension.stop()
|
||||
if inspect.isawaitable(stop_result):
|
||||
await stop_result
|
||||
_RUNNING_EXTENSIONS.pop(ext_name, None)
|
||||
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
|
||||
if scan_shm_forensics is not None:
|
||||
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
|
||||
|
||||
if scan_shm_forensics is not None:
|
||||
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
|
||||
isolated_class_types_in_graph = needed_class_types.intersection(
|
||||
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
|
||||
)
|
||||
graph_uses_isolation = bool(isolated_class_types_in_graph)
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_start running=%d needed=%d",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
len(needed_class_types),
|
||||
)
|
||||
if graph_uses_isolation:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||
|
||||
# If NONE of this extension's nodes are in the execution graph -> evict.
|
||||
if not ext_class_types.intersection(needed_class_types):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
"isolated custom_node not in execution graph, evicting",
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
)
|
||||
|
||||
# Isolated child processes add steady VRAM pressure; reclaim host-side models
|
||||
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
|
||||
try:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if getattr(device, "type", None) == "cuda":
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
free_before = model_management.get_free_memory(device)
|
||||
if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
f"boundary low-vram restart (free={int(free_before)} target={int(required)})",
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.unload_all_models()
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.soft_empty_cache()
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
|
||||
)
|
||||
finally:
|
||||
# Invalidate cached outputs for evicted extensions so stale
|
||||
# RemoteObjectHandle references are not served from cache.
|
||||
if evicted_class_types and caches:
|
||||
total_invalidated = 0
|
||||
for cache in caches:
|
||||
if hasattr(cache, "invalidate_by_class_types"):
|
||||
total_invalidated += cache.invalidate_by_class_types(
|
||||
evicted_class_types
|
||||
)
|
||||
if total_invalidated > 0:
|
||||
logger.info(
|
||||
"%s ISO:cache_invalidated count=%d class_types=%s",
|
||||
LOG_PREFIX,
|
||||
total_invalidated,
|
||||
evicted_class_types,
|
||||
)
|
||||
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
|
||||
)
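The eviction rule above is simply: an already-running extension is stopped when none of its node class types appear in the incoming execution graph, and only if the graph uses isolated nodes at all. Roughly (names illustrative):

needed = {"KSampler", "MyIsolatedNode"}
ext_types = _get_class_types_for_extension("some-extension")  # e.g. {"OtherNode"}

graph_uses_isolation = bool(needed & {s.node_name for s in _ISOLATED_NODE_SPECS})
should_evict = graph_uses_isolation and not (ext_types & needed)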
|
||||
|
||||
|
||||
async def flush_running_extensions_transport_state() -> int:
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:flush_transport_wait_idle",
|
||||
)
|
||||
total_flushed = 0
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
flush_fn = getattr(extension, "flush_transport_state", None)
|
||||
if not callable(flush_fn):
|
||||
continue
|
||||
try:
|
||||
flushed = await flush_fn()
|
||||
if isinstance(flushed, int):
|
||||
total_flushed += flushed
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s workflow-end flush released=%d",
|
||||
LOG_PREFIX,
|
||||
ext_name,
|
||||
flushed,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True
|
||||
)
|
||||
scan_shm_forensics(
|
||||
"ISO:flush_running_extensions_transport_state", refresh_model_context=True
|
||||
)
|
||||
return total_flushed
|
||||
|
||||
|
||||
async def wait_for_model_patcher_quiescence(
|
||||
timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
*,
|
||||
fail_loud: bool = False,
|
||||
marker: str = "ISO:wait_model_patcher_idle",
|
||||
) -> bool:
|
||||
try:
|
||||
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
start = time.perf_counter()
|
||||
idle = await registry.wait_all_idle(timeout_ms)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
if idle:
|
||||
logger.debug(
|
||||
"%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
timeout_ms,
|
||||
elapsed_ms,
|
||||
)
|
||||
return True
|
||||
|
||||
states = await registry.get_all_operation_states()
|
||||
logger.error(
|
||||
"%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
timeout_ms,
|
||||
elapsed_ms,
|
||||
states,
|
||||
)
|
||||
if fail_loud:
|
||||
raise TimeoutError(
|
||||
f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
if fail_loud:
|
||||
raise
|
||||
logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def get_claimed_paths() -> Set[Path]:
|
||||
return _CLAIMED_PATHS
|
||||
|
||||
|
||||
def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None:
|
||||
"""Update all active RPC instances with the current event loop.
|
||||
|
||||
This MUST be called at the start of each workflow execution to ensure
|
||||
RPC calls are scheduled on the correct event loop. This handles the case
|
||||
where asyncio.run() creates a new event loop for each workflow.
|
||||
|
||||
Args:
|
||||
loop: The event loop to use. If None, uses asyncio.get_running_loop().
|
||||
"""
|
||||
if loop is None:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
update_count = 0
|
||||
|
||||
# Update RPCs from ExtensionManagers
|
||||
for manager in _EXTENSION_MANAGERS:
|
||||
if not hasattr(manager, "extensions"):
|
||||
continue
|
||||
for name, extension in manager.extensions.items():
|
||||
if hasattr(extension, "rpc") and extension.rpc is not None:
|
||||
if hasattr(extension.rpc, "update_event_loop"):
|
||||
extension.rpc.update_event_loop(loop)
|
||||
update_count += 1
|
||||
logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'")
|
||||
|
||||
# Also update RPCs from running extensions (they may have direct RPC refs)
|
||||
for name, extension in _RUNNING_EXTENSIONS.items():
|
||||
if hasattr(extension, "rpc") and extension.rpc is not None:
|
||||
if hasattr(extension.rpc, "update_event_loop"):
|
||||
extension.rpc.update_event_loop(loop)
|
||||
update_count += 1
|
||||
logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'")
|
||||
|
||||
if update_count > 0:
|
||||
logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances")
|
||||
else:
|
||||
logger.debug(
|
||||
f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})"
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LOG_PREFIX",
|
||||
"initialize_proxies",
|
||||
"initialize_isolation_nodes",
|
||||
"start_isolation_loading_early",
|
||||
"await_isolation_loading",
|
||||
"notify_execution_graph",
|
||||
"flush_running_extensions_transport_state",
|
||||
"wait_for_model_patcher_quiescence",
|
||||
"get_claimed_paths",
|
||||
"update_rpc_event_loops",
|
||||
"IsolatedNodeSpec",
|
||||
"get_class_types_for_extension",
|
||||
]
|
||||
965
comfy/isolation/adapter.py
Normal file
@@ -0,0 +1,965 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
|
||||
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
|
||||
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
|
||||
|
||||
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||
|
||||
# Singleton proxies that do NOT transitively import torch/PIL/psutil/aiohttp.
|
||||
# Safe to import in sealed workers without host framework modules.
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from comfy.isolation.proxies.helper_proxies import HelperProxiesService
|
||||
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
|
||||
|
||||
# Singleton proxies that transitively import torch, PIL, or heavy host modules.
|
||||
# Only available when torch/host framework is present.
|
||||
CLIPProxy = None
|
||||
CLIPRegistry = None
|
||||
ModelPatcherProxy = None
|
||||
ModelPatcherRegistry = None
|
||||
ModelSamplingProxy = None
|
||||
ModelSamplingRegistry = None
|
||||
VAEProxy = None
|
||||
VAERegistry = None
|
||||
FirstStageModelRegistry = None
|
||||
ModelManagementProxy = None
|
||||
PromptServerService = None
|
||||
ProgressProxy = None
|
||||
UtilsProxy = None
|
||||
_HAS_TORCH_PROXIES = False
|
||||
if _IMPORT_TORCH:
|
||||
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
|
||||
from comfy.isolation.model_patcher_proxy import (
|
||||
ModelPatcherProxy,
|
||||
ModelPatcherRegistry,
|
||||
)
|
||||
from comfy.isolation.model_sampling_proxy import (
|
||||
ModelSamplingProxy,
|
||||
ModelSamplingRegistry,
|
||||
)
|
||||
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
|
||||
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
|
||||
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||
_HAS_TORCH_PROXIES = True
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force /dev/shm for shared memory (bwrap makes /tmp private)
|
||||
import tempfile
|
||||
|
||||
if os.path.exists("/dev/shm"):
|
||||
# Only override when the current temp dir is not already under /dev/shm
|
||||
current_tmp = tempfile.gettempdir()
|
||||
if not current_tmp.startswith("/dev/shm"):
|
||||
logger.debug(
|
||||
f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm"
|
||||
)
|
||||
os.environ["TMPDIR"] = "/dev/shm"
|
||||
tempfile.tempdir = None # Clear cache to force re-evaluation
|
||||
|
||||
|
||||
class ComfyUIAdapter(IsolationAdapter):
|
||||
# ComfyUI-specific IsolationAdapter implementation
|
||||
|
||||
@property
|
||||
def identifier(self) -> str:
|
||||
return "comfyui"
|
||||
|
||||
def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]:
|
||||
if "ComfyUI" in module_path and "custom_nodes" in module_path:
|
||||
parts = module_path.split("ComfyUI")
|
||||
if len(parts) > 1:
|
||||
comfy_root = parts[0] + "ComfyUI"
|
||||
return {
|
||||
"preferred_root": comfy_root,
|
||||
"additional_paths": [
|
||||
os.path.join(comfy_root, "custom_nodes"),
|
||||
os.path.join(comfy_root, "comfy"),
|
||||
],
|
||||
"filtered_subdirs": ["comfy", "app", "comfy_execution", "utils"],
|
||||
}
|
||||
return None
|
||||
|
||||
def get_sandbox_system_paths(self) -> Optional[List[str]]:
|
||||
"""Returns required application paths to mount in the sandbox."""
|
||||
# By inspecting where our adapter is loaded from, we can determine the comfy root
|
||||
adapter_file = inspect.getfile(self.__class__)
|
||||
# adapter_file = /home/johnj/ComfyUI/comfy/isolation/adapter.py
|
||||
comfy_root = os.path.dirname(os.path.dirname(os.path.dirname(adapter_file)))
|
||||
if os.path.exists(comfy_root):
|
||||
return [comfy_root]
|
||||
return None
|
||||
|
||||
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
|
||||
comfy_root = snapshot.get("preferred_root")
|
||||
if not comfy_root:
|
||||
return
|
||||
|
||||
requirements_path = Path(comfy_root) / "requirements.txt"
|
||||
if requirements_path.exists():
|
||||
import re
|
||||
|
||||
for line in requirements_path.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
pkg_name = re.split(r"[<>=!~\[]", line)[0].strip()
|
||||
if pkg_name:
|
||||
logging.getLogger(pkg_name).setLevel(logging.ERROR)
|
||||
|
||||
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
|
||||
if not _IMPORT_TORCH:
|
||||
# Sealed worker without torch — register torch-free TensorValue handler
|
||||
# so IMAGE/MASK/LATENT tensors arrive as numpy arrays, not raw dicts.
|
||||
import numpy as np
|
||||
|
||||
_TORCH_DTYPE_TO_NUMPY = {
|
||||
"torch.float32": np.float32,
|
||||
"torch.float64": np.float64,
|
||||
"torch.float16": np.float16,
|
||||
"torch.bfloat16": np.float32, # numpy has no bfloat16; upcast
|
||||
"torch.int32": np.int32,
|
||||
"torch.int64": np.int64,
|
||||
"torch.int16": np.int16,
|
||||
"torch.int8": np.int8,
|
||||
"torch.uint8": np.uint8,
|
||||
"torch.bool": np.bool_,
|
||||
}
|
||||
|
||||
def _deserialize_tensor_value(data: Dict[str, Any]) -> Any:
|
||||
dtype_str = data["dtype"]
|
||||
np_dtype = _TORCH_DTYPE_TO_NUMPY.get(dtype_str, np.float32)
|
||||
shape = tuple(data["tensor_size"])
|
||||
arr = np.array(data["data"], dtype=np_dtype).reshape(shape)
|
||||
return arr
|
||||
|
||||
_NUMPY_TO_TORCH_DTYPE = {
|
||||
np.float32: "torch.float32",
|
||||
np.float64: "torch.float64",
|
||||
np.float16: "torch.float16",
|
||||
np.int32: "torch.int32",
|
||||
np.int64: "torch.int64",
|
||||
np.int16: "torch.int16",
|
||||
np.int8: "torch.int8",
|
||||
np.uint8: "torch.uint8",
|
||||
np.bool_: "torch.bool",
|
||||
}
|
||||
|
||||
def _serialize_tensor_value(obj: Any) -> Dict[str, Any]:
|
||||
arr = np.asarray(obj, dtype=np.float32) if obj.dtype not in _NUMPY_TO_TORCH_DTYPE else np.asarray(obj)
|
||||
dtype_str = _NUMPY_TO_TORCH_DTYPE.get(arr.dtype.type, "torch.float32")
|
||||
return {
|
||||
"__type__": "TensorValue",
|
||||
"dtype": dtype_str,
|
||||
"tensor_size": list(arr.shape),
|
||||
"requires_grad": False,
|
||||
"data": arr.tolist(),
|
||||
}
|
||||
|
||||
registry.register("TensorValue", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
|
||||
# ndarray output from sealed workers serializes as TensorValue for host torch reconstruction
|
||||
registry.register("ndarray", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
|
||||
return
|
||||
|
||||
import torch
|
||||
|
||||
def serialize_device(obj: Any) -> Dict[str, Any]:
|
||||
return {"__type__": "device", "device_str": str(obj)}
|
||||
|
||||
def deserialize_device(data: Dict[str, Any]) -> Any:
|
||||
return torch.device(data["device_str"])
|
||||
|
||||
registry.register("device", serialize_device, deserialize_device)
|
||||
|
||||
_VALID_DTYPES = {
|
||||
"float16", "float32", "float64", "bfloat16",
|
||||
"int8", "int16", "int32", "int64",
|
||||
"uint8", "bool",
|
||||
}
|
||||
|
||||
def serialize_dtype(obj: Any) -> Dict[str, Any]:
|
||||
return {"__type__": "dtype", "dtype_str": str(obj)}
|
||||
|
||||
def deserialize_dtype(data: Dict[str, Any]) -> Any:
|
||||
dtype_name = data["dtype_str"].replace("torch.", "")
|
||||
if dtype_name not in _VALID_DTYPES:
|
||||
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
|
||||
return getattr(torch, dtype_name)
|
||||
|
||||
registry.register("dtype", serialize_dtype, deserialize_dtype)
|
||||
|
||||
from comfy_api.latest._io import FolderType
|
||||
from comfy_api.latest._ui import SavedImages, SavedResult
|
||||
|
||||
def serialize_saved_result(obj: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"__type__": "SavedResult",
|
||||
"filename": obj.filename,
|
||||
"subfolder": obj.subfolder,
|
||||
"folder_type": obj.type.value,
|
||||
}
|
||||
|
||||
def deserialize_saved_result(data: Dict[str, Any]) -> Any:
|
||||
if isinstance(data, SavedResult):
|
||||
return data
|
||||
folder_type = data["folder_type"] if "folder_type" in data else data["type"]
|
||||
return SavedResult(
|
||||
filename=data["filename"],
|
||||
subfolder=data["subfolder"],
|
||||
type=FolderType(folder_type),
|
||||
)
|
||||
|
||||
registry.register(
|
||||
"SavedResult",
|
||||
serialize_saved_result,
|
||||
deserialize_saved_result,
|
||||
data_type=True,
|
||||
)
|
||||
|
||||
def serialize_saved_images(obj: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"__type__": "SavedImages",
|
||||
"results": [serialize_saved_result(result) for result in obj.results],
|
||||
"is_animated": obj.is_animated,
|
||||
}
|
||||
|
||||
def deserialize_saved_images(data: Dict[str, Any]) -> Any:
|
||||
return SavedImages(
|
||||
results=[deserialize_saved_result(result) for result in data["results"]],
|
||||
is_animated=data.get("is_animated", False),
|
||||
)
|
||||
|
||||
registry.register(
|
||||
"SavedImages",
|
||||
serialize_saved_images,
|
||||
deserialize_saved_images,
|
||||
data_type=True,
|
||||
)
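# Example payload (illustrative values; only the key names come from the
# serializers above, and the filename is hypothetical): a single saved image
# round-trips as
#   {"__type__": "SavedResult", "filename": "ComfyUI_00001_.png",
#    "subfolder": "", "folder_type": "output"}
# and SavedImages wraps a list of such dicts plus "is_animated". The "output"
# value assumes FolderType exposes an output member.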
|
||||
|
||||
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
|
||||
# Child-side: must already have _instance_id (proxy)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
|
||||
raise RuntimeError(
|
||||
f"ModelPatcher in child lacks _instance_id: "
|
||||
f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
)
|
||||
# Host-side: register with registry
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
|
||||
model_id = ModelPatcherRegistry().register(obj)
|
||||
return {"__type__": "ModelPatcherRef", "model_id": model_id}
|
||||
|
||||
def deserialize_model_patcher(data: Any) -> Any:
|
||||
"""Deserialize ModelPatcher refs; pass through already-materialized objects."""
|
||||
if isinstance(data, dict):
|
||||
return ModelPatcherProxy(
|
||||
data["model_id"], registry=None, manage_lifecycle=False
|
||||
)
|
||||
return data
|
||||
|
||||
def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware ModelPatcherRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return ModelPatcherProxy(
|
||||
data["model_id"], registry=None, manage_lifecycle=False
|
||||
)
|
||||
else:
|
||||
return ModelPatcherRegistry()._get_instance(data["model_id"])
|
||||
|
||||
# Register ModelPatcher type for serialization
|
||||
registry.register(
|
||||
"ModelPatcher", serialize_model_patcher, deserialize_model_patcher
|
||||
)
|
||||
# Register ModelPatcherProxy type (already a proxy, just return ref)
|
||||
registry.register(
|
||||
"ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher
|
||||
)
|
||||
# Register ModelPatcherRef for deserialization (context-aware: host or child)
|
||||
registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref)
|
||||
|
||||
def serialize_clip(obj: Any) -> Dict[str, Any]:
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "CLIPRef", "clip_id": obj._instance_id}
|
||||
clip_id = CLIPRegistry().register(obj)
|
||||
return {"__type__": "CLIPRef", "clip_id": clip_id}
|
||||
|
||||
def deserialize_clip(data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
|
||||
return data
|
||||
|
||||
def deserialize_clip_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware CLIPRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
|
||||
else:
|
||||
return CLIPRegistry()._get_instance(data["clip_id"])
|
||||
|
||||
# Register CLIP type for serialization
|
||||
registry.register("CLIP", serialize_clip, deserialize_clip)
|
||||
# Register CLIPProxy type (already a proxy, just return ref)
|
||||
registry.register("CLIPProxy", serialize_clip, deserialize_clip)
|
||||
# Register CLIPRef for deserialization (context-aware: host or child)
|
||||
registry.register("CLIPRef", None, deserialize_clip_ref)
|
||||
|
||||
def serialize_vae(obj: Any) -> Dict[str, Any]:
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "VAERef", "vae_id": obj._instance_id}
|
||||
vae_id = VAERegistry().register(obj)
|
||||
return {"__type__": "VAERef", "vae_id": vae_id}
|
||||
|
||||
def deserialize_vae(data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return VAEProxy(data["vae_id"])
|
||||
return data
|
||||
|
||||
def deserialize_vae_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware VAERef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
# Child: create a proxy
|
||||
return VAEProxy(data["vae_id"])
|
||||
else:
|
||||
# Host: lookup real VAE from registry
|
||||
return VAERegistry()._get_instance(data["vae_id"])
|
||||
|
||||
# Register VAE type for serialization
|
||||
registry.register("VAE", serialize_vae, deserialize_vae)
|
||||
# Register VAEProxy type (already a proxy, just return ref)
|
||||
registry.register("VAEProxy", serialize_vae, deserialize_vae)
|
||||
# Register VAERef for deserialization (context-aware: host or child)
|
||||
registry.register("VAERef", None, deserialize_vae_ref)
|
||||
|
||||
# ModelSampling serialization - handles ModelSampling* types
|
||||
# copyreg removed - no pickle fallback allowed
|
||||
|
||||
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
|
||||
# Proxy with _instance_id — return ref (works from both host and child)
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
|
||||
# Child-side: object created locally in child (e.g. ModelSamplingAdvanced
|
||||
# in nodes_z_image_turbo.py). Serialize as inline data so the host can
|
||||
# reconstruct the real torch.nn.Module.
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
import base64
|
||||
import io as _io
|
||||
|
||||
# Identify base classes from comfy.model_sampling
|
||||
bases = []
|
||||
for base in type(obj).__mro__:
|
||||
if base.__module__ == "comfy.model_sampling" and base.__name__ != "object":
|
||||
bases.append(base.__name__)
|
||||
# Serialize each state_dict tensor as a base64-encoded torch.save payload
|
||||
sd = obj.state_dict()
|
||||
sd_serialized = {}
|
||||
for k, v in sd.items():
|
||||
buf = _io.BytesIO()
|
||||
torch.save(v, buf)
|
||||
sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
# Capture plain attrs (shift, multiplier, sigma_data, etc.)
|
||||
plain_attrs = {}
|
||||
for k, v in obj.__dict__.items():
|
||||
if k.startswith("_"):
|
||||
continue
|
||||
if isinstance(v, (bool, int, float, str)):
|
||||
plain_attrs[k] = v
|
||||
return {
|
||||
"__type__": "ModelSamplingInline",
|
||||
"bases": bases,
|
||||
"state_dict": sd_serialized,
|
||||
"attrs": plain_attrs,
|
||||
}
|
||||
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
|
||||
ms_id = ModelSamplingRegistry().register(obj)
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
|
||||
|
||||
def deserialize_model_sampling(data: Any) -> Any:
|
||||
"""Deserialize ModelSampling refs or inline data."""
|
||||
if isinstance(data, dict):
|
||||
if data.get("__type__") == "ModelSamplingInline":
|
||||
return _reconstruct_model_sampling_inline(data)
|
||||
return ModelSamplingProxy(data["ms_id"])
|
||||
return data
|
||||
|
||||
def _reconstruct_model_sampling_inline(data: Dict[str, Any]) -> Any:
|
||||
"""Reconstruct a ModelSampling object on the host from inline child data."""
|
||||
import comfy.model_sampling as _ms
|
||||
import base64
|
||||
import io as _io
|
||||
|
||||
# Resolve base classes
|
||||
base_classes = []
|
||||
for name in data["bases"]:
|
||||
cls = getattr(_ms, name, None)
|
||||
if cls is not None:
|
||||
base_classes.append(cls)
|
||||
if not base_classes:
|
||||
raise RuntimeError(
|
||||
f"Cannot reconstruct ModelSampling: no known bases in {data['bases']}"
|
||||
)
|
||||
# Create dynamic class matching the child's class hierarchy
|
||||
ReconstructedSampling = type("ReconstructedSampling", tuple(base_classes), {})
|
||||
obj = ReconstructedSampling.__new__(ReconstructedSampling)
|
||||
torch.nn.Module.__init__(obj)
|
||||
# Restore plain attributes first
|
||||
for k, v in data.get("attrs", {}).items():
|
||||
setattr(obj, k, v)
|
||||
# Restore state_dict (buffers like sigmas)
|
||||
for k, v_b64 in data.get("state_dict", {}).items():
|
||||
buf = _io.BytesIO(base64.b64decode(v_b64))
|
||||
tensor = torch.load(buf, weights_only=True)
|
||||
# Register as buffer so it's part of state_dict
|
||||
parts = k.split(".")
|
||||
if len(parts) == 1:
|
||||
cast(Any, obj).register_buffer(parts[0], tensor) # pylint: disable=no-member
|
||||
else:
|
||||
setattr(obj, parts[0], tensor)
|
||||
# Register on host so future references use proxy pattern.
|
||||
# Skip in child process — register() is async RPC and cannot be
|
||||
# called synchronously during deserialization.
|
||||
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||
ModelSamplingRegistry().register(obj)
|
||||
return obj
|
||||
|
||||
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware ModelSamplingRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return ModelSamplingProxy(data["ms_id"])
|
||||
else:
|
||||
return ModelSamplingRegistry()._get_instance(data["ms_id"])
|
||||
|
||||
# Register all ModelSampling* and StableCascadeSampling classes dynamically
|
||||
import comfy.model_sampling
|
||||
|
||||
for ms_cls in vars(comfy.model_sampling).values():
|
||||
if not isinstance(ms_cls, type):
|
||||
continue
|
||||
if not issubclass(ms_cls, torch.nn.Module):
|
||||
continue
|
||||
if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
|
||||
continue
|
||||
registry.register(
|
||||
ms_cls.__name__,
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
|
||||
)
|
||||
# Register ModelSamplingRef for deserialization (context-aware: host or child)
|
||||
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
|
||||
# Register ModelSamplingInline for deserialization (child→host inline transfer)
|
||||
registry.register(
|
||||
"ModelSamplingInline", None, lambda data: _reconstruct_model_sampling_inline(data)
|
||||
)
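# Illustrative inline payload from a child-created ModelSampling object (the
# base name, buffer and attrs are assumptions for illustration):
#   {"__type__": "ModelSamplingInline",
#    "bases": ["ModelSamplingDiscrete"],
#    "state_dict": {"sigmas": "<base64 of torch.save(tensor)>"},
#    "attrs": {"multiplier": 1.0}}
# The host rebuilds a torch.nn.Module subclass from "bases", restores the
# buffers and plain attributes, then registers the result so later references
# can travel as ModelSamplingRef.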
|
||||
|
||||
def serialize_cond(obj: Any) -> Dict[str, Any]:
|
||||
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
return {
|
||||
"__type__": type_key,
|
||||
"cond": obj.cond,
|
||||
}
|
||||
|
||||
def deserialize_cond(data: Dict[str, Any]) -> Any:
|
||||
import importlib
|
||||
|
||||
type_key = data["__type__"]
|
||||
module_name, class_name = type_key.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
return cls(data["cond"])
|
||||
|
||||
def _serialize_public_state(obj: Any) -> Dict[str, Any]:
|
||||
state: Dict[str, Any] = {}
|
||||
for key, value in obj.__dict__.items():
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
if callable(value):
|
||||
continue
|
||||
state[key] = value
|
||||
return state
|
||||
|
||||
def serialize_latent_format(obj: Any) -> Dict[str, Any]:
|
||||
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
return {
|
||||
"__type__": type_key,
|
||||
"state": _serialize_public_state(obj),
|
||||
}
|
||||
|
||||
def deserialize_latent_format(data: Dict[str, Any]) -> Any:
|
||||
import importlib
|
||||
|
||||
type_key = data["__type__"]
|
||||
module_name, class_name = type_key.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
obj = cls()
|
||||
for key, value in data.get("state", {}).items():
|
||||
prop = getattr(type(obj), key, None)
|
||||
if isinstance(prop, property) and prop.fset is None:
|
||||
continue
|
||||
setattr(obj, key, value)
|
||||
return obj
|
||||
|
||||
import comfy.conds
|
||||
|
||||
for cond_cls in vars(comfy.conds).values():
|
||||
if not isinstance(cond_cls, type):
|
||||
continue
|
||||
if not issubclass(cond_cls, comfy.conds.CONDRegular):
|
||||
continue
|
||||
type_key = f"{cond_cls.__module__}.{cond_cls.__name__}"
|
||||
registry.register(type_key, serialize_cond, deserialize_cond)
|
||||
registry.register(cond_cls.__name__, serialize_cond, deserialize_cond)
|
||||
|
||||
import comfy.latent_formats
|
||||
|
||||
for latent_cls in vars(comfy.latent_formats).values():
|
||||
if not isinstance(latent_cls, type):
|
||||
continue
|
||||
if not issubclass(latent_cls, comfy.latent_formats.LatentFormat):
|
||||
continue
|
||||
type_key = f"{latent_cls.__module__}.{latent_cls.__name__}"
|
||||
registry.register(
|
||||
type_key, serialize_latent_format, deserialize_latent_format
|
||||
)
|
||||
registry.register(
|
||||
latent_cls.__name__, serialize_latent_format, deserialize_latent_format
|
||||
)
|
||||
|
||||
# V3 API: unwrap NodeOutput.args
|
||||
def deserialize_node_output(data: Any) -> Any:
|
||||
return getattr(data, "args", data)
|
||||
|
||||
registry.register("NodeOutput", None, deserialize_node_output)
|
||||
|
||||
# KSAMPLER serializer: stores sampler name instead of function object
|
||||
# sampler_function is a callable which gets filtered out by JSONSocketTransport
|
||||
def serialize_ksampler(obj: Any) -> Dict[str, Any]:
|
||||
func_name = obj.sampler_function.__name__
|
||||
# Map function name back to sampler name
|
||||
if func_name == "sample_unipc":
|
||||
sampler_name = "uni_pc"
|
||||
elif func_name == "sample_unipc_bh2":
|
||||
sampler_name = "uni_pc_bh2"
|
||||
elif func_name == "dpm_fast_function":
|
||||
sampler_name = "dpm_fast"
|
||||
elif func_name == "dpm_adaptive_function":
|
||||
sampler_name = "dpm_adaptive"
|
||||
elif func_name.startswith("sample_"):
|
||||
sampler_name = func_name[7:] # Remove "sample_" prefix
|
||||
else:
|
||||
sampler_name = func_name
|
||||
return {
|
||||
"__type__": "KSAMPLER",
|
||||
"sampler_name": sampler_name,
|
||||
"extra_options": obj.extra_options,
|
||||
"inpaint_options": obj.inpaint_options,
|
||||
}
|
||||
|
||||
def deserialize_ksampler(data: Dict[str, Any]) -> Any:
|
||||
import comfy.samplers
|
||||
|
||||
return comfy.samplers.ksampler(
|
||||
data["sampler_name"],
|
||||
data.get("extra_options", {}),
|
||||
data.get("inpaint_options", {}),
|
||||
)
|
||||
|
||||
registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler)
|
||||
|
||||
from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers
|
||||
|
||||
register_hooks_serializers(registry)
|
||||
|
||||
# Generic Numpy Serializer
|
||||
def serialize_numpy(obj: Any) -> Any:
|
||||
import torch
|
||||
|
||||
try:
|
||||
# Attempt zero-copy conversion to Tensor
|
||||
return torch.from_numpy(obj)
|
||||
except Exception:
|
||||
# Fallback for non-numeric arrays (strings, objects, mixes)
|
||||
return obj.tolist()
|
||||
|
||||
def deserialize_numpy_b64(data: Any) -> Any:
|
||||
"""Deserialize base64-encoded ndarray from sealed worker."""
|
||||
import base64
|
||||
import numpy as np
|
||||
if isinstance(data, dict) and "data" in data and "dtype" in data:
|
||||
raw = base64.b64decode(data["data"])
|
||||
arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"])
|
||||
return torch.from_numpy(arr.copy())
|
||||
return data
|
||||
|
||||
registry.register("ndarray", serialize_numpy, deserialize_numpy_b64)
|
||||
|
||||
# -- File3D (comfy_api.latest._util.geometry_types) ---------------------
|
||||
# Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129
|
||||
|
||||
def serialize_file3d(obj: Any) -> Dict[str, Any]:
|
||||
import base64
|
||||
return {
|
||||
"__type__": "File3D",
|
||||
"format": obj.format,
|
||||
"data": base64.b64encode(obj.get_bytes()).decode("ascii"),
|
||||
}
|
||||
|
||||
def deserialize_file3d(data: Any) -> Any:
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from comfy_api.latest._util.geometry_types import File3D
|
||||
return File3D(BytesIO(base64.b64decode(data["data"])), file_format=data["format"])
|
||||
|
||||
registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True)
|
||||
|
||||
# -- VIDEO (comfy_api.latest._input_impl.video_types) -------------------
|
||||
# Origin: ComfyAPI Core v0.0.2 by ComfyOrg (guill), PR #8962
|
||||
|
||||
def serialize_video(obj: Any) -> Dict[str, Any]:
|
||||
components = obj.get_components()
|
||||
images = components.images.detach() if components.images.requires_grad else components.images
|
||||
result: Dict[str, Any] = {
|
||||
"__type__": "VIDEO",
|
||||
"images": images,
|
||||
"frame_rate_num": components.frame_rate.numerator,
|
||||
"frame_rate_den": components.frame_rate.denominator,
|
||||
}
|
||||
if components.audio is not None:
|
||||
waveform = components.audio["waveform"]
|
||||
if waveform.requires_grad:
|
||||
waveform = waveform.detach()
|
||||
result["audio_waveform"] = waveform
|
||||
result["audio_sample_rate"] = components.audio["sample_rate"]
|
||||
if components.metadata is not None:
|
||||
result["metadata"] = components.metadata
|
||||
return result
|
||||
|
||||
def deserialize_video(data: Any) -> Any:
|
||||
from fractions import Fraction
|
||||
from comfy_api.latest._input_impl.video_types import VideoFromComponents
|
||||
from comfy_api.latest._util.video_types import VideoComponents
|
||||
audio = None
|
||||
if "audio_waveform" in data:
|
||||
audio = {"waveform": data["audio_waveform"], "sample_rate": data["audio_sample_rate"]}
|
||||
components = VideoComponents(
|
||||
images=data["images"],
|
||||
frame_rate=Fraction(data["frame_rate_num"], data["frame_rate_den"]),
|
||||
audio=audio,
|
||||
metadata=data.get("metadata"),
|
||||
)
|
||||
return VideoFromComponents(components)
|
||||
|
||||
registry.register("VIDEO", serialize_video, deserialize_video, data_type=True)
|
||||
registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True)
|
||||
registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True)
|
||||
|
||||
def setup_web_directory(self, module: Any) -> None:
|
||||
"""Detect WEB_DIRECTORY on a module and populate/register it.
|
||||
|
||||
Called by the sealed worker after loading the node module.
|
||||
Mirrors extension_wrapper.py:216-227 for host-coupled nodes.
|
||||
Does NOT import extension_wrapper.py (it has `import torch` at module level).
|
||||
"""
|
||||
import shutil
|
||||
|
||||
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
|
||||
if web_dir_attr is None:
|
||||
return
|
||||
|
||||
module_dir = os.path.dirname(os.path.abspath(module.__file__))
|
||||
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
|
||||
|
||||
# Read extension name from pyproject.toml
|
||||
ext_name = os.path.basename(module_dir)
|
||||
pyproject = os.path.join(module_dir, "pyproject.toml")
|
||||
if os.path.exists(pyproject):
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
try:
|
||||
with open(pyproject, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
name = data.get("project", {}).get("name")
|
||||
if name:
|
||||
ext_name = name
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Populate web dir if empty (mirrors _run_prestartup_web_copy)
|
||||
if not (os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path))):
|
||||
os.makedirs(web_dir_path, exist_ok=True)
|
||||
|
||||
# Module-defined copy spec
|
||||
copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
|
||||
if copy_spec is not None and callable(copy_spec):
|
||||
try:
|
||||
copy_spec(web_dir_path)
|
||||
except Exception as e:
|
||||
logger.warning("][ _PRESTARTUP_WEB_COPY failed: %s", e)
|
||||
|
||||
# Fallback: comfy_3d_viewers
|
||||
try:
|
||||
from comfy_3d_viewers import copy_viewer, VIEWER_FILES
|
||||
for viewer in VIEWER_FILES:
|
||||
try:
|
||||
copy_viewer(viewer, web_dir_path)
|
||||
except Exception:
|
||||
pass
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fallback: comfy_dynamic_widgets
|
||||
try:
|
||||
from comfy_dynamic_widgets import get_js_path
|
||||
src = os.path.realpath(get_js_path())
|
||||
if os.path.exists(src):
|
||||
dst_dir = os.path.join(web_dir_path, "js")
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
shutil.copy2(src, os.path.join(dst_dir, "dynamic_widgets.js"))
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
|
||||
logger.info(
|
||||
"][ Adapter: registered web dir for %s (%d files)",
|
||||
ext_name,
|
||||
sum(1 for _ in Path(web_dir_path).rglob("*") if _.is_file()),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register_host_event_handlers(extension: Any) -> None:
|
||||
"""Register host-side event handlers for an isolated extension.
|
||||
|
||||
Wires ``"progress"`` events from the child to ``comfy.utils.PROGRESS_BAR_HOOK``
|
||||
so the ComfyUI frontend receives progress bar updates.
|
||||
"""
|
||||
register_event_handler = inspect.getattr_static(
|
||||
extension, "register_event_handler", None
|
||||
)
|
||||
if not callable(register_event_handler):
|
||||
return
|
||||
|
||||
def _host_progress_handler(payload: dict) -> None:
|
||||
import comfy.utils
|
||||
|
||||
hook = comfy.utils.PROGRESS_BAR_HOOK
|
||||
if hook is not None:
|
||||
hook(
|
||||
payload.get("value", 0),
|
||||
payload.get("total", 0),
|
||||
payload.get("preview"),
|
||||
payload.get("node_id"),
|
||||
)
|
||||
|
||||
extension.register_event_handler("progress", _host_progress_handler)
|
||||
|
||||
def setup_child_event_hooks(self, extension: Any) -> None:
|
||||
"""Wire PROGRESS_BAR_HOOK in the child to emit_event on the extension.
|
||||
|
||||
Host-coupled only — sealed workers do not have comfy.utils (torch).
|
||||
"""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
logger.info("][ ISO:setup_child_event_hooks called, PYISOLATE_CHILD=%s", is_child)
|
||||
if not is_child:
|
||||
return
|
||||
|
||||
if not _IMPORT_TORCH:
|
||||
logger.info("][ ISO:setup_child_event_hooks skipped — sealed worker (no torch)")
|
||||
return
|
||||
|
||||
import comfy.utils
|
||||
|
||||
def _event_progress_hook(value, total, preview=None, node_id=None):
|
||||
logger.debug("][ ISO:event_progress value=%s/%s node_id=%s", value, total, node_id)
|
||||
extension.emit_event("progress", {
|
||||
"value": value,
|
||||
"total": total,
|
||||
"node_id": node_id,
|
||||
})
|
||||
|
||||
comfy.utils.PROGRESS_BAR_HOOK = _event_progress_hook
|
||||
logger.info("][ ISO:PROGRESS_BAR_HOOK wired to event channel")
|
||||
|
||||
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
|
||||
# Always available — no torch/PIL dependency
|
||||
services: List[type[ProxiedSingleton]] = [
|
||||
FolderPathsProxy,
|
||||
HelperProxiesService,
|
||||
WebDirectoryProxy,
|
||||
]
|
||||
# Torch/PIL-dependent proxies
|
||||
if _HAS_TORCH_PROXIES:
|
||||
services.extend([
|
||||
PromptServerService,
|
||||
ModelManagementProxy,
|
||||
UtilsProxy,
|
||||
ProgressProxy,
|
||||
VAERegistry,
|
||||
CLIPRegistry,
|
||||
ModelPatcherRegistry,
|
||||
ModelSamplingRegistry,
|
||||
FirstStageModelRegistry,
|
||||
])
|
||||
return services
|
||||
|
||||
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
|
||||
# Resolve the real name whether it's an instance or the Singleton class itself
|
||||
api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__
|
||||
|
||||
if api_name == "FolderPathsProxy":
|
||||
import folder_paths
|
||||
|
||||
# Replace module-level functions with proxy methods
|
||||
# This is aggressive but necessary for transparent proxying
|
||||
# Handle both instance and class cases
|
||||
instance = api() if isinstance(api, type) else api
|
||||
for name in dir(instance):
|
||||
if not name.startswith("_"):
|
||||
setattr(folder_paths, name, getattr(instance, name))
|
||||
|
||||
# Fence: isolated children get writable temp inside sandbox
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
import tempfile
|
||||
_child_temp = os.path.join(tempfile.gettempdir(), "comfyui_temp")
|
||||
os.makedirs(_child_temp, exist_ok=True)
|
||||
folder_paths.temp_directory = _child_temp
|
||||
|
||||
return
|
||||
|
||||
if api_name == "ModelManagementProxy":
|
||||
if _IMPORT_TORCH:
|
||||
import comfy.model_management
|
||||
|
||||
instance = api() if isinstance(api, type) else api
|
||||
# Replace module-level functions with proxy methods
|
||||
for name in dir(instance):
|
||||
if not name.startswith("_"):
|
||||
setattr(comfy.model_management, name, getattr(instance, name))
|
||||
return
|
||||
|
||||
if api_name == "UtilsProxy":
|
||||
if not _IMPORT_TORCH:
|
||||
logger.info("][ ISO:UtilsProxy handle_api_registration skipped — sealed worker (no torch)")
|
||||
return
|
||||
|
||||
import comfy.utils
|
||||
|
||||
# Statically inject the RPC mechanism so the child can access it
|
||||
# independent of instance lifecycle.
|
||||
api.set_rpc(rpc)
|
||||
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
logger.info("][ ISO:UtilsProxy handle_api_registration PYISOLATE_CHILD=%s", is_child)
|
||||
|
||||
# Progress hook wiring moved to setup_child_event_hooks via event channel
|
||||
|
||||
return
|
||||
|
||||
if api_name == "PromptServerProxy":
|
||||
if not _IMPORT_TORCH:
|
||||
return
|
||||
# Defer heavy import to child context
|
||||
import server
|
||||
|
||||
instance = api() if isinstance(api, type) else api
|
||||
proxy = (
|
||||
instance.instance
|
||||
) # PromptServerProxy instance has .instance property returning self
|
||||
|
||||
original_register_route = proxy.register_route
|
||||
|
||||
def register_route_wrapper(
|
||||
method: str, path: str, handler: Callable[..., Any]
|
||||
) -> None:
|
||||
callback_id = rpc.register_callback(handler)
|
||||
loop = getattr(rpc, "loop", None)
|
||||
if loop and loop.is_running():
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(
|
||||
original_register_route(
|
||||
method, path, handler=callback_id, is_callback=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
original_register_route(
|
||||
method, path, handler=callback_id, is_callback=True
|
||||
)
|
||||
return None
|
||||
|
||||
proxy.register_route = register_route_wrapper
|
||||
|
||||
class RouteTableDefProxy:
|
||||
def __init__(self, proxy_instance: Any):
|
||||
self.proxy = proxy_instance
|
||||
|
||||
def get(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("GET", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def post(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("POST", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def patch(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("PATCH", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def put(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("PUT", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def delete(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("DELETE", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
proxy.routes = RouteTableDefProxy(proxy)
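# Illustrative child-side usage (the route path is hypothetical). With the
# wrapper and RouteTableDefProxy installed, an isolated node keeps the familiar
# decorator style while the handler is shipped to the host as an RPC callback:
#
#   from server import PromptServer
#   routes = PromptServer.instance.routes
#
#   @routes.get("/my_extension/status")
#   async def status(request):
#       ...  # runs in the child; the host dispatches to it via callback_id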
|
||||
|
||||
if (
|
||||
hasattr(server, "PromptServer")
|
||||
and getattr(server.PromptServer, "instance", None) != proxy
|
||||
):
|
||||
server.PromptServer.instance = proxy
|
||||
101
comfy/isolation/child_hooks.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation
|
||||
# Child process initialization for PyIsolate
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_child_process() -> bool:
|
||||
return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
def initialize_child_process() -> None:
|
||||
_setup_child_loop_bridge()
|
||||
|
||||
# Manual RPC injection
|
||||
try:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc:
|
||||
_setup_proxy_callers(rpc)
|
||||
else:
|
||||
logger.warning("Could not get child RPC instance for manual injection")
|
||||
_setup_proxy_callers()
|
||||
except Exception as e:
|
||||
logger.error(f"Manual RPC Injection failed: {e}")
|
||||
_setup_proxy_callers()
|
||||
|
||||
_setup_logging()
|
||||
|
||||
|
||||
def _setup_child_loop_bridge() -> None:
|
||||
import asyncio
|
||||
|
||||
main_loop = None
|
||||
try:
|
||||
main_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
try:
|
||||
main_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
if main_loop is None:
|
||||
return
|
||||
|
||||
try:
|
||||
from .proxies.base import set_global_loop
|
||||
|
||||
set_global_loop(main_loop)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _setup_prompt_server_stub(rpc=None) -> None:
|
||||
try:
|
||||
from .proxies.prompt_server_impl import PromptServerStub
|
||||
|
||||
if rpc:
|
||||
PromptServerStub.set_rpc(rpc)
|
||||
elif hasattr(PromptServerStub, "clear_rpc"):
|
||||
PromptServerStub.clear_rpc()
|
||||
else:
|
||||
PromptServerStub._rpc = None # type: ignore[attr-defined]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup PromptServerStub: {e}")
|
||||
|
||||
|
||||
def _setup_proxy_callers(rpc=None) -> None:
|
||||
try:
|
||||
from .proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from .proxies.helper_proxies import HelperProxiesService
|
||||
from .proxies.model_management_proxy import ModelManagementProxy
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
from .proxies.prompt_server_impl import PromptServerStub
|
||||
from .proxies.utils_proxy import UtilsProxy
|
||||
|
||||
if rpc is None:
|
||||
FolderPathsProxy.clear_rpc()
|
||||
HelperProxiesService.clear_rpc()
|
||||
ModelManagementProxy.clear_rpc()
|
||||
ProgressProxy.clear_rpc()
|
||||
PromptServerStub.clear_rpc()
|
||||
UtilsProxy.clear_rpc()
|
||||
return
|
||||
|
||||
FolderPathsProxy.set_rpc(rpc)
|
||||
HelperProxiesService.set_rpc(rpc)
|
||||
ModelManagementProxy.set_rpc(rpc)
|
||||
ProgressProxy.set_rpc(rpc)
|
||||
PromptServerStub.set_rpc(rpc)
|
||||
UtilsProxy.set_rpc(rpc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup child singleton proxy callers: {e}")
|
||||
|
||||
|
||||
def _setup_logging() -> None:
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
327
comfy/isolation/clip_proxy.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation
|
||||
# CLIP Proxy implementation
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
|
||||
class CondStageModelRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "cond_stage_model"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
|
||||
class CondStageModelProxy(BaseProxy[CondStageModelRegistry]):
|
||||
_registry_class = CondStageModelRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<CondStageModelProxy {self._instance_id}>"
|
||||
|
||||
|
||||
class TokenizerRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "tokenizer"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
|
||||
class TokenizerProxy(BaseProxy[TokenizerRegistry]):
|
||||
_registry_class = TokenizerRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<TokenizerProxy {self._instance_id}>"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CLIPRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "clip"
|
||||
_allowed_setters = {
|
||||
"layer_idx",
|
||||
"tokenizer_options",
|
||||
"use_clip_schedule",
|
||||
"apply_hooks_to_conds",
|
||||
}
|
||||
|
||||
async def get_ram_usage(self, instance_id: str) -> int:
|
||||
return self._get_instance(instance_id).get_ram_usage()
|
||||
|
||||
async def get_patcher_id(self, instance_id: str) -> str:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry
|
||||
|
||||
return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher)
|
||||
|
||||
async def get_cond_stage_model_id(self, instance_id: str) -> str:
|
||||
return CondStageModelRegistry().register(
|
||||
self._get_instance(instance_id).cond_stage_model
|
||||
)
|
||||
|
||||
async def get_tokenizer_id(self, instance_id: str) -> str:
|
||||
return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer)
|
||||
|
||||
async def load_model(self, instance_id: str) -> None:
|
||||
self._get_instance(instance_id).load_model()
|
||||
|
||||
async def clip_layer(self, instance_id: str, layer_idx: int) -> None:
|
||||
self._get_instance(instance_id).clip_layer(layer_idx)
|
||||
|
||||
async def set_tokenizer_option(
|
||||
self, instance_id: str, option_name: str, value: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_tokenizer_option(option_name, value)
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
return getattr(self._get_instance(instance_id), name)
|
||||
|
||||
async def set_property(self, instance_id: str, name: str, value: Any) -> None:
|
||||
if name not in self._allowed_setters:
|
||||
raise PermissionError(f"Setting '{name}' is not allowed via RPC")
|
||||
setattr(self._get_instance(instance_id), name, value)
|
||||
|
||||
async def tokenize(
|
||||
self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).tokenize(
|
||||
text, return_word_ids=return_word_ids, **kwargs
|
||||
)
|
||||
|
||||
async def encode(self, instance_id: str, text: str) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).encode(text))
|
||||
|
||||
async def encode_from_tokens(
|
||||
self,
|
||||
instance_id: str,
|
||||
tokens: Any,
|
||||
return_pooled: bool = False,
|
||||
return_dict: bool = False,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_from_tokens(
|
||||
tokens, return_pooled=return_pooled, return_dict=return_dict
|
||||
)
|
||||
)
|
||||
|
||||
async def encode_from_tokens_scheduled(
|
||||
self,
|
||||
instance_id: str,
|
||||
tokens: Any,
|
||||
unprojected: bool = False,
|
||||
add_dict: Optional[dict] = None,
|
||||
show_pbar: bool = True,
|
||||
) -> Any:
|
||||
add_dict = add_dict or {}
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_from_tokens_scheduled(
|
||||
tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar
|
||||
)
|
||||
)
|
||||
|
||||
async def add_patches(
|
||||
self,
|
||||
instance_id: str,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).add_patches(
|
||||
patches, strength_patch=strength_patch, strength_model=strength_model
|
||||
)
|
||||
|
||||
async def get_key_patches(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).get_key_patches()
|
||||
|
||||
async def load_sd(
|
||||
self, instance_id: str, sd: dict, full_model: bool = False
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).load_sd(sd, full_model=full_model)
|
||||
|
||||
async def get_sd(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).get_sd()
|
||||
|
||||
async def clone(self, instance_id: str) -> str:
|
||||
return self.register(self._get_instance(instance_id).clone())
|
||||
|
||||
|
||||
class CLIPProxy(BaseProxy[CLIPRegistry]):
|
||||
_registry_class = CLIPRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def get_ram_usage(self) -> int:
|
||||
return self._call_rpc("get_ram_usage")
|
||||
|
||||
@property
|
||||
def patcher(self) -> "ModelPatcherProxy":
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
if not hasattr(self, "_patcher_proxy"):
|
||||
patcher_id = self._call_rpc("get_patcher_id")
|
||||
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
|
||||
return self._patcher_proxy
|
||||
|
||||
@patcher.setter
|
||||
def patcher(self, value: Any) -> None:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
if isinstance(value, ModelPatcherProxy):
|
||||
self._patcher_proxy = value
|
||||
else:
|
||||
logger.warning(
|
||||
f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}"
|
||||
)
|
||||
|
||||
@property
|
||||
def cond_stage_model(self) -> CondStageModelProxy:
|
||||
if not hasattr(self, "_cond_stage_model_proxy"):
|
||||
csm_id = self._call_rpc("get_cond_stage_model_id")
|
||||
self._cond_stage_model_proxy = CondStageModelProxy(
|
||||
csm_id, manage_lifecycle=False
|
||||
)
|
||||
return self._cond_stage_model_proxy
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> TokenizerProxy:
|
||||
if not hasattr(self, "_tokenizer_proxy"):
|
||||
tok_id = self._call_rpc("get_tokenizer_id")
|
||||
self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False)
|
||||
return self._tokenizer_proxy
|
||||
|
||||
def load_model(self) -> ModelPatcherProxy:
|
||||
self._call_rpc("load_model")
|
||||
return self.patcher
|
||||
|
||||
@property
|
||||
def layer_idx(self) -> Optional[int]:
|
||||
return self._call_rpc("get_property", "layer_idx")
|
||||
|
||||
@layer_idx.setter
|
||||
def layer_idx(self, value: Optional[int]) -> None:
|
||||
self._call_rpc("set_property", "layer_idx", value)
|
||||
|
||||
@property
|
||||
def tokenizer_options(self) -> dict:
|
||||
return self._call_rpc("get_property", "tokenizer_options")
|
||||
|
||||
@tokenizer_options.setter
|
||||
def tokenizer_options(self, value: dict) -> None:
|
||||
self._call_rpc("set_property", "tokenizer_options", value)
|
||||
|
||||
@property
|
||||
def use_clip_schedule(self) -> bool:
|
||||
return self._call_rpc("get_property", "use_clip_schedule")
|
||||
|
||||
@use_clip_schedule.setter
|
||||
def use_clip_schedule(self, value: bool) -> None:
|
||||
self._call_rpc("set_property", "use_clip_schedule", value)
|
||||
|
||||
@property
|
||||
def apply_hooks_to_conds(self) -> Any:
|
||||
return self._call_rpc("get_property", "apply_hooks_to_conds")
|
||||
|
||||
@apply_hooks_to_conds.setter
|
||||
def apply_hooks_to_conds(self, value: Any) -> None:
|
||||
self._call_rpc("set_property", "apply_hooks_to_conds", value)
|
||||
|
||||
def clip_layer(self, layer_idx: int) -> None:
|
||||
return self._call_rpc("clip_layer", layer_idx)
|
||||
|
||||
def set_tokenizer_option(self, option_name: str, value: Any) -> None:
|
||||
return self._call_rpc("set_tokenizer_option", option_name, value)
|
||||
|
||||
def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any:
|
||||
return self._call_rpc(
|
||||
"tokenize", text, return_word_ids=return_word_ids, **kwargs
|
||||
)
|
||||
|
||||
def encode(self, text: str) -> Any:
|
||||
return self._call_rpc("encode", text)
|
||||
|
||||
def encode_from_tokens(
|
||||
self, tokens: Any, return_pooled: bool = False, return_dict: bool = False
|
||||
) -> Any:
|
||||
res = self._call_rpc(
|
||||
"encode_from_tokens",
|
||||
tokens,
|
||||
return_pooled=return_pooled,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
if return_pooled and isinstance(res, list) and not return_dict:
|
||||
return tuple(res)
|
||||
return res
|
||||
|
||||
def encode_from_tokens_scheduled(
|
||||
self,
|
||||
tokens: Any,
|
||||
unprojected: bool = False,
|
||||
add_dict: Optional[dict] = None,
|
||||
show_pbar: bool = True,
|
||||
) -> Any:
|
||||
add_dict = add_dict or {}
|
||||
return self._call_rpc(
|
||||
"encode_from_tokens_scheduled",
|
||||
tokens,
|
||||
unprojected=unprojected,
|
||||
add_dict=add_dict,
|
||||
show_pbar=show_pbar,
|
||||
)
|
||||
|
||||
def add_patches(
|
||||
self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"add_patches",
|
||||
patches,
|
||||
strength_patch=strength_patch,
|
||||
strength_model=strength_model,
|
||||
)
|
||||
|
||||
def get_key_patches(self) -> Any:
|
||||
return self._call_rpc("get_key_patches")
|
||||
|
||||
def load_sd(self, sd: dict, full_model: bool = False) -> Any:
|
||||
return self._call_rpc("load_sd", sd, full_model=full_model)
|
||||
|
||||
def get_sd(self) -> Any:
|
||||
return self._call_rpc("get_sd")
|
||||
|
||||
def clone(self) -> CLIPProxy:
|
||||
new_id = self._call_rpc("clone")
|
||||
return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS)
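# Illustrative child-side call pattern (prompt text is arbitrary): a CLIPProxy
# handed to an isolated node behaves like comfy.sd.CLIP for the common
# text-encoding path, with every call crossing the RPC boundary:
#
#   tokens = clip.tokenize("a photo of a cat")
#   cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
#
# The tuple unpacking relies on the list-to-tuple conversion done in
# encode_from_tokens above.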
|
||||
|
||||
|
||||
if not IS_CHILD_PROCESS:
|
||||
_CLIP_REGISTRY_SINGLETON = CLIPRegistry()
|
||||
_COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry()
|
||||
_TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry()
|
||||
16
comfy/isolation/custom_node_serializers.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Compatibility shim for the indexed serializer path."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def register_custom_node_serializers(_registry: Any) -> None:
|
||||
"""Legacy no-op shim.
|
||||
|
||||
Serializer registration now lives directly in the active isolation adapter.
|
||||
This module remains importable because the isolation index still references it.
|
||||
"""
|
||||
return None
|
||||
|
||||
__all__ = ["register_custom_node_serializers"]
|
||||
489
comfy/isolation/extension_loader.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
import pyisolate
|
||||
from pyisolate import ExtensionManager, ExtensionManagerConfig
|
||||
from packaging.requirements import InvalidRequirement, Requirement
|
||||
from packaging.utils import canonicalize_name
|
||||
|
||||
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
|
||||
from .host_policy import load_host_policy
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _register_web_directory(extension_name: str, node_dir: Path) -> None:
|
||||
"""Register an isolated extension's web directory on the host side."""
|
||||
import nodes
|
||||
|
||||
# Method 1: pyproject.toml [tool.comfy] web field
|
||||
pyproject = node_dir / "pyproject.toml"
|
||||
if pyproject.exists():
|
||||
try:
|
||||
with pyproject.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
web_dir_name = data.get("tool", {}).get("comfy", {}).get("web")
|
||||
if web_dir_name:
|
||||
web_dir_path = str(node_dir / web_dir_name)
|
||||
if os.path.isdir(web_dir_path):
|
||||
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
||||
logger.debug(
|
||||
"][ Registered web dir for isolated %s: %s",
|
||||
extension_name,
|
||||
web_dir_path,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Method 2: __init__.py WEB_DIRECTORY constant (parse without importing)
|
||||
init_file = node_dir / "__init__.py"
|
||||
if init_file.exists():
|
||||
try:
|
||||
source = init_file.read_text()
|
||||
for line in source.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("WEB_DIRECTORY"):
|
||||
# Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web"
|
||||
_, _, value = stripped.partition("=")
|
||||
value = value.strip().strip("\"'")
|
||||
if value:
|
||||
web_dir_path = str((node_dir / value).resolve())
|
||||
if os.path.isdir(web_dir_path):
|
||||
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
||||
logger.debug(
|
||||
"][ Registered web dir for isolated %s: %s",
|
||||
extension_name,
|
||||
web_dir_path,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass
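# Illustrative manifests for the two discovery methods (paths are hypothetical).
# Method 1 reads the directory name from pyproject.toml:
#
#   [tool.comfy]
#   web = "js"
#
# Method 2 falls back to scanning __init__.py for a line such as
#   WEB_DIRECTORY = "./web"
# without importing the module, so torch-heavy packages never have to be loaded
# on the host just to discover their frontend assets.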
|
||||
|
||||
|
||||
def _get_extension_type(execution_model: str) -> type[Any]:
|
||||
if execution_model == "sealed_worker":
|
||||
return pyisolate.SealedNodeExtension
|
||||
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
|
||||
return ComfyNodeExtension
|
||||
|
||||
|
||||
async def _stop_extension_safe(extension: Any, extension_name: str) -> None:
|
||||
try:
|
||||
stop_result = extension.stop()
|
||||
if inspect.isawaitable(stop_result):
|
||||
await stop_result
|
||||
except Exception:
|
||||
logger.debug("][ %s stop failed", extension_name, exc_info=True)
|
||||
|
||||
|
||||
def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
|
||||
req, sep, marker = dep.partition(";")
|
||||
req = req.strip()
|
||||
marker_suffix = f";{marker}" if sep else ""
|
||||
|
||||
def _resolve_local_path(local_path: str) -> Path | None:
|
||||
for base in base_paths:
|
||||
candidate = (base / local_path).resolve()
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
if req.startswith("./") or req.startswith("../"):
|
||||
resolved = _resolve_local_path(req)
|
||||
if resolved is not None:
|
||||
return f"{resolved}{marker_suffix}"
|
||||
|
||||
if req.startswith("file://"):
|
||||
raw = req[len("file://") :]
|
||||
if raw.startswith("./") or raw.startswith("../"):
|
||||
resolved = _resolve_local_path(raw)
|
||||
if resolved is not None:
|
||||
return f"file://{resolved}{marker_suffix}"
|
||||
|
||||
return dep
|
||||
|
||||
|
||||
def _dependency_name_from_spec(dep: str) -> str | None:
|
||||
stripped = dep.strip()
|
||||
if not stripped or stripped == "-e" or stripped.startswith("-e "):
|
||||
return None
|
||||
if stripped.startswith(("/", "./", "../", "file://")):
|
||||
return None
|
||||
|
||||
try:
|
||||
return canonicalize_name(Requirement(stripped).name)
|
||||
except InvalidRequirement:
|
||||
return None
|
||||
|
||||
|
||||
def _parse_cuda_wheels_config(
|
||||
tool_config: dict[str, object], dependencies: list[str]
|
||||
) -> dict[str, object] | None:
|
||||
raw_config = tool_config.get("cuda_wheels")
|
||||
if raw_config is None:
|
||||
return None
|
||||
if not isinstance(raw_config, dict):
|
||||
raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table")
|
||||
|
||||
index_url = raw_config.get("index_url")
|
||||
index_urls = raw_config.get("index_urls")
|
||||
if index_urls is not None:
|
||||
if not isinstance(index_urls, list) or not all(
|
||||
isinstance(u, str) and u.strip() for u in index_urls
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings"
|
||||
)
|
||||
elif not isinstance(index_url, str) or not index_url.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
|
||||
)
|
||||
|
||||
packages = raw_config.get("packages")
|
||||
if not isinstance(packages, list) or not all(
|
||||
isinstance(package_name, str) and package_name.strip()
|
||||
for package_name in packages
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings"
|
||||
)
|
||||
|
||||
declared_dependencies = {
|
||||
dependency_name
|
||||
for dep in dependencies
|
||||
if (dependency_name := _dependency_name_from_spec(dep)) is not None
|
||||
}
|
||||
normalized_packages = [canonicalize_name(package_name) for package_name in packages]
|
||||
missing = [
|
||||
package_name
|
||||
for package_name in normalized_packages
|
||||
if package_name not in declared_dependencies
|
||||
]
|
||||
if missing:
|
||||
missing_joined = ", ".join(sorted(missing))
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: "
|
||||
f"{missing_joined}"
|
||||
)
|
||||
|
||||
package_map = raw_config.get("package_map", {})
|
||||
if not isinstance(package_map, dict):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] must be a table"
|
||||
)
|
||||
|
||||
normalized_package_map: dict[str, str] = {}
|
||||
for dependency_name, index_package_name in package_map.items():
|
||||
if not isinstance(dependency_name, str) or not dependency_name.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings"
|
||||
)
|
||||
if not isinstance(index_package_name, str) or not index_package_name.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings"
|
||||
)
|
||||
canonical_dependency_name = canonicalize_name(dependency_name)
|
||||
if canonical_dependency_name not in normalized_packages:
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in "
|
||||
"[tool.comfy.isolation.cuda_wheels.packages]"
|
||||
)
|
||||
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
|
||||
|
||||
result: dict = {
|
||||
"packages": normalized_packages,
|
||||
"package_map": normalized_package_map,
|
||||
}
|
||||
if index_urls is not None:
|
||||
result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls]
|
||||
else:
|
||||
result["index_url"] = index_url.rstrip("/") + "/"
|
||||
return result
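# Illustrative [tool.comfy.isolation.cuda_wheels] table (package names and the
# index URL are assumptions, not taken from a real manifest):
#
#   [tool.comfy.isolation.cuda_wheels]
#   index_url = "https://download.example.org/whl/cu124"
#   packages = ["some-cuda-package"]
#
#   [tool.comfy.isolation.cuda_wheels.package_map]
#   some-cuda-package = "some-cuda-package-cu124"
#
# Every entry in "packages" must also appear in [project].dependencies, and
# package_map may only rename packages already listed in "packages"; anything
# else raises ExtensionLoadError above.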
|
||||
|
||||
|
||||
def get_enforcement_policy() -> Dict[str, bool]:
|
||||
return {
|
||||
"force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
|
||||
"force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
|
||||
}
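# Illustrative override (shell syntax): PYISOLATE_ENFORCE_ISOLATED=1 python main.py
# forces the isolated load path even for manifests that do not set can_isolate;
# PYISOLATE_ENFORCE_SANDBOX=1 populates force_sandbox in the same policy dict.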
|
||||
|
||||
|
||||
class ExtensionLoadError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def register_dummy_module(extension_name: str, node_dir: Path) -> None:
|
||||
normalized_name = extension_name.replace("-", "_").replace(".", "_")
|
||||
if normalized_name not in sys.modules:
|
||||
dummy_module = types.ModuleType(normalized_name)
|
||||
dummy_module.__file__ = str(node_dir / "__init__.py")
|
||||
dummy_module.__path__ = [str(node_dir)]
|
||||
dummy_module.__package__ = normalized_name
|
||||
sys.modules[normalized_name] = dummy_module
|
||||
|
||||
|
||||
def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
|
||||
for details in cached_data.values():
|
||||
if not isinstance(details, dict):
|
||||
return True
|
||||
if details.get("is_v3") and "schema_v1" not in details:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def load_isolated_node(
|
||||
node_dir: Path,
|
||||
manifest_path: Path,
|
||||
logger: logging.Logger,
|
||||
build_stub_class: Callable[[str, Dict[str, object], Any], type],
|
||||
venv_root: Path,
|
||||
extension_managers: List[ExtensionManager],
|
||||
) -> List[Tuple[str, str, type]]:
|
||||
try:
|
||||
with manifest_path.open("rb") as handle:
|
||||
manifest_data = tomllib.load(handle)
|
||||
except Exception as e:
|
||||
logger.warning(f"][ Failed to parse {manifest_path}: {e}")
|
||||
return []
|
||||
|
||||
# Parse [tool.comfy.isolation]
|
||||
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
|
||||
can_isolate = tool_config.get("can_isolate", False)
|
||||
share_torch = tool_config.get("share_torch", False)
|
||||
package_manager = tool_config.get("package_manager", "uv")
|
||||
is_conda = package_manager == "conda"
|
||||
execution_model = tool_config.get("execution_model")
|
||||
if execution_model is None:
|
||||
execution_model = "sealed_worker" if is_conda else "host-coupled"
|
||||
|
||||
if "sealed_host_ro_paths" in tool_config:
|
||||
raise ValueError(
|
||||
"Manifest field 'sealed_host_ro_paths' is not allowed. "
|
||||
"Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy."
|
||||
)
|
||||
|
||||
# Conda-specific manifest fields
|
||||
conda_channels: list[str] = (
|
||||
tool_config.get("conda_channels", []) if is_conda else []
|
||||
)
|
||||
conda_dependencies: list[str] = (
|
||||
tool_config.get("conda_dependencies", []) if is_conda else []
|
||||
)
|
||||
conda_platforms: list[str] = (
|
||||
tool_config.get("conda_platforms", []) if is_conda else []
|
||||
)
|
||||
conda_python: str = (
|
||||
tool_config.get("conda_python", "*") if is_conda else "*"
|
||||
)
|
||||
|
||||
# Parse [project] dependencies
|
||||
project_config = manifest_data.get("project", {})
|
||||
dependencies = project_config.get("dependencies", [])
|
||||
if not isinstance(dependencies, list):
|
||||
dependencies = []
|
||||
|
||||
# Get extension name (default to folder name if not in project.name)
|
||||
extension_name = project_config.get("name", node_dir.name)
|
||||
|
||||
# LOGIC: Isolation Decision
|
||||
policy = get_enforcement_policy()
|
||||
isolated = can_isolate or policy["force_isolated"]
|
||||
|
||||
if not isolated:
|
||||
return []
|
||||
|
||||
import folder_paths
|
||||
|
||||
base_paths = [Path(folder_paths.base_path), node_dir]
|
||||
dependencies = [
|
||||
_normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
|
||||
for dep in dependencies
|
||||
]
|
||||
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
|
||||
|
||||
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
|
||||
extension_type = _get_extension_type(execution_model)
|
||||
manager: ExtensionManager = pyisolate.ExtensionManager(
|
||||
extension_type, manager_config
|
||||
)
|
||||
extension_managers.append(manager)
|
||||
|
||||
host_policy = load_host_policy(Path(folder_paths.base_path))
|
||||
|
||||
sandbox_config = {}
|
||||
is_linux = platform.system() == "Linux"
|
||||
|
||||
if is_conda:
|
||||
share_torch = False
|
||||
share_cuda_ipc = False
|
||||
else:
|
||||
share_cuda_ipc = share_torch and is_linux
|
||||
|
||||
if is_linux and isolated:
|
||||
sandbox_config = {
|
||||
"network": host_policy["allow_network"],
|
||||
"writable_paths": host_policy["writable_paths"],
|
||||
"readonly_paths": host_policy["readonly_paths"],
|
||||
}
|
||||
|
||||
extension_config: dict = {
|
||||
"name": extension_name,
|
||||
"module_path": str(node_dir),
|
||||
"isolated": True,
|
||||
"dependencies": dependencies,
|
||||
"share_torch": share_torch,
|
||||
"share_cuda_ipc": share_cuda_ipc,
|
||||
"sandbox_mode": host_policy["sandbox_mode"],
|
||||
"sandbox": sandbox_config,
|
||||
}
|
||||
|
||||
_is_sealed = execution_model == "sealed_worker"
|
||||
_is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
|
||||
logger.info(
|
||||
"][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])",
|
||||
extension_name,
|
||||
"x" if share_torch else " ",
|
||||
"x" if _is_sealed else " ",
|
||||
"x" if _is_sandboxed else " ",
|
||||
)
|
||||
|
||||
if cuda_wheels is not None:
|
||||
extension_config["cuda_wheels"] = cuda_wheels
|
||||
|
||||
# Conda-specific keys
|
||||
if is_conda:
|
||||
extension_config["package_manager"] = "conda"
|
||||
extension_config["conda_channels"] = conda_channels
|
||||
extension_config["conda_dependencies"] = conda_dependencies
|
||||
extension_config["conda_python"] = conda_python
|
||||
find_links = tool_config.get("find_links", [])
|
||||
if find_links:
|
||||
extension_config["find_links"] = find_links
|
||||
if conda_platforms:
|
||||
extension_config["conda_platforms"] = conda_platforms
|
||||
|
||||
if execution_model != "host-coupled":
|
||||
extension_config["execution_model"] = execution_model
|
||||
if execution_model == "sealed_worker":
|
||||
policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", [])
|
||||
if isinstance(policy_ro_paths, list) and policy_ro_paths:
|
||||
extension_config["sealed_host_ro_paths"] = list(policy_ro_paths)
|
||||
# Sealed workers keep the host RPC service inventory even when the
|
||||
# child resolves no API classes locally.
|
||||
|
||||
extension = manager.load_extension(extension_config)
|
||||
register_dummy_module(extension_name, node_dir)
|
||||
|
||||
# Register host-side event handlers via adapter
|
||||
from .adapter import ComfyUIAdapter
|
||||
ComfyUIAdapter.register_host_event_handlers(extension)
|
||||
|
||||
# Register web directory on the host — only when sandbox is disabled.
|
||||
# In sandbox mode, serving untrusted JS to the browser is not safe.
|
||||
if host_policy["sandbox_mode"] == "disabled":
|
||||
_register_web_directory(extension_name, node_dir)
|
||||
|
||||
# Register for proxied web serving — the child's web dir may have
|
||||
# content that doesn't exist on the host (e.g., pip-installed viewer
|
||||
# bundles). The WebDirectoryCache will lazily fetch via RPC.
|
||||
from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache
|
||||
cache = get_web_directory_cache()
|
||||
cache.register_proxy(extension_name, WebDirectoryProxy())
|
||||
|
||||
# Try cache first (lazy spawn)
|
||||
if is_cache_valid(node_dir, manifest_path, venv_root):
|
||||
cached_data = load_from_cache(node_dir, venv_root)
|
||||
if cached_data:
|
||||
if _is_stale_node_cache(cached_data):
|
||||
logger.debug(
|
||||
"][ %s cache is stale/incompatible; rebuilding metadata",
|
||||
extension_name,
|
||||
)
|
||||
else:
|
||||
logger.debug(f"][ {extension_name} loaded from cache")
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
for node_name, details in cached_data.items():
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
specs.append(
|
||||
(node_name, details.get("display_name", node_name), stub_cls)
|
||||
)
|
||||
return specs
|
||||
|
||||
# Cache miss - spawn process and get metadata
|
||||
logger.debug(f"][ {extension_name} cache miss, spawning process for metadata")
|
||||
|
||||
try:
|
||||
remote_nodes: Dict[str, str] = await extension.list_nodes()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"][ %s metadata discovery failed, skipping isolated load: %s",
|
||||
extension_name,
|
||||
exc,
|
||||
)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
if not remote_nodes:
|
||||
logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
cache_data: Dict[str, Dict] = {}
|
||||
|
||||
for node_name, display_name in remote_nodes.items():
|
||||
try:
|
||||
details = await extension.get_node_details(node_name)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"][ %s failed to load metadata for %s, skipping node: %s",
|
||||
extension_name,
|
||||
node_name,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
details["display_name"] = display_name
|
||||
cache_data[node_name] = details
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
specs.append((node_name, display_name, stub_cls))
|
||||
|
||||
if not specs:
|
||||
logger.warning(
|
||||
"][ %s produced no usable nodes after metadata scan; skipping",
|
||||
extension_name,
|
||||
)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
# Save metadata to cache for future runs
|
||||
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
|
||||
logger.debug(f"][ {extension_name} metadata cached")
|
||||
|
||||
# Re-check web directory AFTER child has populated it
|
||||
if host_policy["sandbox_mode"] == "disabled":
|
||||
_register_web_directory(extension_name, node_dir)
|
||||
|
||||
# EJECT: Kill process after getting metadata (will respawn on first execution)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
|
||||
return specs
|
||||
|
||||
|
||||
__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]
|
||||
878
comfy/isolation/extension_wrapper.py
Normal file
878
comfy/isolation/extension_wrapper.py
Normal file
@@ -0,0 +1,878 @@
|
||||
# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
|
||||
|
||||
class AttrDict(dict):
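"""dict subclass whose keys can also be read as attributes."""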
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self[item]
|
||||
except KeyError as e:
|
||||
raise AttributeError(item) from e
|
||||
|
||||
def copy(self):
|
||||
return AttrDict(super().copy())
|
||||
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pyisolate import ExtensionBase
|
||||
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
V3_DISCOVERY_TIMEOUT = 30
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None:
|
||||
"""Run the web asset copy step that prestartup_script.py used to do.
|
||||
|
||||
If the module's web/ directory is empty and the module had a
|
||||
prestartup_script.py that copied assets from pip packages, this
|
||||
function replicates that work inside the child process.
|
||||
|
||||
Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if
|
||||
defined, otherwise falls back to detecting common asset packages.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Already populated — nothing to do
|
||||
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||
return
|
||||
|
||||
os.makedirs(web_dir_path, exist_ok=True)
|
||||
|
||||
# Try module-defined copy spec first (generic hook for any node pack)
|
||||
copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
|
||||
if copy_spec is not None and callable(copy_spec):
|
||||
try:
|
||||
copy_spec(web_dir_path)
|
||||
logger.info(
|
||||
"%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"%s _PRESTARTUP_WEB_COPY failed for %s: %s",
|
||||
LOG_PREFIX, module_dir, e,
|
||||
)
|
||||
|
||||
# Fallback: detect comfy_3d_viewers and run copy_viewer()
|
||||
try:
|
||||
from comfy_3d_viewers import copy_viewer, VIEWER_FILES
|
||||
viewers = list(VIEWER_FILES.keys())
|
||||
for viewer in viewers:
|
||||
try:
|
||||
copy_viewer(viewer, web_dir_path)
|
||||
except Exception:
|
||||
pass
|
||||
if any(os.scandir(web_dir_path)):
|
||||
logger.info(
|
||||
"%s Copied %d viewer types from comfy_3d_viewers to %s",
|
||||
LOG_PREFIX, len(viewers), web_dir_path,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fallback: detect comfy_dynamic_widgets
|
||||
try:
|
||||
from comfy_dynamic_widgets import get_js_path
|
||||
src = os.path.realpath(get_js_path())
|
||||
if os.path.exists(src):
|
||||
dst_dir = os.path.join(web_dir_path, "js")
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
dst = os.path.join(dst_dir, "dynamic_widgets.js")
|
||||
shutil.copy2(src, dst)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _read_extension_name(module_dir: str) -> str:
|
||||
"""Read extension name from pyproject.toml, falling back to directory name."""
|
||||
pyproject = os.path.join(module_dir, "pyproject.toml")
|
||||
if os.path.exists(pyproject):
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
try:
|
||||
with open(pyproject, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
name = data.get("project", {}).get("name")
|
||||
if name:
|
||||
return name
|
||||
except Exception:
|
||||
pass
|
||||
return os.path.basename(module_dir)
|
||||
|
||||
|
||||
def _flush_tensor_transport_state(marker: str) -> int:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return 0
|
||||
if not callable(flush_tensor_keeper):
|
||||
return 0
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
|
||||
)
|
||||
return flushed
|
||||
|
||||
|
||||
def _relieve_child_vram_pressure(marker: str) -> None:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if not hasattr(device, "type") or device.type == "cpu":
|
||||
return
|
||||
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=True)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
|
||||
|
||||
|
||||
def _sanitize_for_transport(value):
|
||||
primitives = (str, int, float, bool, type(None))
|
||||
if isinstance(value, primitives):
|
||||
return value
|
||||
|
||||
cls_name = value.__class__.__name__
|
||||
if cls_name == "FlexibleOptionalInputType":
|
||||
return {
|
||||
"__pyisolate_flexible_optional__": True,
|
||||
"type": _sanitize_for_transport(getattr(value, "type", "*")),
|
||||
}
|
||||
if cls_name == "AnyType":
|
||||
return {"__pyisolate_any_type__": True, "value": str(value)}
|
||||
if cls_name == "ByPassTypeTuple":
|
||||
return {
|
||||
"__pyisolate_bypass_tuple__": [
|
||||
_sanitize_for_transport(v) for v in tuple(value)
|
||||
]
|
||||
}
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {k: _sanitize_for_transport(v) for k, v in value.items()}
|
||||
if isinstance(value, tuple):
|
||||
return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]}
|
||||
if isinstance(value, list):
|
||||
return [_sanitize_for_transport(v) for v in value]
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
# Re-export RemoteObjectHandle from pyisolate for backward compatibility
|
||||
# The canonical definition is now in pyisolate._internal.remote_handle
|
||||
from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401
|
||||
|
||||
|
||||
class ComfyNodeExtension(ExtensionBase):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.node_classes: Dict[str, type] = {}
|
||||
self.display_names: Dict[str, str] = {}
|
||||
self.node_instances: Dict[str, Any] = {}
|
||||
self.remote_objects: Dict[str, Any] = {}
|
||||
self._route_handlers: Dict[str, Any] = {}
|
||||
self._module: Any = None
|
||||
|
||||
async def on_module_loaded(self, module: Any) -> None:
|
||||
self._module = module
|
||||
|
||||
# Registries are initialized in host_hooks.py initialize_host_process()
|
||||
# They auto-register via ProxiedSingleton when instantiated
|
||||
# NO additional setup required here - if a registry is missing from host_hooks, it WILL fail
|
||||
|
||||
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
|
||||
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
|
||||
|
||||
# Register web directory with WebDirectoryProxy (child-side)
|
||||
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
|
||||
if web_dir_attr is not None:
|
||||
module_dir = os.path.dirname(os.path.abspath(module.__file__))
|
||||
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
|
||||
ext_name = _read_extension_name(module_dir)
|
||||
|
||||
# If web dir is empty, run the copy step that prestartup_script.py did
|
||||
_run_prestartup_web_copy(module, module_dir, web_dir_path)
|
||||
|
||||
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
|
||||
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
|
||||
|
||||
try:
|
||||
from comfy_api.latest import ComfyExtension
|
||||
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if not (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, ComfyExtension)
|
||||
and obj is not ComfyExtension
|
||||
):
|
||||
continue
|
||||
if not obj.__module__.startswith(module.__name__):
|
||||
continue
|
||||
try:
|
||||
ext_instance = obj()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s V3 Extension %s timed out in on_load()",
|
||||
LOG_PREFIX,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
try:
|
||||
v3_nodes = await asyncio.wait_for(
|
||||
ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s V3 Extension %s timed out in get_node_list()",
|
||||
LOG_PREFIX,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
for node_cls in v3_nodes:
|
||||
if hasattr(node_cls, "GET_SCHEMA"):
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
self.node_classes[schema.node_id] = node_cls
|
||||
if schema.display_name:
|
||||
self.display_names[schema.node_id] = schema.display_name
|
||||
except Exception as e:
|
||||
logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
module_name = getattr(module, "__name__", "isolated_nodes")
|
||||
for node_cls in self.node_classes.values():
|
||||
if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
|
||||
node_cls.__module__ = module_name
|
||||
|
||||
self.node_instances = {}
|
||||
|
||||
async def list_nodes(self) -> Dict[str, str]:
|
||||
return {name: self.display_names.get(name, name) for name in self.node_classes}
|
||||
|
||||
async def get_node_info(self, node_name: str) -> Dict[str, Any]:
|
||||
return await self.get_node_details(node_name)
|
||||
|
||||
async def get_node_details(self, node_name: str) -> Dict[str, Any]:
|
||||
node_cls = self._get_node_class(node_name)
|
||||
is_v3 = issubclass(node_cls, _ComfyNodeInternal)
|
||||
|
||||
input_types_raw = (
|
||||
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
|
||||
)
|
||||
output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
|
||||
if output_is_list is not None:
|
||||
output_is_list = tuple(bool(x) for x in output_is_list)
|
||||
|
||||
details: Dict[str, Any] = {
|
||||
"input_types": _sanitize_for_transport(input_types_raw),
|
||||
"return_types": tuple(
|
||||
str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
|
||||
),
|
||||
"return_names": getattr(node_cls, "RETURN_NAMES", None),
|
||||
"function": str(getattr(node_cls, "FUNCTION", "execute")),
|
||||
"category": str(getattr(node_cls, "CATEGORY", "")),
|
||||
"output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
|
||||
"output_is_list": output_is_list,
|
||||
"is_v3": is_v3,
|
||||
}
|
||||
|
||||
if is_v3:
|
||||
try:
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
schema_v1 = asdict(schema.get_v1_info(node_cls))
|
||||
try:
|
||||
schema_v3 = asdict(schema.get_v3_info(node_cls))
|
||||
except (AttributeError, TypeError):
|
||||
schema_v3 = self._build_schema_v3_fallback(schema)
|
||||
details.update(
|
||||
{
|
||||
"schema_v1": schema_v1,
|
||||
"schema_v3": schema_v3,
|
||||
"hidden": [h.value for h in (schema.hidden or [])],
|
||||
"description": getattr(schema, "description", ""),
|
||||
"deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
|
||||
"experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)),
|
||||
"api_node": bool(getattr(node_cls, "API_NODE", False)),
|
||||
"input_is_list": bool(
|
||||
getattr(node_cls, "INPUT_IS_LIST", False)
|
||||
),
|
||||
"not_idempotent": bool(
|
||||
getattr(node_cls, "NOT_IDEMPOTENT", False)
|
||||
),
|
||||
"accept_all_inputs": bool(
|
||||
getattr(node_cls, "ACCEPT_ALL_INPUTS", False)
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"%s V3 schema serialization failed for %s: %s",
|
||||
LOG_PREFIX,
|
||||
node_name,
|
||||
exc,
|
||||
)
|
||||
return details
|
||||
|
||||
def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
|
||||
input_dict: Dict[str, Any] = {}
|
||||
output_dict: Dict[str, Any] = {}
|
||||
hidden_list: List[str] = []
|
||||
|
||||
if getattr(schema, "inputs", None):
|
||||
for inp in schema.inputs:
|
||||
self._add_schema_io_v3(inp, input_dict)
|
||||
if getattr(schema, "outputs", None):
|
||||
for out in schema.outputs:
|
||||
self._add_schema_io_v3(out, output_dict)
|
||||
if getattr(schema, "hidden", None):
|
||||
for h in schema.hidden:
|
||||
hidden_list.append(getattr(h, "value", str(h)))
|
||||
|
||||
return {
|
||||
"input": input_dict,
|
||||
"output": output_dict,
|
||||
"hidden": hidden_list,
|
||||
"name": getattr(schema, "node_id", None),
|
||||
"display_name": getattr(schema, "display_name", None),
|
||||
"description": getattr(schema, "description", None),
|
||||
"category": getattr(schema, "category", None),
|
||||
"output_node": getattr(schema, "is_output_node", False),
|
||||
"deprecated": getattr(schema, "is_deprecated", False),
|
||||
"experimental": getattr(schema, "is_experimental", False),
|
||||
"api_node": getattr(schema, "is_api_node", False),
|
||||
}
|
||||
|
||||
def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
|
||||
io_id = getattr(io_obj, "id", None)
|
||||
if io_id is None:
|
||||
return
|
||||
|
||||
io_type_fn = getattr(io_obj, "get_io_type", None)
|
||||
io_type = (
|
||||
io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
|
||||
)
|
||||
|
||||
as_dict_fn = getattr(io_obj, "as_dict", None)
|
||||
payload = as_dict_fn() if callable(as_dict_fn) else {}
|
||||
|
||||
target[str(io_id)] = (io_type, payload)
|
||||
|
||||
async def get_input_types(self, node_name: str) -> Dict[str, Any]:
|
||||
node_cls = self._get_node_class(node_name)
|
||||
if hasattr(node_cls, "INPUT_TYPES"):
|
||||
return node_cls.INPUT_TYPES()
|
||||
return {}
|
||||
|
||||
async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]:
|
||||
logger.debug(
|
||||
"%s ISO:child_execute_start ext=%s node=%s input_keys=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
len(inputs),
|
||||
)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
_relieve_child_vram_pressure("EXT:pre_execute")
|
||||
|
||||
resolved_inputs = self._resolve_remote_objects(inputs)
|
||||
|
||||
instance = self._get_node_instance(node_name)
|
||||
node_cls = self._get_node_class(node_name)
|
||||
|
||||
# V3 API nodes expect hidden parameters in cls.hidden, not as kwargs
|
||||
# Hidden params come through RPC as string keys like "Hidden.prompt"
|
||||
from comfy_api.latest._io import Hidden, HiddenHolder
|
||||
|
||||
# Map string representations back to Hidden enum keys
|
||||
hidden_string_map = {
|
||||
"Hidden.unique_id": Hidden.unique_id,
|
||||
"Hidden.prompt": Hidden.prompt,
|
||||
"Hidden.extra_pnginfo": Hidden.extra_pnginfo,
|
||||
"Hidden.dynprompt": Hidden.dynprompt,
|
||||
"Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
|
||||
"Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
|
||||
# Uppercase enum VALUE forms — V3 execution engine passes these
|
||||
"UNIQUE_ID": Hidden.unique_id,
|
||||
"PROMPT": Hidden.prompt,
|
||||
"EXTRA_PNGINFO": Hidden.extra_pnginfo,
|
||||
"DYNPROMPT": Hidden.dynprompt,
|
||||
"AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org,
|
||||
"API_KEY_COMFY_ORG": Hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Find and extract hidden parameters (both enum and string form)
|
||||
hidden_found = {}
|
||||
keys_to_remove = []
|
||||
|
||||
for key in list(resolved_inputs.keys()):
|
||||
# Check string form first (from RPC serialization)
|
||||
if key in hidden_string_map:
|
||||
hidden_found[hidden_string_map[key]] = resolved_inputs[key]
|
||||
keys_to_remove.append(key)
|
||||
# Also check enum form (direct calls)
|
||||
elif isinstance(key, Hidden):
|
||||
hidden_found[key] = resolved_inputs[key]
|
||||
keys_to_remove.append(key)
|
||||
|
||||
# Remove hidden params from kwargs
|
||||
for key in keys_to_remove:
|
||||
resolved_inputs.pop(key)
|
||||
|
||||
# Set hidden on node class if any hidden params found
|
||||
if hidden_found:
|
||||
if not hasattr(node_cls, "hidden") or node_cls.hidden is None:
|
||||
node_cls.hidden = HiddenHolder.from_dict(hidden_found)
|
||||
else:
|
||||
# Update existing hidden holder
|
||||
for key, value in hidden_found.items():
|
||||
setattr(node_cls.hidden, key.value.lower(), value)
|
||||
|
||||
# INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this
|
||||
# flag is set. The isolation RPC delivers unwrapped values, so we must
|
||||
# wrap each input in a single-element list to match the contract.
|
||||
if getattr(node_cls, "INPUT_IS_LIST", False):
|
||||
resolved_inputs = {k: [v] for k, v in resolved_inputs.items()}
|
||||
|
||||
function_name = getattr(node_cls, "FUNCTION", "execute")
|
||||
if not hasattr(instance, function_name):
|
||||
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
|
||||
|
||||
handler = getattr(instance, function_name)
|
||||
|
||||
try:
|
||||
import torch
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
with torch.inference_mode():
|
||||
result = await handler(**resolved_inputs)
|
||||
else:
|
||||
import functools
|
||||
|
||||
def _run_with_inference_mode(**kwargs):
|
||||
with torch.inference_mode():
|
||||
return handler(**kwargs)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None, functools.partial(_run_with_inference_mode, **resolved_inputs)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s ISO:child_execute_error ext=%s node=%s",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
)
|
||||
raise
|
||||
|
||||
if type(result).__name__ == "NodeOutput":
|
||||
node_output_dict = {
|
||||
"__node_output__": True,
|
||||
"args": self._wrap_unpicklable_objects(result.args),
|
||||
}
|
||||
if result.ui is not None:
|
||||
node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui)
|
||||
if getattr(result, "expand", None) is not None:
|
||||
node_output_dict["expand"] = result.expand
|
||||
if getattr(result, "block_execution", None) is not None:
|
||||
node_output_dict["block_execution"] = result.block_execution
|
||||
return node_output_dict
|
||||
if self._is_comfy_protocol_return(result):
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
async def flush_transport_state(self) -> int:
|
||||
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||
return 0
|
||||
logger.debug(
|
||||
"%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
|
||||
)
|
||||
flushed = _flush_tensor_transport_state("EXT:workflow_end")
|
||||
try:
|
||||
from comfy.isolation.model_patcher_proxy_registry import (
|
||||
ModelPatcherRegistry,
|
||||
)
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
removed = registry.sweep_pending_cleanup()
|
||||
if removed > 0:
|
||||
logger.debug(
|
||||
"%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
|
||||
)
|
||||
logger.debug(
|
||||
"%s ISO:child_flush_done ext=%s flushed=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
flushed,
|
||||
)
|
||||
return flushed
|
||||
|
||||
async def get_remote_object(self, object_id: str) -> Any:
|
||||
"""Retrieve a remote object by ID for host-side deserialization."""
|
||||
if object_id not in self.remote_objects:
|
||||
raise KeyError(f"Remote object {object_id} not found")
|
||||
|
||||
return self.remote_objects[object_id]
|
||||
|
||||
def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle:
|
||||
object_id = str(uuid.uuid4())
|
||||
self.remote_objects[object_id] = obj
|
||||
return RemoteObjectHandle(object_id, type(obj).__name__)
|
||||
|
||||
async def call_remote_object_method(
|
||||
self,
|
||||
object_id: str,
|
||||
method_name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Invoke a method or attribute-backed accessor on a child-owned object."""
|
||||
obj = await self.get_remote_object(object_id)
|
||||
|
||||
if method_name == "get_patcher_attr":
|
||||
return getattr(obj, args[0])
|
||||
if method_name == "get_model_options":
|
||||
return getattr(obj, "model_options")
|
||||
if method_name == "set_model_options":
|
||||
setattr(obj, "model_options", args[0])
|
||||
return None
|
||||
if method_name == "get_object_patches":
|
||||
return getattr(obj, "object_patches")
|
||||
if method_name == "get_patches":
|
||||
return getattr(obj, "patches")
|
||||
if method_name == "get_wrappers":
|
||||
return getattr(obj, "wrappers")
|
||||
if method_name == "get_callbacks":
|
||||
return getattr(obj, "callbacks")
|
||||
if method_name == "get_load_device":
|
||||
return getattr(obj, "load_device")
|
||||
if method_name == "get_offload_device":
|
||||
return getattr(obj, "offload_device")
|
||||
if method_name == "get_hook_mode":
|
||||
return getattr(obj, "hook_mode")
|
||||
if method_name == "get_parent":
|
||||
parent = getattr(obj, "parent", None)
|
||||
if parent is None:
|
||||
return None
|
||||
return self._store_remote_object_handle(parent)
|
||||
if method_name == "get_inner_model_attr":
|
||||
attr_name = args[0]
|
||||
if hasattr(obj.model, attr_name):
|
||||
return getattr(obj.model, attr_name)
|
||||
if hasattr(obj, attr_name):
|
||||
return getattr(obj, attr_name)
|
||||
return None
|
||||
if method_name == "inner_model_apply_model":
|
||||
return obj.model.apply_model(*args[0], **args[1])
|
||||
if method_name == "inner_model_extra_conds_shapes":
|
||||
return obj.model.extra_conds_shapes(*args[0], **args[1])
|
||||
if method_name == "inner_model_extra_conds":
|
||||
return obj.model.extra_conds(*args[0], **args[1])
|
||||
if method_name == "inner_model_memory_required":
|
||||
return obj.model.memory_required(*args[0], **args[1])
|
||||
if method_name == "process_latent_in":
|
||||
return obj.model.process_latent_in(*args[0], **args[1])
|
||||
if method_name == "process_latent_out":
|
||||
return obj.model.process_latent_out(*args[0], **args[1])
|
||||
if method_name == "scale_latent_inpaint":
|
||||
return obj.model.scale_latent_inpaint(*args[0], **args[1])
|
||||
if method_name.startswith("get_"):
|
||||
attr_name = method_name[4:]
|
||||
if hasattr(obj, attr_name):
|
||||
return getattr(obj, attr_name)
|
||||
|
||||
target = getattr(obj, method_name)
|
||||
if callable(target):
|
||||
result = target(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if type(result).__name__ == "ModelPatcher":
|
||||
return self._store_remote_object_handle(result)
|
||||
return result
|
||||
if args or kwargs:
|
||||
raise TypeError(f"{method_name} is not callable on remote object {object_id}")
|
||||
return target
|
||||
|
||||
def _wrap_unpicklable_objects(self, data: Any) -> Any:
|
||||
if isinstance(data, (str, int, float, bool, type(None))):
|
||||
return data
|
||||
if isinstance(data, torch.Tensor):
|
||||
tensor = data.detach() if data.requires_grad else data
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
|
||||
return tensor.cpu()
|
||||
return tensor
|
||||
|
||||
# Special-case clip vision outputs: preserve attribute access by packing fields
|
||||
if hasattr(data, "penultimate_hidden_states") or hasattr(
|
||||
data, "last_hidden_state"
|
||||
):
|
||||
fields = {}
|
||||
for attr in (
|
||||
"penultimate_hidden_states",
|
||||
"last_hidden_state",
|
||||
"image_embeds",
|
||||
"text_embeds",
|
||||
):
|
||||
if hasattr(data, attr):
|
||||
try:
|
||||
fields[attr] = self._wrap_unpicklable_objects(
|
||||
getattr(data, attr)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if fields:
|
||||
return {"__pyisolate_attribute_container__": True, "data": fields}
|
||||
|
||||
# Avoid converting arbitrary objects with stateful methods (models, etc.)
|
||||
# They will be handled via RemoteObjectHandle below.
|
||||
|
||||
type_name = type(data).__name__
|
||||
if type_name == "ModelPatcherProxy":
|
||||
return {"__type__": "ModelPatcherRef", "model_id": data._instance_id}
|
||||
if type_name == "CLIPProxy":
|
||||
return {"__type__": "CLIPRef", "clip_id": data._instance_id}
|
||||
if type_name == "VAEProxy":
|
||||
return {"__type__": "VAERef", "vae_id": data._instance_id}
|
||||
if type_name == "ModelSamplingProxy":
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id}
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
wrapped = [self._wrap_unpicklable_objects(item) for item in data]
|
||||
return tuple(wrapped) if isinstance(data, tuple) else wrapped
|
||||
if isinstance(data, dict):
|
||||
converted_dict = {
|
||||
k: self._wrap_unpicklable_objects(v) for k, v in data.items()
|
||||
}
|
||||
return {"__pyisolate_attrdict__": True, "data": converted_dict}
|
||||
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
|
||||
registry = SerializerRegistry.get_instance()
|
||||
if registry.is_data_type(type_name):
|
||||
serializer = registry.get_serializer(type_name)
|
||||
if serializer:
|
||||
return serializer(data)
|
||||
|
||||
return self._store_remote_object_handle(data)
|
||||
|
||||
def _resolve_remote_objects(self, data: Any) -> Any:
|
||||
if isinstance(data, RemoteObjectHandle):
|
||||
if data.object_id not in self.remote_objects:
|
||||
raise KeyError(f"Remote object {data.object_id} not found")
|
||||
return self.remote_objects[data.object_id]
|
||||
|
||||
if isinstance(data, dict):
|
||||
ref_type = data.get("__type__")
|
||||
if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"):
|
||||
from pyisolate._internal.model_serialization import (
|
||||
deserialize_proxy_result,
|
||||
)
|
||||
|
||||
return deserialize_proxy_result(data)
|
||||
if ref_type == "ModelSamplingRef":
|
||||
from pyisolate._internal.model_serialization import (
|
||||
deserialize_proxy_result,
|
||||
)
|
||||
|
||||
return deserialize_proxy_result(data)
|
||||
return {k: self._resolve_remote_objects(v) for k, v in data.items()}
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
resolved = [self._resolve_remote_objects(item) for item in data]
|
||||
return tuple(resolved) if isinstance(data, tuple) else resolved
|
||||
return data
|
||||
|
||||
def _get_node_class(self, node_name: str) -> type:
|
||||
if node_name not in self.node_classes:
|
||||
raise KeyError(f"Unknown node: {node_name}")
|
||||
return self.node_classes[node_name]
|
||||
|
||||
def _get_node_instance(self, node_name: str) -> Any:
|
||||
if node_name not in self.node_instances:
|
||||
if node_name not in self.node_classes:
|
||||
raise KeyError(f"Unknown node: {node_name}")
|
||||
self.node_instances[node_name] = self.node_classes[node_name]()
|
||||
return self.node_instances[node_name]
|
||||
|
||||
async def before_module_loaded(self) -> None:
|
||||
# Set up the child-side proxies before the extension module is imported
|
||||
try:
|
||||
from comfy.isolation import initialize_proxies
|
||||
|
||||
initialize_proxies()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).error(
|
||||
f"Failed to call initialize_proxies in before_module_loaded: {e}"
|
||||
)
|
||||
|
||||
await super().before_module_loaded()
|
||||
try:
|
||||
from comfy_api.latest import ComfyAPI_latest
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
|
||||
ComfyAPI_latest.Execution = ProgressProxy
|
||||
# ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision
|
||||
# fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision
|
||||
# latest_ui.folder_paths = fp_proxy
|
||||
# latest_resources.folder_paths = fp_proxy
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def call_route_handler(
|
||||
self,
|
||||
handler_module: str,
|
||||
handler_func: str,
|
||||
request_data: Dict[str, Any],
|
||||
) -> Any:
|
||||
cache_key = f"{handler_module}.{handler_func}"
|
||||
if cache_key not in self._route_handlers:
|
||||
if self._module is not None and hasattr(self._module, "__file__"):
|
||||
node_dir = os.path.dirname(self._module.__file__)
|
||||
if node_dir not in sys.path:
|
||||
sys.path.insert(0, node_dir)
|
||||
try:
|
||||
module = importlib.import_module(handler_module)
|
||||
self._route_handlers[cache_key] = getattr(module, handler_func)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(f"Route handler not found: {cache_key}") from e
|
||||
|
||||
handler = self._route_handlers[cache_key]
|
||||
mock_request = MockRequest(request_data)
|
||||
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
result = await handler(mock_request)
|
||||
else:
|
||||
result = handler(mock_request)
|
||||
return self._serialize_response(result)
|
||||
|
||||
def _is_comfy_protocol_return(self, result: Any) -> bool:
|
||||
"""
|
||||
Check if the result matches the ComfyUI 'Protocol Return' schema.
|
||||
|
||||
A Protocol Return is a dictionary containing specific reserved keys that
|
||||
ComfyUI's execution engine interprets as instructions (UI updates,
|
||||
workflow expansion, etc.) rather than purely data outputs.
|
||||
|
||||
Schema:
|
||||
- Must be a dict
|
||||
- Must contain at least one of: 'ui', 'result', 'expand'
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return False
|
||||
return any(key in result for key in ("ui", "result", "expand"))
|
||||
|
||||
def _serialize_response(self, response: Any) -> Dict[str, Any]:
|
||||
if response is None:
|
||||
return {"type": "text", "body": "", "status": 204}
|
||||
if isinstance(response, dict):
|
||||
return {"type": "json", "body": response, "status": 200}
|
||||
if isinstance(response, str):
|
||||
return {"type": "text", "body": response, "status": 200}
|
||||
if hasattr(response, "text") and hasattr(response, "status"):
|
||||
return {
|
||||
"type": "text",
|
||||
"body": response.text
|
||||
if hasattr(response, "text")
|
||||
else str(response.body),
|
||||
"status": response.status,
|
||||
"headers": dict(response.headers)
|
||||
if hasattr(response, "headers")
|
||||
else {},
|
||||
}
|
||||
if hasattr(response, "body") and hasattr(response, "status"):
|
||||
body = response.body
|
||||
if isinstance(body, bytes):
|
||||
try:
|
||||
return {
|
||||
"type": "text",
|
||||
"body": body.decode("utf-8"),
|
||||
"status": response.status,
|
||||
}
|
||||
except UnicodeDecodeError:
|
||||
return {
|
||||
"type": "binary",
|
||||
"body": body.hex(),
|
||||
"status": response.status,
|
||||
}
|
||||
return {"type": "json", "body": body, "status": response.status}
|
||||
return {"type": "text", "body": str(response), "status": 200}
|
||||
|
||||
|
||||
class MockRequest:
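"""Lightweight stand-in for an aiohttp-style request, built from a plain dict so route handlers can run in the child process."""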
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.method = data.get("method", "GET")
|
||||
self.path = data.get("path", "/")
|
||||
self.query = data.get("query", {})
|
||||
self._body = data.get("body", {})
|
||||
self._text = data.get("text", "")
|
||||
self.headers = data.get("headers", {})
|
||||
self.content_type = data.get(
|
||||
"content_type", self.headers.get("Content-Type", "application/json")
|
||||
)
|
||||
self.match_info = data.get("match_info", {})
|
||||
|
||||
async def json(self) -> Any:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
if isinstance(self._body, str):
|
||||
return json.loads(self._body)
|
||||
return {}
|
||||
|
||||
async def post(self) -> Dict[str, Any]:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
return {}
|
||||
|
||||
async def text(self) -> str:
|
||||
if self._text:
|
||||
return self._text
|
||||
if isinstance(self._body, str):
|
||||
return self._body
|
||||
if isinstance(self._body, dict):
|
||||
return json.dumps(self._body)
|
||||
return ""
|
||||
|
||||
async def read(self) -> bytes:
|
||||
return (await self.text()).encode("utf-8")
|
||||
30
comfy/isolation/host_hooks.py
Normal file
30
comfy/isolation/host_hooks.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# Host process initialization for PyIsolate
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_host_process() -> None:
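"""Reset the root logger to a NullHandler, then instantiate the proxy service singletons; each registers itself for RPC via ProxiedSingleton on construction."""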
|
||||
root = logging.getLogger()
|
||||
for handler in root.handlers[:]:
|
||||
root.removeHandler(handler)
|
||||
root.addHandler(logging.NullHandler())
|
||||
|
||||
from .proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from .proxies.helper_proxies import HelperProxiesService
|
||||
from .proxies.model_management_proxy import ModelManagementProxy
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
from .proxies.prompt_server_impl import PromptServerService
|
||||
from .proxies.utils_proxy import UtilsProxy
|
||||
from .proxies.web_directory_proxy import WebDirectoryProxy
|
||||
from .vae_proxy import VAERegistry
|
||||
|
||||
FolderPathsProxy()
|
||||
HelperProxiesService()
|
||||
ModelManagementProxy()
|
||||
ProgressProxy()
|
||||
PromptServerService()
|
||||
UtilsProxy()
|
||||
WebDirectoryProxy()
|
||||
VAERegistry()
|
||||
178
comfy/isolation/host_policy.py
Normal file
178
comfy/isolation/host_policy.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# pylint: disable=logging-fstring-interpolation
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Dict, List, TypedDict
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
|
||||
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
|
||||
FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"})
|
||||
|
||||
|
||||
class HostSecurityPolicy(TypedDict):
|
||||
sandbox_mode: str
|
||||
allow_network: bool
|
||||
writable_paths: List[str]
|
||||
readonly_paths: List[str]
|
||||
sealed_worker_ro_import_paths: List[str]
|
||||
whitelist: Dict[str, str]
|
||||
|
||||
|
||||
DEFAULT_POLICY: HostSecurityPolicy = {
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": ["/dev/shm"],
|
||||
"readonly_paths": [],
|
||||
"sealed_worker_ro_import_paths": [],
|
||||
"whitelist": {},
|
||||
}
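# For reference, an illustrative [tool.comfy.host] override in pyproject.toml
# (hypothetical values; only keys read by load_host_policy below are shown):
#
#     [tool.comfy.host]
#     sandbox_mode = "required"
#     allow_network = false
#     writable_paths = ["/dev/shm", "/var/cache/comfy"]
#     sealed_worker_ro_import_paths = ["/opt/shared-site-packages"]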
|
||||
|
||||
|
||||
def _default_policy() -> HostSecurityPolicy:
|
||||
return {
|
||||
"sandbox_mode": DEFAULT_POLICY["sandbox_mode"],
|
||||
"allow_network": DEFAULT_POLICY["allow_network"],
|
||||
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
|
||||
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
|
||||
"sealed_worker_ro_import_paths": list(DEFAULT_POLICY["sealed_worker_ro_import_paths"]),
|
||||
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_writable_paths(paths: list[object]) -> list[str]:
|
||||
normalized_paths: list[str] = []
|
||||
for raw_path in paths:
|
||||
# Host-policy paths are contract-style POSIX paths; keep representation
|
||||
# stable across Windows/Linux so tests and config behavior stay consistent.
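# For example, a hypothetical "C:\models\cache" becomes "C:/models/cache"; "/tmp" entries are dropped.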
|
||||
normalized_path = str(PurePosixPath(str(raw_path).replace("\\", "/")))
|
||||
if normalized_path in FORBIDDEN_WRITABLE_PATHS:
|
||||
continue
|
||||
normalized_paths.append(normalized_path)
|
||||
return normalized_paths
|
||||
|
||||
|
||||
def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]:
|
||||
if not file_path.is_absolute():
|
||||
file_path = config_path.parent / file_path
|
||||
if not file_path.exists():
|
||||
logger.warning("whitelist_file %s not found, skipping.", file_path)
|
||||
return {}
|
||||
entries: Dict[str, str] = {}
|
||||
for line in file_path.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
entries[line] = "*"
|
||||
logger.debug("Loaded %d whitelist entries from %s", len(entries), file_path)
|
||||
return entries
|
||||
|
||||
|
||||
def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]:
|
||||
if not isinstance(raw_paths, list):
|
||||
raise ValueError(
|
||||
"tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths."
|
||||
)
|
||||
|
||||
normalized_paths: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for raw_path in raw_paths:
|
||||
if not isinstance(raw_path, str) or not raw_path.strip():
|
||||
raise ValueError(
|
||||
"tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings."
|
||||
)
|
||||
normalized_path = str(PurePosixPath(raw_path.replace("\\", "/")))
|
||||
# Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...)
|
||||
is_absolute = normalized_path.startswith("/") or (
|
||||
len(normalized_path) >= 3 and normalized_path[1] == ":" and normalized_path[2] == "/"
|
||||
)
|
||||
if not is_absolute:
|
||||
raise ValueError(
|
||||
"tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths."
|
||||
)
|
||||
if normalized_path not in seen:
|
||||
seen.add(normalized_path)
|
||||
normalized_paths.append(normalized_path)
|
||||
|
||||
return normalized_paths
|
||||
|
||||
|
||||
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
|
||||
config_override = os.environ.get(HOST_POLICY_PATH_ENV)
|
||||
config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
|
||||
policy = _default_policy()
|
||||
|
||||
if not config_path.exists():
|
||||
logger.debug("Host policy file missing at %s, using defaults.", config_path)
|
||||
return policy
|
||||
|
||||
try:
|
||||
with config_path.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse host policy from %s, using defaults.",
|
||||
config_path,
|
||||
exc_info=True,
|
||||
)
|
||||
return policy
|
||||
|
||||
tool_config = data.get("tool", {}).get("comfy", {}).get("host", {})
|
||||
if not isinstance(tool_config, dict):
|
||||
logger.debug("No [tool.comfy.host] section found, using defaults.")
|
||||
return policy
|
||||
|
||||
sandbox_mode = tool_config.get("sandbox_mode")
|
||||
if isinstance(sandbox_mode, str):
|
||||
normalized_sandbox_mode = sandbox_mode.strip().lower()
|
||||
if normalized_sandbox_mode in VALID_SANDBOX_MODES:
|
||||
policy["sandbox_mode"] = normalized_sandbox_mode
|
||||
else:
|
||||
logger.warning(
|
||||
"Invalid host sandbox_mode %r in %s, using default %r.",
|
||||
sandbox_mode,
|
||||
config_path,
|
||||
DEFAULT_POLICY["sandbox_mode"],
|
||||
)
|
||||
|
||||
if "allow_network" in tool_config:
|
||||
policy["allow_network"] = bool(tool_config["allow_network"])
|
||||
|
||||
if "writable_paths" in tool_config:
|
||||
policy["writable_paths"] = _normalize_writable_paths(tool_config["writable_paths"])
|
||||
|
||||
if "readonly_paths" in tool_config:
|
||||
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
|
||||
|
||||
if "sealed_worker_ro_import_paths" in tool_config:
|
||||
policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths(
|
||||
tool_config["sealed_worker_ro_import_paths"]
|
||||
)
|
||||
|
||||
whitelist_file = tool_config.get("whitelist_file")
|
||||
if isinstance(whitelist_file, str):
|
||||
policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path))
|
||||
|
||||
whitelist_raw = tool_config.get("whitelist")
|
||||
if isinstance(whitelist_raw, dict):
|
||||
policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})
|
||||
|
||||
logger.debug(
|
||||
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
|
||||
len(policy["whitelist"]),
|
||||
policy["sandbox_mode"],
|
||||
policy["allow_network"],
|
||||
)
|
||||
return policy
|
||||
|
||||
|
||||
__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"]
|
||||
221
comfy/isolation/manifest_loader.py
Normal file
221
comfy/isolation/manifest_loader.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import folder_paths
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_SUBDIR = "cache"
|
||||
CACHE_KEY_FILE = "cache_key"
|
||||
CACHE_DATA_FILE = "node_info.json"
|
||||
CACHE_KEY_LENGTH = 16
|
||||
_NESTED_SCAN_ROOT = "packages"
|
||||
_IGNORED_MANIFEST_DIRS = {".git", ".venv", "__pycache__"}
|
||||
|
||||
|
||||
def _read_manifest(manifest_path: Path) -> dict[str, Any] | None:
|
||||
try:
|
||||
with manifest_path.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _is_isolation_manifest(data: dict[str, Any]) -> bool:
|
||||
return (
|
||||
"tool" in data
|
||||
and "comfy" in data["tool"]
|
||||
and "isolation" in data["tool"]["comfy"]
|
||||
)
|
||||
|
||||
|
||||
def _discover_nested_manifests(entry: Path) -> List[Tuple[Path, Path]]:
|
||||
packages_root = entry / _NESTED_SCAN_ROOT
|
||||
if not packages_root.exists() or not packages_root.is_dir():
|
||||
return []
|
||||
|
||||
nested: List[Tuple[Path, Path]] = []
|
||||
for manifest in sorted(packages_root.rglob("pyproject.toml")):
|
||||
node_dir = manifest.parent
|
||||
if any(part in _IGNORED_MANIFEST_DIRS for part in node_dir.parts):
|
||||
continue
|
||||
|
||||
data = _read_manifest(manifest)
|
||||
if not data or not _is_isolation_manifest(data):
|
||||
continue
|
||||
|
||||
isolation = data["tool"]["comfy"]["isolation"]
|
||||
if isolation.get("standalone") is True:
|
||||
nested.append((node_dir, manifest))
|
||||
|
||||
return nested
|
||||
|
||||
|
||||
def find_manifest_directories() -> List[Tuple[Path, Path]]:
|
||||
"""Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation]."""
|
||||
manifest_dirs: List[Tuple[Path, Path]] = []
|
||||
|
||||
# Standard custom_nodes paths
|
||||
for base_path in folder_paths.get_folder_paths("custom_nodes"):
|
||||
base = Path(base_path)
|
||||
if not base.exists() or not base.is_dir():
|
||||
continue
|
||||
|
||||
for entry in base.iterdir():
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
|
||||
# Look for pyproject.toml
|
||||
manifest = entry / "pyproject.toml"
|
||||
if not manifest.exists():
|
||||
continue
|
||||
|
||||
data = _read_manifest(manifest)
|
||||
if not data or not _is_isolation_manifest(data):
|
||||
continue
|
||||
|
||||
manifest_dirs.append((entry, manifest))
|
||||
manifest_dirs.extend(_discover_nested_manifests(entry))
|
||||
|
||||
return manifest_dirs
|
||||
|
||||
|
||||
def compute_cache_key(node_dir: Path, manifest_path: Path) -> str:
|
||||
"""Hash manifest + .py mtimes + Python version + PyIsolate version."""
|
||||
hasher = hashlib.sha256()
|
||||
|
||||
try:
|
||||
# Hashing the manifest content ensures config changes invalidate cache
|
||||
hasher.update(manifest_path.read_bytes())
|
||||
except OSError:
|
||||
hasher.update(b"__manifest_read_error__")
|
||||
|
||||
try:
|
||||
py_files = sorted(node_dir.rglob("*.py"))
|
||||
for py_file in py_files:
|
||||
rel_path = py_file.relative_to(node_dir)
|
||||
if "__pycache__" in str(rel_path) or ".venv" in str(rel_path):
|
||||
continue
|
||||
hasher.update(str(rel_path).encode("utf-8"))
|
||||
try:
|
||||
hasher.update(str(py_file.stat().st_mtime).encode("utf-8"))
|
||||
except OSError:
|
||||
hasher.update(b"__file_stat_error__")
|
||||
except OSError:
|
||||
hasher.update(b"__dir_scan_error__")
|
||||
|
||||
hasher.update(sys.version.encode("utf-8"))
|
||||
|
||||
try:
|
||||
import pyisolate
|
||||
|
||||
hasher.update(pyisolate.__version__.encode("utf-8"))
|
||||
except (ImportError, AttributeError):
|
||||
hasher.update(b"__pyisolate_unknown__")
|
||||
|
||||
return hasher.hexdigest()[:CACHE_KEY_LENGTH]
|
||||
|
||||
|
||||
def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]:
|
||||
"""Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/."""
|
||||
cache_dir = venv_root / node_dir.name / CACHE_SUBDIR
|
||||
return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE)
|
||||
|
||||
|
||||
def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool:
|
||||
"""Return True only if stored cache key matches current computed key."""
|
||||
try:
|
||||
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
if not cache_key_file.exists() or not cache_data_file.exists():
|
||||
return False
|
||||
current_key = compute_cache_key(node_dir, manifest_path)
|
||||
stored_key = cache_key_file.read_text(encoding="utf-8").strip()
|
||||
return current_key == stored_key
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]:
|
||||
"""Load node metadata from cache, return None on any error."""
|
||||
try:
|
||||
_, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
if not cache_data_file.exists():
|
||||
return None
|
||||
data = json.loads(cache_data_file.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
return data
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def save_to_cache(
|
||||
node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path
|
||||
) -> None:
|
||||
"""Save node metadata and cache key atomically."""
|
||||
try:
|
||||
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
cache_dir = cache_key_file.parent
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
cache_key = compute_cache_key(node_dir, manifest_path)
|
||||
|
||||
# Atomic write: data
|
||||
tmp_data_fd, tmp_data_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(tmp_data_fd, "w", encoding="utf-8") as f:
|
||||
json.dump(node_data, f, indent=2)
|
||||
os.replace(tmp_data_path, cache_data_file)
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_data_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
# Atomic write: key
|
||||
tmp_key_fd, tmp_key_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(tmp_key_fd, "w", encoding="utf-8") as f:
|
||||
f.write(cache_key)
|
||||
os.replace(tmp_key_path, cache_key_file)
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_key_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LOG_PREFIX",
|
||||
"find_manifest_directories",
|
||||
"compute_cache_key",
|
||||
"get_cache_path",
|
||||
"is_cache_valid",
|
||||
"load_from_cache",
|
||||
"save_to_cache",
|
||||
]
|
||||
888
comfy/isolation/model_patcher_proxy.py
Normal file
888
comfy/isolation/model_patcher_proxy.py
Normal file
@@ -0,0 +1,888 @@
|
||||
# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access
|
||||
# RPC proxy for ModelPatcher (parent process)
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, List, Set, Dict, Callable
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
)
|
||||
from comfy.isolation.model_patcher_proxy_registry import (
|
||||
ModelPatcherRegistry,
|
||||
AutoPatcherEjector,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
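"""RPC proxy that stands in for a ModelPatcher owned by another process; attribute reads and method calls are forwarded through ModelPatcherRegistry."""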
|
||||
_registry_class = ModelPatcherRegistry
|
||||
__module__ = "comfy.model_patcher"
|
||||
_APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024
|
||||
|
||||
def _spawn_related_proxy(self, instance_id: str) -> "ModelPatcherProxy":
|
||||
proxy = ModelPatcherProxy(
|
||||
instance_id,
|
||||
self._registry,
|
||||
manage_lifecycle=not IS_CHILD_PROCESS,
|
||||
)
|
||||
if getattr(self, "_rpc_caller", None) is not None:
|
||||
proxy._rpc_caller = self._rpc_caller
|
||||
return proxy
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is not None:
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
self._registry_class, self._registry_class.get_remote_id()
|
||||
)
|
||||
else:
|
||||
self._rpc_caller = self._registry
|
||||
return self._rpc_caller
|
||||
|
||||
def get_all_callbacks(self, call_type: str = None) -> Any:
|
||||
return self._call_rpc("get_all_callbacks", call_type)
|
||||
|
||||
def get_all_wrappers(self, wrapper_type: str = None) -> Any:
|
||||
return self._call_rpc("get_all_wrappers", wrapper_type)
|
||||
|
||||
def _load_list(self, *args, **kwargs) -> Any:
|
||||
return self._call_rpc("load_list_internal", *args, **kwargs)
|
||||
|
||||
def prepare_hook_patches_current_keyframe(
|
||||
self, t: Any, hook_group: Any, model_options: Any
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"prepare_hook_patches_current_keyframe", t, hook_group, model_options
|
||||
)
|
||||
|
||||
def add_hook_patches(
|
||||
self,
|
||||
hook: Any,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"add_hook_patches", hook, patches, strength_patch, strength_model
|
||||
)
|
||||
|
||||
def clear_cached_hook_weights(self) -> None:
|
||||
self._call_rpc("clear_cached_hook_weights")
|
||||
|
||||
def get_combined_hook_patches(self, hooks: Any) -> Any:
|
||||
return self._call_rpc("get_combined_hook_patches", hooks)
|
||||
|
||||
def get_additional_models_with_key(self, key: str) -> Any:
|
||||
return self._call_rpc("get_additional_models_with_key", key)
|
||||
|
||||
@property
|
||||
def object_patches(self) -> Any:
|
||||
return self._call_rpc("get_object_patches")
|
||||
|
||||
@property
|
||||
def patches(self) -> Any:
|
||||
res = self._call_rpc("get_patches")
|
||||
if isinstance(res, dict):
|
||||
new_res = {}
|
||||
for k, v in res.items():
|
||||
new_list = []
|
||||
for item in v:
|
||||
if isinstance(item, list):
|
||||
new_list.append(tuple(item))
|
||||
else:
|
||||
new_list.append(item)
|
||||
new_res[k] = new_list
|
||||
return new_res
|
||||
return res
|
||||
|
||||
@property
|
||||
def pinned(self) -> Set:
|
||||
val = self._call_rpc("get_patcher_attr", "pinned")
|
||||
return set(val) if val is not None else set()
|
||||
|
||||
@property
|
||||
def hook_patches(self) -> Dict:
|
||||
val = self._call_rpc("get_patcher_attr", "hook_patches")
|
||||
if val is None:
|
||||
return {}
|
||||
try:
|
||||
from comfy.hooks import _HookRef
|
||||
import json
|
||||
|
||||
new_val = {}
|
||||
for k, v in val.items():
|
||||
if isinstance(k, str):
|
||||
if k.startswith("PYISOLATE_HOOKREF:"):
|
||||
ref_id = k.split(":", 1)[1]
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = ref_id
|
||||
new_val[h] = v
|
||||
elif k.startswith("__pyisolate_key__"):
|
||||
try:
|
||||
json_str = k[len("__pyisolate_key__") :]
|
||||
data = json.loads(json_str)
|
||||
ref_id = None
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
if (
|
||||
isinstance(item, list)
|
||||
and len(item) == 2
|
||||
and item[0] == "id"
|
||||
):
|
||||
ref_id = item[1]
|
||||
break
|
||||
if ref_id:
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = ref_id
|
||||
new_val[h] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
except Exception:
|
||||
new_val[k] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
return new_val
|
||||
except ImportError:
|
||||
return val
|
||||
|
||||
def set_hook_mode(self, hook_mode: Any) -> None:
|
||||
self._call_rpc("set_hook_mode", hook_mode)
|
||||
|
||||
def register_all_hook_patches(
|
||||
self,
|
||||
hooks: Any,
|
||||
target_dict: Any,
|
||||
model_options: Any = None,
|
||||
registered: Any = None,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"register_all_hook_patches", hooks, target_dict, model_options, registered
|
||||
)
|
||||
|
||||
def is_clone(self, other: Any) -> bool:
|
||||
if isinstance(other, ModelPatcherProxy):
|
||||
return self._call_rpc("is_clone_by_id", other._instance_id)
|
||||
return False
|
||||
|
||||
def clone(self) -> ModelPatcherProxy:
|
||||
new_id = self._call_rpc("clone")
|
||||
return self._spawn_related_proxy(new_id)
|
||||
|
||||
def clone_has_same_weights(self, clone: Any) -> bool:
|
||||
if isinstance(clone, ModelPatcherProxy):
|
||||
return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id)
|
||||
if not IS_CHILD_PROCESS:
|
||||
return self._call_rpc("is_clone", clone)
|
||||
return False
|
||||
|
||||
def get_model_object(self, name: str) -> Any:
|
||||
return self._call_rpc("get_model_object", name)
|
||||
|
||||
@property
|
||||
def model_options(self) -> dict:
|
||||
data = self._call_rpc("get_model_options")
|
||||
import json
|
||||
|
||||
def _decode_keys(obj):
|
||||
if isinstance(obj, dict):
|
||||
new_d = {}
|
||||
for k, v in obj.items():
|
||||
if isinstance(k, str) and k.startswith("__pyisolate_key__"):
|
||||
try:
|
||||
                            json_str = k[len("__pyisolate_key__"):]
|
||||
val = json.loads(json_str)
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
new_d[val] = _decode_keys(v)
|
||||
                        except Exception:
|
||||
new_d[k] = _decode_keys(v)
|
||||
else:
|
||||
new_d[k] = _decode_keys(v)
|
||||
return new_d
|
||||
if isinstance(obj, list):
|
||||
return [_decode_keys(x) for x in obj]
|
||||
return obj
|
||||
|
||||
return _decode_keys(data)
|
||||
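A minimal sketch (not part of the diff) of the key encoding that _decode_keys undoes: tuple keys cannot cross the JSON-based RPC boundary, so they are assumed to arrive as strings carrying the __pyisolate_key__ prefix used above.

import json

wire_key = "__pyisolate_key__" + json.dumps(["attn1", 3])      # hypothetical wire form
restored = json.loads(wire_key[len("__pyisolate_key__"):])
assert tuple(restored) == ("attn1", 3)                          # what _decode_keys rebuilds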
|
||||
@model_options.setter
|
||||
def model_options(self, value: dict) -> None:
|
||||
self._call_rpc("set_model_options", value)
|
||||
|
||||
def apply_hooks(self, hooks: Any) -> Any:
|
||||
return self._call_rpc("apply_hooks", hooks)
|
||||
|
||||
def prepare_state(self, timestep: Any) -> Any:
|
||||
return self._call_rpc("prepare_state", timestep)
|
||||
|
||||
def restore_hook_patches(self) -> None:
|
||||
self._call_rpc("restore_hook_patches")
|
||||
|
||||
def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None:
|
||||
self._call_rpc("unpatch_hooks", whitelist_keys_set)
|
||||
|
||||
def model_patches_to(self, device: Any) -> Any:
|
||||
return self._call_rpc("model_patches_to", device)
|
||||
|
||||
def partially_load(
|
||||
self, device: Any, extra_memory: Any, force_patch_weights: bool = False
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"partially_load", device, extra_memory, force_patch_weights
|
||||
)
|
||||
|
||||
def partially_unload(
|
||||
self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False
|
||||
) -> int:
|
||||
return self._call_rpc(
|
||||
"partially_unload", device_to, memory_to_free, force_patch_weights
|
||||
)
|
||||
|
||||
def load(
|
||||
self,
|
||||
device_to: Any = None,
|
||||
lowvram_model_memory: int = 0,
|
||||
force_patch_weights: bool = False,
|
||||
full_load: bool = False,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"load", device_to, lowvram_model_memory, force_patch_weights, full_load
|
||||
)
|
||||
|
||||
def patch_model(
|
||||
self,
|
||||
device_to: Any = None,
|
||||
lowvram_model_memory: int = 0,
|
||||
load_weights: bool = True,
|
||||
force_patch_weights: bool = False,
|
||||
) -> Any:
|
||||
self._call_rpc(
|
||||
"patch_model",
|
||||
device_to,
|
||||
lowvram_model_memory,
|
||||
load_weights,
|
||||
force_patch_weights,
|
||||
)
|
||||
return self
|
||||
|
||||
def unpatch_model(
|
||||
self, device_to: Any = None, unpatch_weights: bool = True
|
||||
) -> None:
|
||||
self._call_rpc("unpatch_model", device_to, unpatch_weights)
|
||||
|
||||
def detach(self, unpatch_all: bool = True) -> Any:
|
||||
self._call_rpc("detach", unpatch_all)
|
||||
return self.model
|
||||
|
||||
def _cpu_tensor_bytes(self, obj: Any) -> int:
|
||||
import torch
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if obj.device.type == "cpu":
|
||||
return obj.nbytes
|
||||
return 0
|
||||
if isinstance(obj, dict):
|
||||
return sum(self._cpu_tensor_bytes(v) for v in obj.values())
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return sum(self._cpu_tensor_bytes(v) for v in obj)
|
||||
return 0
|
||||
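A small self-contained check of the accounting _cpu_tensor_bytes performs (CPU tensors only, found recursively in dicts, lists, and tuples). The shapes are illustrative, and __new__ is used only because the helper touches no instance state.

import torch

inst = ModelPatcherProxy.__new__(ModelPatcherProxy)
payload = ({"x": torch.zeros(2, 3)}, [torch.zeros(4)], 1.0)
# float32 tensors: 2*3*4 + 4*4 = 40 bytes; the plain scalar contributes nothing.
assert inst._cpu_tensor_bytes(payload) == 40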
|
||||
def _ensure_apply_model_headroom(self, required_bytes: int) -> bool:
|
||||
if required_bytes <= 0:
|
||||
return True
|
||||
|
||||
import torch
|
||||
import comfy.model_management as model_management
|
||||
|
||||
target_raw = self.load_device
|
||||
try:
|
||||
if isinstance(target_raw, torch.device):
|
||||
target = target_raw
|
||||
elif isinstance(target_raw, str):
|
||||
target = torch.device(target_raw)
|
||||
elif isinstance(target_raw, int):
|
||||
target = torch.device(f"cuda:{target_raw}")
|
||||
else:
|
||||
target = torch.device(target_raw)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
if target.type != "cuda":
|
||||
return True
|
||||
|
||||
required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES
|
||||
if model_management.get_free_memory(target) >= required:
|
||||
return True
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
model_management.free_memory(required, target, for_dynamic=True)
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
# Escalate to non-dynamic unloading before dispatching CUDA transfer.
|
||||
model_management.free_memory(required, target, for_dynamic=False)
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
model_management.load_models_gpu(
|
||||
[self],
|
||||
minimum_memory_required=required,
|
||||
)
|
||||
|
||||
return model_management.get_free_memory(target) >= required
|
||||
|
||||
def apply_model(self, *args, **kwargs) -> Any:
|
||||
import torch
|
||||
|
||||
def _preferred_device() -> Any:
|
||||
for value in args:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
return None
|
||||
|
||||
def _move_result_to_device(obj: Any, device: Any) -> Any:
|
||||
if device is None:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.to(device) if obj.device != device else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _move_result_to_device(v, device) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_move_result_to_device(v, device) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_move_result_to_device(v, device) for v in obj)
|
||||
return obj
|
||||
|
||||
# DynamicVRAM models must keep load/offload decisions in host process.
|
||||
# Child-side CUDA staging here can deadlock before first inference RPC.
|
||||
if self.is_dynamic():
|
||||
out = self._call_rpc("inner_model_apply_model", args, kwargs)
|
||||
return _move_result_to_device(out, _preferred_device())
|
||||
|
||||
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
|
||||
self._ensure_apply_model_headroom(required_bytes)
|
||||
|
||||
def _to_cuda(obj: Any) -> Any:
|
||||
if isinstance(obj, torch.Tensor) and obj.device.type == "cpu":
|
||||
return obj.to("cuda")
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cuda(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cuda(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cuda(v) for v in obj)
|
||||
return obj
|
||||
|
||||
try:
|
||||
args_cuda = _to_cuda(args)
|
||||
kwargs_cuda = _to_cuda(kwargs)
|
||||
except torch.OutOfMemoryError:
|
||||
self._ensure_apply_model_headroom(required_bytes)
|
||||
args_cuda = _to_cuda(args)
|
||||
kwargs_cuda = _to_cuda(kwargs)
|
||||
|
||||
out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
|
||||
return _move_result_to_device(out, _preferred_device())
|
||||
|
||||
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
|
||||
keys = self._call_rpc("model_state_dict", filter_prefix)
|
||||
return dict.fromkeys(keys, None)
|
||||
|
||||
def add_patches(self, *args: Any, **kwargs: Any) -> Any:
|
||||
res = self._call_rpc("add_patches", *args, **kwargs)
|
||||
if isinstance(res, list):
|
||||
return [tuple(x) if isinstance(x, list) else x for x in res]
|
||||
return res
|
||||
|
||||
def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any:
|
||||
return self._call_rpc("get_key_patches", filter_prefix)
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||
self._call_rpc("patch_weight_to_device", key, device_to, inplace_update)
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
self._call_rpc("pin_weight_to_device", key)
|
||||
|
||||
def unpin_weight(self, key):
|
||||
self._call_rpc("unpin_weight", key)
|
||||
|
||||
def unpin_all_weights(self):
|
||||
self._call_rpc("unpin_all_weights")
|
||||
|
||||
def calculate_weight(self, patches, weight, key, intermediate_dtype=None):
|
||||
return self._call_rpc(
|
||||
"calculate_weight", patches, weight, key, intermediate_dtype
|
||||
)
|
||||
|
||||
def inject_model(self) -> None:
|
||||
self._call_rpc("inject_model")
|
||||
|
||||
def eject_model(self) -> None:
|
||||
self._call_rpc("eject_model")
|
||||
|
||||
def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any:
|
||||
return AutoPatcherEjector(
|
||||
self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only
|
||||
)
|
||||
|
||||
@property
|
||||
def is_injected(self) -> bool:
|
||||
return self._call_rpc("get_is_injected")
|
||||
|
||||
@property
|
||||
def skip_injection(self) -> bool:
|
||||
return self._call_rpc("get_skip_injection")
|
||||
|
||||
@skip_injection.setter
|
||||
def skip_injection(self, value: bool) -> None:
|
||||
self._call_rpc("set_skip_injection", value)
|
||||
|
||||
def clean_hooks(self) -> None:
|
||||
self._call_rpc("clean_hooks")
|
||||
|
||||
def pre_run(self) -> None:
|
||||
self._call_rpc("pre_run")
|
||||
|
||||
def cleanup(self) -> None:
|
||||
try:
|
||||
self._call_rpc("cleanup")
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ModelPatcherProxy cleanup RPC failed for %s",
|
||||
self._instance_id,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
super().cleanup()
|
||||
|
||||
@property
|
||||
def model(self) -> _InnerModelProxy:
|
||||
return _InnerModelProxy(self)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
_whitelisted_attrs = {
|
||||
"hook_patches_backup",
|
||||
"hook_backup",
|
||||
"cached_hook_patches",
|
||||
"current_hooks",
|
||||
"forced_hooks",
|
||||
"is_clip",
|
||||
"patches_uuid",
|
||||
"pinned",
|
||||
"attachments",
|
||||
"additional_models",
|
||||
"injections",
|
||||
"hook_patches",
|
||||
"model_lowvram",
|
||||
"model_loaded_weight_memory",
|
||||
"backup",
|
||||
"object_patches_backup",
|
||||
"weight_wrapper_patches",
|
||||
"weight_inplace_update",
|
||||
"force_cast_weights",
|
||||
}
|
||||
if name in _whitelisted_attrs:
|
||||
return self._call_rpc("get_patcher_attr", name)
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
def load_lora(
|
||||
self,
|
||||
lora_path: str,
|
||||
strength_model: float,
|
||||
clip: Optional[Any] = None,
|
||||
strength_clip: float = 1.0,
|
||||
) -> tuple:
|
||||
clip_id = None
|
||||
if clip is not None:
|
||||
clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None))
|
||||
result = self._call_rpc(
|
||||
"load_lora", lora_path, strength_model, clip_id, strength_clip
|
||||
)
|
||||
new_model = None
|
||||
if result.get("model_id"):
|
||||
new_model = self._spawn_related_proxy(result["model_id"])
|
||||
new_clip = None
|
||||
if result.get("clip_id"):
|
||||
from comfy.isolation.clip_proxy import CLIPProxy
|
||||
|
||||
new_clip = CLIPProxy(result["clip_id"])
|
||||
return (new_model, new_clip)
|
||||
|
||||
@property
|
||||
def load_device(self) -> Any:
|
||||
return self._call_rpc("get_load_device")
|
||||
|
||||
@property
|
||||
def offload_device(self) -> Any:
|
||||
return self._call_rpc("get_offload_device")
|
||||
|
||||
@property
|
||||
def device(self) -> Any:
|
||||
return self.load_device
|
||||
|
||||
def current_loaded_device(self) -> Any:
|
||||
return self._call_rpc("current_loaded_device")
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self._call_rpc("get_size")
|
||||
|
||||
def model_size(self) -> Any:
|
||||
return self._call_rpc("model_size")
|
||||
|
||||
def loaded_size(self) -> Any:
|
||||
return self._call_rpc("loaded_size")
|
||||
|
||||
def get_ram_usage(self) -> int:
|
||||
return self._call_rpc("get_ram_usage")
|
||||
|
||||
def lowvram_patch_counter(self) -> int:
|
||||
return self._call_rpc("lowvram_patch_counter")
|
||||
|
||||
def memory_required(self, input_shape: Any) -> Any:
|
||||
return self._call_rpc("memory_required", input_shape)
|
||||
|
||||
def get_operation_state(self) -> Dict[str, Any]:
|
||||
state = self._call_rpc("get_operation_state")
|
||||
return state if isinstance(state, dict) else {}
|
||||
|
||||
def wait_for_idle(self, timeout_ms: int = 0) -> bool:
|
||||
return bool(self._call_rpc("wait_for_idle", timeout_ms))
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return bool(self._call_rpc("is_dynamic"))
|
||||
|
||||
def get_free_memory(self, device: Any) -> Any:
|
||||
return self._call_rpc("get_free_memory", device)
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload: int) -> Any:
|
||||
return self._call_rpc("partially_unload_ram", ram_to_unload)
|
||||
|
||||
def model_dtype(self) -> Any:
|
||||
res = self._call_rpc("model_dtype")
|
||||
if isinstance(res, str) and res.startswith("torch."):
|
||||
try:
|
||||
import torch
|
||||
|
||||
attr = res.split(".")[-1]
|
||||
if hasattr(torch, attr):
|
||||
return getattr(torch, attr)
|
||||
except ImportError:
|
||||
pass
|
||||
return res
|
||||
|
||||
@property
|
||||
def hook_mode(self) -> Any:
|
||||
return self._call_rpc("get_hook_mode")
|
||||
|
||||
@hook_mode.setter
|
||||
def hook_mode(self, value: Any) -> None:
|
||||
self._call_rpc("set_hook_mode", value)
|
||||
|
||||
def set_model_sampler_cfg_function(
|
||||
self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_cfg_function",
|
||||
sampler_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_post_cfg_function(
|
||||
self, post_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_post_cfg_function",
|
||||
post_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_pre_cfg_function(
|
||||
self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_pre_cfg_function",
|
||||
pre_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None:
|
||||
self._call_rpc("set_model_sampler_calc_cond_batch_function", fn)
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None:
|
||||
self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function)
|
||||
|
||||
def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None:
|
||||
self._call_rpc("set_model_denoise_mask_function", denoise_mask_function)
|
||||
|
||||
def set_model_patch(self, patch: Any, name: str) -> None:
|
||||
self._call_rpc("set_model_patch", patch, name)
|
||||
|
||||
def set_model_patch_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
name: str,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_patch_replace",
|
||||
patch,
|
||||
name,
|
||||
block_name,
|
||||
number,
|
||||
transformer_index,
|
||||
)
|
||||
|
||||
def set_model_attn1_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
|
||||
def set_model_attn2_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def set_model_attn1_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self.set_model_patch_replace(
|
||||
patch, "attn1", block_name, number, transformer_index
|
||||
)
|
||||
|
||||
def set_model_attn2_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self.set_model_patch_replace(
|
||||
patch, "attn2", block_name, number, transformer_index
|
||||
)
|
||||
|
||||
def set_model_attn1_output_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn1_output_patch")
|
||||
|
||||
def set_model_attn2_output_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn2_output_patch")
|
||||
|
||||
def set_model_input_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "input_block_patch")
|
||||
|
||||
def set_model_input_block_patch_after_skip(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "input_block_patch_after_skip")
|
||||
|
||||
def set_model_output_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "output_block_patch")
|
||||
|
||||
def set_model_emb_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "emb_patch")
|
||||
|
||||
def set_model_forward_timestep_embed_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
||||
|
||||
def set_model_double_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "double_block")
|
||||
|
||||
def set_model_post_input_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_rope_options(
|
||||
self,
|
||||
scale_x=1.0,
|
||||
shift_x=0.0,
|
||||
scale_y=1.0,
|
||||
shift_y=0.0,
|
||||
scale_t=1.0,
|
||||
shift_t=0.0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
options = {
|
||||
"scale_x": scale_x,
|
||||
"shift_x": shift_x,
|
||||
"scale_y": scale_y,
|
||||
"shift_y": shift_y,
|
||||
"scale_t": scale_t,
|
||||
"shift_t": shift_t,
|
||||
}
|
||||
options.update(kwargs)
|
||||
self._call_rpc("set_model_rope_options", options)
|
||||
|
||||
def set_model_compute_dtype(self, dtype: Any) -> None:
|
||||
self._call_rpc("set_model_compute_dtype", dtype)
|
||||
|
||||
def add_object_patch(self, name: str, obj: Any) -> None:
|
||||
self._call_rpc("add_object_patch", name, obj)
|
||||
|
||||
def add_weight_wrapper(self, name: str, function: Any) -> None:
|
||||
self._call_rpc("add_weight_wrapper", name, function)
|
||||
|
||||
def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None:
|
||||
self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn)
|
||||
|
||||
def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None:
|
||||
self.add_wrapper_with_key(wrapper_type, None, wrapper)
|
||||
|
||||
def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None:
|
||||
self._call_rpc("remove_wrappers_with_key", wrapper_type, key)
|
||||
|
||||
@property
|
||||
def wrappers(self) -> Any:
|
||||
return self._call_rpc("get_wrappers")
|
||||
|
||||
def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None:
|
||||
self._call_rpc("add_callback_with_key", call_type, key, callback)
|
||||
|
||||
def add_callback(self, call_type: str, callback: Any) -> None:
|
||||
self.add_callback_with_key(call_type, None, callback)
|
||||
|
||||
def remove_callbacks_with_key(self, call_type: str, key: str) -> None:
|
||||
self._call_rpc("remove_callbacks_with_key", call_type, key)
|
||||
|
||||
@property
|
||||
def callbacks(self) -> Any:
|
||||
return self._call_rpc("get_callbacks")
|
||||
|
||||
def set_attachments(self, key: str, attachment: Any) -> None:
|
||||
self._call_rpc("set_attachments", key, attachment)
|
||||
|
||||
def get_attachment(self, key: str) -> Any:
|
||||
return self._call_rpc("get_attachment", key)
|
||||
|
||||
def remove_attachments(self, key: str) -> None:
|
||||
self._call_rpc("remove_attachments", key)
|
||||
|
||||
def set_injections(self, key: str, injections: Any) -> None:
|
||||
self._call_rpc("set_injections", key, injections)
|
||||
|
||||
def get_injections(self, key: str) -> Any:
|
||||
return self._call_rpc("get_injections", key)
|
||||
|
||||
def remove_injections(self, key: str) -> None:
|
||||
self._call_rpc("remove_injections", key)
|
||||
|
||||
def set_additional_models(self, key: str, models: Any) -> None:
|
||||
ids = [m._instance_id for m in models]
|
||||
self._call_rpc("set_additional_models", key, ids)
|
||||
|
||||
def remove_additional_models(self, key: str) -> None:
|
||||
self._call_rpc("remove_additional_models", key)
|
||||
|
||||
def get_nested_additional_models(self) -> Any:
|
||||
return self._call_rpc("get_nested_additional_models")
|
||||
|
||||
def get_additional_models(self) -> List[ModelPatcherProxy]:
|
||||
ids = self._call_rpc("get_additional_models")
|
||||
return [self._spawn_related_proxy(mid) for mid in ids]
|
||||
|
||||
def model_patches_models(self) -> Any:
|
||||
return self._call_rpc("model_patches_models")
|
||||
|
||||
@property
|
||||
def parent(self) -> Any:
|
||||
return self._call_rpc("get_parent")
|
||||
|
||||
def model_mmap_residency(self, free: bool = False) -> tuple:
|
||||
result = self._call_rpc("model_mmap_residency", free)
|
||||
if isinstance(result, list):
|
||||
return tuple(result)
|
||||
return result
|
||||
|
||||
def pinned_memory_size(self) -> int:
|
||||
return self._call_rpc("pinned_memory_size")
|
||||
|
||||
def get_non_dynamic_delegate(self) -> ModelPatcherProxy:
|
||||
new_id = self._call_rpc("get_non_dynamic_delegate")
|
||||
return self._spawn_related_proxy(new_id)
|
||||
|
||||
def disable_model_cfg1_optimization(self) -> None:
|
||||
self._call_rpc("disable_model_cfg1_optimization")
|
||||
|
||||
def set_model_noise_refiner_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "noise_refiner")
|
||||
|
||||
|
||||
class _InnerModelProxy:
|
||||
def __init__(self, parent: ModelPatcherProxy):
|
||||
self._parent = parent
|
||||
self._model_sampling = None
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(name)
|
||||
if name == "model_config":
|
||||
from types import SimpleNamespace
|
||||
|
||||
data = self._parent._call_rpc("get_inner_model_attr", name)
|
||||
if isinstance(data, dict):
|
||||
return SimpleNamespace(**data)
|
||||
return data
|
||||
if name in (
|
||||
"latent_format",
|
||||
"model_type",
|
||||
"current_weight_patches_uuid",
|
||||
):
|
||||
return self._parent._call_rpc("get_inner_model_attr", name)
|
||||
if name == "load_device":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "load_device")
|
||||
if name == "device":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "device")
|
||||
if name == "current_patcher":
|
||||
proxy = ModelPatcherProxy(
|
||||
self._parent._instance_id,
|
||||
self._parent._registry,
|
||||
manage_lifecycle=False,
|
||||
)
|
||||
if getattr(self._parent, "_rpc_caller", None) is not None:
|
||||
proxy._rpc_caller = self._parent._rpc_caller
|
||||
return proxy
|
||||
if name == "model_sampling":
|
||||
if self._model_sampling is None:
|
||||
self._model_sampling = self._parent._call_rpc(
|
||||
"get_model_object", "model_sampling"
|
||||
)
|
||||
return self._model_sampling
|
||||
if name == "extra_conds_shapes":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_extra_conds_shapes", a, k
|
||||
)
|
||||
if name == "extra_conds":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_extra_conds", a, k
|
||||
)
|
||||
if name == "memory_required":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_memory_required", a, k
|
||||
)
|
||||
if name == "apply_model":
|
||||
# Delegate to parent's method to get the CPU->CUDA optimization
|
||||
return self._parent.apply_model
|
||||
if name == "process_latent_in":
|
||||
return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k)
|
||||
if name == "process_latent_out":
|
||||
return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k)
|
||||
if name == "scale_latent_inpaint":
|
||||
return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k)
|
||||
if name == "diffusion_model":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "diffusion_model")
|
||||
raise AttributeError(f"'{name}' not supported on isolated InnerModel")
|
||||
1311  comfy/isolation/model_patcher_proxy_registry.py  (new file; diff suppressed because it is too large)
156  comfy/isolation/model_patcher_proxy_utils.py  (new file)
@@ -0,0 +1,156 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access
|
||||
# Isolation utilities and serializers for ModelPatcherProxy
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any:
|
||||
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
isolation_active = args.use_process_isolation or is_child
|
||||
|
||||
if not isolation_active:
|
||||
return model_patcher
|
||||
if is_child:
|
||||
return model_patcher
|
||||
if isinstance(model_patcher, ModelPatcherProxy):
|
||||
return model_patcher
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
model_id = registry.register(model_patcher)
|
||||
logger.debug(f"Isolated ModelPatcher: {model_id}")
|
||||
return ModelPatcherProxy(model_id, registry, manage_lifecycle=True)
|
||||
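A hedged usage sketch of the helper above: loader code in the host process can route any freshly created ModelPatcher through it; the surrounding workflow code is illustrative.

# patcher: a comfy.model_patcher.ModelPatcher produced by whatever loader ran.
patcher = maybe_wrap_model_for_isolation(patcher)
# With isolation inactive (or inside a child process) the original object comes
# back untouched; in the parent with args.use_process_isolation set, the patcher
# is registered and a ModelPatcherProxy handle is returned in its place.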
|
||||
|
||||
def register_hooks_serializers(registry=None):
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
import comfy.hooks
|
||||
|
||||
if registry is None:
|
||||
registry = SerializerRegistry.get_instance()
|
||||
|
||||
def serialize_enum(obj):
|
||||
return {"__enum__": f"{type(obj).__name__}.{obj.name}"}
|
||||
|
||||
def deserialize_enum(data):
|
||||
cls_name, val_name = data["__enum__"].split(".")
|
||||
cls = getattr(comfy.hooks, cls_name)
|
||||
return cls[val_name]
|
||||
|
||||
registry.register("EnumHookType", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumHookScope", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumHookMode", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumWeightTarget", serialize_enum, deserialize_enum)
|
||||
|
||||
def serialize_hook_group(obj):
|
||||
return {"__type__": "HookGroup", "hooks": obj.hooks}
|
||||
|
||||
def deserialize_hook_group(data):
|
||||
hg = comfy.hooks.HookGroup()
|
||||
for h in data["hooks"]:
|
||||
hg.add(h)
|
||||
return hg
|
||||
|
||||
registry.register("HookGroup", serialize_hook_group, deserialize_hook_group)
|
||||
|
||||
def serialize_dict_state(obj):
|
||||
d = obj.__dict__.copy()
|
||||
d["__type__"] = type(obj).__name__
|
||||
if "custom_should_register" in d:
|
||||
del d["custom_should_register"]
|
||||
return d
|
||||
|
||||
def deserialize_dict_state_generic(cls):
|
||||
def _deserialize(data):
|
||||
h = cls()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
return _deserialize
|
||||
|
||||
def deserialize_hook_keyframe(data):
|
||||
h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0))
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe)
|
||||
|
||||
def deserialize_hook_keyframe_group(data):
|
||||
h = comfy.hooks.HookKeyframeGroup()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register(
|
||||
"HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group
|
||||
)
|
||||
|
||||
def deserialize_hook(data):
|
||||
h = comfy.hooks.Hook()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("Hook", serialize_dict_state, deserialize_hook)
|
||||
|
||||
def deserialize_weight_hook(data):
|
||||
h = comfy.hooks.WeightHook()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook)
|
||||
|
||||
def serialize_set(obj):
|
||||
return {"__set__": list(obj)}
|
||||
|
||||
def deserialize_set(data):
|
||||
return set(data["__set__"])
|
||||
|
||||
registry.register("set", serialize_set, deserialize_set)
|
||||
|
||||
try:
|
||||
from comfy.weight_adapter.lora import LoRAAdapter
|
||||
|
||||
def serialize_lora(obj):
|
||||
return {"weights": {}, "loaded_keys": list(obj.loaded_keys)}
|
||||
|
||||
def deserialize_lora(data):
|
||||
return LoRAAdapter(set(data["loaded_keys"]), data["weights"])
|
||||
|
||||
registry.register("LoRAAdapter", serialize_lora, deserialize_lora)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from comfy.hooks import _HookRef
|
||||
import uuid
|
||||
|
||||
def serialize_hook_ref(obj):
|
||||
return {
|
||||
"__hook_ref__": True,
|
||||
"id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())),
|
||||
}
|
||||
|
||||
def deserialize_hook_ref(data):
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = data.get("id", str(uuid.uuid4()))
|
||||
return h
|
||||
|
||||
registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register _HookRef: {e}")
|
||||
|
||||
|
||||
try:
|
||||
register_hooks_serializers()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize hook serializers: {e}")
|
||||
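A hedged sketch of how an extension-side type could be wired into the same serializer registry; MyOptions and its field are hypothetical, while the register(name, serialize, deserialize) call mirrors the ones above.

from pyisolate._internal.serialization_registry import SerializerRegistry


class MyOptions:  # hypothetical extension-side state object
    def __init__(self, strength: float = 1.0):
        self.strength = strength


def _serialize_my_options(obj):
    return {"strength": obj.strength}


def _deserialize_my_options(data):
    return MyOptions(strength=data.get("strength", 1.0))


SerializerRegistry.get_instance().register(
    "MyOptions", _serialize_my_options, _deserialize_my_options
)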
360  comfy/isolation/model_sampling_proxy.py  (new file)
@@ -0,0 +1,360 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
get_thread_loop,
|
||||
run_coro_in_new_loop,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _describe_value(obj: Any) -> str:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
torch = None
|
||||
try:
|
||||
if torch is not None and isinstance(obj, torch.Tensor):
|
||||
return (
|
||||
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
|
||||
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return "%s(id=%s)" % (type(obj).__name__, id(obj))
|
||||
|
||||
|
||||
def _prefer_device(*tensors: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return None
|
||||
for t in tensors:
|
||||
if isinstance(t, torch.Tensor) and t.is_cuda:
|
||||
return t.device
|
||||
for t in tensors:
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.device
|
||||
return None
|
||||
|
||||
|
||||
def _to_device(obj: Any, device: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
if device is None:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if obj.device != device:
|
||||
return obj.to(device)
|
||||
return obj
|
||||
if isinstance(obj, (list, tuple)):
|
||||
converted = [_to_device(x, device) for x in obj]
|
||||
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_device(v, device) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
def _to_cpu_for_rpc(obj: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
t = obj.detach() if obj.requires_grad else obj
|
||||
if t.is_cuda:
|
||||
return t.to("cpu")
|
||||
return t
|
||||
if isinstance(obj, (list, tuple)):
|
||||
converted = [_to_cpu_for_rpc(x) for x in obj]
|
||||
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class ModelSamplingRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "modelsampling"
|
||||
|
||||
async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.calculate_input(sigma, noise))
|
||||
|
||||
async def calculate_denoised(
|
||||
self, instance_id: str, sigma: Any, model_output: Any, model_input: Any
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(
|
||||
sampling.calculate_denoised(sigma, model_output, model_input)
|
||||
)
|
||||
|
||||
async def noise_scaling(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
noise: Any,
|
||||
latent_image: Any,
|
||||
max_denoise: bool = False,
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(
|
||||
sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise)
|
||||
)
|
||||
|
||||
async def inverse_noise_scaling(
|
||||
self, instance_id: str, sigma: Any, latent: Any
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent))
|
||||
|
||||
async def timestep(self, instance_id: str, sigma: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.timestep(sigma)
|
||||
|
||||
async def sigma(self, instance_id: str, timestep: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.sigma(timestep)
|
||||
|
||||
async def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.percent_to_sigma(percent)
|
||||
|
||||
async def get_sigma_min(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_min)
|
||||
|
||||
async def get_sigma_max(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_max)
|
||||
|
||||
async def get_sigma_data(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_data)
|
||||
|
||||
async def get_sigmas(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigmas)
|
||||
|
||||
async def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
|
||||
sampling = self._get_instance(instance_id)
|
||||
sampling.set_sigmas(sigmas)
|
||||
|
||||
|
||||
class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
||||
_registry_class = ModelSamplingRegistry
|
||||
__module__ = "comfy.isolation.model_sampling_proxy"
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is not None:
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id()
|
||||
)
|
||||
else:
|
||||
registry = ModelSamplingRegistry()
|
||||
|
||||
class _LocalCaller:
|
||||
def calculate_input(
|
||||
self, instance_id: str, sigma: Any, noise: Any
|
||||
) -> Any:
|
||||
return registry.calculate_input(instance_id, sigma, noise)
|
||||
|
||||
def calculate_denoised(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
model_output: Any,
|
||||
model_input: Any,
|
||||
) -> Any:
|
||||
return registry.calculate_denoised(
|
||||
instance_id, sigma, model_output, model_input
|
||||
)
|
||||
|
||||
def noise_scaling(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
noise: Any,
|
||||
latent_image: Any,
|
||||
max_denoise: bool = False,
|
||||
) -> Any:
|
||||
return registry.noise_scaling(
|
||||
instance_id, sigma, noise, latent_image, max_denoise
|
||||
)
|
||||
|
||||
def inverse_noise_scaling(
|
||||
self, instance_id: str, sigma: Any, latent: Any
|
||||
) -> Any:
|
||||
return registry.inverse_noise_scaling(
|
||||
instance_id, sigma, latent
|
||||
)
|
||||
|
||||
def timestep(self, instance_id: str, sigma: Any) -> Any:
|
||||
return registry.timestep(instance_id, sigma)
|
||||
|
||||
def sigma(self, instance_id: str, timestep: Any) -> Any:
|
||||
return registry.sigma(instance_id, timestep)
|
||||
|
||||
def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
|
||||
return registry.percent_to_sigma(instance_id, percent)
|
||||
|
||||
def get_sigma_min(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_min(instance_id)
|
||||
|
||||
def get_sigma_max(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_max(instance_id)
|
||||
|
||||
def get_sigma_data(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_data(instance_id)
|
||||
|
||||
def get_sigmas(self, instance_id: str) -> Any:
|
||||
return registry.get_sigmas(instance_id)
|
||||
|
||||
def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
|
||||
return registry.set_sigmas(instance_id, sigmas)
|
||||
|
||||
self._rpc_caller = _LocalCaller()
|
||||
return self._rpc_caller
|
||||
|
||||
def _call(self, method_name: str, *args: Any) -> Any:
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
result = method(self._instance_id, *args)
|
||||
timeout_ms = self._rpc_timeout_ms()
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
call_id = "%s:%s:%s:%.6f" % (
|
||||
self._instance_id,
|
||||
method_name,
|
||||
thread_id,
|
||||
start_perf,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
timeout_ms,
|
||||
)
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
out = run_coro_in_new_loop(result)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
out = loop.run_until_complete(result)
|
||||
else:
|
||||
out = result
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
_describe_value(out),
|
||||
)
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _rpc_timeout_ms() -> int:
|
||||
raw = os.environ.get(
|
||||
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
|
||||
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
|
||||
)
|
||||
try:
|
||||
timeout_ms = int(raw)
|
||||
except ValueError:
|
||||
timeout_ms = 30000
|
||||
return max(1, timeout_ms)
|
||||
|
||||
@property
|
||||
def sigma_min(self) -> Any:
|
||||
return self._call("get_sigma_min")
|
||||
|
||||
@property
|
||||
def sigma_max(self) -> Any:
|
||||
return self._call("get_sigma_max")
|
||||
|
||||
@property
|
||||
def sigma_data(self) -> Any:
|
||||
return self._call("get_sigma_data")
|
||||
|
||||
@property
|
||||
def sigmas(self) -> Any:
|
||||
return self._call("get_sigmas")
|
||||
|
||||
def calculate_input(self, sigma: Any, noise: Any) -> Any:
|
||||
return self._call("calculate_input", sigma, noise)
|
||||
|
||||
def calculate_denoised(
|
||||
self, sigma: Any, model_output: Any, model_input: Any
|
||||
) -> Any:
|
||||
return self._call("calculate_denoised", sigma, model_output, model_input)
|
||||
|
||||
def noise_scaling(
|
||||
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
|
||||
) -> Any:
|
||||
preferred_device = _prefer_device(noise, latent_image)
|
||||
out = self._call(
|
||||
"noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(noise),
|
||||
_to_cpu_for_rpc(latent_image),
|
||||
max_denoise,
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
|
||||
preferred_device = _prefer_device(latent)
|
||||
out = self._call(
|
||||
"inverse_noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(latent),
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def timestep(self, sigma: Any) -> Any:
|
||||
return self._call("timestep", sigma)
|
||||
|
||||
def sigma(self, timestep: Any) -> Any:
|
||||
return self._call("sigma", timestep)
|
||||
|
||||
def percent_to_sigma(self, percent: float) -> Any:
|
||||
return self._call("percent_to_sigma", percent)
|
||||
|
||||
def set_sigmas(self, sigmas: Any) -> None:
|
||||
return self._call("set_sigmas", sigmas)
|
||||
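A hedged usage sketch from the consuming side: property and method access on the proxy resolves to the registry RPCs named in _call above; the instance id is illustrative and assumed to have been issued by the host registry.

import torch

sampling = ModelSamplingProxy("modelsampling_0")   # id issued by the host registry
s_max = sampling.sigma_max                         # -> "get_sigma_max" RPC
t = sampling.timestep(s_max)                       # -> "timestep" RPC
noise = torch.randn(1, 4, 8, 8)
latent = torch.zeros_like(noise)
out = sampling.noise_scaling(s_max, noise, latent, max_denoise=True)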
17  comfy/isolation/proxies/__init__.py  (new file)
@@ -0,0 +1,17 @@
|
||||
from .base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
get_thread_loop,
|
||||
run_coro_in_new_loop,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"IS_CHILD_PROCESS",
|
||||
"BaseRegistry",
|
||||
"BaseProxy",
|
||||
"get_thread_loop",
|
||||
"run_coro_in_new_loop",
|
||||
"detach_if_grad",
|
||||
]
|
||||
301  comfy/isolation/proxies/base.py  (new file)
@@ -0,0 +1,301 @@
|
||||
# pylint: disable=global-statement,import-outside-toplevel,protected-access
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
||||
|
||||
try:
|
||||
from pyisolate import ProxiedSingleton
|
||||
except ImportError:
|
||||
|
||||
class ProxiedSingleton: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
_thread_local = threading.local()
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_thread_loop() -> asyncio.AbstractEventLoop:
|
||||
loop = getattr(_thread_local, "loop", None)
|
||||
if loop is None or loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
_thread_local.loop = loop
|
||||
return loop
|
||||
|
||||
|
||||
def run_coro_in_new_loop(coro: Any) -> Any:
|
||||
result_box: Dict[str, Any] = {}
|
||||
exc_box: Dict[str, BaseException] = {}
|
||||
|
||||
def runner() -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result_box["value"] = loop.run_until_complete(coro)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
exc_box["exc"] = exc
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
t = threading.Thread(target=runner, daemon=True)
|
||||
t.start()
|
||||
t.join()
|
||||
if "exc" in exc_box:
|
||||
raise exc_box["exc"]
|
||||
return result_box.get("value")
|
||||
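A quick self-contained check of run_coro_in_new_loop's contract: the coroutine runs on a throwaway loop in a helper thread, the result comes back synchronously, and any exception is re-raised in the caller.

import asyncio

async def _double(x: int) -> int:
    await asyncio.sleep(0)
    return x * 2

assert run_coro_in_new_loop(_double(21)) == 42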
|
||||
|
||||
def detach_if_grad(obj: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.detach() if obj.requires_grad else obj
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return type(obj)(detach_if_grad(x) for x in obj)
|
||||
if isinstance(obj, dict):
|
||||
return {k: detach_if_grad(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class BaseRegistry(ProxiedSingleton, Generic[T]):
|
||||
_type_prefix: str = "base"
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
|
||||
super().__init__()
|
||||
self._registry: Dict[str, T] = {}
|
||||
self._id_map: Dict[int, str] = {}
|
||||
self._counter = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def register(self, instance: T) -> str:
|
||||
with self._lock:
|
||||
obj_id = id(instance)
|
||||
if obj_id in self._id_map:
|
||||
return self._id_map[obj_id]
|
||||
instance_id = f"{self._type_prefix}_{self._counter}"
|
||||
self._counter += 1
|
||||
self._registry[instance_id] = instance
|
||||
self._id_map[obj_id] = instance_id
|
||||
return instance_id
|
||||
|
||||
def unregister_sync(self, instance_id: str) -> None:
|
||||
with self._lock:
|
||||
instance = self._registry.pop(instance_id, None)
|
||||
if instance:
|
||||
self._id_map.pop(id(instance), None)
|
||||
|
||||
def _get_instance(self, instance_id: str) -> T:
|
||||
if IS_CHILD_PROCESS:
|
||||
raise RuntimeError(
|
||||
f"[{self.__class__.__name__}] _get_instance called in child"
|
||||
)
|
||||
with self._lock:
|
||||
instance = self._registry.get(instance_id)
|
||||
if instance is None:
|
||||
raise ValueError(f"{instance_id} not found")
|
||||
return instance
|
||||
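A minimal host-process sketch of the registry round-trip (a plain object stands in for a ModelPatcher; ids follow the "<prefix>_<counter>" scheme produced by register, and direct construction matches how ModelPatcherRegistry() is instantiated elsewhere in this changeset).

class _DummyRegistry(BaseRegistry[object]):
    _type_prefix = "dummy"


reg = _DummyRegistry()
obj = object()
instance_id = reg.register(obj)            # e.g. "dummy_0"
assert reg.register(obj) == instance_id    # re-registering the same object is a no-op
assert reg._get_instance(instance_id) is obj
reg.unregister_sync(instance_id)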
|
||||
|
||||
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
|
||||
def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
global _GLOBAL_LOOP
|
||||
_GLOBAL_LOOP = loop
|
||||
|
||||
|
||||
def run_sync_rpc_coro(coro: Any, timeout_ms: Optional[int] = None) -> Any:
|
||||
if timeout_ms is not None:
|
||||
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
|
||||
|
||||
try:
|
||||
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
|
||||
try:
|
||||
curr_loop = asyncio.get_running_loop()
|
||||
if curr_loop is _GLOBAL_LOOP:
|
||||
pass
|
||||
except RuntimeError:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
|
||||
return future.result(
|
||||
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(coro)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc
|
||||
except concurrent.futures.TimeoutError as exc:
|
||||
raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc
|
||||
|
||||
|
||||
def call_singleton_rpc(
|
||||
caller: Any,
|
||||
method_name: str,
|
||||
*args: Any,
|
||||
timeout_ms: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if caller is None:
|
||||
raise RuntimeError(f"No RPC caller available for {method_name}")
|
||||
method = getattr(caller, method_name)
|
||||
return run_sync_rpc_coro(method(*args, **kwargs), timeout_ms=timeout_ms)
|
||||
|
||||
|
||||
class BaseProxy(Generic[T]):
|
||||
_registry_class: type = BaseRegistry # type: ignore[type-arg]
|
||||
__module__: str = "comfy.isolation.proxies.base"
|
||||
_TIMEOUT_RPC_METHODS = frozenset(
|
||||
{
|
||||
"partially_load",
|
||||
"partially_unload",
|
||||
"load",
|
||||
"patch_model",
|
||||
"unpatch_model",
|
||||
"inner_model_apply_model",
|
||||
"memory_required",
|
||||
"model_dtype",
|
||||
"inner_model_memory_required",
|
||||
"inner_model_extra_conds_shapes",
|
||||
"inner_model_extra_conds",
|
||||
"process_latent_in",
|
||||
"process_latent_out",
|
||||
"scale_latent_inpaint",
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
registry: Optional[Any] = None,
|
||||
manage_lifecycle: bool = False,
|
||||
) -> None:
|
||||
self._instance_id = instance_id
|
||||
self._rpc_caller: Optional[Any] = None
|
||||
self._registry = registry if registry is not None else self._registry_class()
|
||||
self._manage_lifecycle = manage_lifecycle
|
||||
self._cleaned_up = False
|
||||
if manage_lifecycle and not IS_CHILD_PROCESS:
|
||||
self._finalizer = weakref.finalize(
|
||||
self, self._registry.unregister_sync, instance_id
|
||||
)
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is None:
|
||||
raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
self._registry_class, self._registry_class.get_remote_id()
|
||||
)
|
||||
return self._rpc_caller
|
||||
|
||||
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
|
||||
if method_name not in self._TIMEOUT_RPC_METHODS:
|
||||
return None
|
||||
try:
|
||||
timeout_ms = int(
|
||||
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
|
||||
)
|
||||
except ValueError:
|
||||
timeout_ms = 120000
|
||||
return max(1, timeout_ms)
|
||||
|
||||
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
|
||||
coro = method(self._instance_id, *args, **kwargs)
|
||||
if timeout_ms is not None:
|
||||
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
|
||||
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
try:
|
||||
running_loop = asyncio.get_running_loop()
|
||||
loop_id: Optional[int] = id(running_loop)
|
||||
except RuntimeError:
|
||||
loop_id = None
|
||||
logger.debug(
|
||||
"ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
|
||||
"thread=%s loop=%s timeout_ms=%s",
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self._instance_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
loop_id,
|
||||
timeout_ms,
|
||||
)
|
||||
|
||||
try:
|
||||
return run_sync_rpc_coro(coro, timeout_ms=timeout_ms)
|
||||
except TimeoutError as exc:
|
||||
raise TimeoutError(
|
||||
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
|
||||
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
|
||||
) from exc
|
||||
finally:
|
||||
end_epoch = time.time()
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
|
||||
"elapsed_ms=%.3f thread=%s loop=%s",
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self._instance_id,
|
||||
end_epoch,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
loop_id,
|
||||
)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
return {"_instance_id": self._instance_id}
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
self._instance_id = state["_instance_id"]
|
||||
self._rpc_caller = None
|
||||
self._registry = self._registry_class()
|
||||
self._manage_lifecycle = False
|
||||
self._cleaned_up = False
|
||||
|
||||
def cleanup(self) -> None:
|
||||
if self._cleaned_up or IS_CHILD_PROCESS:
|
||||
return
|
||||
self._cleaned_up = True
|
||||
finalizer = getattr(self, "_finalizer", None)
|
||||
if finalizer is not None:
|
||||
finalizer.detach()
|
||||
self._registry.unregister_sync(self._instance_id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} {self._instance_id}>"
|
||||
|
||||
|
||||
def create_rpc_method(method_name: str) -> Callable[..., Any]:
|
||||
def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
|
||||
return self._call_rpc(method_name, *args, **kwargs)
|
||||
|
||||
method.__name__ = method_name
|
||||
return method
|
||||
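A hedged sketch of how create_rpc_method can stamp out simple forwarding methods on a proxy subclass instead of writing each wrapper by hand; the method names below are ones the ModelPatcher registry already exposes via _call_rpc elsewhere in this changeset.

class _SlimPatcherProxy(BaseProxy[Any]):
    _registry_class = BaseRegistry


for _name in ("pre_run", "clean_hooks", "eject_model"):
    setattr(_SlimPatcherProxy, _name, create_rpc_method(_name))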
202  comfy/isolation/proxies/folder_paths_proxy.py  (new file)
@@ -0,0 +1,202 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
from .base import call_singleton_rpc
|
||||
|
||||
|
||||
def _folder_paths():
|
||||
import folder_paths
|
||||
|
||||
return folder_paths
|
||||
|
||||
|
||||
def _is_child_process() -> bool:
|
||||
return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
def _serialize_folder_names_and_paths(data: dict[str, tuple[list[str], set[str]]]) -> dict[str, dict[str, list[str]]]:
|
||||
return {
|
||||
key: {"paths": list(paths), "extensions": sorted(list(extensions))}
|
||||
for key, (paths, extensions) in data.items()
|
||||
}
|
||||
|
||||
|
||||
def _deserialize_folder_names_and_paths(data: dict[str, dict[str, list[str]]]) -> dict[str, tuple[list[str], set[str]]]:
|
||||
return {
|
||||
key: (list(value.get("paths", [])), set(value.get("extensions", [])))
|
||||
for key, value in data.items()
|
||||
}
|
||||
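A small self-contained check of the two helpers above (the sample folder entry is illustrative): extension sets become sorted lists on the wire and are rebuilt as sets on the way back.

original = {"loras": (["/models/loras"], {".safetensors", ".pt"})}
wire = _serialize_folder_names_and_paths(original)
assert wire == {
    "loras": {"paths": ["/models/loras"], "extensions": [".pt", ".safetensors"]}
}
assert _deserialize_folder_names_and_paths(wire) == original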
|
||||
|
||||
class FolderPathsProxy(ProxiedSingleton):
|
||||
"""
|
||||
Dynamic proxy for folder_paths.
|
||||
Uses __getattr__ for most lookups, with explicit handling for
|
||||
mutable collections to ensure efficient by-value transfer.
|
||||
"""
|
||||
|
||||
_rpc: Optional[Any] = None
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
|
||||
|
||||
@classmethod
|
||||
def clear_rpc(cls) -> None:
|
||||
cls._rpc = None
|
||||
|
||||
@classmethod
|
||||
def _get_caller(cls) -> Any:
|
||||
if cls._rpc is None:
|
||||
raise RuntimeError("FolderPathsProxy RPC caller is not configured")
|
||||
return cls._rpc
|
||||
|
||||
def __getattr__(self, name):
|
||||
if _is_child_process():
|
||||
property_rpc = {
|
||||
"models_dir": "rpc_get_models_dir",
|
||||
"folder_names_and_paths": "rpc_get_folder_names_and_paths",
|
||||
"extension_mimetypes_cache": "rpc_get_extension_mimetypes_cache",
|
||||
"filename_list_cache": "rpc_get_filename_list_cache",
|
||||
}
|
||||
rpc_name = property_rpc.get(name)
|
||||
if rpc_name is not None:
|
||||
return call_singleton_rpc(self._get_caller(), rpc_name)
|
||||
raise AttributeError(name)
|
||||
return getattr(_folder_paths(), name)
|
||||
|
||||
@property
|
||||
def folder_names_and_paths(self) -> Dict:
|
||||
if _is_child_process():
|
||||
payload = call_singleton_rpc(self._get_caller(), "rpc_get_folder_names_and_paths")
|
||||
return _deserialize_folder_names_and_paths(payload)
|
||||
return _folder_paths().folder_names_and_paths
|
||||
|
||||
@property
|
||||
def extension_mimetypes_cache(self) -> Dict:
|
||||
if _is_child_process():
|
||||
return dict(call_singleton_rpc(self._get_caller(), "rpc_get_extension_mimetypes_cache"))
|
||||
return dict(_folder_paths().extension_mimetypes_cache)
|
||||
|
||||
@property
|
||||
def filename_list_cache(self) -> Dict:
|
||||
if _is_child_process():
|
||||
return dict(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list_cache"))
|
||||
return dict(_folder_paths().filename_list_cache)
|
||||
|
||||
@property
|
||||
def models_dir(self) -> str:
|
||||
if _is_child_process():
|
||||
return str(call_singleton_rpc(self._get_caller(), "rpc_get_models_dir"))
|
||||
return _folder_paths().models_dir
|
||||
|
||||
def get_temp_directory(self) -> str:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(self._get_caller(), "rpc_get_temp_directory")
|
||||
return _folder_paths().get_temp_directory()
|
||||
|
||||
def get_input_directory(self) -> str:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(self._get_caller(), "rpc_get_input_directory")
|
||||
return _folder_paths().get_input_directory()
|
||||
|
||||
def get_output_directory(self) -> str:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(self._get_caller(), "rpc_get_output_directory")
|
||||
return _folder_paths().get_output_directory()
|
||||
|
||||
def get_user_directory(self) -> str:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(self._get_caller(), "rpc_get_user_directory")
|
||||
return _folder_paths().get_user_directory()
|
||||
|
||||
def get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(
|
||||
self._get_caller(), "rpc_get_annotated_filepath", name, default_dir
|
||||
)
|
||||
return _folder_paths().get_annotated_filepath(name, default_dir)
|
||||
|
||||
def exists_annotated_filepath(self, name: str) -> bool:
|
||||
if _is_child_process():
|
||||
return bool(
|
||||
call_singleton_rpc(self._get_caller(), "rpc_exists_annotated_filepath", name)
|
||||
)
|
||||
return bool(_folder_paths().exists_annotated_filepath(name))
|
||||
|
||||
def add_model_folder_path(
|
||||
self, folder_name: str, full_folder_path: str, is_default: bool = False
|
||||
) -> None:
|
||||
if _is_child_process():
|
||||
call_singleton_rpc(
|
||||
self._get_caller(),
|
||||
"rpc_add_model_folder_path",
|
||||
folder_name,
|
||||
full_folder_path,
|
||||
is_default,
|
||||
)
|
||||
return None
|
||||
_folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default)
|
||||
return None
|
||||
|
||||
def get_folder_paths(self, folder_name: str) -> list[str]:
|
||||
if _is_child_process():
|
||||
return list(call_singleton_rpc(self._get_caller(), "rpc_get_folder_paths", folder_name))
|
||||
return list(_folder_paths().get_folder_paths(folder_name))
|
||||
|
||||
def get_filename_list(self, folder_name: str) -> list[str]:
|
||||
if _is_child_process():
|
||||
return list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name))
|
||||
return list(_folder_paths().get_filename_list(folder_name))
|
||||
|
||||
def get_full_path(self, folder_name: str, filename: str) -> str | None:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(self._get_caller(), "rpc_get_full_path", folder_name, filename)
|
||||
return _folder_paths().get_full_path(folder_name, filename)
|
||||
|
||||
async def rpc_get_models_dir(self) -> str:
|
||||
return _folder_paths().models_dir
|
||||
|
||||
async def rpc_get_folder_names_and_paths(self) -> dict[str, dict[str, list[str]]]:
|
||||
return _serialize_folder_names_and_paths(_folder_paths().folder_names_and_paths)
|
||||
|
||||
async def rpc_get_extension_mimetypes_cache(self) -> dict[str, Any]:
|
||||
return dict(_folder_paths().extension_mimetypes_cache)
|
||||
|
||||
async def rpc_get_filename_list_cache(self) -> dict[str, Any]:
|
||||
return dict(_folder_paths().filename_list_cache)
|
||||
|
||||
async def rpc_get_temp_directory(self) -> str:
|
||||
return _folder_paths().get_temp_directory()
|
||||
|
||||
async def rpc_get_input_directory(self) -> str:
|
||||
return _folder_paths().get_input_directory()
|
||||
|
||||
async def rpc_get_output_directory(self) -> str:
|
||||
return _folder_paths().get_output_directory()
|
||||
|
||||
async def rpc_get_user_directory(self) -> str:
|
||||
return _folder_paths().get_user_directory()
|
||||
|
||||
async def rpc_get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str:
|
||||
return _folder_paths().get_annotated_filepath(name, default_dir)
|
||||
|
||||
async def rpc_exists_annotated_filepath(self, name: str) -> bool:
|
||||
return _folder_paths().exists_annotated_filepath(name)
|
||||
|
||||
async def rpc_add_model_folder_path(
|
||||
self, folder_name: str, full_folder_path: str, is_default: bool = False
|
||||
) -> None:
|
||||
_folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default)
|
||||
|
||||
async def rpc_get_folder_paths(self, folder_name: str) -> list[str]:
|
||||
return _folder_paths().get_folder_paths(folder_name)
|
||||
|
||||
async def rpc_get_filename_list(self, folder_name: str) -> list[str]:
|
||||
return _folder_paths().get_filename_list(folder_name)
|
||||
|
||||
async def rpc_get_full_path(self, folder_name: str, filename: str) -> str | None:
|
||||
return _folder_paths().get_full_path(folder_name, filename)
|
||||
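Editor's note: the folder_names_and_paths helpers above exist because sets and tuples are not JSON-RPC friendly. A minimal round-trip sketch using only the two module-level helpers, with made-up sample data:

# Illustrative only; the sample mapping is invented for this sketch.
data = {"checkpoints": (["/models/checkpoints"], {".safetensors", ".ckpt"})}
wire = _serialize_folder_names_and_paths(data)       # sets become sorted lists for transport
restored = _deserialize_folder_names_and_paths(wire)
assert restored["checkpoints"][0] == ["/models/checkpoints"]
assert restored["checkpoints"][1] == {".safetensors", ".ckpt"}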
158
comfy/isolation/proxies/helper_proxies.py
Normal file
@@ -0,0 +1,158 @@
from __future__ import annotations

import os
from typing import Any, Dict, Optional

from pyisolate import ProxiedSingleton

from .base import call_singleton_rpc


class AnyTypeProxy(str):
    """Replacement for custom AnyType objects used by some nodes."""

    def __new__(cls, value: str = "*"):
        return super().__new__(cls, value)

    def __ne__(self, other):  # type: ignore[override]
        return False


class FlexibleOptionalInputProxy(dict):
    """Replacement for FlexibleOptionalInputType to allow dynamic inputs."""

    def __init__(self, flex_type, data: Optional[Dict[str, object]] = None):
        super().__init__()
        self.type = flex_type
        if data:
            self.update(data)

    def __getitem__(self, key):  # type: ignore[override]
        return (self.type,)

    def __contains__(self, key):  # type: ignore[override]
        return True


class ByPassTypeTupleProxy(tuple):
    """Replacement for ByPassTypeTuple to mirror wildcard fallback behavior."""

    def __new__(cls, values):
        return super().__new__(cls, values)

    def __getitem__(self, index):  # type: ignore[override]
        if index >= len(self):
            return AnyTypeProxy("*")
        return super().__getitem__(index)


def _restore_special_value(value: Any) -> Any:
    if isinstance(value, dict):
        if value.get("__pyisolate_any_type__"):
            return AnyTypeProxy(value.get("value", "*"))
        if value.get("__pyisolate_flexible_optional__"):
            flex_type = _restore_special_value(value.get("type"))
            data_raw = value.get("data")
            data = (
                {k: _restore_special_value(v) for k, v in data_raw.items()}
                if isinstance(data_raw, dict)
                else {}
            )
            return FlexibleOptionalInputProxy(flex_type, data)
        if value.get("__pyisolate_tuple__") is not None:
            return tuple(
                _restore_special_value(v) for v in value["__pyisolate_tuple__"]
            )
        if value.get("__pyisolate_bypass_tuple__") is not None:
            return ByPassTypeTupleProxy(
                tuple(
                    _restore_special_value(v)
                    for v in value["__pyisolate_bypass_tuple__"]
                )
            )
        return {k: _restore_special_value(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_restore_special_value(v) for v in value]
    return value


def _serialize_special_value(value: Any) -> Any:
    if isinstance(value, AnyTypeProxy):
        return {"__pyisolate_any_type__": True, "value": str(value)}
    if isinstance(value, FlexibleOptionalInputProxy):
        return {
            "__pyisolate_flexible_optional__": True,
            "type": _serialize_special_value(value.type),
            "data": {k: _serialize_special_value(v) for k, v in value.items()},
        }
    if isinstance(value, ByPassTypeTupleProxy):
        return {
            "__pyisolate_bypass_tuple__": [_serialize_special_value(v) for v in value]
        }
    if isinstance(value, tuple):
        return {"__pyisolate_tuple__": [_serialize_special_value(v) for v in value]}
    if isinstance(value, list):
        return [_serialize_special_value(v) for v in value]
    if isinstance(value, dict):
        return {k: _serialize_special_value(v) for k, v in value.items()}
    return value


def _restore_input_types_local(raw: Dict[str, object]) -> Dict[str, object]:
    if not isinstance(raw, dict):
        return raw  # type: ignore[return-value]

    restored: Dict[str, object] = {}
    for section, entries in raw.items():
        if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"):
            restored[section] = _restore_special_value(entries)
        elif isinstance(entries, dict):
            restored[section] = {
                k: _restore_special_value(v) for k, v in entries.items()
            }
        else:
            restored[section] = _restore_special_value(entries)
    return restored


class HelperProxiesService(ProxiedSingleton):
    _rpc: Optional[Any] = None

    @classmethod
    def set_rpc(cls, rpc: Any) -> None:
        cls._rpc = rpc.create_caller(cls, cls.get_remote_id())

    @classmethod
    def clear_rpc(cls) -> None:
        cls._rpc = None

    @classmethod
    def _get_caller(cls) -> Any:
        if cls._rpc is None:
            raise RuntimeError("HelperProxiesService RPC caller is not configured")
        return cls._rpc

    async def rpc_restore_input_types(self, raw: Dict[str, object]) -> Dict[str, object]:
        restored = _restore_input_types_local(raw)
        return _serialize_special_value(restored)


def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
    """Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
    if os.environ.get("PYISOLATE_CHILD") == "1":
        payload = call_singleton_rpc(
            HelperProxiesService._get_caller(),
            "rpc_restore_input_types",
            raw,
        )
        return _restore_input_types_local(payload)
    return _restore_input_types_local(raw)


__all__ = [
    "AnyTypeProxy",
    "FlexibleOptionalInputProxy",
    "ByPassTypeTupleProxy",
    "HelperProxiesService",
    "restore_input_types",
]
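Editor's note: _serialize_special_value and _restore_special_value are intended to be inverses over these wildcard helper types. A small round trip, with invented input data, illustrates the behavior:

# Illustrative only; the INPUT_TYPES-style entry below is made up.
any_t = AnyTypeProxy("*")
flex = FlexibleOptionalInputProxy(any_t, {"seed": ("INT", {"default": 0})})
wire = _serialize_special_value(flex)        # plain dicts/lists, safe for JSON-RPC
back = _restore_special_value(wire)
assert isinstance(back, FlexibleOptionalInputProxy)
assert back["anything"] == (back.type,)      # dynamic lookup still answers for any key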
142
comfy/isolation/proxies/model_management_proxy.py
Normal file
@@ -0,0 +1,142 @@
from __future__ import annotations

import os
from typing import Any, Optional

from pyisolate import ProxiedSingleton

from .base import call_singleton_rpc


def _mm():
    import comfy.model_management

    return comfy.model_management


def _is_child_process() -> bool:
    return os.environ.get("PYISOLATE_CHILD") == "1"


class TorchDeviceProxy:
    def __init__(self, device_str: str):
        self._device_str = device_str
        if ":" in device_str:
            device_type, index = device_str.split(":", 1)
            self.type = device_type
            self.index = int(index)
        else:
            self.type = device_str
            self.index = None

    def __str__(self) -> str:
        return self._device_str

    def __repr__(self) -> str:
        return f"TorchDeviceProxy({self._device_str!r})"


def _serialize_value(value: Any) -> Any:
    value_type = type(value)
    if value_type.__module__ == "torch" and value_type.__name__ == "device":
        return {"__pyisolate_torch_device__": str(value)}
    if isinstance(value, TorchDeviceProxy):
        return {"__pyisolate_torch_device__": str(value)}
    if isinstance(value, tuple):
        return {"__pyisolate_tuple__": [_serialize_value(item) for item in value]}
    if isinstance(value, list):
        return [_serialize_value(item) for item in value]
    if isinstance(value, dict):
        return {key: _serialize_value(inner) for key, inner in value.items()}
    return value


def _deserialize_value(value: Any) -> Any:
    if isinstance(value, dict):
        if "__pyisolate_torch_device__" in value:
            return TorchDeviceProxy(value["__pyisolate_torch_device__"])
        if "__pyisolate_tuple__" in value:
            return tuple(_deserialize_value(item) for item in value["__pyisolate_tuple__"])
        return {key: _deserialize_value(inner) for key, inner in value.items()}
    if isinstance(value, list):
        return [_deserialize_value(item) for item in value]
    return value


def _normalize_argument(value: Any) -> Any:
    if isinstance(value, TorchDeviceProxy):
        import torch

        return torch.device(str(value))
    if isinstance(value, dict):
        if "__pyisolate_torch_device__" in value:
            import torch

            return torch.device(value["__pyisolate_torch_device__"])
        if "__pyisolate_tuple__" in value:
            return tuple(_normalize_argument(item) for item in value["__pyisolate_tuple__"])
        return {key: _normalize_argument(inner) for key, inner in value.items()}
    if isinstance(value, list):
        return [_normalize_argument(item) for item in value]
    return value


class ModelManagementProxy(ProxiedSingleton):
    """
    Exact-relay proxy for comfy.model_management.
    Child calls never import comfy.model_management directly; they serialize
    arguments, relay to host, and deserialize the host result back.
    """

    _rpc: Optional[Any] = None

    @classmethod
    def set_rpc(cls, rpc: Any) -> None:
        cls._rpc = rpc.create_caller(cls, cls.get_remote_id())

    @classmethod
    def clear_rpc(cls) -> None:
        cls._rpc = None

    @classmethod
    def _get_caller(cls) -> Any:
        if cls._rpc is None:
            raise RuntimeError("ModelManagementProxy RPC caller is not configured")
        return cls._rpc

    def _relay_call(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
        payload = call_singleton_rpc(
            self._get_caller(),
            "rpc_call",
            method_name,
            _serialize_value(args),
            _serialize_value(kwargs),
        )
        return _deserialize_value(payload)

    @property
    def VRAMState(self):
        return _mm().VRAMState

    @property
    def CPUState(self):
        return _mm().CPUState

    @property
    def OOM_EXCEPTION(self):
        return _mm().OOM_EXCEPTION

    def __getattr__(self, name: str):
        if _is_child_process():
            def child_method(*args: Any, **kwargs: Any) -> Any:
                return self._relay_call(name, *args, **kwargs)

            return child_method
        return getattr(_mm(), name)

    async def rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any:
        normalized_args = _normalize_argument(_deserialize_value(args))
        normalized_kwargs = _normalize_argument(_deserialize_value(kwargs))
        method = getattr(_mm(), method_name)
        result = method(*normalized_args, **normalized_kwargs)
        return _serialize_value(result)
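Editor's note: the tagged-dict encoding above is what keeps torch.device values representable across the RPC boundary. A short sketch of the child-to-host argument path, assuming torch is importable where _normalize_argument runs:

# Illustrative only; the argument tuple is made up.
wire = _serialize_value((TorchDeviceProxy("cuda:0"), {"size": 512}))
# On the host, rpc_call first deserializes, then normalizes the tag back into a real device.
args = _normalize_argument(_deserialize_value(wire))
# args is now (torch.device("cuda:0"), {"size": 512})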
87
comfy/isolation/proxies/progress_proxy.py
Normal file
@@ -0,0 +1,87 @@
from __future__ import annotations

import logging
import os
from typing import Any, Optional

try:
    from pyisolate import ProxiedSingleton
except ImportError:

    class ProxiedSingleton:
        pass


from .base import call_singleton_rpc


def _get_progress_state():
    from comfy_execution.progress import get_progress_state

    return get_progress_state()


def _is_child_process() -> bool:
    return os.environ.get("PYISOLATE_CHILD") == "1"


logger = logging.getLogger(__name__)


class ProgressProxy(ProxiedSingleton):
    _rpc: Optional[Any] = None

    @classmethod
    def set_rpc(cls, rpc: Any) -> None:
        cls._rpc = rpc.create_caller(cls, cls.get_remote_id())

    @classmethod
    def clear_rpc(cls) -> None:
        cls._rpc = None

    @classmethod
    def _get_caller(cls) -> Any:
        if cls._rpc is None:
            raise RuntimeError("ProgressProxy RPC caller is not configured")
        return cls._rpc

    def set_progress(
        self,
        value: float,
        max_value: float,
        node_id: Optional[str] = None,
        image: Any = None,
    ) -> None:
        if _is_child_process():
            call_singleton_rpc(
                self._get_caller(),
                "rpc_set_progress",
                value,
                max_value,
                node_id,
                image,
            )
            return None

        _get_progress_state().update_progress(
            node_id=node_id,
            value=value,
            max_value=max_value,
            image=image,
        )
        return None

    async def rpc_set_progress(
        self,
        value: float,
        max_value: float,
        node_id: Optional[str] = None,
        image: Any = None,
    ) -> None:
        _get_progress_state().update_progress(
            node_id=node_id,
            value=value,
            max_value=max_value,
            image=image,
        )


__all__ = ["ProgressProxy"]
271
comfy/isolation/proxies/prompt_server_impl.py
Normal file
@@ -0,0 +1,271 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called
"""Stateless RPC Implementation for PromptServer.

Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture.
- Host: PromptServerService (RPC Handler)
- Child: PromptServerStub (Interface Implementation)
"""

from __future__ import annotations

import asyncio
import os
from typing import Any, Dict, Optional, Callable

import logging

# IMPORTS
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc

logger = logging.getLogger(__name__)
LOG_PREFIX = "[Isolation:C<->H]"

# ...

# =============================================================================
# CHILD SIDE: PromptServerStub
# =============================================================================


class PromptServerStub:
    """Stateless Stub for PromptServer."""

    # Masquerade as the real server module
    __module__ = "server"

    _instance: Optional["PromptServerStub"] = None
    _rpc: Optional[Any] = None  # This will be the Caller object
    _source_file: Optional[str] = None

    def __init__(self):
        self.routes = RouteStub(self)

    @classmethod
    def set_rpc(cls, rpc: Any) -> None:
        """Inject RPC client (called by adapter.py or manually)."""
        # Create a caller that targets the host-side PromptServerService.
        # The service class is defined later in this module, so it can be
        # passed directly as the service type; the string ID keeps the remote
        # registration name explicit.
        target_id = "PromptServerService"
        cls._rpc = rpc.create_caller(PromptServerService, target_id)

    @classmethod
    def clear_rpc(cls) -> None:
        cls._rpc = None

    @property
    def instance(self) -> "PromptServerStub":
        return self

    # ... Compatibility ...
    @classmethod
    def _get_source_file(cls) -> str:
        if cls._source_file is None:
            import folder_paths

            cls._source_file = os.path.join(folder_paths.base_path, "server.py")
        return cls._source_file

    @property
    def __file__(self) -> str:
        return self._get_source_file()

    # --- Properties ---
    @property
    def client_id(self) -> Optional[str]:
        return "isolated_client"

    def supports(self, feature: str) -> bool:
        return True

    @property
    def app(self):
        raise RuntimeError(
            "PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
        )

    @property
    def prompt_queue(self):
        raise RuntimeError(
            "PromptServer.prompt_queue is not accessible in isolated nodes."
        )

    # --- UI Communication (RPC Delegates) ---
    async def send_sync(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ) -> None:
        if self._rpc:
            await self._rpc.ui_send_sync(event, data, sid)

    async def send(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ) -> None:
        if self._rpc:
            await self._rpc.ui_send(event, data, sid)

    def send_progress_text(self, text: str, node_id: str, sid=None) -> None:
        if self._rpc:
            # Node code calls this synchronously, but the host method is async.
            # Schedule the coroutine on the running loop when there is one, and
            # fall back to a blocking RPC call otherwise.
            import asyncio

            try:
                loop = asyncio.get_running_loop()
                loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
            except RuntimeError:
                call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid)

    # --- Route Registration Logic ---
    def register_route(self, method: str, path: str, handler: Callable):
        """Register a route handler via RPC."""
        if not self._rpc:
            logger.error("RPC not initialized in PromptServerStub")
            return

        # Fire registration async
        try:
            loop = asyncio.get_running_loop()
            loop.create_task(self._rpc.register_route_rpc(method, path, handler))
        except RuntimeError:
            call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler)


class RouteStub:
    """Simulates aiohttp.web.RouteTableDef."""

    def __init__(self, stub: PromptServerStub):
        self._stub = stub

    def get(self, path: str):
        def decorator(handler):
            self._stub.register_route("GET", path, handler)
            return handler

        return decorator

    def post(self, path: str):
        def decorator(handler):
            self._stub.register_route("POST", path, handler)
            return handler

        return decorator

    def patch(self, path: str):
        def decorator(handler):
            self._stub.register_route("PATCH", path, handler)
            return handler

        return decorator

    def put(self, path: str):
        def decorator(handler):
            self._stub.register_route("PUT", path, handler)
            return handler

        return decorator

    def delete(self, path: str):
        def decorator(handler):
            self._stub.register_route("DELETE", path, handler)
            return handler

        return decorator


# =============================================================================
# HOST SIDE: PromptServerService
# =============================================================================


class PromptServerService(ProxiedSingleton):
    """Host-side RPC Service for PromptServer."""

    def __init__(self):
        # We will bind to the real server instance lazily or via global import
        pass

    @property
    def server(self):
        from server import PromptServer

        return PromptServer.instance

    async def ui_send_sync(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ):
        await self.server.send_sync(event, data, sid)

    async def ui_send(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ):
        await self.server.send(event, data, sid)

    async def ui_send_progress_text(self, text: str, node_id: str, sid=None):
        # Made async to be awaitable by RPC layer
        self.server.send_progress_text(text, node_id, sid)

    async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
        """RPC Target: Register a route that forwards to the Child."""
        from aiohttp import web
        logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")

        async def route_wrapper(request: web.Request) -> web.Response:
            # 1. Capture request data
            req_data = {
                "method": request.method,
                "path": request.path,
                "query": dict(request.query),
            }
            if request.can_read_body:
                req_data["text"] = await request.text()

            try:
                # 2. Call Child Handler via RPC (child_handler_proxy is async callable)
                result = await child_handler_proxy(req_data)

                # 3. Serialize Response
                return self._serialize_response(result)
            except Exception as e:
                logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
                return web.Response(status=500, text=str(e))

        # Register loop
        self.server.app.router.add_route(method, path, route_wrapper)

    def _serialize_response(self, result: Any) -> Any:
        """Helper to convert Child result -> web.Response"""
        from aiohttp import web
        if isinstance(result, web.Response):
            return result
        # Handle dict (json)
        if isinstance(result, dict):
            return web.json_response(result)
        # Handle string
        if isinstance(result, str):
            return web.Response(text=result)
        # Fallback
        return web.Response(text=str(result))
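Editor's note: on the child side, node packs keep the familiar routes decorator surface; each decoration simply becomes a register_route_rpc call, and the host's route_wrapper forwards a plain request dict back to the child handler. A hypothetical isolated-extension route, assuming the stub is installed as the child's PromptServer:

# Illustrative only; route path and handler are invented.
stub = PromptServerStub()

@stub.routes.get("/my_extension/status")
async def status_handler(req_data):
    # req_data is the dict built by route_wrapper on the host: method, path, query, optional text
    return {"ok": True, "path": req_data["path"]}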
64
comfy/isolation/proxies/utils_proxy.py
Normal file
@@ -0,0 +1,64 @@
# pylint: disable=cyclic-import,import-outside-toplevel
from __future__ import annotations

from typing import Optional, Any
from pyisolate import ProxiedSingleton

import os


def _comfy_utils():
    import comfy.utils
    return comfy.utils


class UtilsProxy(ProxiedSingleton):
    """
    Proxy for comfy.utils.
    Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates
    from isolated nodes reach the host.
    """

    # _instance and __new__ removed to rely on SingletonMetaclass
    _rpc: Optional[Any] = None

    @classmethod
    def set_rpc(cls, rpc: Any) -> None:
        # Create caller using class name as ID (standard for Singletons)
        cls._rpc = rpc.create_caller(cls, "UtilsProxy")

    @classmethod
    def clear_rpc(cls) -> None:
        cls._rpc = None

    async def progress_bar_hook(
        self,
        value: int,
        total: int,
        preview: Optional[bytes] = None,
        node_id: Optional[str] = None,
    ) -> Any:
        """
        Host-side implementation: forwards the call to the real global hook.
        Child-side: this method call is intercepted by RPC and sent to host.
        """
        if os.environ.get("PYISOLATE_CHILD") == "1":
            if UtilsProxy._rpc is None:
                raise RuntimeError("UtilsProxy RPC caller is not configured")
            return await UtilsProxy._rpc.progress_bar_hook(
                value, total, preview, node_id
            )

        # Host Execution
        utils = _comfy_utils()
        if utils.PROGRESS_BAR_HOOK is not None:
            return utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
        return None

    def set_progress_bar_global_hook(self, hook: Any) -> None:
        """Forward hook registration (though usually not needed from child)."""
        if os.environ.get("PYISOLATE_CHILD") == "1":
            raise RuntimeError(
                "UtilsProxy.set_progress_bar_global_hook is not available in child without exact relay support"
            )
        _comfy_utils().set_progress_bar_global_hook(hook)
219
comfy/isolation/proxies/web_directory_proxy.py
Normal file
@@ -0,0 +1,219 @@
"""WebDirectoryProxy — serves isolated node web assets via RPC.

Child side: enumerates and reads files from the extension's web/ directory.
Host side: gets an RPC proxy that fetches file listings and contents on demand.

Only files with allowed extensions (.js, .html, .css) are served.
Directory traversal is rejected. File contents are base64-encoded for
safe JSON-RPC transport.
"""

from __future__ import annotations

import base64
import logging
import os
from pathlib import Path, PurePosixPath
from typing import Any, Dict, List

from pyisolate import ProxiedSingleton

logger = logging.getLogger(__name__)

ALLOWED_EXTENSIONS = frozenset({".js", ".html", ".css"})

MIME_TYPES = {
    ".js": "application/javascript",
    ".html": "text/html",
    ".css": "text/css",
}


class WebDirectoryProxy(ProxiedSingleton):
    """Proxy for serving isolated extension web directories.

    On the child side, this class has direct filesystem access to the
    extension's web/ directory. On the host side, callers get an RPC
    proxy whose method calls are forwarded to the child.
    """

    # {extension_name: absolute_path_to_web_dir}
    _web_dirs: dict[str, str] = {}

    @classmethod
    def register_web_dir(cls, extension_name: str, web_dir_path: str) -> None:
        """Register an extension's web directory (child-side only)."""
        cls._web_dirs[extension_name] = web_dir_path
        logger.info(
            "][ WebDirectoryProxy: registered %s -> %s",
            extension_name,
            web_dir_path,
        )

    def list_web_files(self, extension_name: str) -> List[Dict[str, str]]:
        """Return a list of servable files in the extension's web directory.

        Each entry is {"relative_path": "js/foo.js", "content_type": "application/javascript"}.
        Only files with allowed extensions are included.
        """
        web_dir = self._web_dirs.get(extension_name)
        if not web_dir:
            return []

        root = Path(web_dir)
        if not root.is_dir():
            return []

        result: List[Dict[str, str]] = []
        for path in sorted(root.rglob("*")):
            if not path.is_file():
                continue
            ext = path.suffix.lower()
            if ext not in ALLOWED_EXTENSIONS:
                continue
            rel = path.relative_to(root)
            result.append({
                "relative_path": str(PurePosixPath(rel)),
                "content_type": MIME_TYPES[ext],
            })
        return result

    def get_web_file(
        self, extension_name: str, relative_path: str
    ) -> Dict[str, Any]:
        """Return the contents of a single web file as base64.

        Raises ValueError for traversal attempts or disallowed file types.
        Returns {"content": <base64 str>, "content_type": <MIME str>}.
        """
        _validate_path(relative_path)

        web_dir = self._web_dirs.get(extension_name)
        if not web_dir:
            raise FileNotFoundError(
                f"No web directory registered for {extension_name}"
            )

        root = Path(web_dir)
        target = (root / relative_path).resolve()

        # Ensure resolved path is under the web directory
        if not str(target).startswith(str(root.resolve())):
            raise ValueError(f"Path escapes web directory: {relative_path}")

        if not target.is_file():
            raise FileNotFoundError(f"File not found: {relative_path}")

        ext = target.suffix.lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise ValueError(f"Disallowed file type: {ext}")

        content_type = MIME_TYPES[ext]
        raw = target.read_bytes()

        return {
            "content": base64.b64encode(raw).decode("ascii"),
            "content_type": content_type,
        }


def _validate_path(relative_path: str) -> None:
    """Reject directory traversal and absolute paths."""
    if os.path.isabs(relative_path):
        raise ValueError(f"Absolute paths are not allowed: {relative_path}")
    if ".." in PurePosixPath(relative_path).parts:
        raise ValueError(f"Directory traversal is not allowed: {relative_path}")


# ---------------------------------------------------------------------------
# Host-side cache and aiohttp handler
# ---------------------------------------------------------------------------


class WebDirectoryCache:
    """Host-side in-memory cache for proxied web directory contents.

    Populated lazily via RPC calls to the child's WebDirectoryProxy.
    Once a file is cached, subsequent requests are served from memory.
    """

    def __init__(self) -> None:
        # {extension_name: {relative_path: {"content": bytes, "content_type": str}}}
        self._file_cache: dict[str, dict[str, dict[str, Any]]] = {}
        # {extension_name: [{"relative_path": str, "content_type": str}, ...]}
        self._listing_cache: dict[str, list[dict[str, str]]] = {}
        # {extension_name: WebDirectoryProxy (RPC proxy instance)}
        self._proxies: dict[str, Any] = {}

    def register_proxy(self, extension_name: str, proxy: Any) -> None:
        """Register an RPC proxy for an extension's web directory."""
        self._proxies[extension_name] = proxy
        logger.info(
            "][ WebDirectoryCache: registered proxy for %s", extension_name
        )

    @property
    def extension_names(self) -> list[str]:
        return list(self._proxies.keys())

    def list_files(self, extension_name: str) -> list[dict[str, str]]:
        """List servable files for an extension (cached after first call)."""
        if extension_name not in self._listing_cache:
            proxy = self._proxies.get(extension_name)
            if proxy is None:
                return []
            try:
                self._listing_cache[extension_name] = proxy.list_web_files(
                    extension_name
                )
            except Exception:
                logger.warning(
                    "][ WebDirectoryCache: failed to list files for %s",
                    extension_name,
                    exc_info=True,
                )
                return []
        return self._listing_cache[extension_name]

    def get_file(
        self, extension_name: str, relative_path: str
    ) -> dict[str, Any] | None:
        """Get file content (cached after first fetch). Returns None on miss."""
        ext_cache = self._file_cache.get(extension_name)
        if ext_cache and relative_path in ext_cache:
            return ext_cache[relative_path]

        proxy = self._proxies.get(extension_name)
        if proxy is None:
            return None

        try:
            result = proxy.get_web_file(extension_name, relative_path)
        except (FileNotFoundError, ValueError):
            return None
        except Exception:
            logger.warning(
                "][ WebDirectoryCache: failed to fetch %s/%s",
                extension_name,
                relative_path,
                exc_info=True,
            )
            return None

        decoded = {
            "content": base64.b64decode(result["content"]),
            "content_type": result["content_type"],
        }

        if extension_name not in self._file_cache:
            self._file_cache[extension_name] = {}
        self._file_cache[extension_name][relative_path] = decoded
        return decoded


# Global cache instance — populated during isolation loading
_web_directory_cache = WebDirectoryCache()


def get_web_directory_cache() -> WebDirectoryCache:
    return _web_directory_cache
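Editor's note: the traversal guard is deliberately strict; anything absolute or containing a ".." segment is rejected before the filesystem is touched. For example:

_validate_path("js/widget.js")       # accepted
_validate_path("../secrets.txt")     # raises ValueError (traversal)
_validate_path("/etc/passwd")        # raises ValueError (absolute path)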
49
comfy/isolation/rpc_bridge.py
Normal file
@@ -0,0 +1,49 @@
import asyncio
import logging
import threading

logger = logging.getLogger(__name__)


class RpcBridge:
    """Minimal helper to run coroutines synchronously inside isolated processes.

    If an event loop is already running, the coroutine is executed on a fresh
    thread with its own loop to avoid nested run_until_complete errors.
    """

    def run_sync(self, maybe_coro):
        if not asyncio.iscoroutine(maybe_coro):
            return maybe_coro

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop and loop.is_running():
            result_container = {}
            exc_container = {}

            def _runner():
                try:
                    new_loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(new_loop)
                    result_container["value"] = new_loop.run_until_complete(maybe_coro)
                except Exception as exc:  # pragma: no cover
                    exc_container["error"] = exc
                finally:
                    try:
                        new_loop.close()
                    except Exception:
                        pass

            t = threading.Thread(target=_runner, daemon=True)
            t.start()
            t.join()

            if "error" in exc_container:
                raise exc_container["error"]
            return result_container.get("value")

        return asyncio.run(maybe_coro)
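Editor's note: run_sync lets synchronous child code call async entry points without knowing whether an event loop is already active. A small sketch of the two paths:

# Illustrative only.
bridge = RpcBridge()

async def fetch():
    return 42

print(bridge.run_sync(fetch()))   # 42; uses a helper thread if a loop is already running
print(bridge.run_sync("plain"))   # non-coroutines pass straight through unchanged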
471
comfy/isolation/runtime_helpers.py
Normal file
@@ -0,0 +1,471 @@
# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member
from __future__ import annotations

import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Set, TYPE_CHECKING

from .proxies.helper_proxies import restore_input_types
from .shm_forensics import scan_shm_forensics

_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"

_ComfyNodeInternal = object
latest_io = None

if _IMPORT_TORCH:
    from comfy_api.internal import _ComfyNodeInternal
    from comfy_api.latest import _io as latest_io

if TYPE_CHECKING:
    from .extension_wrapper import ComfyNodeExtension

LOG_PREFIX = "]["
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024


class _RemoteObjectRegistryCaller:
    def __init__(self, extension: Any) -> None:
        self._extension = extension

    def __getattr__(self, method_name: str) -> Any:
        async def _call(instance_id: str, *args: Any, **kwargs: Any) -> Any:
            return await self._extension.call_remote_object_method(
                instance_id,
                method_name,
                *args,
                **kwargs,
            )

        return _call


def _wrap_remote_handles_as_host_proxies(value: Any, extension: Any) -> Any:
    from pyisolate._internal.remote_handle import RemoteObjectHandle

    if isinstance(value, RemoteObjectHandle):
        if value.type_name == "ModelPatcher":
            from comfy.isolation.model_patcher_proxy import ModelPatcherProxy

            proxy = ModelPatcherProxy(value.object_id, manage_lifecycle=False)
            proxy._rpc_caller = _RemoteObjectRegistryCaller(extension)  # type: ignore[attr-defined]
            proxy._pyisolate_remote_handle = value  # type: ignore[attr-defined]
            return proxy
        if value.type_name == "VAE":
            from comfy.isolation.vae_proxy import VAEProxy

            proxy = VAEProxy(value.object_id, manage_lifecycle=False)
            proxy._rpc_caller = _RemoteObjectRegistryCaller(extension)  # type: ignore[attr-defined]
            proxy._pyisolate_remote_handle = value  # type: ignore[attr-defined]
            return proxy
        if value.type_name == "CLIP":
            from comfy.isolation.clip_proxy import CLIPProxy

            proxy = CLIPProxy(value.object_id, manage_lifecycle=False)
            proxy._rpc_caller = _RemoteObjectRegistryCaller(extension)  # type: ignore[attr-defined]
            proxy._pyisolate_remote_handle = value  # type: ignore[attr-defined]
            return proxy
        if value.type_name == "ModelSampling":
            from comfy.isolation.model_sampling_proxy import ModelSamplingProxy

            proxy = ModelSamplingProxy(value.object_id, manage_lifecycle=False)
            proxy._rpc_caller = _RemoteObjectRegistryCaller(extension)  # type: ignore[attr-defined]
            proxy._pyisolate_remote_handle = value  # type: ignore[attr-defined]
            return proxy
        return value

    if isinstance(value, dict):
        return {
            k: _wrap_remote_handles_as_host_proxies(v, extension) for k, v in value.items()
        }

    if isinstance(value, (list, tuple)):
        wrapped = [_wrap_remote_handles_as_host_proxies(item, extension) for item in value]
        return type(value)(wrapped)

    return value


def _resource_snapshot() -> Dict[str, int]:
    fd_count = -1
    shm_sender_files = 0
    try:
        fd_count = len(os.listdir("/proc/self/fd"))
    except Exception:
        pass
    try:
        shm_root = Path("/dev/shm")
        if shm_root.exists():
            prefix = f"torch_{os.getpid()}_"
            shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*"))
    except Exception:
        pass
    return {"fd_count": fd_count, "shm_sender_files": shm_sender_files}


def _tensor_transport_summary(value: Any) -> Dict[str, int]:
    summary: Dict[str, int] = {
        "tensor_count": 0,
        "cpu_tensors": 0,
        "cuda_tensors": 0,
        "shared_cpu_tensors": 0,
        "tensor_bytes": 0,
    }
    try:
        import torch
    except Exception:
        return summary

    def visit(node: Any) -> None:
        if isinstance(node, torch.Tensor):
            summary["tensor_count"] += 1
            summary["tensor_bytes"] += int(node.numel() * node.element_size())
            if node.device.type == "cpu":
                summary["cpu_tensors"] += 1
                if node.is_shared():
                    summary["shared_cpu_tensors"] += 1
            elif node.device.type == "cuda":
                summary["cuda_tensors"] += 1
            return
        if isinstance(node, dict):
            for v in node.values():
                visit(v)
            return
        if isinstance(node, (list, tuple)):
            for v in node:
                visit(v)

    visit(value)
    return summary


def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None:
    for key, value in inputs.items():
        key_text = str(key)
        if "unique_id" in key_text:
            return str(value)
    return None


def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None:
    try:
        from pyisolate import flush_tensor_keeper  # type: ignore[attr-defined]
    except Exception:
        return
    if not callable(flush_tensor_keeper):
        return
    flushed = flush_tensor_keeper()
    if flushed > 0:
        logger.debug(
            "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
        )


def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None:
    import comfy.model_management as model_management

    model_management.cleanup_models_gc()
    model_management.cleanup_models()

    device = model_management.get_torch_device()
    if not hasattr(device, "type") or device.type == "cpu":
        return

    required = max(
        model_management.minimum_inference_memory(),
        _PRE_EXEC_MIN_FREE_VRAM_BYTES,
    )
    if model_management.get_free_memory(device) < required:
        model_management.free_memory(required, device, for_dynamic=True)
        if model_management.get_free_memory(device) < required:
            model_management.free_memory(required, device, for_dynamic=False)
            model_management.cleanup_models()
        model_management.soft_empty_cache()
        logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)


def _detach_shared_cpu_tensors(value: Any) -> Any:
    try:
        import torch
    except Exception:
        return value

    if isinstance(value, torch.Tensor):
        if value.device.type == "cpu" and value.is_shared():
            clone = value.clone()
            if value.requires_grad:
                clone.requires_grad_(True)
            return clone
        return value
    if isinstance(value, list):
        return [_detach_shared_cpu_tensors(v) for v in value]
    if isinstance(value, tuple):
        return tuple(_detach_shared_cpu_tensors(v) for v in value)
    if isinstance(value, dict):
        return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()}
    return value


def build_stub_class(
    node_name: str,
    info: Dict[str, object],
    extension: "ComfyNodeExtension",
    running_extensions: Dict[str, "ComfyNodeExtension"],
    logger: logging.Logger,
) -> type:
    if latest_io is None:
        raise RuntimeError("comfy_api.latest._io is required to build isolation stubs")
    is_v3 = bool(info.get("is_v3", False))
    function_name = "_pyisolate_execute"
    restored_input_types = restore_input_types(info.get("input_types", {}))

    async def _execute(self, **inputs):
        from comfy.isolation import _RUNNING_EXTENSIONS

        # Update BOTH the local dict AND the module-level dict
        running_extensions[extension.name] = extension
        _RUNNING_EXTENSIONS[extension.name] = extension
        prev_child = None
        node_unique_id = _extract_hidden_unique_id(inputs)
        summary = _tensor_transport_summary(inputs)
        resources = _resource_snapshot()
        logger.debug(
            "%s ISO:execute_start ext=%s node=%s uid=%s",
            LOG_PREFIX,
            extension.name,
            node_name,
            node_unique_id or "-",
        )
        logger.debug(
            "%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
            LOG_PREFIX,
            extension.name,
            node_name,
            node_unique_id or "-",
            summary["tensor_count"],
            summary["cpu_tensors"],
            summary["cuda_tensors"],
            summary["shared_cpu_tensors"],
            summary["tensor_bytes"],
            resources["fd_count"],
            resources["shm_sender_files"],
        )
        scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
        try:
            if os.environ.get("PYISOLATE_CHILD") != "1":
                _relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
                scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
            from pyisolate._internal.model_serialization import (
                serialize_for_isolation,
                deserialize_from_isolation,
            )

            prev_child = os.environ.pop("PYISOLATE_CHILD", None)
            logger.debug(
                "%s ISO:serialize_start ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            # Unwrap NodeOutput-like dicts before serialization.
            # OUTPUT_NODE nodes return {"ui": {...}, "result": (outputs...)}
            # and the executor may pass this dict as input to downstream nodes.
            unwrapped_inputs = {}
            for k, v in inputs.items():
                if isinstance(v, dict) and "result" in v and ("ui" in v or "__node_output__" in v):
                    result = v.get("result")
                    if isinstance(result, (tuple, list)) and len(result) > 0:
                        unwrapped_inputs[k] = result[0]
                    else:
                        unwrapped_inputs[k] = result
                else:
                    unwrapped_inputs[k] = v
            serialized = serialize_for_isolation(unwrapped_inputs)
            logger.debug(
                "%s ISO:serialize_done ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            logger.debug(
                "%s ISO:dispatch_start ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            result = await extension.execute_node(node_name, **serialized)
            logger.debug(
                "%s ISO:dispatch_done ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            # Reconstruct NodeOutput if the child serialized one
            if isinstance(result, dict) and result.get("__node_output__"):
                from comfy_api.latest import io as latest_io
                args_raw = result.get("args", ())
                deserialized_args = await deserialize_from_isolation(args_raw, extension)
                deserialized_args = _wrap_remote_handles_as_host_proxies(
                    deserialized_args, extension
                )
                deserialized_args = _detach_shared_cpu_tensors(deserialized_args)
                ui_raw = result.get("ui")
                deserialized_ui = None
                if ui_raw is not None:
                    deserialized_ui = await deserialize_from_isolation(ui_raw, extension)
                    deserialized_ui = _wrap_remote_handles_as_host_proxies(
                        deserialized_ui, extension
                    )
                    deserialized_ui = _detach_shared_cpu_tensors(deserialized_ui)
                scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
                return latest_io.NodeOutput(
                    *deserialized_args,
                    ui=deserialized_ui,
                    expand=result.get("expand"),
                    block_execution=result.get("block_execution"),
                )
            # OUTPUT_NODE: if sealed worker returned a tuple/list whose first
            # element is a {"ui": ...} dict, unwrap it for the executor.
            if (isinstance(result, (tuple, list)) and len(result) == 1
                    and isinstance(result[0], dict) and "ui" in result[0]):
                return result[0]
            deserialized = await deserialize_from_isolation(result, extension)
            deserialized = _wrap_remote_handles_as_host_proxies(deserialized, extension)
            scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
            return _detach_shared_cpu_tensors(deserialized)
        except ImportError:
            return await extension.execute_node(node_name, **inputs)
        except Exception:
            logger.exception(
                "%s ISO:execute_error ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            raise
        finally:
            if prev_child is not None:
                os.environ["PYISOLATE_CHILD"] = prev_child
            logger.debug(
                "%s ISO:execute_end ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True)

    def _input_types(
        cls,
        include_hidden: bool = True,
        return_schema: bool = False,
        live_inputs: Any = None,
    ):
        if not is_v3:
            return restored_input_types

        inputs_copy = copy.deepcopy(restored_input_types)
        if not include_hidden:
            inputs_copy.pop("hidden", None)

        v3_data: Dict[str, Any] = {"hidden_inputs": {}}
        dynamic = inputs_copy.pop("dynamic_paths", None)
        if dynamic is not None:
            v3_data["dynamic_paths"] = dynamic

        if return_schema:
            hidden_vals = info.get("hidden", []) or []
            hidden_enums = []
            for h in hidden_vals:
                try:
                    hidden_enums.append(latest_io.Hidden(h))
                except Exception:
                    hidden_enums.append(h)

            class SchemaProxy:
                hidden = hidden_enums

            return inputs_copy, SchemaProxy, v3_data
        return inputs_copy

    def _validate_class(cls):
        return True

    def _get_node_info_v1(cls):
        node_info = copy.deepcopy(info.get("schema_v1", {}))
        relative_python_module = node_info.get("python_module")
        if not isinstance(relative_python_module, str) or not relative_python_module:
            relative_python_module = f"custom_nodes.{extension.name}"
        node_info["python_module"] = relative_python_module
        return node_info

    def _get_base_class(cls):
        return latest_io.ComfyNode

    attributes: Dict[str, object] = {
        "FUNCTION": function_name,
        "CATEGORY": info.get("category", ""),
        "OUTPUT_NODE": info.get("output_node", False),
        "RETURN_TYPES": tuple(info.get("return_types", ()) or ()),
        "RETURN_NAMES": info.get("return_names"),
        function_name: _execute,
        "_pyisolate_extension": extension,
        "_pyisolate_node_name": node_name,
        "INPUT_TYPES": classmethod(_input_types),
    }

    output_is_list = info.get("output_is_list")
    if output_is_list is not None:
        attributes["OUTPUT_IS_LIST"] = tuple(output_is_list)

    if is_v3:
        attributes["VALIDATE_CLASS"] = classmethod(_validate_class)
        attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1)
        attributes["GET_BASE_CLASS"] = classmethod(_get_base_class)
        attributes["DESCRIPTION"] = info.get("description", "")
        attributes["EXPERIMENTAL"] = info.get("experimental", False)
        attributes["DEPRECATED"] = info.get("deprecated", False)
        attributes["API_NODE"] = info.get("api_node", False)
        attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False)
        attributes["ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
        attributes["_ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
        attributes["INPUT_IS_LIST"] = info.get("input_is_list", False)

    class_name = f"PyIsolate_{node_name}".replace(" ", "_")
    bases = (_ComfyNodeInternal,) if is_v3 else ()
    stub_cls = type(class_name, bases, attributes)

    if is_v3:
        try:
            stub_cls.VALIDATE_CLASS()
        except Exception as e:
            logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e)

    return stub_cls


def get_class_types_for_extension(
    extension_name: str,
    running_extensions: Dict[str, "ComfyNodeExtension"],
    specs: List[Any],
) -> Set[str]:
    extension = running_extensions.get(extension_name)
    if not extension:
        return set()

    ext_path = Path(extension.module_path)
    class_types = set()
    for spec in specs:
        if spec.module_path.resolve() == ext_path.resolve():
            class_types.add(spec.node_name)
    return class_types


__all__ = ["build_stub_class", "get_class_types_for_extension"]
217
comfy/isolation/shm_forensics.py
Normal file
@@ -0,0 +1,217 @@
# pylint: disable=consider-using-from-import,import-outside-toplevel
from __future__ import annotations

import atexit
import hashlib
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Set

LOG_PREFIX = "]["
logger = logging.getLogger(__name__)


def _shm_debug_enabled() -> bool:
    return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1"


class _SHMForensicsTracker:
    def __init__(self) -> None:
        self._started = False
        self._tracked_files: Set[str] = set()
        self._current_model_context: Dict[str, str] = {
            "id": "unknown",
            "name": "unknown",
            "hash": "????",
        }

    @staticmethod
    def _snapshot_shm() -> Set[str]:
        shm_path = Path("/dev/shm")
        if not shm_path.exists():
            return set()
        return {f.name for f in shm_path.glob("torch_*")}

    def start(self) -> None:
        if self._started or not _shm_debug_enabled():
            return
        self._tracked_files = self._snapshot_shm()
        self._started = True
        logger.debug(
            "%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files)
        )

    def stop(self) -> None:
        if not self._started:
            return
        self.scan("shutdown", refresh_model_context=True)
        self._started = False
        logger.debug("%s SHM:forensics_disabled", LOG_PREFIX)

    def _compute_model_hash(self, model_patcher: Any) -> str:
        try:
            model_instance_id = getattr(model_patcher, "_instance_id", None)
            if model_instance_id is not None:
                model_id_text = str(model_instance_id)
                return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text

            import torch

            real_model = (
                model_patcher.model
                if hasattr(model_patcher, "model")
                else model_patcher
            )
            tensor = None
            if hasattr(real_model, "parameters"):
                for p in real_model.parameters():
                    if torch.is_tensor(p) and p.numel() > 0:
                        tensor = p
                        break

            if tensor is None:
                return "0000"

            flat = tensor.flatten()
            values = []
            indices = [0, flat.shape[0] // 2, flat.shape[0] - 1]
            for i in indices:
                if i < flat.shape[0]:
                    values.append(flat[i].item())

            size = 0
            if hasattr(model_patcher, "model_size"):
                size = model_patcher.model_size()
            sample_str = f"{values}_{id(model_patcher):016x}_{size}"
            return hashlib.sha256(sample_str.encode()).hexdigest()[-4:]
        except Exception:
            return "err!"

    def _get_models_snapshot(self) -> List[Dict[str, Any]]:
        try:
            import comfy.model_management as model_management
        except Exception:
            return []

        snapshot: List[Dict[str, Any]] = []
        try:
            for loaded_model in model_management.current_loaded_models:
                model = loaded_model.model
                if model is None:
                    continue
                if str(getattr(loaded_model, "device", "")) != "cuda:0":
                    continue

                name = (
                    model.model.__class__.__name__
                    if hasattr(model, "model")
                    else type(model).__name__
                )
                model_hash = self._compute_model_hash(model)
                model_instance_id = getattr(model, "_instance_id", None)
                if model_instance_id is None:
                    model_instance_id = model_hash
                snapshot.append(
                    {
                        "name": str(name),
                        "id": str(model_instance_id),
                        "hash": str(model_hash or "????"),
                        "used": bool(getattr(loaded_model, "currently_used", False)),
                    }
                )
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
return snapshot
|
||||
|
||||
def _update_model_context(self) -> None:
|
||||
snapshot = self._get_models_snapshot()
|
||||
selected = None
|
||||
|
||||
used_models = [m for m in snapshot if m.get("used") and m.get("id")]
|
||||
if used_models:
|
||||
selected = used_models[-1]
|
||||
else:
|
||||
live_models = [m for m in snapshot if m.get("id")]
|
||||
if live_models:
|
||||
selected = live_models[-1]
|
||||
|
||||
if selected is None:
|
||||
self._current_model_context = {
|
||||
"id": "unknown",
|
||||
"name": "unknown",
|
||||
"hash": "????",
|
||||
}
|
||||
return
|
||||
|
||||
self._current_model_context = {
|
||||
"id": str(selected.get("id", "unknown")),
|
||||
"name": str(selected.get("name", "unknown")),
|
||||
"hash": str(selected.get("hash", "????") or "????"),
|
||||
}
|
||||
|
||||
def scan(self, marker: str, refresh_model_context: bool = True) -> None:
|
||||
if not self._started or not _shm_debug_enabled():
|
||||
return
|
||||
|
||||
if refresh_model_context:
|
||||
self._update_model_context()
|
||||
|
||||
current = self._snapshot_shm()
|
||||
added = current - self._tracked_files
|
||||
removed = self._tracked_files - current
|
||||
self._tracked_files = current
|
||||
|
||||
if not added and not removed:
|
||||
logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker)
|
||||
return
|
||||
|
||||
for filename in sorted(added):
|
||||
logger.info("%s SHM:created | %s", LOG_PREFIX, filename)
|
||||
model_id = self._current_model_context["id"]
|
||||
if model_id == "unknown":
|
||||
logger.error(
|
||||
"%s SHM:model_association_missing | file=%s | reason=no_active_model_context",
|
||||
LOG_PREFIX,
|
||||
filename,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s",
|
||||
LOG_PREFIX,
|
||||
model_id,
|
||||
filename,
|
||||
self._current_model_context["name"],
|
||||
self._current_model_context["hash"],
|
||||
)
|
||||
|
||||
for filename in sorted(removed):
|
||||
logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename)
|
||||
|
||||
logger.debug(
|
||||
"%s SHM:scan marker=%s created=%d deleted=%d active=%d",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
len(added),
|
||||
len(removed),
|
||||
len(self._tracked_files),
|
||||
)
|
||||
|
||||
|
||||
_TRACKER = _SHMForensicsTracker()
|
||||
|
||||
|
||||
def start_shm_forensics() -> None:
|
||||
_TRACKER.start()
|
||||
|
||||
|
||||
def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None:
|
||||
_TRACKER.scan(marker, refresh_model_context=refresh_model_context)
|
||||
|
||||
|
||||
def stop_shm_forensics() -> None:
|
||||
_TRACKER.stop()
|
||||
|
||||
|
||||
atexit.register(stop_shm_forensics)
|
||||
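A usage sketch for the forensics hooks defined above. The tracker is a no-op unless COMFY_ISO_SHM_DEBUG=1 is set; the marker strings and call sites below are illustrative, not the actual integration points.

import os
os.environ["COMFY_ISO_SHM_DEBUG"] = "1"  # tracker stays inert without this

from comfy.isolation.shm_forensics import (
    start_shm_forensics, scan_shm_forensics, stop_shm_forensics,
)

start_shm_forensics()                    # snapshot current /dev/shm/torch_* files
scan_shm_forensics("after_model_load")   # log segments created/deleted since last scan
scan_shm_forensics("after_sampling", refresh_model_context=False)
stop_shm_forensics()                     # also runs automatically via atexit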
214
comfy/isolation/vae_proxy.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
)
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy, ModelPatcherRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FirstStageModelRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "first_stage_model"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
async def has_property(self, instance_id: str, name: str) -> bool:
|
||||
obj = self._get_instance(instance_id)
|
||||
return hasattr(obj, name)
|
||||
|
||||
|
||||
class FirstStageModelProxy(BaseProxy[FirstStageModelRegistry]):
|
||||
_registry_class = FirstStageModelRegistry
|
||||
__module__ = "comfy.ldm.models.autoencoder"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<FirstStageModelProxy {self._instance_id}>"
|
||||
|
||||
|
||||
class VAERegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "vae"
|
||||
|
||||
async def get_patcher_id(self, instance_id: str) -> str:
|
||||
vae = self._get_instance(instance_id)
|
||||
return ModelPatcherRegistry().register(vae.patcher)
|
||||
|
||||
async def get_first_stage_model_id(self, instance_id: str) -> str:
|
||||
vae = self._get_instance(instance_id)
|
||||
return FirstStageModelRegistry().register(vae.first_stage_model)
|
||||
|
||||
async def encode(self, instance_id: str, pixels: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).encode(pixels))
|
||||
|
||||
async def encode_tiled(
|
||||
self,
|
||||
instance_id: str,
|
||||
pixels: Any,
|
||||
tile_x: int = 512,
|
||||
tile_y: int = 512,
|
||||
overlap: int = 64,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_tiled(
|
||||
pixels, tile_x=tile_x, tile_y=tile_y, overlap=overlap
|
||||
)
|
||||
)
|
||||
|
||||
async def decode(self, instance_id: str, samples: Any, **kwargs: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).decode(samples, **kwargs))
|
||||
|
||||
async def decode_tiled(
|
||||
self,
|
||||
instance_id: str,
|
||||
samples: Any,
|
||||
tile_x: int = 64,
|
||||
tile_y: int = 64,
|
||||
overlap: int = 16,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).decode_tiled(
|
||||
samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap, **kwargs
|
||||
)
|
||||
)
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
return getattr(self._get_instance(instance_id), name)
|
||||
|
||||
async def memory_used_encode(self, instance_id: str, shape: Any, dtype: Any) -> int:
|
||||
return self._get_instance(instance_id).memory_used_encode(shape, dtype)
|
||||
|
||||
async def memory_used_decode(self, instance_id: str, shape: Any, dtype: Any) -> int:
|
||||
return self._get_instance(instance_id).memory_used_decode(shape, dtype)
|
||||
|
||||
async def process_input(self, instance_id: str, image: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).process_input(image))
|
||||
|
||||
async def process_output(self, instance_id: str, image: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).process_output(image))
|
||||
|
||||
|
||||
class VAEProxy(BaseProxy[VAERegistry]):
|
||||
_registry_class = VAERegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
@property
|
||||
def patcher(self) -> ModelPatcherProxy:
|
||||
if not hasattr(self, "_patcher_proxy"):
|
||||
patcher_id = self._call_rpc("get_patcher_id")
|
||||
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
|
||||
return self._patcher_proxy
|
||||
|
||||
@property
|
||||
def first_stage_model(self) -> FirstStageModelProxy:
|
||||
if not hasattr(self, "_first_stage_model_proxy"):
|
||||
fsm_id = self._call_rpc("get_first_stage_model_id")
|
||||
self._first_stage_model_proxy = FirstStageModelProxy(
|
||||
fsm_id, manage_lifecycle=False
|
||||
)
|
||||
return self._first_stage_model_proxy
|
||||
|
||||
@property
|
||||
def vae_dtype(self) -> Any:
|
||||
return self._get_property("vae_dtype")
|
||||
|
||||
def encode(self, pixels: Any) -> Any:
|
||||
return self._call_rpc("encode", pixels)
|
||||
|
||||
def encode_tiled(
|
||||
self, pixels: Any, tile_x: int = 512, tile_y: int = 512, overlap: int = 64
|
||||
) -> Any:
|
||||
return self._call_rpc("encode_tiled", pixels, tile_x, tile_y, overlap)
|
||||
|
||||
def decode(self, samples: Any, **kwargs: Any) -> Any:
|
||||
return self._call_rpc("decode", samples, **kwargs)
|
||||
|
||||
def decode_tiled(
|
||||
self,
|
||||
samples: Any,
|
||||
tile_x: int = 64,
|
||||
tile_y: int = 64,
|
||||
overlap: int = 16,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"decode_tiled", samples, tile_x, tile_y, overlap, **kwargs
|
||||
)
|
||||
|
||||
def get_sd(self) -> Any:
|
||||
return self._call_rpc("get_sd")
|
||||
|
||||
def _get_property(self, name: str) -> Any:
|
||||
return self._call_rpc("get_property", name)
|
||||
|
||||
@property
|
||||
def latent_dim(self) -> int:
|
||||
return self._get_property("latent_dim")
|
||||
|
||||
@property
|
||||
def latent_channels(self) -> int:
|
||||
return self._get_property("latent_channels")
|
||||
|
||||
@property
|
||||
def downscale_ratio(self) -> Any:
|
||||
return self._get_property("downscale_ratio")
|
||||
|
||||
@property
|
||||
def upscale_ratio(self) -> Any:
|
||||
return self._get_property("upscale_ratio")
|
||||
|
||||
@property
|
||||
def output_channels(self) -> int:
|
||||
return self._get_property("output_channels")
|
||||
|
||||
@property
|
||||
    def not_video(self) -> bool:
|
||||
return self._get_property("not_video")
|
||||
|
||||
@property
|
||||
def device(self) -> Any:
|
||||
return self._get_property("device")
|
||||
|
||||
@property
|
||||
def working_dtypes(self) -> Any:
|
||||
return self._get_property("working_dtypes")
|
||||
|
||||
@property
|
||||
def disable_offload(self) -> bool:
|
||||
return self._get_property("disable_offload")
|
||||
|
||||
@property
|
||||
def size(self) -> Any:
|
||||
return self._get_property("size")
|
||||
|
||||
def memory_used_encode(self, shape: Any, dtype: Any) -> int:
|
||||
return self._call_rpc("memory_used_encode", shape, dtype)
|
||||
|
||||
def memory_used_decode(self, shape: Any, dtype: Any) -> int:
|
||||
return self._call_rpc("memory_used_decode", shape, dtype)
|
||||
|
||||
def process_input(self, image: Any) -> Any:
|
||||
return self._call_rpc("process_input", image)
|
||||
|
||||
def process_output(self, image: Any) -> Any:
|
||||
return self._call_rpc("process_output", image)
|
||||
|
||||
|
||||
if not IS_CHILD_PROCESS:
|
||||
_VAE_REGISTRY_SINGLETON = VAERegistry()
|
||||
_FIRST_STAGE_MODEL_REGISTRY_SINGLETON = FirstStageModelRegistry()
|
||||
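The VAE classes above follow the same registry/proxy split as the other isolation proxies: the host process keeps the real object in a registry keyed by an instance id, and the child process holds a thin proxy that forwards each call by id. A self-contained sketch of that idea, illustrative only and not the actual BaseRegistry/BaseProxy RPC machinery:

import itertools

class HostRegistry:
    _objects = {}
    _ids = itertools.count()

    @classmethod
    def register(cls, obj) -> str:
        instance_id = f"vae_{next(cls._ids)}"
        cls._objects[instance_id] = obj
        return instance_id

    @classmethod
    def call(cls, instance_id: str, method: str, *args, **kwargs):
        return getattr(cls._objects[instance_id], method)(*args, **kwargs)

class ChildProxy:
    def __init__(self, instance_id: str):
        self._instance_id = instance_id

    def decode(self, samples):
        # In the real code this call crosses the process boundary over RPC.
        return HostRegistry.call(self._instance_id, "decode", samples)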
@@ -1,4 +1,5 @@
|
||||
import math
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
from scipy import integrate
|
||||
@@ -12,8 +13,8 @@ from . import deis
|
||||
from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
import comfy.memory_management
|
||||
from comfy.cli_args import args
|
||||
from comfy.utils import model_trange as trange
|
||||
|
||||
def append_zero(x):
|
||||
@@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if isolation_active:
|
||||
target_device = sigmas.device
|
||||
if x.device != target_device:
|
||||
x = x.to(target_device)
|
||||
s_in = s_in.to(target_device)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
if s_churn > 0:
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
|
||||
@@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
return tensor * m_mult
|
||||
else:
|
||||
for d in modulation_dims:
|
||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
|
||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
|
||||
if m_add is not None:
|
||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
|
||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
|
||||
return tensor
|
||||
|
||||
|
||||
@@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module):
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
|
||||
extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
for p in patch:
|
||||
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
|
||||
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
@@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module):
|
||||
del qkv
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
for p in patch:
|
||||
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
|
||||
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
@@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
|
||||
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
|
||||
freqs_cis = freqs_cis[:, :, :x_.shape[2]]
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
@@ -170,7 +170,7 @@ class Flux(nn.Module):
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||
img = out["img"]
|
||||
txt = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
|
||||
@@ -2,6 +2,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
@@ -81,7 +82,7 @@ class LowPassFilter1d(nn.Module):
|
||||
_, C, _ = x.shape
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
@@ -125,7 +126,7 @@ class UpSample1d(nn.Module):
|
||||
_, C, _ = x.shape
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
||||
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C
|
||||
)
|
||||
x = x[..., self.pad_left : -self.pad_right]
|
||||
return x
|
||||
@@ -190,7 +191,7 @@ class Snake(nn.Module):
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
|
||||
@@ -217,8 +218,8 @@ class SnakeBeta(nn.Module):
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
b = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||
b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
b = torch.exp(b)
|
||||
@@ -596,7 +597,7 @@ class _STFTFn(nn.Module):
|
||||
y = y.unsqueeze(1) # (B, 1, T)
|
||||
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||
y = F.pad(y, (left_pad, 0))
|
||||
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
||||
spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
|
||||
n_freqs = spec.shape[1] // 2
|
||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||
magnitude = torch.sqrt(real ** 2 + imag ** 2)
|
||||
@@ -647,7 +648,7 @@ class MelSTFT(nn.Module):
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
energy = torch.norm(magnitude, dim=1)
|
||||
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
|
||||
mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
|
||||
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
@@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
break
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
if first_op_done == False:
|
||||
model_management.soft_empty_cache(True)
|
||||
if cleared_cache == False:
|
||||
|
||||
@@ -258,7 +258,8 @@ def slice_attention(q, k, v):
|
||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||
del s2
|
||||
break
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
model_management.soft_empty_cache(True)
|
||||
steps *= 2
|
||||
if steps > 128:
|
||||
@@ -314,7 +315,8 @@ def pytorch_attention(q, k, v):
|
||||
try:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
oom_fallback = True
|
||||
if oom_fallback:
|
||||
|
||||
@@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking(
|
||||
try:
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
except model_management.OOM_EXCEPTION:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
|
||||
torch.exp(attn_scores, out=attn_scores)
|
||||
|
||||
@@ -149,6 +149,9 @@ class Attention(nn.Module):
|
||||
seq_img = hidden_states.shape[1]
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
@@ -167,15 +170,22 @@ class Attention(nn.Module):
|
||||
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
if encoder_hidden_states_mask is not None:
|
||||
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
for p in patch:
|
||||
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
|
||||
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
|
||||
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||
attn_mask, transformer_options=transformer_options,
|
||||
skip_reshape=True)
|
||||
@@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
ref_num_tokens = []
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
@@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = num_embeds
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
@@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
txt_ids = out["txt_ids"]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
|
||||
@@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
|
||||
if tp > 0 and not k.startswith("clip_"):
|
||||
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
|
||||
@@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
|
||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import os
|
||||
import comfy.ldm.lightricks.av_model
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
@@ -112,8 +113,20 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||
|
||||
from comfy.cli_args import args
|
||||
isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
if isolation_runtime_enabled:
|
||||
def __reduce__(self):
|
||||
"""Ensure pickling yields a proxy instead of failing on local class."""
|
||||
try:
|
||||
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
|
||||
registry = ModelSamplingRegistry()
|
||||
ms_id = registry.register(self)
|
||||
return (ModelSamplingProxy, (ms_id,))
|
||||
except Exception as exc:
|
||||
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
|
||||
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
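The __reduce__ override above exists because ModelSampling is defined inside a function and therefore cannot be pickled by reference; returning a (callable, args) pair makes unpickling rebuild a proxy from a registered id instead. A minimal stand-alone illustration of the same trick, with hypothetical names:

import pickle

_REGISTRY = {}

class HandleProxy:
    def __init__(self, instance_id):
        self.instance_id = instance_id

def make_handle(value):
    class Handle:  # locally defined, so normally unpicklable
        def __reduce__(self):
            instance_id = id(self)
            _REGISTRY[instance_id] = self
            return (HandleProxy, (instance_id,))
    handle = Handle()
    handle.value = value
    return handle

restored = pickle.loads(pickle.dumps(make_handle(42)))
print(type(restored).__name__)  # HandleProxy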
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import comfy.memory_management
|
||||
import comfy.supported_models
|
||||
import comfy.supported_models_base
|
||||
import comfy.utils
|
||||
@@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
new[:old_weight.shape[0]] = old_weight
|
||||
old_weight = new
|
||||
|
||||
if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
|
||||
old_weight = old_weight.clone()
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
weight = weight.clone()
|
||||
old_weight = weight
|
||||
w = weight
|
||||
w[:] = fun(weight)
|
||||
|
||||
@@ -270,6 +270,23 @@ try:
|
||||
except:
|
||||
OOM_EXCEPTION = Exception
|
||||
|
||||
try:
|
||||
ACCELERATOR_ERROR = torch.AcceleratorError
|
||||
except AttributeError:
|
||||
ACCELERATOR_ERROR = RuntimeError
|
||||
|
||||
def is_oom(e):
|
||||
if isinstance(e, OOM_EXCEPTION):
|
||||
return True
|
||||
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
|
||||
discard_cuda_async_error()
|
||||
return True
|
||||
return False
|
||||
|
||||
def raise_non_oom(e):
|
||||
if not is_oom(e):
|
||||
raise e
|
||||
|
||||
XFORMERS_VERSION = ""
|
||||
XFORMERS_ENABLED_VAE = True
|
||||
if args.disable_xformers:
|
||||
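is_oom() and raise_non_oom() let callers catch Exception broadly while treating only genuine out-of-memory conditions (including torch.AcceleratorError with error code 2) as recoverable; that is why the attention and VAE hunks in this diff replace `except OOM_EXCEPTION` with `except Exception`. A hedged sketch of the resulting call-site pattern; run_full and run_tiled are hypothetical stand-ins for the real callables:

import comfy.model_management as model_management

def decode_with_fallback(run_full, run_tiled):
    try:
        return run_full()
    except Exception as e:
        model_management.raise_non_oom(e)   # anything that is not OOM re-raises here
        model_management.soft_empty_cache(True)
        return run_tiled()                  # only reached on a real OOM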
@@ -355,7 +372,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
@@ -383,7 +400,7 @@ try:
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||
@@ -480,6 +497,9 @@ except:
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
def _isolation_mode_enabled():
|
||||
return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
@@ -559,8 +579,9 @@ class LoadedModel:
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.detach(unpatch_weights)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
if self.model_finalizer is not None:
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
return True
|
||||
|
||||
@@ -574,8 +595,15 @@ class LoadedModel:
|
||||
if self._patcher_finalizer is not None:
|
||||
self._patcher_finalizer.detach()
|
||||
|
||||
def dead_state(self):
|
||||
model_ref_gone = self.model is None
|
||||
real_model_ref = self.real_model
|
||||
real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None
|
||||
return model_ref_gone, real_model_ref_gone
|
||||
|
||||
def is_dead(self):
|
||||
return self.real_model() is not None and self.model is None
|
||||
model_ref_gone, real_model_ref_gone = self.dead_state()
|
||||
return model_ref_gone or real_model_ref_gone
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
@@ -621,6 +649,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
isolation_active = _isolation_mode_enabled()
|
||||
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
@@ -629,6 +658,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
if can_unload and isolation_active:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
flush_tensor_keeper = None
|
||||
if callable(flush_tensor_keeper):
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
|
||||
gc.collect()
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
@@ -649,7 +689,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
unloaded = current_loaded_models.pop(i)
|
||||
model_obj = unloaded.model
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
unloaded_models.append(unloaded)
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
@@ -708,7 +754,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
for i in to_unload:
|
||||
model_to_unload = current_loaded_models.pop(i)
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
if model_to_unload.model_finalizer is not None:
|
||||
model_to_unload.model_finalizer.detach()
|
||||
model_to_unload.model_finalizer = None
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
@@ -771,25 +819,62 @@ def loaded_models(only_currently_used=False):
|
||||
|
||||
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
|
||||
reset_cast_buffers()
|
||||
if not _isolation_mode_enabled():
|
||||
dead_found = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].is_dead():
|
||||
dead_found = True
|
||||
break
|
||||
|
||||
if dead_found:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
leaked = current_loaded_models.pop(i)
|
||||
model_obj = getattr(leaked, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
return
|
||||
|
||||
dead_found = False
|
||||
has_real_model_leak = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
||||
do_gc = True
|
||||
break
|
||||
model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
|
||||
if model_ref_gone or real_model_ref_gone:
|
||||
dead_found = True
|
||||
if real_model_ref_gone and not model_ref_gone:
|
||||
has_real_model_leak = True
|
||||
|
||||
if do_gc:
|
||||
if dead_found:
|
||||
if has_real_model_leak:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
else:
|
||||
logging.debug("Cleaning stale loaded-model entries with released patcher references.")
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
model_ref_gone, real_model_ref_gone = cur.dead_state()
|
||||
if model_ref_gone or real_model_ref_gone:
|
||||
if real_model_ref_gone and not model_ref_gone:
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
else:
|
||||
logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
|
||||
leaked = current_loaded_models.pop(i)
|
||||
model_obj = getattr(leaked, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
|
||||
|
||||
def archive_model_dtypes(model):
|
||||
@@ -803,11 +888,20 @@ def archive_model_dtypes(model):
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].real_model() is None:
|
||||
real_model_ref = current_loaded_models[i].real_model
|
||||
if real_model_ref is None:
|
||||
to_delete = [i] + to_delete
|
||||
continue
|
||||
if callable(real_model_ref) and real_model_ref() is None:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
x = current_loaded_models.pop(i)
|
||||
model_obj = getattr(x, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
del x
|
||||
|
||||
def dtype_size(dtype):
|
||||
@@ -939,7 +1033,7 @@ def text_encoder_offload_device():
|
||||
def text_encoder_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
||||
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
|
||||
if should_use_fp16(prioritize_performance=False):
|
||||
return get_torch_device()
|
||||
else:
|
||||
@@ -1148,6 +1242,7 @@ def reset_cast_buffers():
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
@@ -1262,7 +1357,7 @@ def discard_cuda_async_error():
|
||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
_ = a + b
|
||||
synchronize()
|
||||
except torch.AcceleratorError:
|
||||
except RuntimeError:
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
|
||||
|
||||
@@ -599,6 +599,27 @@ class ModelPatcher:
|
||||
|
||||
return models
|
||||
|
||||
def model_patches_call_function(self, function_name="cleanup", arguments={}):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], function_name):
|
||||
getattr(patch_list[i], function_name)(**arguments)
|
||||
if "patches_replace" in to:
|
||||
patches = to["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], function_name):
|
||||
getattr(patch_list[k], function_name)(**arguments)
|
||||
if "model_function_wrapper" in self.model_options:
|
||||
wrap_func = self.model_options["model_function_wrapper"]
|
||||
if hasattr(wrap_func, function_name):
|
||||
getattr(wrap_func, function_name)(**arguments)
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
@@ -715,8 +736,8 @@ class ModelPatcher:
|
||||
default = True # default random weights in non leaf modules
|
||||
break
|
||||
if default and default_device is not None:
|
||||
for param in params.values():
|
||||
param.data = param.data.to(device=default_device)
|
||||
for param_name, param in params.items():
|
||||
param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None))
|
||||
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
@@ -1062,6 +1083,7 @@ class ModelPatcher:
|
||||
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
||||
|
||||
def cleanup(self):
|
||||
self.model_patches_call_function(function_name="cleanup")
|
||||
self.clean_hooks()
|
||||
if hasattr(self.model, "current_patcher"):
|
||||
self.model.current_patcher = None
|
||||
|
||||
@@ -11,12 +11,14 @@ from functools import partial
|
||||
import collections
|
||||
import math
|
||||
import logging
|
||||
import os
|
||||
import comfy.sampler_helpers
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.utils
|
||||
from comfy.cli_args import args
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -210,9 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
result = executor.execute(model, conds, x_in, timestep, model_options)
|
||||
return result
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@@ -269,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
|
||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
|
||||
if memory_required * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
@@ -294,9 +299,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
if isolation_active:
|
||||
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
|
||||
input_x = torch.cat(input_x).to(target_device)
|
||||
else:
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
if isolation_active:
|
||||
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
|
||||
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
|
||||
else:
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
@@ -327,9 +340,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
out_t = output[o]
|
||||
mult_t = mult[o]
|
||||
if isolation_active:
|
||||
target_dev = out_conds[cond_index].device
|
||||
if hasattr(out_t, "device") and out_t.device != target_dev:
|
||||
out_t = out_t.to(target_dev)
|
||||
if hasattr(mult_t, "device") and mult_t.device != target_dev:
|
||||
mult_t = mult_t.to(target_dev)
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
out_conds[cond_index] += out_t * mult_t
|
||||
out_counts[cond_index] += mult_t
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
@@ -337,8 +358,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
out_c += out_t * mult_t
|
||||
out_cts += mult_t
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
@@ -392,14 +413,31 @@ class KSamplerX0Inpaint:
|
||||
self.inner_model = model
|
||||
self.sigmas = sigmas
|
||||
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if denoise_mask is not None:
|
||||
if isolation_active and denoise_mask.device != x.device:
|
||||
denoise_mask = denoise_mask.to(x.device)
|
||||
if "denoise_mask_function" in model_options:
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
|
||||
if isolation_active:
|
||||
latent_image = self.latent_image
|
||||
if hasattr(latent_image, "device") and latent_image.device != x.device:
|
||||
latent_image = latent_image.to(x.device)
|
||||
scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image)
|
||||
if hasattr(scaled, "device") and scaled.device != x.device:
|
||||
scaled = scaled.to(x.device)
|
||||
else:
|
||||
scaled = self.inner_model.inner_model.scale_latent_inpaint(
|
||||
x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image
|
||||
)
|
||||
x = x * denoise_mask + scaled * latent_mask
|
||||
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out = out * denoise_mask + self.latent_image * latent_mask
|
||||
latent_image = self.latent_image
|
||||
if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device:
|
||||
latent_image = latent_image.to(out.device)
|
||||
out = out * denoise_mask + latent_image * latent_mask
|
||||
return out
|
||||
|
||||
def simple_scheduler(model_sampling, steps):
|
||||
@@ -741,7 +779,11 @@ class KSAMPLER(Sampler):
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
|
||||
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||
model_sampling = model_wrap.inner_model.model_sampling
|
||||
noise = model_sampling.noise_scaling(
|
||||
sigmas[0], noise, latent_image, max_denoise
|
||||
)
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
|
||||
@@ -954,7 +954,8 @@ class VAE:
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except model_management.OOM_EXCEPTION:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
@@ -1029,7 +1030,8 @@ class VAE:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except model_management.OOM_EXCEPTION:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
|
||||
@@ -15,6 +15,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
"node_replacements": True,
|
||||
"assets": args.enable_assets,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D, PLY as _PLY, NPZ as _NPZ
|
||||
|
||||
|
||||
class FolderType(str, Enum):
|
||||
@@ -678,6 +678,16 @@ class Mesh(ComfyTypeIO):
|
||||
Type = MESH
|
||||
|
||||
|
||||
@comfytype(io_type="PLY")
|
||||
class Ply(ComfyTypeIO):
|
||||
Type = _PLY
|
||||
|
||||
|
||||
@comfytype(io_type="NPZ")
|
||||
class Npz(ComfyTypeIO):
|
||||
Type = _NPZ
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D")
|
||||
class File3DAny(ComfyTypeIO):
|
||||
"""General 3D file type - accepts any supported 3D format."""
|
||||
@@ -2197,6 +2207,8 @@ __all__ = [
|
||||
"LossMap",
|
||||
"Voxel",
|
||||
"Mesh",
|
||||
"Ply",
|
||||
"Npz",
|
||||
"File3DAny",
|
||||
"File3DGLB",
|
||||
"File3DGLTF",
|
||||
|
||||
@@ -65,6 +65,22 @@ class SavedAudios(_UIOutput):
|
||||
return {"audio": self.results}
|
||||
|
||||
|
||||
def _is_isolated_child() -> bool:
|
||||
return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
def _get_preview_folder_type() -> FolderType:
|
||||
if _is_isolated_child():
|
||||
return FolderType.output
|
||||
return FolderType.temp
|
||||
|
||||
|
||||
def _get_preview_route_prefix(folder_type: FolderType) -> str:
|
||||
if folder_type == FolderType.output:
|
||||
return "output"
|
||||
return "temp"
|
||||
|
||||
|
||||
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
|
||||
if folder_type == FolderType.input:
|
||||
return folder_paths.get_input_directory()
|
||||
@@ -388,10 +404,11 @@ class AudioSaveHelper:
|
||||
|
||||
class PreviewImage(_UIOutput):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
|
||||
folder_type = _get_preview_folder_type()
|
||||
self.values = ImageSaveHelper.save_images(
|
||||
image,
|
||||
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
folder_type=folder_type,
|
||||
cls=cls,
|
||||
compress_level=1,
|
||||
)
|
||||
@@ -412,10 +429,11 @@ class PreviewMask(PreviewImage):
|
||||
|
||||
class PreviewAudio(_UIOutput):
|
||||
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
|
||||
folder_type = _get_preview_folder_type()
|
||||
self.values = AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
folder_type=folder_type,
|
||||
cls=cls,
|
||||
format="flac",
|
||||
quality="128k",
|
||||
@@ -438,15 +456,16 @@ class PreviewUI3D(_UIOutput):
|
||||
self.model_file = model_file
|
||||
self.camera_info = camera_info
|
||||
self.bg_image_path = None
|
||||
folder_type = _get_preview_folder_type()
|
||||
bg_image = kwargs.get("bg_image", None)
|
||||
if bg_image is not None:
|
||||
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
|
||||
img = PILImage.fromarray(img_array)
|
||||
temp_dir = folder_paths.get_temp_directory()
|
||||
preview_dir = _get_directory_by_folder_type(folder_type)
|
||||
filename = f"bg_{uuid.uuid4().hex}.png"
|
||||
bg_image_path = os.path.join(temp_dir, filename)
|
||||
bg_image_path = os.path.join(preview_dir, filename)
|
||||
img.save(bg_image_path, compress_level=1)
|
||||
self.bg_image_path = f"temp/{filename}"
|
||||
self.bg_image_path = f"{_get_preview_route_prefix(folder_type)}/{filename}"
|
||||
|
||||
def as_dict(self):
|
||||
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH, File3D
|
||||
from .image_types import SVG
|
||||
from .ply_types import PLY
|
||||
from .npz_types import NPZ
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
@@ -11,4 +13,6 @@ __all__ = [
|
||||
"MESH",
|
||||
"File3D",
|
||||
"SVG",
|
||||
"PLY",
|
||||
"NPZ",
|
||||
]
|
||||
|
||||
27
comfy_api/latest/_util/npz_types.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class NPZ:
|
||||
"""Ordered collection of NPZ file payloads.
|
||||
|
||||
Each entry in ``frames`` is a complete compressed ``.npz`` file stored
|
||||
as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO).
|
||||
``save_to`` writes numbered files into a directory.
|
||||
"""
|
||||
|
||||
def __init__(self, frames: list[bytes]) -> None:
|
||||
self.frames = frames
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return len(self.frames)
|
||||
|
||||
def save_to(self, directory: str, prefix: str = "frame") -> str:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
for i, frame_bytes in enumerate(self.frames):
|
||||
path = os.path.join(directory, f"{prefix}_{i:06d}.npz")
|
||||
with open(path, "wb") as f:
|
||||
f.write(frame_bytes)
|
||||
return directory
|
||||
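A usage sketch for the NPZ carrier above: each frame is a complete compressed .npz file serialized to bytes, exactly as the docstring describes. The output directory is illustrative, and the import assumes a ComfyUI checkout on the path.

import io
import numpy as np
from comfy_api.latest._util.npz_types import NPZ

frames = []
for _ in range(3):
    buf = io.BytesIO()
    np.savez_compressed(buf, depth=np.random.rand(4, 4).astype(np.float32))
    frames.append(buf.getvalue())

payload = NPZ(frames)
print(payload.num_frames)                           # 3
payload.save_to("/tmp/npz_frames", prefix="frame")  # writes frame_000000.npz ...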
97
comfy_api/latest/_util/ply_types.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PLY:
|
||||
"""Point cloud payload for PLY file output.
|
||||
|
||||
Supports two schemas:
|
||||
- Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format)
|
||||
- Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format)
|
||||
|
||||
When ``raw_data`` is provided, the object acts as an opaque binary PLY
|
||||
carrier and ``save_to`` writes the bytes directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
points: np.ndarray | None = None,
|
||||
colors: np.ndarray | None = None,
|
||||
confidence: np.ndarray | None = None,
|
||||
view_id: np.ndarray | None = None,
|
||||
raw_data: bytes | None = None,
|
||||
) -> None:
|
||||
self.raw_data = raw_data
|
||||
if raw_data is not None:
|
||||
self.points = None
|
||||
self.colors = None
|
||||
self.confidence = None
|
||||
self.view_id = None
|
||||
return
|
||||
if points is None:
|
||||
raise ValueError("Either points or raw_data must be provided")
|
||||
if points.ndim != 2 or points.shape[1] != 3:
|
||||
raise ValueError(f"points must be (N, 3), got {points.shape}")
|
||||
self.points = np.ascontiguousarray(points, dtype=np.float32)
|
||||
self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None
|
||||
self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None
|
||||
self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None
|
||||
|
||||
@property
|
||||
def is_gaussian(self) -> bool:
|
||||
return self.raw_data is not None
|
||||
|
||||
@property
|
||||
def num_points(self) -> int:
|
||||
if self.points is not None:
|
||||
return self.points.shape[0]
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _to_numpy(arr, dtype):
|
||||
if arr is None:
|
||||
return None
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
|
||||
return np.ascontiguousarray(arr, dtype=dtype)
|
||||
|
||||
def save_to(self, path: str) -> str:
|
||||
if self.raw_data is not None:
|
||||
with open(path, "wb") as f:
|
||||
f.write(self.raw_data)
|
||||
return path
|
||||
self.points = self._to_numpy(self.points, np.float32)
|
||||
self.colors = self._to_numpy(self.colors, np.float32)
|
||||
self.confidence = self._to_numpy(self.confidence, np.float32)
|
||||
self.view_id = self._to_numpy(self.view_id, np.int32)
|
||||
N = self.num_points
|
||||
header_lines = [
|
||||
"ply",
|
||||
"format ascii 1.0",
|
||||
f"element vertex {N}",
|
||||
"property float x",
|
||||
"property float y",
|
||||
"property float z",
|
||||
]
|
||||
if self.colors is not None:
|
||||
header_lines += ["property uchar red", "property uchar green", "property uchar blue"]
|
||||
if self.confidence is not None:
|
||||
header_lines.append("property float confidence")
|
||||
if self.view_id is not None:
|
||||
header_lines.append("property int view_id")
|
||||
header_lines.append("end_header")
|
||||
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(header_lines) + "\n")
|
||||
for i in range(N):
|
||||
parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"]
|
||||
if self.colors is not None:
|
||||
r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8)
|
||||
parts.append(f"{r} {g} {b}")
|
||||
if self.confidence is not None:
|
||||
parts.append(f"{self.confidence[i]}")
|
||||
if self.view_id is not None:
|
||||
parts.append(f"{int(self.view_id[i])}")
|
||||
f.write(" ".join(parts) + "\n")
|
||||
return path
|
||||
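A short sketch of the point-cloud path of the PLY class above, assuming a tiny random cloud (the file path is illustrative); the Gaussian path would instead pass raw_data=... bytes produced by plyfile.

import numpy as np

points = np.random.rand(100, 3).astype(np.float32)
colors = np.random.rand(100, 3).astype(np.float32)   # 0..1 values, scaled to uchar on save

cloud = PLY(points=points, colors=colors)
print(cloud.is_gaussian, cloud.num_points)  # False 100
cloud.save_to("/tmp/cloud.ply")             # writes an ASCII PLY with xyz + rgb per vertex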
259
comfy_api/latest/_util/trimesh_types.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TrimeshData:
|
||||
"""Triangular mesh payload for cross-process transfer.
|
||||
|
||||
Lightweight carrier for mesh geometry that does not depend on the
|
||||
``trimesh`` library. Serializers create this on the host side;
|
||||
isolated child processes convert to/from ``trimesh.Trimesh`` as needed.
|
||||
|
||||
Supports both ColorVisuals (vertex_colors) and TextureVisuals
|
||||
(uv + material with textures).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vertices: np.ndarray,
|
||||
faces: np.ndarray,
|
||||
vertex_normals: np.ndarray | None = None,
|
||||
face_normals: np.ndarray | None = None,
|
||||
vertex_colors: np.ndarray | None = None,
|
||||
uv: np.ndarray | None = None,
|
||||
material: dict | None = None,
|
||||
vertex_attributes: dict | None = None,
|
||||
face_attributes: dict | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> None:
|
||||
self.vertices = np.ascontiguousarray(vertices, dtype=np.float64)
|
||||
self.faces = np.ascontiguousarray(faces, dtype=np.int64)
|
||||
self.vertex_normals = (
|
||||
np.ascontiguousarray(vertex_normals, dtype=np.float64)
|
||||
if vertex_normals is not None
|
||||
else None
|
||||
)
|
||||
self.face_normals = (
|
||||
np.ascontiguousarray(face_normals, dtype=np.float64)
|
||||
if face_normals is not None
|
||||
else None
|
||||
)
|
||||
self.vertex_colors = (
|
||||
np.ascontiguousarray(vertex_colors, dtype=np.uint8)
|
||||
if vertex_colors is not None
|
||||
else None
|
||||
)
|
||||
self.uv = (
|
||||
np.ascontiguousarray(uv, dtype=np.float64)
|
||||
if uv is not None
|
||||
else None
|
||||
)
|
||||
self.material = material
|
||||
self.vertex_attributes = vertex_attributes or {}
|
||||
self.face_attributes = face_attributes or {}
|
||||
self.metadata = self._detensorize_dict(metadata) if metadata else {}
|
||||
|
||||
@staticmethod
|
||||
def _detensorize_dict(d):
|
||||
"""Recursively convert any tensors in a dict back to numpy arrays."""
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if hasattr(v, "numpy"):
|
||||
result[k] = v.cpu().numpy() if hasattr(v, "cpu") else v.numpy()
|
||||
elif isinstance(v, dict):
|
||||
result[k] = TrimeshData._detensorize_dict(v)
|
||||
elif isinstance(v, list):
|
||||
result[k] = [
|
||||
item.cpu().numpy() if hasattr(item, "numpy") and hasattr(item, "cpu")
|
||||
else item.numpy() if hasattr(item, "numpy")
|
||||
else item
|
||||
for item in v
|
||||
]
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _to_numpy(arr, dtype):
|
||||
if arr is None:
|
||||
return None
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
|
||||
return np.ascontiguousarray(arr, dtype=dtype)
|
||||
|
||||
@property
|
||||
def num_vertices(self) -> int:
|
||||
return self.vertices.shape[0]
|
||||
|
||||
@property
|
||||
def num_faces(self) -> int:
|
||||
return self.faces.shape[0]
|
||||
|
||||
@property
|
||||
def has_texture(self) -> bool:
|
||||
return self.uv is not None and self.material is not None
|
||||
|
||||
def to_trimesh(self):
|
||||
"""Convert to trimesh.Trimesh (requires trimesh in the environment)."""
|
||||
import trimesh
|
||||
from trimesh.visual import TextureVisuals
|
||||
|
||||
kwargs = {}
|
||||
if self.vertex_normals is not None:
|
||||
kwargs["vertex_normals"] = self.vertex_normals
|
||||
if self.face_normals is not None:
|
||||
kwargs["face_normals"] = self.face_normals
|
||||
if self.metadata:
|
||||
kwargs["metadata"] = self.metadata
|
||||
|
||||
mesh = trimesh.Trimesh(
|
||||
vertices=self.vertices, faces=self.faces, process=False, **kwargs
|
||||
)
|
||||
|
||||
# Reconstruct visual
|
||||
if self.has_texture:
|
||||
material = self._dict_to_material(self.material)
|
||||
mesh.visual = TextureVisuals(uv=self.uv, material=material)
|
||||
elif self.vertex_colors is not None:
|
||||
mesh.visual.vertex_colors = self.vertex_colors
|
||||
|
||||
for k, v in self.vertex_attributes.items():
|
||||
mesh.vertex_attributes[k] = v
|
||||
|
||||
for k, v in self.face_attributes.items():
|
||||
mesh.face_attributes[k] = v
|
||||
|
||||
return mesh
|
||||
|
||||
@staticmethod
|
||||
def _material_to_dict(material) -> dict:
|
||||
"""Serialize a trimesh material to a plain dict."""
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from trimesh.visual.material import PBRMaterial, SimpleMaterial
|
||||
|
||||
result = {"type": type(material).__name__, "name": getattr(material, "name", None)}
|
||||
|
||||
if isinstance(material, PBRMaterial):
|
||||
result["baseColorFactor"] = material.baseColorFactor
|
||||
result["metallicFactor"] = material.metallicFactor
|
||||
result["roughnessFactor"] = material.roughnessFactor
|
||||
result["emissiveFactor"] = material.emissiveFactor
|
||||
result["alphaMode"] = material.alphaMode
|
||||
result["alphaCutoff"] = material.alphaCutoff
|
||||
result["doubleSided"] = material.doubleSided
|
||||
|
||||
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
|
||||
"metallicRoughnessTexture", "occlusionTexture"):
|
||||
tex = getattr(material, tex_name, None)
|
||||
if tex is not None:
|
||||
buf = BytesIO()
|
||||
tex.save(buf, format="PNG")
|
||||
result[tex_name] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
elif isinstance(material, SimpleMaterial):
|
||||
result["main_color"] = list(material.main_color) if material.main_color is not None else None
|
||||
result["glossiness"] = material.glossiness
|
||||
if hasattr(material, "image") and material.image is not None:
|
||||
buf = BytesIO()
|
||||
material.image.save(buf, format="PNG")
|
||||
result["image"] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_material(d: dict):
|
||||
"""Reconstruct a trimesh material from a plain dict."""
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from trimesh.visual.material import PBRMaterial, SimpleMaterial
|
||||
|
||||
mat_type = d.get("type", "PBRMaterial")
|
||||
|
||||
if mat_type == "PBRMaterial":
|
||||
kwargs = {
|
||||
"name": d.get("name"),
|
||||
"baseColorFactor": d.get("baseColorFactor"),
|
||||
"metallicFactor": d.get("metallicFactor"),
|
||||
"roughnessFactor": d.get("roughnessFactor"),
|
||||
"emissiveFactor": d.get("emissiveFactor"),
|
||||
"alphaMode": d.get("alphaMode"),
|
||||
"alphaCutoff": d.get("alphaCutoff"),
|
||||
"doubleSided": d.get("doubleSided"),
|
||||
}
|
||||
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
|
||||
"metallicRoughnessTexture", "occlusionTexture"):
|
||||
if tex_name in d and d[tex_name] is not None:
|
||||
img = Image.open(BytesIO(base64.b64decode(d[tex_name])))
|
||||
kwargs[tex_name] = img
|
||||
return PBRMaterial(**{k: v for k, v in kwargs.items() if v is not None})
|
||||
|
||||
elif mat_type == "SimpleMaterial":
|
||||
kwargs = {
|
||||
"name": d.get("name"),
|
||||
"glossiness": d.get("glossiness"),
|
||||
}
|
||||
if d.get("main_color") is not None:
|
||||
kwargs["diffuse"] = d["main_color"]
|
||||
if d.get("image") is not None:
|
||||
kwargs["image"] = Image.open(BytesIO(base64.b64decode(d["image"])))
|
||||
return SimpleMaterial(**kwargs)
|
||||
|
||||
raise ValueError(f"Unknown material type: {mat_type}")
|
||||
|
||||
@classmethod
|
||||
def from_trimesh(cls, mesh) -> TrimeshData:
|
||||
"""Create from a trimesh.Trimesh object."""
|
||||
from trimesh.visual.texture import TextureVisuals
|
||||
|
||||
vertex_normals = None
|
||||
if mesh._cache.cache.get("vertex_normals") is not None:
|
||||
vertex_normals = np.asarray(mesh.vertex_normals)
|
||||
|
||||
face_normals = None
|
||||
if mesh._cache.cache.get("face_normals") is not None:
|
||||
face_normals = np.asarray(mesh.face_normals)
|
||||
|
||||
vertex_colors = None
|
||||
uv = None
|
||||
material = None
|
||||
|
||||
if isinstance(mesh.visual, TextureVisuals):
|
||||
if mesh.visual.uv is not None:
|
||||
uv = np.asarray(mesh.visual.uv, dtype=np.float64)
|
||||
if mesh.visual.material is not None:
|
||||
material = cls._material_to_dict(mesh.visual.material)
|
||||
else:
|
||||
try:
|
||||
vc = mesh.visual.vertex_colors
|
||||
if vc is not None and len(vc) > 0:
|
||||
vertex_colors = np.asarray(vc, dtype=np.uint8)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
va = {}
|
||||
if hasattr(mesh, "vertex_attributes") and mesh.vertex_attributes:
|
||||
for k, v in mesh.vertex_attributes.items():
|
||||
va[k] = np.asarray(v) if hasattr(v, "__array__") else v
|
||||
|
||||
fa = {}
|
||||
if hasattr(mesh, "face_attributes") and mesh.face_attributes:
|
||||
for k, v in mesh.face_attributes.items():
|
||||
fa[k] = np.asarray(v) if hasattr(v, "__array__") else v
|
||||
|
||||
return cls(
|
||||
vertices=np.asarray(mesh.vertices),
|
||||
faces=np.asarray(mesh.faces),
|
||||
vertex_normals=vertex_normals,
|
||||
face_normals=face_normals,
|
||||
vertex_colors=vertex_colors,
|
||||
uv=uv,
|
||||
material=material,
|
||||
vertex_attributes=va if va else None,
|
||||
face_attributes=fa if fa else None,
|
||||
metadata=mesh.metadata if mesh.metadata else None,
|
||||
)
|
||||
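A hedged round-trip sketch for the TrimeshData carrier above (requires the trimesh package; the one-triangle mesh is illustrative): the host builds the payload with from_trimesh, and an isolated child rebuilds a trimesh.Trimesh with to_trimesh.

import numpy as np
import trimesh

# illustrative one-triangle mesh, not part of the diff
mesh = trimesh.Trimesh(
    vertices=np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float64),
    faces=np.array([[0, 1, 2]], dtype=np.int64),
    process=False,
)

carrier = TrimeshData.from_trimesh(mesh)        # host side: trimesh-free payload
print(carrier.num_vertices, carrier.num_faces)  # 3 1

rebuilt = carrier.to_trimesh()                  # child side: back to trimesh.Trimesh
assert rebuilt.vertices.shape == (3, 3)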
@@ -66,13 +66,17 @@ class To3DProTaskQueryRequest(BaseModel):
    JobId: str = Field(...)


class To3DUVFileInput(BaseModel):
class TaskFile3DInput(BaseModel):
    Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
    Url: str = Field(...)


class To3DUVTaskRequest(BaseModel):
    File: To3DUVFileInput = Field(...)
    File: TaskFile3DInput = Field(...)


class To3DPartTaskRequest(BaseModel):
    File: TaskFile3DInput = Field(...)


class TextureEditImageInfo(BaseModel):
@@ -80,7 +84,13 @@ class TextureEditImageInfo(BaseModel):


class TextureEditTaskRequest(BaseModel):
    File3D: To3DUVFileInput = Field(...)
    File3D: TaskFile3DInput = Field(...)
    Image: TextureEditImageInfo | None = Field(None)
    Prompt: str | None = Field(None)
    EnablePBR: bool | None = Field(None)


class SmartTopologyRequest(BaseModel):
    File3D: TaskFile3DInput = Field(...)
    PolygonType: str | None = Field(...)
    FaceLevel: str | None = Field(...)
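A small sketch of how the renamed TaskFile3DInput feeds the new request models (the URL is a placeholder, not a real endpoint):

req = SmartTopologyRequest(
    File3D=TaskFile3DInput(Type="GLB", Url="https://example.com/model.glb"),  # placeholder URL
    PolygonType="triangle",
    FaceLevel="medium",
)
payload = req.model_dump(exclude_none=True)  # what sync_op ultimately sends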
68
comfy_api_nodes/apis/reve.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RevePostprocessingOperation(BaseModel):
|
||||
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
|
||||
upscale_factor: int | None = Field(
|
||||
None,
|
||||
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
|
||||
ge=2,
|
||||
le=4,
|
||||
)
|
||||
|
||||
|
||||
class ReveImageCreateRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
version: str = Field(...)
|
||||
test_time_scaling: int = Field(
|
||||
...,
|
||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
||||
ge=1,
|
||||
le=15,
|
||||
)
|
||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
||||
None, description="Optional postprocessing operations to apply after generation."
|
||||
)
|
||||
|
||||
|
||||
class ReveImageEditRequest(BaseModel):
|
||||
edit_instruction: str = Field(...)
|
||||
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
|
||||
aspect_ratio: str | None = Field(...)
|
||||
version: str = Field(...)
|
||||
test_time_scaling: int | None = Field(
|
||||
...,
|
||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
||||
ge=1,
|
||||
le=15,
|
||||
)
|
||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
||||
None, description="Optional postprocessing operations to apply after generation."
|
||||
)
|
||||
|
||||
|
||||
class ReveImageRemixRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
|
||||
aspect_ratio: str | None = Field(...)
|
||||
version: str = Field(...)
|
||||
test_time_scaling: int | None = Field(
|
||||
...,
|
||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
||||
ge=1,
|
||||
le=15,
|
||||
)
|
||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
||||
None, description="Optional postprocessing operations to apply after generation."
|
||||
)
|
||||
|
||||
|
||||
class ReveImageResponse(BaseModel):
|
||||
image: str | None = Field(None, description="The base64 encoded image data.")
|
||||
request_id: str | None = Field(None, description="A unique id for the request.")
|
||||
credits_used: float | None = Field(None, description="The number of credits used for this request.")
|
||||
version: str | None = Field(None, description="The specific model version used.")
|
||||
content_violation: bool | None = Field(
|
||||
None, description="Indicates whether the generated image violates the content policy."
|
||||
)
|
||||
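As a rough illustration of how these request models compose (prompt text and factor values are made up), a create request with optional postprocessing might be built like this:

ops = [
    RevePostprocessingOperation(process="upscale", upscale_factor=2),
    RevePostprocessingOperation(process="remove_background"),
]
req = ReveImageCreateRequest(
    prompt="a lighthouse at dusk",        # illustrative prompt
    aspect_ratio="16:9",
    version="reve-create@20250915",
    test_time_scaling=1,
    postprocessing=ops,
)
body = req.model_dump(exclude_none=True)  # serialized by sync_op_raw before sending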
@@ -72,18 +72,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
)


class GeminiModel(str, Enum):
    """
    Gemini Model Names allowed by comfy-api
    """

    gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
    gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
    gemini_2_5_pro = "gemini-2.5-pro"
    gemini_2_5_flash = "gemini-2.5-flash"
    gemini_3_0_pro = "gemini-3-pro-preview"


class GeminiImageModel(str, Enum):
    """
    Gemini Image Model Names allowed by comfy-api
@@ -237,10 +225,14 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
        input_tokens_price = 0.30
        output_text_tokens_price = 2.50
        output_image_tokens_price = 30.0
    elif response.modelVersion == "gemini-3-pro-preview":
    elif response.modelVersion in ("gemini-3-pro-preview", "gemini-3.1-pro-preview"):
        input_tokens_price = 2
        output_text_tokens_price = 12.0
        output_image_tokens_price = 0.0
    elif response.modelVersion == "gemini-3.1-flash-lite-preview":
        input_tokens_price = 0.25
        output_text_tokens_price = 1.50
        output_image_tokens_price = 0.0
    elif response.modelVersion == "gemini-3-pro-image-preview":
        input_tokens_price = 2
        output_text_tokens_price = 12.0
@@ -292,8 +284,16 @@ class GeminiNode(IO.ComfyNode):
                ),
                IO.Combo.Input(
                    "model",
                    options=GeminiModel,
                    default=GeminiModel.gemini_2_5_pro,
                    options=[
                        "gemini-2.5-pro-preview-05-06",
                        "gemini-2.5-flash-preview-04-17",
                        "gemini-2.5-pro",
                        "gemini-2.5-flash",
                        "gemini-3-pro-preview",
                        "gemini-3-1-pro",
                        "gemini-3-1-flash-lite",
                    ],
                    default="gemini-3-1-pro",
                    tooltip="The Gemini model to use for generating responses.",
                ),
                IO.Int.Input(
@@ -363,11 +363,16 @@ class GeminiNode(IO.ComfyNode):
                "usd": [0.00125, 0.01],
                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
            }
            : $contains($m, "gemini-3-pro-preview") ? {
            : ($contains($m, "gemini-3-pro-preview") or $contains($m, "gemini-3-1-pro")) ? {
                "type": "list_usd",
                "usd": [0.002, 0.012],
                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
            }
            : $contains($m, "gemini-3-1-flash-lite") ? {
                "type": "list_usd",
                "usd": [0.00025, 0.0015],
                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
            }
            : {"type":"text", "text":"Token-based"}
        )
        """,
@@ -436,12 +441,14 @@ class GeminiNode(IO.ComfyNode):
        files: list[GeminiPart] | None = None,
        system_prompt: str = "",
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=False)
        if model == "gemini-3-pro-preview":
            model = "gemini-3.1-pro-preview"  # model "gemini-3-pro-preview" will soon be deprecated by Google
        elif model == "gemini-3-1-pro":
            model = "gemini-3.1-pro-preview"
        elif model == "gemini-3-1-flash-lite":
            model = "gemini-3.1-flash-lite-preview"

        # Create parts list with text prompt as the first part
        parts: list[GeminiPart] = [GeminiPart(text=prompt)]

        # Add other modal parts
        if images is not None:
            parts.extend(await create_image_parts(cls, images))
        if audio is not None:
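The alias handling in execute() above is a simple string remap; a hedged equivalent as a lookup table (the names mirror the combo options, but this dict itself is not in the diff):

_MODEL_ALIASES = {
    "gemini-3-pro-preview": "gemini-3.1-pro-preview",   # old preview id, soon deprecated
    "gemini-3-1-pro": "gemini-3.1-pro-preview",
    "gemini-3-1-flash-lite": "gemini-3.1-flash-lite-preview",
}
model = _MODEL_ALIASES.get(model, model)  # gemini-2.5-* ids pass through unchanged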
@@ -5,18 +5,19 @@ from comfy_api_nodes.apis.hunyuan3d import (
|
||||
Hunyuan3DViewImage,
|
||||
InputGenerateType,
|
||||
ResultFile3D,
|
||||
SmartTopologyRequest,
|
||||
TaskFile3DInput,
|
||||
TextureEditTaskRequest,
|
||||
To3DPartTaskRequest,
|
||||
To3DProTaskCreateResponse,
|
||||
To3DProTaskQueryRequest,
|
||||
To3DProTaskRequest,
|
||||
To3DProTaskResultResponse,
|
||||
To3DUVFileInput,
|
||||
To3DUVTaskRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_file_3d,
|
||||
download_url_to_image_tensor,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
sync_op,
|
||||
@@ -344,7 +345,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
||||
outputs=[
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -375,7 +375,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=To3DUVTaskRequest(
|
||||
File=To3DUVFileInput(
|
||||
File=TaskFile3DInput(
|
||||
Type=file_format.upper(),
|
||||
Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
|
||||
)
|
||||
@@ -394,7 +394,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
|
||||
)
|
||||
|
||||
|
||||
@@ -463,7 +462,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=TextureEditTaskRequest(
|
||||
File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
|
||||
File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
|
||||
Prompt=prompt,
|
||||
EnablePBR=True,
|
||||
),
|
||||
@@ -538,8 +537,8 @@ class Tencent3DPartNode(IO.ComfyNode):
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=To3DUVTaskRequest(
|
||||
File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
|
||||
data=To3DPartTaskRequest(
|
||||
File=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
@@ -557,15 +556,107 @@ class Tencent3DPartNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class TencentSmartTopologyNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentSmartTopologyNode",
|
||||
display_name="Hunyuan3D: Smart Topology",
|
||||
category="api node/3d/Tencent",
|
||||
description="Perform smart retopology on a 3D model. "
|
||||
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DAny],
|
||||
tooltip="Input 3D model (GLB or OBJ)",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"polygon_type",
|
||||
options=["triangle", "quadrilateral"],
|
||||
tooltip="Surface composition type.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"face_level",
|
||||
options=["medium", "high", "low"],
|
||||
tooltip="Polygon reduction level.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":1.0}'),
|
||||
)
|
||||
|
||||
SUPPORTED_FORMATS = {"glb", "obj"}
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_3d: Types.File3D,
|
||||
polygon_type: str,
|
||||
face_level: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
_ = seed
|
||||
file_format = model_3d.format.lower()
|
||||
if file_format not in cls.SUPPORTED_FORMATS:
|
||||
raise ValueError(
|
||||
f"Unsupported file format: '{file_format}'. " f"Supported: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
|
||||
)
|
||||
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=SmartTopologyRequest(
|
||||
File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
|
||||
PolygonType=polygon_type,
|
||||
FaceLevel=face_level,
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed: [{response.Error.Code}] {response.Error.Message}")
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||
)
|
||||
|
||||
|
||||
class TencentHunyuan3DExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TencentTextToModelNode,
|
||||
TencentImageToModelNode,
|
||||
# TencentModelTo3DUVNode,
|
||||
TencentModelTo3DUVNode,
|
||||
# Tencent3DTextureEditNode,
|
||||
Tencent3DPartNode,
|
||||
TencentSmartTopologyNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
395
comfy_api_nodes/nodes_reve.py
Normal file
@@ -0,0 +1,395 @@
|
||||
from io import BytesIO
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.reve import (
|
||||
ReveImageCreateRequest,
|
||||
ReveImageEditRequest,
|
||||
ReveImageRemixRequest,
|
||||
RevePostprocessingOperation,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
bytesio_to_image_tensor,
|
||||
sync_op_raw,
|
||||
tensor_to_base64_string,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
|
||||
ops = []
|
||||
if upscale["upscale"] == "enabled":
|
||||
ops.append(
|
||||
RevePostprocessingOperation(
|
||||
process="upscale",
|
||||
upscale_factor=upscale["upscale_factor"],
|
||||
)
|
||||
)
|
||||
if remove_background:
|
||||
ops.append(RevePostprocessingOperation(process="remove_background"))
|
||||
return ops or None
|
||||
|
||||
|
||||
def _postprocessing_inputs():
|
||||
return [
|
||||
IO.DynamicCombo.Input(
|
||||
"upscale",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("disabled", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"enabled",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"upscale_factor",
|
||||
default=2,
|
||||
min=2,
|
||||
max=4,
|
||||
step=1,
|
||||
tooltip="Upscale factor (2x, 3x, or 4x).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Upscale the generated image. May add additional cost.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"remove_background",
|
||||
default=False,
|
||||
tooltip="Remove the background from the generated image. May add additional cost.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _reve_price_extractor(headers: dict) -> float | None:
|
||||
credits_used = headers.get("x-reve-credits-used")
|
||||
if credits_used is not None:
|
||||
return float(credits_used) / 524.48
|
||||
return None
|
||||
|
||||
|
||||
def _reve_response_header_validator(headers: dict) -> None:
|
||||
error_code = headers.get("x-reve-error-code")
|
||||
if error_code:
|
||||
raise ValueError(f"Reve API error: {error_code}")
|
||||
if headers.get("x-reve-content-violation", "").lower() == "true":
|
||||
raise ValueError("The generated image was flagged for content policy violation.")
|
||||
|
||||
|
||||
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
|
||||
return [
|
||||
IO.DynamicCombo.Option(
|
||||
version,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=aspect_ratios,
|
||||
tooltip="Aspect ratio of the output image.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"test_time_scaling",
|
||||
default=1,
|
||||
min=1,
|
||||
max=5,
|
||||
step=1,
|
||||
tooltip="Higher values produce better images but cost more credits.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
for version in versions
|
||||
]
|
||||
|
||||
|
||||
class ReveImageCreateNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageCreateNode",
|
||||
display_name="Reve Image Create",
|
||||
category="api node/image/Reve",
|
||||
description="Generate images from text descriptions using Reve.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired image. Maximum 2560 characters.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_model_inputs(
|
||||
["reve-create@20250915"],
|
||||
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
|
||||
),
|
||||
tooltip="Model version to use for generation.",
|
||||
),
|
||||
*_postprocessing_inputs(),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
upscale: dict,
|
||||
remove_background: bool,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2560)
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/reve/v1/image/create",
|
||||
method="POST",
|
||||
headers={"Accept": "image/webp"},
|
||||
),
|
||||
as_binary=True,
|
||||
price_extractor=_reve_price_extractor,
|
||||
response_header_validator=_reve_response_header_validator,
|
||||
data=ReveImageCreateRequest(
|
||||
prompt=prompt,
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
version=model["model"],
|
||||
test_time_scaling=model["test_time_scaling"],
|
||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
||||
|
||||
|
||||
class ReveImageEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageEditNode",
|
||||
display_name="Reve Image Edit",
|
||||
category="api node/image/Reve",
|
||||
description="Edit images using natural language instructions with Reve.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to edit."),
|
||||
IO.String.Input(
|
||||
"edit_instruction",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_model_inputs(
|
||||
["reve-edit@20250915", "reve-edit-fast@20251030"],
|
||||
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
|
||||
),
|
||||
tooltip="Model version to use for editing.",
|
||||
),
|
||||
*_postprocessing_inputs(),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
edit_instruction: str,
|
||||
model: dict,
|
||||
upscale: dict,
|
||||
remove_background: bool,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(edit_instruction, min_length=1, max_length=2560)
|
||||
tts = model["test_time_scaling"]
|
||||
ar = model["aspect_ratio"]
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/reve/v1/image/edit",
|
||||
method="POST",
|
||||
headers={"Accept": "image/webp"},
|
||||
),
|
||||
as_binary=True,
|
||||
price_extractor=_reve_price_extractor,
|
||||
response_header_validator=_reve_response_header_validator,
|
||||
data=ReveImageEditRequest(
|
||||
edit_instruction=edit_instruction,
|
||||
reference_image=tensor_to_base64_string(image),
|
||||
aspect_ratio=ar if ar != "auto" else None,
|
||||
version=model["model"],
|
||||
test_time_scaling=tts if tts and tts > 1 else None,
|
||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
||||
|
||||
|
||||
class ReveImageRemixNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageRemixNode",
|
||||
display_name="Reve Image Remix",
|
||||
category="api node/image/Reve",
|
||||
description="Combine reference images with text prompts to create new images using Reve.",
|
||||
inputs=[
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="image_",
|
||||
min=1,
|
||||
max=6,
|
||||
),
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired image. "
|
||||
"May include XML img tags to reference specific images by index, "
|
||||
"e.g. <img>0</img>, <img>1</img>, etc.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_model_inputs(
|
||||
["reve-remix@20250915", "reve-remix-fast@20251030"],
|
||||
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
|
||||
),
|
||||
tooltip="Model version to use for remixing.",
|
||||
),
|
||||
*_postprocessing_inputs(),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
reference_images: IO.Autogrow.Type,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
upscale: dict,
|
||||
remove_background: bool,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2560)
|
||||
if not reference_images:
|
||||
raise ValueError("At least one reference image is required.")
|
||||
ref_base64_list = []
|
||||
for key in reference_images:
|
||||
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
|
||||
if len(ref_base64_list) > 6:
|
||||
raise ValueError("Maximum 6 reference images are allowed.")
|
||||
tts = model["test_time_scaling"]
|
||||
ar = model["aspect_ratio"]
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/reve/v1/image/remix",
|
||||
method="POST",
|
||||
headers={"Accept": "image/webp"},
|
||||
),
|
||||
as_binary=True,
|
||||
price_extractor=_reve_price_extractor,
|
||||
response_header_validator=_reve_response_header_validator,
|
||||
data=ReveImageRemixRequest(
|
||||
prompt=prompt,
|
||||
reference_images=ref_base64_list,
|
||||
aspect_ratio=ar if ar != "auto" else None,
|
||||
version=model["model"],
|
||||
test_time_scaling=tts if tts and tts > 1 else None,
|
||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
||||
|
||||
|
||||
class ReveExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ReveImageCreateNode,
|
||||
ReveImageEditNode,
|
||||
ReveImageRemixNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ReveExtension:
|
||||
return ReveExtension()
|
||||
@@ -67,6 +67,7 @@ class _RequestConfig:
    progress_origin_ts: float | None = None
    price_extractor: Callable[[dict[str, Any]], float | None] | None = None
    is_rate_limited: Callable[[int, Any], bool] | None = None
    response_header_validator: Callable[[dict[str, str]], None] | None = None


@dataclass
@@ -83,7 +84,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 500, 502, 503, 504}  # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]


async def sync_op(
@@ -202,11 +203,13 @@ async def sync_op_raw(
    monitor_progress: bool = True,
    max_retries_on_rate_limit: int = 16,
    is_rate_limited: Callable[[int, Any], bool] | None = None,
    response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes:
    """
    Make a single network request.
    - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
    - If as_binary=True: returns bytes.
    - response_header_validator: optional callback receiving response headers dict
    """
    if isinstance(data, BaseModel):
        data = data.model_dump(exclude_none=True)
@@ -232,6 +235,7 @@ async def sync_op_raw(
        price_extractor=price_extractor,
        max_retries_on_rate_limit=max_retries_on_rate_limit,
        is_rate_limited=is_rate_limited,
        response_header_validator=response_header_validator,
    )
    return await _request_base(cfg, expect_binary=as_binary)
@@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
                cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
            )
        bytes_payload = bytes(buff)
        resp_headers = {k.lower(): v for k, v in resp.headers.items()}
        if cfg.price_extractor:
            with contextlib.suppress(Exception):
                extracted_price = cfg.price_extractor(resp_headers)
        if cfg.response_header_validator:
            cfg.response_header_validator(resp_headers)
        operation_succeeded = True
        final_elapsed_seconds = int(time.monotonic() - start_time)
        request_logger.log_request_response(
@@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
            request_method=method,
            request_url=url,
            response_status_code=resp.status,
            response_headers=dict(resp.headers),
            response_headers=resp_headers,
            response_content=bytes_payload,
        )
        return bytes_payload
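A hedged sketch of how a caller wires the new response_header_validator together with a binary response (the header name, endpoint path, and request_model are placeholders, modeled on the Reve node above):

# inside an async execute(), as in the Reve nodes above
def _check_headers(headers: dict[str, str]) -> None:
    # headers arrive lower-cased by _request_base
    if headers.get("x-example-error-code"):          # example header name
        raise ValueError(headers["x-example-error-code"])

image_bytes = await sync_op_raw(
    cls,
    ApiEndpoint(path="/proxy/example/image", method="POST"),   # placeholder path
    as_binary=True,
    response_header_validator=_check_headers,
    data=request_model,                              # any BaseModel; placeholder here
)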
18
comfy_api_sealed_worker/__init__.py
Normal file
@@ -0,0 +1,18 @@
"""comfy_api_sealed_worker — torch-free type definitions for sealed worker children.

Drop-in replacement for comfy_api.latest._util type imports in sealed workers
that do not have torch installed. Contains only data type definitions (TrimeshData,
PLY, NPZ, etc.) with numpy-only dependencies.

Usage in serializers:
    if _IMPORT_TORCH:
        from comfy_api.latest._util.trimesh_types import TrimeshData
    else:
        from comfy_api_sealed_worker.trimesh_types import TrimeshData
"""

from .trimesh_types import TrimeshData
from .ply_types import PLY
from .npz_types import NPZ

__all__ = ["TrimeshData", "PLY", "NPZ"]
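A minimal sketch of the serializer-side guard the docstring describes (the try/except flag is one possible way to set _IMPORT_TORCH; the diff itself does not show how it is defined):

try:
    import torch  # noqa: F401
    _IMPORT_TORCH = True
except ImportError:
    _IMPORT_TORCH = False

if _IMPORT_TORCH:
    from comfy_api.latest._util.trimesh_types import TrimeshData
else:
    from comfy_api_sealed_worker.trimesh_types import TrimeshData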
27
comfy_api_sealed_worker/npz_types.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class NPZ:
|
||||
"""Ordered collection of NPZ file payloads.
|
||||
|
||||
Each entry in ``frames`` is a complete compressed ``.npz`` file stored
|
||||
as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO).
|
||||
``save_to`` writes numbered files into a directory.
|
||||
"""
|
||||
|
||||
def __init__(self, frames: list[bytes]) -> None:
|
||||
self.frames = frames
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return len(self.frames)
|
||||
|
||||
def save_to(self, directory: str, prefix: str = "frame") -> str:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
for i, frame_bytes in enumerate(self.frames):
|
||||
path = os.path.join(directory, f"{prefix}_{i:06d}.npz")
|
||||
with open(path, "wb") as f:
|
||||
f.write(frame_bytes)
|
||||
return directory
|
||||
97
comfy_api_sealed_worker/ply_types.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PLY:
|
||||
"""Point cloud payload for PLY file output.
|
||||
|
||||
Supports two schemas:
|
||||
- Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format)
|
||||
- Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format)
|
||||
|
||||
When ``raw_data`` is provided, the object acts as an opaque binary PLY
|
||||
carrier and ``save_to`` writes the bytes directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
points: np.ndarray | None = None,
|
||||
colors: np.ndarray | None = None,
|
||||
confidence: np.ndarray | None = None,
|
||||
view_id: np.ndarray | None = None,
|
||||
raw_data: bytes | None = None,
|
||||
) -> None:
|
||||
self.raw_data = raw_data
|
||||
if raw_data is not None:
|
||||
self.points = None
|
||||
self.colors = None
|
||||
self.confidence = None
|
||||
self.view_id = None
|
||||
return
|
||||
if points is None:
|
||||
raise ValueError("Either points or raw_data must be provided")
|
||||
if points.ndim != 2 or points.shape[1] != 3:
|
||||
raise ValueError(f"points must be (N, 3), got {points.shape}")
|
||||
self.points = np.ascontiguousarray(points, dtype=np.float32)
|
||||
self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None
|
||||
self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None
|
||||
self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None
|
||||
|
||||
@property
|
||||
def is_gaussian(self) -> bool:
|
||||
return self.raw_data is not None
|
||||
|
||||
@property
|
||||
def num_points(self) -> int:
|
||||
if self.points is not None:
|
||||
return self.points.shape[0]
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _to_numpy(arr, dtype):
|
||||
if arr is None:
|
||||
return None
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
|
||||
return np.ascontiguousarray(arr, dtype=dtype)
|
||||
|
||||
def save_to(self, path: str) -> str:
|
||||
if self.raw_data is not None:
|
||||
with open(path, "wb") as f:
|
||||
f.write(self.raw_data)
|
||||
return path
|
||||
self.points = self._to_numpy(self.points, np.float32)
|
||||
self.colors = self._to_numpy(self.colors, np.float32)
|
||||
self.confidence = self._to_numpy(self.confidence, np.float32)
|
||||
self.view_id = self._to_numpy(self.view_id, np.int32)
|
||||
N = self.num_points
|
||||
header_lines = [
|
||||
"ply",
|
||||
"format ascii 1.0",
|
||||
f"element vertex {N}",
|
||||
"property float x",
|
||||
"property float y",
|
||||
"property float z",
|
||||
]
|
||||
if self.colors is not None:
|
||||
header_lines += ["property uchar red", "property uchar green", "property uchar blue"]
|
||||
if self.confidence is not None:
|
||||
header_lines.append("property float confidence")
|
||||
if self.view_id is not None:
|
||||
header_lines.append("property int view_id")
|
||||
header_lines.append("end_header")
|
||||
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(header_lines) + "\n")
|
||||
for i in range(N):
|
||||
parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"]
|
||||
if self.colors is not None:
|
||||
r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8)
|
||||
parts.append(f"{r} {g} {b}")
|
||||
if self.confidence is not None:
|
||||
parts.append(f"{self.confidence[i]}")
|
||||
if self.view_id is not None:
|
||||
parts.append(f"{int(self.view_id[i])}")
|
||||
f.write(" ".join(parts) + "\n")
|
||||
return path
|
||||
259
comfy_api_sealed_worker/trimesh_types.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TrimeshData:
|
||||
"""Triangular mesh payload for cross-process transfer.
|
||||
|
||||
Lightweight carrier for mesh geometry that does not depend on the
|
||||
``trimesh`` library. Serializers create this on the host side;
|
||||
isolated child processes convert to/from ``trimesh.Trimesh`` as needed.
|
||||
|
||||
Supports both ColorVisuals (vertex_colors) and TextureVisuals
|
||||
(uv + material with textures).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vertices: np.ndarray,
|
||||
faces: np.ndarray,
|
||||
vertex_normals: np.ndarray | None = None,
|
||||
face_normals: np.ndarray | None = None,
|
||||
vertex_colors: np.ndarray | None = None,
|
||||
uv: np.ndarray | None = None,
|
||||
material: dict | None = None,
|
||||
vertex_attributes: dict | None = None,
|
||||
face_attributes: dict | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> None:
|
||||
self.vertices = np.ascontiguousarray(vertices, dtype=np.float64)
|
||||
self.faces = np.ascontiguousarray(faces, dtype=np.int64)
|
||||
self.vertex_normals = (
|
||||
np.ascontiguousarray(vertex_normals, dtype=np.float64)
|
||||
if vertex_normals is not None
|
||||
else None
|
||||
)
|
||||
self.face_normals = (
|
||||
np.ascontiguousarray(face_normals, dtype=np.float64)
|
||||
if face_normals is not None
|
||||
else None
|
||||
)
|
||||
self.vertex_colors = (
|
||||
np.ascontiguousarray(vertex_colors, dtype=np.uint8)
|
||||
if vertex_colors is not None
|
||||
else None
|
||||
)
|
||||
self.uv = (
|
||||
np.ascontiguousarray(uv, dtype=np.float64)
|
||||
if uv is not None
|
||||
else None
|
||||
)
|
||||
self.material = material
|
||||
self.vertex_attributes = vertex_attributes or {}
|
||||
self.face_attributes = face_attributes or {}
|
||||
self.metadata = self._detensorize_dict(metadata) if metadata else {}
|
||||
|
||||
@staticmethod
|
||||
def _detensorize_dict(d):
|
||||
"""Recursively convert any tensors in a dict back to numpy arrays."""
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if hasattr(v, "numpy"):
|
||||
result[k] = v.cpu().numpy() if hasattr(v, "cpu") else v.numpy()
|
||||
elif isinstance(v, dict):
|
||||
result[k] = TrimeshData._detensorize_dict(v)
|
||||
elif isinstance(v, list):
|
||||
result[k] = [
|
||||
item.cpu().numpy() if hasattr(item, "numpy") and hasattr(item, "cpu")
|
||||
else item.numpy() if hasattr(item, "numpy")
|
||||
else item
|
||||
for item in v
|
||||
]
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _to_numpy(arr, dtype):
|
||||
if arr is None:
|
||||
return None
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
|
||||
return np.ascontiguousarray(arr, dtype=dtype)
|
||||
|
||||
@property
|
||||
def num_vertices(self) -> int:
|
||||
return self.vertices.shape[0]
|
||||
|
||||
@property
|
||||
def num_faces(self) -> int:
|
||||
return self.faces.shape[0]
|
||||
|
||||
@property
|
||||
def has_texture(self) -> bool:
|
||||
return self.uv is not None and self.material is not None
|
||||
|
||||
def to_trimesh(self):
|
||||
"""Convert to trimesh.Trimesh (requires trimesh in the environment)."""
|
||||
import trimesh
|
||||
from trimesh.visual import TextureVisuals
|
||||
|
||||
kwargs = {}
|
||||
if self.vertex_normals is not None:
|
||||
kwargs["vertex_normals"] = self.vertex_normals
|
||||
if self.face_normals is not None:
|
||||
kwargs["face_normals"] = self.face_normals
|
||||
if self.metadata:
|
||||
kwargs["metadata"] = self.metadata
|
||||
|
||||
mesh = trimesh.Trimesh(
|
||||
vertices=self.vertices, faces=self.faces, process=False, **kwargs
|
||||
)
|
||||
|
||||
# Reconstruct visual
|
||||
if self.has_texture:
|
||||
material = self._dict_to_material(self.material)
|
||||
mesh.visual = TextureVisuals(uv=self.uv, material=material)
|
||||
elif self.vertex_colors is not None:
|
||||
mesh.visual.vertex_colors = self.vertex_colors
|
||||
|
||||
for k, v in self.vertex_attributes.items():
|
||||
mesh.vertex_attributes[k] = v
|
||||
|
||||
for k, v in self.face_attributes.items():
|
||||
mesh.face_attributes[k] = v
|
||||
|
||||
return mesh
|
||||
|
||||
@staticmethod
|
||||
def _material_to_dict(material) -> dict:
|
||||
"""Serialize a trimesh material to a plain dict."""
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from trimesh.visual.material import PBRMaterial, SimpleMaterial
|
||||
|
||||
result = {"type": type(material).__name__, "name": getattr(material, "name", None)}
|
||||
|
||||
if isinstance(material, PBRMaterial):
|
||||
result["baseColorFactor"] = material.baseColorFactor
|
||||
result["metallicFactor"] = material.metallicFactor
|
||||
result["roughnessFactor"] = material.roughnessFactor
|
||||
result["emissiveFactor"] = material.emissiveFactor
|
||||
result["alphaMode"] = material.alphaMode
|
||||
result["alphaCutoff"] = material.alphaCutoff
|
||||
result["doubleSided"] = material.doubleSided
|
||||
|
||||
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
|
||||
"metallicRoughnessTexture", "occlusionTexture"):
|
||||
tex = getattr(material, tex_name, None)
|
||||
if tex is not None:
|
||||
buf = BytesIO()
|
||||
tex.save(buf, format="PNG")
|
||||
result[tex_name] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
elif isinstance(material, SimpleMaterial):
|
||||
result["main_color"] = list(material.main_color) if material.main_color is not None else None
|
||||
result["glossiness"] = material.glossiness
|
||||
if hasattr(material, "image") and material.image is not None:
|
||||
buf = BytesIO()
|
||||
material.image.save(buf, format="PNG")
|
||||
result["image"] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_material(d: dict):
|
||||
"""Reconstruct a trimesh material from a plain dict."""
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from trimesh.visual.material import PBRMaterial, SimpleMaterial
|
||||
|
||||
mat_type = d.get("type", "PBRMaterial")
|
||||
|
||||
if mat_type == "PBRMaterial":
|
||||
kwargs = {
|
||||
"name": d.get("name"),
|
||||
"baseColorFactor": d.get("baseColorFactor"),
|
||||
"metallicFactor": d.get("metallicFactor"),
|
||||
"roughnessFactor": d.get("roughnessFactor"),
|
||||
"emissiveFactor": d.get("emissiveFactor"),
|
||||
"alphaMode": d.get("alphaMode"),
|
||||
"alphaCutoff": d.get("alphaCutoff"),
|
||||
"doubleSided": d.get("doubleSided"),
|
||||
}
|
||||
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
|
||||
"metallicRoughnessTexture", "occlusionTexture"):
|
||||
if tex_name in d and d[tex_name] is not None:
|
||||
img = Image.open(BytesIO(base64.b64decode(d[tex_name])))
|
||||
kwargs[tex_name] = img
|
||||
return PBRMaterial(**{k: v for k, v in kwargs.items() if v is not None})
|
||||
|
||||
elif mat_type == "SimpleMaterial":
|
||||
kwargs = {
|
||||
"name": d.get("name"),
|
||||
"glossiness": d.get("glossiness"),
|
||||
}
|
||||
if d.get("main_color") is not None:
|
||||
kwargs["diffuse"] = d["main_color"]
|
||||
if d.get("image") is not None:
|
||||
kwargs["image"] = Image.open(BytesIO(base64.b64decode(d["image"])))
|
||||
return SimpleMaterial(**kwargs)
|
||||
|
||||
raise ValueError(f"Unknown material type: {mat_type}")
|
||||
|
||||
@classmethod
|
||||
def from_trimesh(cls, mesh) -> TrimeshData:
|
||||
"""Create from a trimesh.Trimesh object."""
|
||||
from trimesh.visual.texture import TextureVisuals
|
||||
|
||||
vertex_normals = None
|
||||
if mesh._cache.cache.get("vertex_normals") is not None:
|
||||
vertex_normals = np.asarray(mesh.vertex_normals)
|
||||
|
||||
face_normals = None
|
||||
if mesh._cache.cache.get("face_normals") is not None:
|
||||
face_normals = np.asarray(mesh.face_normals)
|
||||
|
||||
vertex_colors = None
|
||||
uv = None
|
||||
material = None
|
||||
|
||||
if isinstance(mesh.visual, TextureVisuals):
|
||||
if mesh.visual.uv is not None:
|
||||
uv = np.asarray(mesh.visual.uv, dtype=np.float64)
|
||||
if mesh.visual.material is not None:
|
||||
material = cls._material_to_dict(mesh.visual.material)
|
||||
else:
|
||||
try:
|
||||
vc = mesh.visual.vertex_colors
|
||||
if vc is not None and len(vc) > 0:
|
||||
vertex_colors = np.asarray(vc, dtype=np.uint8)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
va = {}
|
||||
if hasattr(mesh, "vertex_attributes") and mesh.vertex_attributes:
|
||||
for k, v in mesh.vertex_attributes.items():
|
||||
va[k] = np.asarray(v) if hasattr(v, "__array__") else v
|
||||
|
||||
fa = {}
|
||||
if hasattr(mesh, "face_attributes") and mesh.face_attributes:
|
||||
for k, v in mesh.face_attributes.items():
|
||||
fa[k] = np.asarray(v) if hasattr(v, "__array__") else v
|
||||
|
||||
return cls(
|
||||
vertices=np.asarray(mesh.vertices),
|
||||
faces=np.asarray(mesh.faces),
|
||||
vertex_normals=vertex_normals,
|
||||
face_normals=face_normals,
|
||||
vertex_colors=vertex_colors,
|
||||
uv=uv,
|
||||
material=material,
|
||||
vertex_attributes=va if va else None,
|
||||
face_attributes=fa if fa else None,
|
||||
metadata=mesh.metadata if mesh.metadata else None,
|
||||
)
@@ -253,10 +253,12 @@ class LTXVAddGuide(io.ComfyNode):
         return frame_idx, latent_idx
 
     @classmethod
-    def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1):
+    def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1, causal_fix=None):
         keyframe_idxs, _ = get_keyframe_idxs(cond)
         _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
-        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
+        if causal_fix is None:
+            causal_fix = frame_idx == 0 or guiding_latent.shape[2] == 1
+        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=causal_fix)
         pixel_coords[:, 0] += frame_idx
 
         # The following adjusts keyframe end positions for small grid IC-LoRA.
@@ -278,12 +280,12 @@ class LTXVAddGuide(io.ComfyNode):
         return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
 
     @classmethod
-    def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1):
+    def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1, causal_fix=None):
         if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
             raise ValueError("Adding guide to a combined AV latent is not supported.")
 
-        positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
-        negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
+        positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix)
+        negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix)
 
         if guide_mask is not None:
             target_h = max(noise_mask.shape[3], guide_mask.shape[3])
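The hunk above makes causal_fix an explicit parameter while keeping the previous behaviour as the default: when left as None it is derived from the keyframe position and the number of guiding latent frames. A small standalone sketch of that default, with placeholder names and values:

# Hypothetical illustration of the new default resolution (not part of the diff).
def resolve_causal_fix(frame_idx: int, guiding_latent_frames: int, causal_fix=None) -> bool:
    if causal_fix is None:
        # the fix applies when the guide sits at index 0 or is a single latent frame
        causal_fix = frame_idx == 0 or guiding_latent_frames == 1
    return causal_fix

assert resolve_causal_fix(0, 8) is True
assert resolve_causal_fix(5, 1) is True
assert resolve_causal_fix(5, 8) is False
assert resolve_causal_fix(5, 8, causal_fix=True) is True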
comfy_extras/nodes_math.py (new file, 119 lines)
@@ -0,0 +1,119 @@
"""Math expression node using simpleeval for safe evaluation.

Provides a ComfyMathExpression node that evaluates math expressions
against dynamically-grown numeric inputs.
"""

from __future__ import annotations

import math
import string

from simpleeval import simple_eval
from typing_extensions import override

from comfy_api.latest import ComfyExtension, io


MAX_EXPONENT = 4000


def _variadic_sum(*args):
    """Support both sum(values) and sum(a, b, c)."""
    if len(args) == 1 and hasattr(args[0], "__iter__"):
        return sum(args[0])
    return sum(args)


def _safe_pow(base, exp):
    """Wrap pow() with an exponent cap to prevent DoS via huge exponents.

    The ** operator is already guarded by simpleeval's safe_power, but
    pow() as a callable bypasses that guard.
    """
    if abs(exp) > MAX_EXPONENT:
        raise ValueError(f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})")
    return pow(base, exp)


MATH_FUNCTIONS = {
    "sum": _variadic_sum,
    "min": min,
    "max": max,
    "abs": abs,
    "round": round,
    "pow": _safe_pow,
    "sqrt": math.sqrt,
    "ceil": math.ceil,
    "floor": math.floor,
    "log": math.log,
    "log2": math.log2,
    "log10": math.log10,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "int": int,
    "float": float,
}


class MathExpressionNode(io.ComfyNode):
    """Evaluates a math expression against dynamically-grown inputs."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        autogrow = io.Autogrow.TemplateNames(
            input=io.MultiType.Input("value", [io.Float, io.Int]),
            names=list(string.ascii_lowercase),
            min=1,
        )
        return io.Schema(
            node_id="ComfyMathExpression",
            display_name="Math Expression",
            category="math",
            search_aliases=[
                "expression", "formula", "calculate", "calculator",
                "eval", "math",
            ],
            inputs=[
                io.String.Input("expression", default="a + b", multiline=True),
                io.Autogrow.Input("values", template=autogrow),
            ],
            outputs=[
                io.Float.Output(display_name="FLOAT"),
                io.Int.Output(display_name="INT"),
            ],
        )

    @classmethod
    def execute(
        cls, expression: str, values: io.Autogrow.Type
    ) -> io.NodeOutput:
        if not expression.strip():
            raise ValueError("Expression cannot be empty.")

        context: dict = dict(values)
        context["values"] = list(values.values())

        result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS)
        # bool check must come first because bool is a subclass of int in Python
        if isinstance(result, bool) or not isinstance(result, (int, float)):
            raise ValueError(
                f"Math Expression '{expression}' must evaluate to a numeric result, "
                f"got {type(result).__name__}: {result!r}"
            )
        if not math.isfinite(result):
            raise ValueError(
                f"Math Expression '{expression}' produced a non-finite result: {result}"
            )
        return io.NodeOutput(float(result), int(result))


class MathExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [MathExpressionNode]


async def comfy_entrypoint() -> MathExtension:
    return MathExtension()
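Stripped of the node machinery, execute() reduces to a simpleeval call with the per-input names plus the whitelisted functions above. A short standalone sketch of that evaluation; the input values are arbitrary examples:

import math
from simpleeval import simple_eval

funcs = {"sqrt": math.sqrt, "min": min, "max": max}  # subset of MATH_FUNCTIONS above
names = {"a": 3, "b": 4.0}
names["values"] = list(names.values())  # the node also exposes all inputs as a list

print(simple_eval("sqrt(a*a + b*b)", names=names, functions=funcs))  # 5.0
print(simple_eval("max(values)", names=names, functions=funcs))      # 4.0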
comfy_extras/nodes_save_npz.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import os

import folder_paths
from comfy_api.latest import io
from comfy_api.latest._util.npz_types import NPZ


class SaveNPZ(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveNPZ",
            display_name="Save NPZ",
            category="3d",
            is_output_node=True,
            inputs=[
                io.Npz.Input("npz"),
                io.String.Input("filename_prefix", default="da3_streaming/ComfyUI"),
            ],
        )

    @classmethod
    def execute(cls, npz: NPZ, filename_prefix: str) -> io.NodeOutput:
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, folder_paths.get_output_directory()
        )
        batch_dir = os.path.join(full_output_folder, f"{filename}_{counter:05}")
        os.makedirs(batch_dir, exist_ok=True)
        filenames = []
        for i, frame_bytes in enumerate(npz.frames):
            f = f"frame_{i:06d}.npz"
            with open(os.path.join(batch_dir, f), "wb") as fh:
                fh.write(frame_bytes)
            filenames.append(f)
        return io.NodeOutput(ui={"npz_files": [{"folder": os.path.join(subfolder, f"{filename}_{counter:05}"), "count": len(filenames), "type": "output"}]})


NODE_CLASS_MAPPINGS = {
    "SaveNPZ": SaveNPZ,
}
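Since each frame is written as its own .npz file inside a per-run directory, the saved batch can be read back with plain NumPy. A hypothetical reader; the directory path is an example following the node's naming scheme, not a fixed location:

import glob
import os

import numpy as np

batch_dir = "output/da3_streaming/ComfyUI_00001"  # example path, depends on the configured output directory
for path in sorted(glob.glob(os.path.join(batch_dir, "frame_*.npz"))):
    data = np.load(path)      # returns an NpzFile for .npz archives
    print(path, data.files)   # list the arrays stored in this frame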
comfy_extras/nodes_save_ply.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import os

import folder_paths
from comfy_api.latest import io
from comfy_api.latest._util.ply_types import PLY


class SavePLY(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SavePLY",
            display_name="Save PLY",
            category="3d",
            is_output_node=True,
            inputs=[
                io.Ply.Input("ply"),
                io.String.Input("filename_prefix", default="pointcloud/ComfyUI"),
            ],
        )

    @classmethod
    def execute(cls, ply: PLY, filename_prefix: str) -> io.NodeOutput:
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, folder_paths.get_output_directory()
        )
        f = f"{filename}_{counter:05}_.ply"
        ply.save_to(os.path.join(full_output_folder, f))
        return io.NodeOutput(ui={"pointclouds": [{"filename": f, "subfolder": subfolder, "type": "output"}]})


NODE_CLASS_MAPPINGS = {
    "SavePLY": SavePLY,
}
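The PLY output is a single file per execution, so it can be inspected with any PLY-aware tool; for instance trimesh can load it back, assuming the node writes a standard PLY. The path below is illustrative only:

import trimesh

# Example path following the node's naming scheme (actual path depends on the output directory)
cloud = trimesh.load("output/pointcloud/ComfyUI_00001_.ply")
print(type(cloud).__name__, len(cloud.vertices))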
@@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
oom = False
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
tile //= 2
|
||||
if tile < 128:
|
||||
raise e
|
||||
|
||||
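The change above broadens the except clause and delegates to model_management.raise_non_oom, which, going by its name, re-raises anything that is not an out-of-memory error so that only OOM falls through to the tile-halving retry. The surrounding loop follows a common halve-the-tile-until-it-fits pattern; a generic sketch with placeholder names, not the ComfyUI API:

def run_tiled(op, tile: int, min_tile: int = 128):
    """Retry a tiled operation with progressively smaller tiles on out-of-memory errors."""
    while True:
        try:
            return op(tile)
        except MemoryError as e:  # stand-in for the backend's OOM exception
            tile //= 2
            if tile < min_tile:
                raise e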
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.16.2"
|
||||
__version__ = "0.16.4"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.