Compare commits

..

15 Commits

Author SHA1 Message Date
John Pollock
c02372936d Merge pull request #12959 from pollockjj/wheel-support-pr
feat(isolation): wheel support for isolated custom nodes
2026-03-15 01:29:05 -05:00
John Pollock
6aa0b838a0 feat(isolation): wheel support for isolated custom nodes
Extends pyisolate process isolation with wheel-based dependency
management, sandbox mode policy, and compatibility fixes validated
against DA3 as the first complex isolated custom node.

- Add sandbox_mode policy (required/disabled) with COMFY_HOST_POLICY_PATH
  env override for host security configuration
- Plumb cuda_wheels config and standardize child environment detection
- Add PLY, NPZ, File3D, VIDEO serializers for core save nodes
- Register isolated extension web directories on host side
  (sandbox_mode=disabled) for frontend JS widget serving
- Capture ACCEPT_ALL_INPUTS from child node class to prevent
  @classproperty trigger on proxy class
- Serialize NodeOutput with ui/expand/block_execution through JSON-RPC
- Cherry-pick V3 python_module metadata fix (e4592190)
- Remove improperly committed cache artifacts
2026-03-15 01:25:40 -05:00
John Pollock
54461f9ecc Merge pull request #12898 from pollockjj/pyisolate-pr-final
feat(isolation): upstream master sync + serializers, save nodes, and fixes
2026-03-12 01:48:58 -05:00
John Pollock
b602cc4533 fix(lint): suppress F401 on side-effect import utils.extra_config 2026-03-12 01:38:07 -05:00
John Pollock
08b92a48c3 Merge branch 'pyisolate-squash-final' into pyisolate-pr-final
# Conflicts:
#	comfy/isolation/adapter.py
#	comfy/isolation/extension_loader.py
#	comfy/isolation/extension_wrapper.py
#	comfy/isolation/model_patcher_proxy_utils.py
#	comfy/isolation/runtime_helpers.py
#	comfy/model_patcher.py
#	comfy/supported_models.py
#	main.py
#	nodes.py
#	requirements.txt
#	server.py
2026-03-12 01:33:42 -05:00
John Pollock
c5e7b9cdaf feat(isolation): process isolation for custom nodes via pyisolate
Adds opt-in process isolation for custom nodes using pyisolate's
bwrap sandbox and JSON-RPC bridge. Each isolated node pack runs in
its own child process with zero-copy tensor transfer via shared memory.

Core infrastructure:
- CLI flag --use-process-isolation to enable isolation
- Host/child startup fencing via PYISOLATE_CHILD env var
- Manifest-driven node discovery and extension loading
- JSON-RPC bridge between host and child processes
- Shared memory forensics for leak detection

Proxy layer:
- ModelPatcher, CLIP, VAE, and ModelSampling proxies
- Host service proxies (folder_paths, model_management, progress, etc.)
- Proxy base with automatic method forwarding

Execution integration:
- Extension wrapper with V3 hidden param mapping
- Runtime helpers for isolated node execution
- Host policy for node isolation decisions
- Fenced sampler device handling and model ejection parity

Serializers for cross-process data transfer:
- File3D (GLB), PLY (structured + gaussian), NPZ (streaming frames),
  VIDEO (VideoFromFile + VideoFromComponents) serializers
- data_type flag in SerializerRegistry for type-aware dispatch
- Isolated get_temp_directory() fence

New core save nodes:
- SavePLY and SaveNPZ with comfytype registrations (Ply, Npz)

DynamicVRAM compatibility:
- comfy-aimdo early init gated by isolation fence

Tests:
- Integration and policy tests for isolation lifecycle
- Manifest loader, host policy, proxy, and adapter unit tests

Depends on: pyisolate >= 0.9.2
2026-03-12 01:13:43 -05:00
John Pollock
623a9d21e9 Merge pull request #12775 from pollockjj/pyisolate-pr-20260304
feat(isolation): DynamicVRAM compatibility for process isolation
2026-03-05 06:01:33 +00:00
John Pollock
9250191c65 feat(isolation): DynamicVRAM compatibility for process isolation
DynamicVRAM's on-demand model loading/offloading conflicted with  process isolation in three ways: RPC tensor transport stalls from mid-call GPU offload, race conditions between model lifecycle and active RPC operations, and false positive memory leak detection from changed finalizer patterns.

- Marshal CUDA tensors to CPU before RPC transport for dynamic models
- Add operation state tracking + quiescence waits at workflow boundaries
- Distinguish proxy reference release from actual leaks in cleanup_models_gc
- Fix init order: DynamicVRAM must initialize before isolation proxies
- Add RPC timeouts to prevent indefinite hangs on model unavailability
- Prevent proxy-of-proxy chains from DynamicVRAM model reload cycles
- Add torch.device/torch.dtype serializers for new DynamicVRAM RPC paths
- Guard isolation overhead so non-isolated workflows are unaffected
- Migrate env var to PYISOLATE_CHILD
2026-03-04 23:48:02 -06:00
Jedrzej Kosinski
a0f8784e9f Merge remote-tracking branch 'pollock/comfy-isolation-squash' into pyisolate-support 2026-02-28 03:54:28 -08:00
John Pollock
7962db477a test(isolation): isolation integration + policy tests 2026-02-27 13:07:23 -06:00
John Pollock
3c8ba051b6 fix(isolation-lifecycle): execution/model ejection parity + fenced sampler device handling
add pyisolate==0.9.1 to requirements.txt
2026-02-27 13:07:23 -06:00
John Pollock
a1c3124821 feat(isolation-model-proxies): model patcher + clip/vae/model sampling proxies 2026-02-27 12:42:11 -06:00
John Pollock
9ca799362d feat(isolation-proxies): proxy base + host service proxies 2026-02-27 12:41:58 -06:00
John Pollock
22f5e43c12 feat(isolation-runtime): manifest loading, orchestration, host policy, shm forensics 2026-02-27 12:41:44 -06:00
John Pollock
3cfd5e3311 feat(isolation-bootstrap): cli flag + host/child startup fencing 2026-02-27 12:41:27 -06:00
133 changed files with 9995 additions and 4290 deletions

View File

@@ -1,103 +0,0 @@
#!/usr/bin/env bash
# Checks pull request commits for AI agent Co-authored-by trailers.
# Exits non-zero when any are found and prints fix instructions.
set -euo pipefail
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
# Known AI coding-agent trailer patterns (case-insensitive).
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
AGENT_PATTERNS=(
# Anthropic — Claude Code / Amp
'noreply@anthropic\.com'
# Cursor
'cursoragent@cursor\.com'
# GitHub Copilot
'copilot-swe-agent\[bot\]'
'copilot@github\.com'
# OpenAI Codex
'noreply@openai\.com'
'codex@openai\.com'
# Aider
'aider@aider\.chat'
# Google — Gemini / Jules
'gemini@google\.com'
'jules@google\.com'
# Windsurf / Codeium
'@codeium\.com'
# Devin
'devin-ai-integration\[bot\]'
'devin@cognition\.ai'
'devin@cognition-labs\.com'
# Amazon Q Developer
'amazon-q-developer'
'@amazon\.com.*[Qq].[Dd]eveloper'
# Cline
'cline-bot'
'cline@cline\.ai'
# Continue
'continue-agent'
'continue@continue\.dev'
# Sourcegraph
'noreply@sourcegraph\.com'
# Generic catch-alls for common agent name patterns
'Co-authored-by:.*\b[Cc]laude\b'
'Co-authored-by:.*\b[Cc]opilot\b'
'Co-authored-by:.*\b[Cc]ursor\b'
'Co-authored-by:.*\b[Cc]odex\b'
'Co-authored-by:.*\b[Gg]emini\b'
'Co-authored-by:.*\b[Aa]ider\b'
'Co-authored-by:.*\b[Dd]evin\b'
'Co-authored-by:.*\b[Ww]indsurf\b'
'Co-authored-by:.*\b[Cc]line\b'
'Co-authored-by:.*\b[Aa]mazon Q\b'
'Co-authored-by:.*\b[Jj]ules\b'
'Co-authored-by:.*\bOpenCode\b'
)
# Build a single alternation regex from all patterns.
regex=""
for pattern in "${AGENT_PATTERNS[@]}"; do
if [[ -n "$regex" ]]; then
regex="${regex}|${pattern}"
else
regex="$pattern"
fi
done
# Collect Co-authored-by lines from every commit in the PR range.
violations=""
while IFS= read -r sha; do
message="$(git log -1 --format='%B' "$sha")"
matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
if [[ -z "$matched_lines" ]]; then
continue
fi
while IFS= read -r line; do
if echo "$line" | grep -iqE "$regex"; then
short="$(git log -1 --format='%h' "$sha")"
violations="${violations} ${short}: ${line}"$'\n'
fi
done <<< "$matched_lines"
done < <(git rev-list "${base_sha}..${head_sha}")
if [[ -n "$violations" ]]; then
echo "::error::AI agent Co-authored-by trailers detected in PR commits."
echo ""
echo "The following commits contain Co-authored-by trailers from AI coding agents:"
echo ""
echo "$violations"
echo "These trailers should be removed before merging."
echo ""
echo "To fix, rewrite the commit messages with:"
echo " git rebase -i ${base_sha}"
echo ""
echo "and remove the Co-authored-by lines, then force-push your branch."
echo ""
echo "If you believe this is a false positive, please open an issue."
exit 1
fi
echo "No AI agent Co-authored-by trailers found."

View File

@@ -1,19 +0,0 @@
name: Check AI Co-Authors
on:
pull_request:
branches: ['*']
jobs:
check-ai-co-authors:
name: Check for AI agent co-author trailers
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Check commits for AI co-author trailers
run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"

1
.gitignore vendored
View File

@@ -24,3 +24,4 @@ web_custom_versions/
openapi.yaml
filtered-openapi.yaml
uv.lock
.pyisolate_venvs/

View File

@@ -38,8 +38,6 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started
### Local
#### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started.
- Available on Windows & macOS.
@@ -51,13 +49,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
#### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
### Cloud
#### [Comfy Cloud](https://www.comfy.org/cloud)
- Our official paid cloud version for those who can't afford local hardware.
## Examples
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.

View File

@@ -8,7 +8,7 @@ from alembic import context
config = context.config
from app.database.models import Base, NAMING_CONVENTION
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
@@ -51,10 +51,7 @@ def run_migrations_online() -> None:
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
render_as_batch=True,
naming_convention=NAMING_CONVENTION,
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():

View File

@@ -1,98 +0,0 @@
"""
Add system_metadata and job_id columns to asset_references.
Change preview_id FK from assets.id to asset_references.id.
Revision ID: 0003_add_metadata_job_id
Revises: 0002_merge_to_asset_references
Create Date: 2026-03-09
"""
from alembic import op
import sqlalchemy as sa
from app.database.models import NAMING_CONVENTION
revision = "0003_add_metadata_job_id"
down_revision = "0002_merge_to_asset_references"
branch_labels = None
depends_on = None
def upgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.add_column(
sa.Column("system_metadata", sa.JSON(), nullable=True)
)
batch_op.add_column(
sa.Column("job_id", sa.String(length=36), nullable=True)
)
# Change preview_id FK from assets.id to asset_references.id (self-ref).
# Existing values are asset-content IDs that won't match reference IDs,
# so null them out first.
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
with op.batch_alter_table(
"asset_references", naming_convention=NAMING_CONVENTION
) as batch_op:
batch_op.drop_constraint(
"fk_asset_references_preview_id_assets", type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_asset_references_preview_id_asset_references",
"asset_references",
["preview_id"],
["id"],
ondelete="SET NULL",
)
batch_op.create_index(
"ix_asset_references_preview_id", ["preview_id"]
)
# Purge any all-null meta rows before adding the constraint
op.execute(
"DELETE FROM asset_reference_meta"
" WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
)
with op.batch_alter_table("asset_reference_meta") as batch_op:
batch_op.create_check_constraint(
"ck_asset_reference_meta_has_value",
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
)
def downgrade() -> None:
# SQLite doesn't reflect CHECK constraints, so we must declare it
# explicitly via table_args for the batch recreate to find it.
# Use the fully-rendered constraint name to avoid the naming convention
# doubling the prefix.
with op.batch_alter_table(
"asset_reference_meta",
table_args=[
sa.CheckConstraint(
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
name="ck_asset_reference_meta_has_value",
),
],
) as batch_op:
batch_op.drop_constraint(
"ck_asset_reference_meta_has_value", type_="check"
)
with op.batch_alter_table(
"asset_references", naming_convention=NAMING_CONVENTION
) as batch_op:
batch_op.drop_index("ix_asset_references_preview_id")
batch_op.drop_constraint(
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_asset_references_preview_id_assets",
"assets",
["preview_id"],
["id"],
ondelete="SET NULL",
)
with op.batch_alter_table("asset_references") as batch_op:
batch_op.drop_column("job_id")
batch_op.drop_column("system_metadata")

View File

@@ -13,7 +13,6 @@ from pydantic import ValidationError
import folder_paths
from app import user_manager
from app.assets.api import schemas_in, schemas_out
from app.assets.services import schemas
from app.assets.api.schemas_in import (
AssetValidationError,
UploadError,
@@ -39,7 +38,6 @@ from app.assets.services import (
update_asset_metadata,
upload_from_temp_path,
)
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -124,61 +122,6 @@ def _validate_sort_field(requested: str | None) -> str:
return "created_at"
def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
"""Build a /api/view preview URL from asset tags and user_metadata filename."""
if not user_metadata:
return None
filename = user_metadata.get("filename")
if not filename:
return None
if "input" in tags:
view_type = "input"
elif "output" in tags:
view_type = "output"
else:
return None
subfolder = ""
if "/" in filename:
subfolder, filename = filename.rsplit("/", 1)
encoded_filename = urllib.parse.quote(filename, safe="")
url = f"/api/view?type={view_type}&filename={encoded_filename}"
if subfolder:
url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
return url
def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
"""Build an Asset response from a service result."""
if result.ref.preview_id:
preview_detail = get_asset_detail(result.ref.preview_id)
if preview_detail:
preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata)
else:
preview_url = None
else:
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
preview_url=preview_url,
preview_id=result.ref.preview_id,
user_metadata=result.ref.user_metadata or {},
metadata=result.ref.system_metadata,
job_id=result.ref.job_id,
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
created_at=result.ref.created_at,
updated_at=result.ref.updated_at,
last_access_time=result.ref.last_access_time,
)
@ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response:
@@ -221,7 +164,20 @@ async def list_assets_route(request: web.Request) -> web.Response:
order=order,
)
summaries = [_build_asset_response(item) for item in result.items]
summaries = [
schemas_out.AssetSummary(
id=item.ref.id,
name=item.ref.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes) if item.asset else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.ref.created_at,
updated_at=item.ref.updated_at,
last_access_time=item.ref.last_access_time,
)
for item in result.items
]
payload = schemas_out.AssetsList(
assets=summaries,
@@ -251,7 +207,18 @@ async def get_asset_route(request: web.Request) -> web.Response:
{"id": reference_id},
)
payload = _build_asset_response(result)
payload = schemas_out.AssetDetail(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
)
except ValueError as e:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
@@ -263,7 +230,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
@@ -345,31 +312,32 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
400, "INVALID_JSON", "Request body must be valid JSON."
)
# Derive name from hash if not provided
name = body.name
if name is None:
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
result = create_from_hash(
hash_str=body.hash,
name=name,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
preview_id=body.preview_id,
)
if result is None:
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
)
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated(
**asset.model_dump(),
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new,
)
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
return web.json_response(payload_out.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
@@ -390,8 +358,6 @@ async def upload_asset(request: web.Request) -> web.Response:
"name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash,
"mime_type": parsed.provided_mime_type,
"preview_id": parsed.provided_preview_id,
}
)
except ValidationError as ve:
@@ -420,8 +386,6 @@ async def upload_asset(request: web.Request) -> web.Response:
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
)
if result is None:
delete_temp_file_if_exists(parsed.tmp_path)
@@ -446,8 +410,6 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=parsed.file_client_name,
owner_id=owner_id,
expected_hash=spec.hash,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
)
except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path)
@@ -466,13 +428,21 @@ async def upload_asset(request: web.Request) -> web.Response:
logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated(
**asset.model_dump(),
payload = schemas_out.AssetCreated(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new,
)
status = 201 if result.created_new else 200
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
return web.json_response(payload.model_dump(mode="json"), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@@ -494,9 +464,15 @@ async def update_asset_route(request: web.Request) -> web.Response:
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
preview_id=body.preview_id,
)
payload = _build_asset_response(result)
payload = schemas_out.AssetUpdated(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
updated_at=result.ref.updated_at,
)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve:
@@ -510,7 +486,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
@@ -579,7 +555,7 @@ async def get_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsList(
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
)
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
@@ -627,7 +603,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
@@ -674,29 +650,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.get("/api/assets/tags/refine")
@_require_assets_feature_enabled
async def get_tags_refine(request: web.Request) -> web.Response:
"""GET request to get tag histogram for filtered assets."""
query_dict = get_query_dict(request)
try:
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
except ValidationError as ve:
return _build_validation_error_response("INVALID_QUERY", ve)
tag_counts = list_tag_histogram(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
)
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed")

View File

@@ -45,8 +45,6 @@ class ParsedUpload:
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel):
@@ -100,17 +98,11 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel):
name: str | None = None
user_metadata: dict[str, Any] | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
@model_validator(mode="after")
def _validate_at_least_one_field(self):
if all(
v is None
for v in (self.name, self.user_metadata, self.preview_id)
):
raise ValueError(
"Provide at least one of: name, user_metadata, preview_id."
)
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
return self
@@ -118,11 +110,9 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str | None = None
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
mime_type: str | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
@field_validator("hash")
@classmethod
@@ -148,44 +138,6 @@ class CreateFromHashBody(BaseModel):
return []
class TagsRefineQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@@ -234,25 +186,21 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: optional list; if provided, first is root ('models'|'input'|'output');
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
- mime_type: optional MIME type override
- preview_id: optional asset_reference ID for preview
Files are stored using the content hash as filename stem.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(default_factory=list)
tags: list[str] = Field(..., min_length=1)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None) # references an asset_reference id
@field_validator("hash", mode="before")
@classmethod
@@ -331,7 +279,7 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("at least one tag is required for uploads")
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")

View File

@@ -4,10 +4,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class Asset(BaseModel):
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
``id`` here is the AssetReference id, not the content-addressed Asset id."""
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: str | None = None
@@ -15,14 +12,8 @@ class Asset(BaseModel):
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
preview_url: str | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
user_metadata: dict[str, Any] = Field(default_factory=dict)
is_immutable: bool = False
metadata: dict[str, Any] | None = None
job_id: str | None = None
prompt_id: str | None = None # deprecated: use job_id
created_at: datetime
updated_at: datetime
created_at: datetime | None = None
updated_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@@ -32,16 +23,50 @@ class Asset(BaseModel):
return v.isoformat() if v else None
class AssetCreated(Asset):
created_new: bool
class AssetsList(BaseModel):
assets: list[Asset]
assets: list[AssetSummary]
total: int
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
@@ -66,7 +91,3 @@ class TagsRemove(BaseModel):
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagHistogram(BaseModel):
tag_counts: dict[str, int]

View File

@@ -52,8 +52,6 @@ async def parse_multipart_upload(
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
file_written = 0
tmp_path: str | None = None
@@ -130,16 +128,6 @@ async def parse_multipart_upload(
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
elif fname == "id":
raise UploadError(
400,
"UNSUPPORTED_FIELD",
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
)
elif fname == "mime_type":
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
@@ -164,8 +152,6 @@ async def parse_multipart_upload(
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id,
)

View File

@@ -45,7 +45,13 @@ class Asset(Base):
passive_deletes=True,
)
# preview_id on AssetReference is a self-referential FK to asset_references.id
preview_of: Mapped[list[AssetReference]] = relationship(
"AssetReference",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
foreign_keys=lambda: [AssetReference.preview_id],
viewonly=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
@@ -85,15 +91,11 @@ class AssetReference(Base):
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
String(36), ForeignKey("assets.id", ondelete="SET NULL")
)
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True)
)
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True), nullable=True, default=None
)
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
@@ -113,10 +115,10 @@ class AssetReference(Base):
foreign_keys=[asset_id],
lazy="selectin",
)
preview_ref: Mapped[AssetReference | None] = relationship(
"AssetReference",
preview_asset: Mapped[Asset | None] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_id],
remote_side=lambda: [AssetReference.id],
)
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
@@ -150,7 +152,6 @@ class AssetReference(Base):
Index("ix_asset_references_created_at", "created_at"),
Index("ix_asset_references_last_access_time", "last_access_time"),
Index("ix_asset_references_deleted_at", "deleted_at"),
Index("ix_asset_references_preview_id", "preview_id"),
Index("ix_asset_references_owner_name", "owner_id", "name"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
@@ -191,10 +192,6 @@ class AssetReferenceMeta(Base):
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
CheckConstraint(
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
name="has_value",
),
)

View File

@@ -31,21 +31,16 @@ from app.assets.database.queries.asset_reference import (
get_unenriched_references,
get_unreferenced_unhashed_asset_ids,
insert_reference,
list_all_file_paths_by_asset_id,
list_references_by_asset_id,
list_references_page,
mark_references_missing_outside_prefixes,
rebuild_metadata_projection,
reference_exists,
reference_exists_for_asset_id,
restore_references_by_paths,
set_reference_metadata,
set_reference_preview,
set_reference_system_metadata,
soft_delete_reference_by_id,
update_reference_access_time,
update_reference_name,
update_is_missing_by_asset_id,
update_reference_timestamps,
update_reference_updated_at,
upsert_reference,
@@ -59,7 +54,6 @@ from app.assets.database.queries.tags import (
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_reference_tags,
list_tag_counts_for_filtered_assets,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_reference,
@@ -103,26 +97,20 @@ __all__ = [
"get_unenriched_references",
"get_unreferenced_unhashed_asset_ids",
"insert_reference",
"list_all_file_paths_by_asset_id",
"list_references_by_asset_id",
"list_references_page",
"list_tag_counts_for_filtered_assets",
"list_tags_with_usage",
"mark_references_missing_outside_prefixes",
"reassign_asset_references",
"rebuild_metadata_projection",
"reference_exists",
"reference_exists_for_asset_id",
"remove_missing_tag_for_asset_id",
"remove_tags_from_reference",
"restore_references_by_paths",
"set_reference_metadata",
"set_reference_preview",
"set_reference_system_metadata",
"soft_delete_reference_by_id",
"set_reference_tags",
"update_asset_hash_and_mime",
"update_is_missing_by_asset_id",
"update_reference_access_time",
"update_reference_name",
"update_reference_timestamps",

View File

@@ -69,7 +69,7 @@ def upsert_asset(
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and not asset.mime_type:
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
@@ -118,7 +118,7 @@ def update_asset_hash_and_mime(
return False
if asset_hash is not None:
asset.hash = asset_hash
if mime_type is not None and not asset.mime_type:
if mime_type is not None:
asset.mime_type = mime_type
return True

View File

@@ -10,7 +10,7 @@ from decimal import Decimal
from typing import NamedTuple, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, select
from sqlalchemy import delete, exists, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, noload
@@ -24,14 +24,12 @@ from app.assets.database.models import (
)
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
apply_metadata_filter,
apply_tag_filters,
build_prefix_like_conditions,
build_visible_owner_clause,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
def _check_is_scalar(v):
@@ -46,6 +44,15 @@ def _check_is_scalar(v):
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
"""Convert a scalar value to a typed projection row."""
if value is None:
return {
"key": key,
"ordinal": ordinal,
"val_str": None,
"val_num": None,
"val_bool": None,
"val_json": None,
}
if isinstance(value, bool):
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
if isinstance(value, (int, float, Decimal)):
@@ -59,19 +66,96 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict:
def convert_metadata_to_rows(key: str, value) -> list[dict]:
"""Turn a metadata key/value into typed projection rows."""
if value is None:
return []
return [_scalar_to_row(key, 0, None)]
if _check_is_scalar(value):
return [_scalar_to_row(key, 0, value)]
if isinstance(value, list):
if all(_check_is_scalar(x) for x in value):
return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
return [{"key": key, "ordinal": 0, "val_json": value}]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_reference_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetReferenceMeta.val_json.is_(None),
AssetReferenceMeta.val_str.is_(None),
AssetReferenceMeta.val_num.is_(None),
AssetReferenceMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def get_reference_by_id(
@@ -128,21 +212,6 @@ def reference_exists_for_asset_id(
return session.execute(q).first() is not None
def reference_exists(
session: Session,
reference_id: str,
) -> bool:
"""Return True if a reference with the given ID exists (not soft-deleted)."""
q = (
select(sa.literal(True))
.select_from(AssetReference)
.where(AssetReference.id == reference_id)
.where(AssetReference.deleted_at.is_(None))
.limit(1)
)
return session.execute(q).first() is not None
def insert_reference(
session: Session,
asset_id: str,
@@ -267,8 +336,8 @@ def list_references_page(
escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
@@ -297,8 +366,8 @@ def list_references_page(
count_stmt = count_stmt.where(
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
)
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = int(session.execute(count_stmt).scalar_one() or 0)
refs = session.execute(base).unique().scalars().all()
@@ -310,7 +379,7 @@ def list_references_page(
select(AssetReferenceTag.asset_reference_id, Tag.name)
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
.order_by(AssetReferenceTag.tag_name.asc())
.order_by(AssetReferenceTag.added_at)
)
for ref_id, tag_name in rows.all():
tag_map[ref_id].append(tag_name)
@@ -423,42 +492,6 @@ def update_reference_updated_at(
)
def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
"""Delete and rebuild AssetReferenceMeta rows from merged system+user metadata.
The merged dict is ``{**system_metadata, **user_metadata}`` so user keys
override system keys of the same name.
"""
session.execute(
delete(AssetReferenceMeta).where(
AssetReferenceMeta.asset_reference_id == ref.id
)
)
session.flush()
merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
if not merged:
return
rows: list[AssetReferenceMeta] = []
for k, v in merged.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetReferenceMeta(
asset_reference_id=ref.id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def set_reference_metadata(
session: Session,
reference_id: str,
@@ -472,24 +505,33 @@ def set_reference_metadata(
ref.updated_at = get_utc_now()
session.flush()
rebuild_metadata_projection(session, ref)
def set_reference_system_metadata(
session: Session,
reference_id: str,
system_metadata: dict | None = None,
) -> None:
"""Set system_metadata on a reference and rebuild the merged projection."""
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
ref.system_metadata = system_metadata or {}
ref.updated_at = get_utc_now()
session.execute(
delete(AssetReferenceMeta).where(
AssetReferenceMeta.asset_reference_id == reference_id
)
)
session.flush()
rebuild_metadata_projection(session, ref)
if not user_metadata:
return
rows: list[AssetReferenceMeta] = []
for k, v in user_metadata.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetReferenceMeta(
asset_reference_id=reference_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def delete_reference_by_id(
@@ -529,19 +571,19 @@ def soft_delete_reference_by_id(
def set_reference_preview(
session: Session,
reference_id: str,
preview_reference_id: str | None = None,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
if preview_reference_id is None:
if preview_asset_id is None:
ref.preview_id = None
else:
if not session.get(AssetReference, preview_reference_id):
raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
ref.preview_id = preview_reference_id
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
ref.preview_id = preview_asset_id
ref.updated_at = get_utc_now()
session.flush()
@@ -567,8 +609,6 @@ def list_references_by_asset_id(
session.execute(
select(AssetReference)
.where(AssetReference.asset_id == asset_id)
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
.order_by(AssetReference.id.asc())
)
.scalars()
@@ -576,25 +616,6 @@ def list_references_by_asset_id(
)
def list_all_file_paths_by_asset_id(
session: Session,
asset_id: str,
) -> list[str]:
"""Return every file_path for an asset, including soft-deleted/missing refs.
Used for orphan cleanup where all on-disk files must be removed.
"""
return list(
session.execute(
select(AssetReference.file_path)
.where(AssetReference.asset_id == asset_id)
.where(AssetReference.file_path.isnot(None))
)
.scalars()
.all()
)
def upsert_reference(
session: Session,
asset_id: str,
@@ -834,22 +855,6 @@ def bulk_update_is_missing(
return total
def update_is_missing_by_asset_id(
session: Session, asset_id: str, value: bool
) -> int:
"""Set is_missing flag for ALL references belonging to an asset.
Returns: Number of rows updated
"""
result = session.execute(
sa.update(AssetReference)
.where(AssetReference.asset_id == asset_id)
.where(AssetReference.deleted_at.is_(None))
.values(is_missing=value)
)
return result.rowcount
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
"""Delete references by their IDs.

View File

@@ -1,14 +1,12 @@
"""Shared utilities for database query modules."""
import os
from decimal import Decimal
from typing import Iterable, Sequence
from typing import Iterable
import sqlalchemy as sa
from sqlalchemy import exists
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
from app.assets.helpers import escape_sql_like_string, normalize_tags
from app.assets.database.models import AssetReference
from app.assets.helpers import escape_sql_like_string
MAX_BIND_PARAMS = 800
@@ -54,74 +52,3 @@ def build_prefix_like_conditions(
escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_reference_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
return sa.not_(
sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
)
)
if isinstance(value, bool):
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt

View File

@@ -8,15 +8,12 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import (
Asset,
AssetReference,
AssetReferenceMeta,
AssetReferenceTag,
Tag,
)
from app.assets.database.queries.common import (
apply_metadata_filter,
apply_tag_filters,
build_visible_owner_clause,
iter_row_chunks,
)
@@ -75,9 +72,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
tag_name
for (tag_name,) in (
session.execute(
select(AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id == reference_id)
.order_by(AssetReferenceTag.tag_name.asc())
select(AssetReferenceTag.tag_name).where(
AssetReferenceTag.asset_reference_id == reference_id
)
)
).all()
]
@@ -120,7 +117,7 @@ def set_reference_tags(
)
session.flush()
return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
def add_tags_to_reference(
@@ -275,12 +272,6 @@ def list_tags_with_usage(
.select_from(AssetReferenceTag)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name)
.subquery()
@@ -317,12 +308,6 @@ def list_tags_with_usage(
select(AssetReferenceTag.tag_name)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name)
)
@@ -335,53 +320,6 @@ def list_tags_with_usage(
return rows_norm, int(total or 0)
def list_tag_counts_for_filtered_assets(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
"""Return tag counts for assets matching the given filters.
Uses the same filtering logic as list_references_page but returns
{tag_name: count} instead of paginated references.
"""
# Build a subquery of matching reference IDs
ref_sq = (
select(AssetReference.id)
.join(Asset, Asset.id == AssetReference.asset_id)
.where(build_visible_owner_clause(owner_id))
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
ref_sq = ref_sq.subquery()
# Count tags across those references
q = (
select(
AssetReferenceTag.tag_name,
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
.group_by(AssetReferenceTag.tag_name)
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
.limit(limit)
)
rows = session.execute(q).all()
return {tag_name: int(cnt) for tag_name, cnt in rows}
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],

View File

@@ -18,7 +18,7 @@ from app.assets.database.queries import (
mark_references_missing_outside_prefixes,
reassign_asset_references,
remove_missing_tag_for_asset_id,
set_reference_system_metadata,
set_reference_metadata,
update_asset_hash_and_mime,
)
from app.assets.services.bulk_ingest import (
@@ -490,8 +490,8 @@ def enrich_asset(
logging.warning("Failed to hash %s: %s", file_path, e)
if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata()
set_reference_system_metadata(session, reference_id, system_metadata)
user_metadata = metadata.to_user_metadata()
set_reference_metadata(session, reference_id, user_metadata)
if full_hash:
existing = get_asset_by_hash(session, full_hash)

View File

@@ -16,12 +16,10 @@ from app.assets.database.queries import (
get_reference_by_id,
get_reference_with_owner_check,
list_references_page,
list_all_file_paths_by_asset_id,
list_references_by_asset_id,
set_reference_metadata,
set_reference_preview,
set_reference_tags,
update_asset_hash_and_mime,
update_reference_access_time,
update_reference_name,
update_reference_updated_at,
@@ -69,8 +67,6 @@ def update_asset_metadata(
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> AssetDetailResult:
with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id)
@@ -107,21 +103,6 @@ def update_asset_metadata(
)
touched = True
if mime_type is not None:
updated = update_asset_hash_and_mime(
session, asset_id=ref.asset_id, mime_type=mime_type
)
if updated:
touched = True
if preview_id is not None:
set_reference_preview(
session,
reference_id=reference_id,
preview_reference_id=preview_id,
)
touched = True
if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id)
@@ -178,9 +159,11 @@ def delete_asset_reference(
session.commit()
return True
# Orphaned asset - gather ALL file paths (including
# soft-deleted / missing refs) so their on-disk files get cleaned up.
file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
# Orphaned asset - delete it and its files
refs = list_references_by_asset_id(session, asset_id=asset_id)
file_paths = [
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
]
# Also include the just-deleted file path
if file_path:
file_paths.append(file_path)
@@ -202,7 +185,7 @@ def delete_asset_reference(
def set_asset_preview(
reference_id: str,
preview_reference_id: str | None = None,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
@@ -211,7 +194,7 @@ def set_asset_preview(
set_reference_preview(
session,
reference_id=reference_id,
preview_reference_id=preview_reference_id,
preview_asset_id=preview_asset_id,
)
result = fetch_reference_asset_and_tags(
@@ -280,47 +263,6 @@ def list_assets_page(
return ListAssetsResult(items=items, total=total)
def resolve_hash_to_path(
asset_hash: str,
owner_id: str = "",
) -> DownloadResolutionResult | None:
"""Resolve a blake3 hash to an on-disk file path.
Only references visible to *owner_id* are considered (owner-less
references are always visible).
Returns a DownloadResolutionResult with abs_path, content_type, and
download_name, or None if no asset or live path is found.
"""
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash)
if not asset:
return None
refs = list_references_by_asset_id(session, asset_id=asset.id)
visible = [
r for r in refs
if r.owner_id == "" or r.owner_id == owner_id
]
abs_path = select_best_live_path(visible)
if not abs_path:
return None
display_name = os.path.basename(abs_path)
for ref in visible:
if ref.file_path == abs_path and ref.name:
display_name = ref.name
break
ctype = (
asset.mime_type
or mimetypes.guess_type(display_name)[0]
or "application/octet-stream"
)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=display_name,
)
def resolve_asset_for_download(
reference_id: str,
owner_id: str = "",

View File

@@ -11,14 +11,13 @@ from app.assets.database.queries import (
add_tags_to_reference,
fetch_reference_and_asset,
get_asset_by_hash,
get_existing_asset_ids,
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
reference_exists,
remove_missing_tag_for_asset_id,
set_reference_metadata,
set_reference_tags,
update_asset_hash_and_mime,
upsert_asset,
upsert_reference,
validate_tags_exist,
@@ -27,7 +26,6 @@ from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_relative_filename,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
validate_path_within_base,
)
@@ -67,7 +65,7 @@ def _ingest_file_from_path(
with create_session() as session:
if preview_id:
if not reference_exists(session, preview_id):
if preview_id not in get_existing_asset_ids(session, [preview_id]):
preview_id = None
asset, asset_created, asset_updated = upsert_asset(
@@ -137,8 +135,6 @@ def _register_existing_asset(
tags: list[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> RegisterAssetResult:
user_metadata = user_metadata or {}
@@ -147,25 +143,14 @@ def _register_existing_asset(
if not asset:
raise ValueError(f"No asset with hash {asset_hash}")
if mime_type and not asset.mime_type:
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
if preview_id:
if not reference_exists(session, preview_id):
preview_id = None
ref, ref_created = get_or_create_reference(
session,
asset_id=asset.id,
owner_id=owner_id,
name=name,
preview_id=preview_id,
)
if not ref_created:
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
tag_names = get_reference_tags(session, reference_id=ref.id)
result = RegisterAssetResult(
ref=extract_reference_data(ref),
@@ -257,8 +242,6 @@ def upload_from_temp_path(
client_filename: str | None = None,
owner_id: str = "",
expected_hash: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult:
try:
digest, _ = hashing.compute_blake3_hash(temp_path)
@@ -287,8 +270,6 @@ def upload_from_temp_path(
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
mime_type=mime_type,
preview_id=preview_id,
)
return UploadResult(
ref=result.ref,
@@ -310,7 +291,7 @@ def upload_from_temp_path(
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir)
content_type = mime_type or (
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
@@ -334,7 +315,7 @@ def upload_from_temp_path(
mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id,
preview_id=preview_id,
preview_id=None,
user_metadata=user_metadata or {},
tags=tags,
tag_origin="manual",
@@ -361,99 +342,30 @@ def upload_from_temp_path(
)
def register_file_in_place(
abs_path: str,
name: str,
tags: list[str],
owner_id: str = "",
mime_type: str | None = None,
) -> UploadResult:
"""Register an already-saved file in the asset database without moving it.
Tags are derived from the filesystem path (root category + subfolder names),
merged with any caller-provided tags, matching the behavior of the scanner.
If the path is not under a known root, only the caller-provided tags are used.
"""
try:
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
except ValueError:
path_tags = []
merged_tags = normalize_tags([*path_tags, *tags])
try:
digest, _ = hashing.compute_blake3_hash(abs_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash file: {e}")
asset_hash = "blake3:" + digest
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
content_type = mime_type or (
mimetypes.guess_type(abs_path, strict=False)[0]
or "application/octet-stream"
)
ingest_result = _ingest_file_from_path(
abs_path=abs_path,
asset_hash=asset_hash,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name, fallback=digest),
owner_id=owner_id,
tags=merged_tags,
tag_origin="upload",
require_existing_tags=False,
)
reference_id = ingest_result.reference_id
if not reference_id:
raise RuntimeError("failed to create asset reference")
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
ref, asset = pair
tag_names = get_reference_tags(session, reference_id=ref.id)
return UploadResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def create_from_hash(
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult | None:
canonical = hash_str.strip().lower()
try:
result = _register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
mime_type=mime_type,
preview_id=preview_id,
)
except ValueError:
logging.warning("create_from_hash: no asset found for hash %s", canonical)
return None
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
result = _register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
return UploadResult(
ref=result.ref,

View File

@@ -25,9 +25,7 @@ class ReferenceData:
preview_id: str | None
created_at: datetime
updated_at: datetime
system_metadata: dict[str, Any] | None = None
job_id: str | None = None
last_access_time: datetime | None = None
last_access_time: datetime | None
@dataclass(frozen=True)
@@ -95,8 +93,6 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
file_path=ref.file_path,
user_metadata=ref.user_metadata,
preview_id=ref.preview_id,
system_metadata=ref.system_metadata,
job_id=ref.job_id,
created_at=ref.created_at,
updated_at=ref.updated_at,
last_access_time=ref.last_access_time,

View File

@@ -1,5 +1,3 @@
from typing import Sequence
from app.assets.database.queries import (
AddTagsResult,
RemoveTagsResult,
@@ -8,7 +6,6 @@ from app.assets.database.queries import (
list_tags_with_usage,
remove_tags_from_reference,
)
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
from app.assets.services.schemas import TagUsage
from app.database.db import create_session
@@ -76,23 +73,3 @@ def list_tags(
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
def list_tag_histogram(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
with create_session() as session:
return list_tag_counts_for_filtered_assets(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
)

View File

@@ -1,18 +1,9 @@
from typing import Any
from datetime import datetime
from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase
NAMING_CONVENTION = {
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
class Base(DeclarativeBase):
metadata = MetaData(naming_convention=NAMING_CONVENTION)
pass
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys()

View File

@@ -6,7 +6,6 @@ import uuid
import glob
import shutil
import logging
import tempfile
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
@@ -378,15 +377,8 @@ class UserManager():
try:
body = await request.read()
dir_name = os.path.dirname(path)
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
try:
with os.fdopen(fd, "wb") as f:
f.write(body)
os.replace(tmp_path, path)
except:
os.unlink(tmp_path)
raise
with open(path, "wb") as f:
f.write(body)
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(

View File

@@ -83,8 +83,6 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
@@ -149,7 +147,6 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@@ -182,6 +179,8 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
@@ -263,6 +262,4 @@ else:
args.fast = set(args.fast)
def enables_dynamic_vram():
if args.enable_dynamic_vram:
return True
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu

View File

@@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
gradient_stops: NotRequired[list[dict]]
"""Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
gradient_stops: NotRequired[list[list[float]]]
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
class HiddenInputTypeDict(TypedDict):

View File

@@ -93,50 +93,6 @@ class IndexListCallbacks:
return {}
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
return None
cond_tensor = cond_value.cond
if temporal_dim >= cond_tensor.ndim:
return None
cond_size = cond_tensor.size(temporal_dim)
if temporal_scale == 1:
expected_size = x_in.size(window.dim) - temporal_offset
if cond_size != expected_size:
return None
if temporal_offset == 0 and temporal_scale == 1:
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
return cond_value._copy_with(sliced)
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
if temporal_offset > 0:
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
indices = [i for i in indices if 0 <= i]
else:
indices = list(window.index_list)
if not indices:
return None
if temporal_scale > 1:
scaled = []
for i in indices:
for k in range(temporal_scale):
si = i * temporal_scale + k
if si < cond_size:
scaled.append(si)
indices = scaled
if not indices:
return None
idx = tuple([slice(None)] * temporal_dim + [indices])
sliced = cond_tensor[idx].to(device)
return cond_value._copy_with(sliced)
@dataclass
class ContextSchedule:
name: str
@@ -221,17 +177,10 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item[cond_key] = result
handled = True
break
if not handled and self._model is not None:
result = self._model.resize_cond_for_context_window(
cond_key, cond_value, window, x_in, device,
retain_index_list=self.cond_retain_index_list)
if result is not None:
new_cond_item[cond_key] = result
handled = True
if handled:
continue
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
@@ -275,7 +224,6 @@ class IndexListContextHandler(ContextHandlerABC):
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self._model = model
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))

View File

@@ -209,39 +209,3 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False)
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
def roundup(x_val, multiple):
return ((x_val + multiple - 1) // multiple) * multiple
if pad_32x:
rows, cols = x.shape
padded_rows = roundup(rows, 32)
padded_cols = roundup(cols, 32)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
F8_E4M3_MAX = 448.0
E8M0_BIAS = 127
BLOCK_SIZE = 32
rows, cols = x.shape
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
# E8M0 block scales (power-of-2 exponents)
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
block_scales_e8m0 = exp_biased.to(torch.uint8)
zero_mask = (max_abs == 0)
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
# Scale per-block then stochastic round
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)

View File

@@ -14,6 +14,9 @@ if TYPE_CHECKING:
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from comfy.cli_args import args
import uuid
import os
from node_helpers import conditioning_set_values
# #######################################################################################################
@@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
HookedOnly = "hooked_only"
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
class _HookRef:
pass
def __init__(self):
if _ISOLATION_HOOKREF_MODE:
self._pyisolate_id = str(uuid.uuid4())
def _ensure_pyisolate_id(self):
pyisolate_id = getattr(self, "_pyisolate_id", None)
if pyisolate_id is None:
pyisolate_id = str(uuid.uuid4())
self._pyisolate_id = pyisolate_id
return pyisolate_id
def __eq__(self, other):
if not _ISOLATION_HOOKREF_MODE:
return self is other
if not isinstance(other, _HookRef):
return False
return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()
def __hash__(self):
if not _ISOLATION_HOOKREF_MODE:
return id(self)
return hash(self._ensure_pyisolate_id())
def __str__(self):
if not _ISOLATION_HOOKREF_MODE:
return super().__str__()
return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
@@ -168,6 +200,8 @@ class WeightHook(Hook):
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if self.weights is None:
self.weights = {}
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Clip:

394
comfy/isolation/__init__.py Normal file
View File

@@ -0,0 +1,394 @@
# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation
from __future__ import annotations
import asyncio
import inspect
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set, TYPE_CHECKING
import folder_paths
from .extension_loader import load_isolated_node
from .manifest_loader import find_manifest_directories
from .runtime_helpers import build_stub_class, get_class_types_for_extension
from .shm_forensics import scan_shm_forensics, start_shm_forensics
if TYPE_CHECKING:
from pyisolate import ExtensionManager
from .extension_wrapper import ComfyNodeExtension
LOG_PREFIX = "]["
isolated_node_timings: List[tuple[float, Path, int]] = []
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
logger = logging.getLogger(__name__)
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
def initialize_proxies() -> None:
from .child_hooks import is_child_process
is_child = is_child_process()
if is_child:
from .child_hooks import initialize_child_process
initialize_child_process()
else:
from .host_hooks import initialize_host_process
initialize_host_process()
start_shm_forensics()
@dataclass(frozen=True)
class IsolatedNodeSpec:
node_name: str
display_name: str
stub_class: type
module_path: Path
_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = []
_CLAIMED_PATHS: Set[Path] = set()
_ISOLATION_SCAN_ATTEMPTED = False
_EXTENSION_MANAGERS: List["ExtensionManager"] = []
_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {}
_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None
_EARLY_START_TIME: Optional[float] = None
def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None:
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
if _ISOLATION_BACKGROUND_TASK is not None:
return
_EARLY_START_TIME = time.perf_counter()
_ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes())
async def await_isolation_loading() -> List[IsolatedNodeSpec]:
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
if _ISOLATION_BACKGROUND_TASK is not None:
specs = await _ISOLATION_BACKGROUND_TASK
return specs
return await initialize_isolation_nodes()
async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS
if _ISOLATED_NODE_SPECS:
return _ISOLATED_NODE_SPECS
if _ISOLATION_SCAN_ATTEMPTED:
return []
_ISOLATION_SCAN_ATTEMPTED = True
manifest_entries = find_manifest_directories()
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
if not manifest_entries:
return []
os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1"
concurrency_limit = max(1, (os.cpu_count() or 4) // 2)
semaphore = asyncio.Semaphore(concurrency_limit)
async def load_with_semaphore(
node_dir: Path, manifest: Path
) -> List[IsolatedNodeSpec]:
async with semaphore:
load_start = time.perf_counter()
spec_list = await load_isolated_node(
node_dir,
manifest,
logger,
lambda name, info, extension: build_stub_class(
name,
info,
extension,
_RUNNING_EXTENSIONS,
logger,
),
PYISOLATE_VENV_ROOT,
_EXTENSION_MANAGERS,
)
spec_list = [
IsolatedNodeSpec(
node_name=node_name,
display_name=display_name,
stub_class=stub_cls,
module_path=node_dir,
)
for node_name, display_name, stub_cls in spec_list
]
isolated_node_timings.append(
(time.perf_counter() - load_start, node_dir, len(spec_list))
)
return spec_list
tasks = [
load_with_semaphore(node_dir, manifest)
for node_dir, manifest in manifest_entries
]
results = await asyncio.gather(*tasks, return_exceptions=True)
specs: List[IsolatedNodeSpec] = []
for result in results:
if isinstance(result, Exception):
logger.error(
"%s Isolated node failed during startup; continuing: %s",
LOG_PREFIX,
result,
)
continue
specs.extend(result)
_ISOLATED_NODE_SPECS = specs
return list(_ISOLATED_NODE_SPECS)
def _get_class_types_for_extension(extension_name: str) -> Set[str]:
"""Get all node class types (node names) belonging to an extension."""
extension = _RUNNING_EXTENSIONS.get(extension_name)
if not extension:
return set()
ext_path = Path(extension.module_path)
class_types = set()
for spec in _ISOLATED_NODE_SPECS:
if spec.module_path.resolve() == ext_path.resolve():
class_types.add(spec.node_name)
return class_types
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
"""Evict running extensions not needed for current execution."""
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:notify_graph_wait_idle",
)
async def _stop_extension(
ext_name: str, extension: "ComfyNodeExtension", reason: str
) -> None:
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
stop_result = extension.stop()
if inspect.isawaitable(stop_result):
await stop_result
_RUNNING_EXTENSIONS.pop(ext_name, None)
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
isolated_class_types_in_graph = needed_class_types.intersection(
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
)
graph_uses_isolation = bool(isolated_class_types_in_graph)
logger.debug(
"%s ISO:notify_graph_start running=%d needed=%d",
LOG_PREFIX,
len(_RUNNING_EXTENSIONS),
len(needed_class_types),
)
if graph_uses_isolation:
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
ext_class_types = _get_class_types_for_extension(ext_name)
# If NONE of this extension's nodes are in the execution graph -> evict.
if not ext_class_types.intersection(needed_class_types):
await _stop_extension(
ext_name,
extension,
"isolated custom_node not in execution graph, evicting",
)
else:
logger.debug(
"%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph",
LOG_PREFIX,
len(_RUNNING_EXTENSIONS),
)
# Isolated child processes add steady VRAM pressure; reclaim host-side models
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
try:
import comfy.model_management as model_management
device = model_management.get_torch_device()
if getattr(device, "type", None) == "cuda":
required = max(
model_management.minimum_inference_memory(),
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
)
free_before = model_management.get_free_memory(device)
if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation:
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
await _stop_extension(
ext_name,
extension,
f"boundary low-vram restart (free={int(free_before)} target={int(required)})",
)
if model_management.get_free_memory(device) < required:
model_management.unload_all_models()
model_management.cleanup_models_gc()
model_management.cleanup_models()
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=False)
model_management.soft_empty_cache()
except Exception:
logger.debug(
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
)
finally:
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
logger.debug(
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
)
async def flush_running_extensions_transport_state() -> int:
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:flush_transport_wait_idle",
)
total_flushed = 0
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
flush_fn = getattr(extension, "flush_transport_state", None)
if not callable(flush_fn):
continue
try:
flushed = await flush_fn()
if isinstance(flushed, int):
total_flushed += flushed
if flushed > 0:
logger.debug(
"%s %s workflow-end flush released=%d",
LOG_PREFIX,
ext_name,
flushed,
)
except Exception:
logger.debug(
"%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True
)
scan_shm_forensics(
"ISO:flush_running_extensions_transport_state", refresh_model_context=True
)
return total_flushed
async def wait_for_model_patcher_quiescence(
timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
*,
fail_loud: bool = False,
marker: str = "ISO:wait_model_patcher_idle",
) -> bool:
try:
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
registry = ModelPatcherRegistry()
start = time.perf_counter()
idle = await registry.wait_all_idle(timeout_ms)
elapsed_ms = (time.perf_counter() - start) * 1000.0
if idle:
logger.debug(
"%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
LOG_PREFIX,
marker,
timeout_ms,
elapsed_ms,
)
return True
states = await registry.get_all_operation_states()
logger.error(
"%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
LOG_PREFIX,
marker,
timeout_ms,
elapsed_ms,
states,
)
if fail_loud:
raise TimeoutError(
f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
)
return False
except Exception:
if fail_loud:
raise
logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
return False
def get_claimed_paths() -> Set[Path]:
return _CLAIMED_PATHS
def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None:
"""Update all active RPC instances with the current event loop.
This MUST be called at the start of each workflow execution to ensure
RPC calls are scheduled on the correct event loop. This handles the case
where asyncio.run() creates a new event loop for each workflow.
Args:
loop: The event loop to use. If None, uses asyncio.get_running_loop().
"""
if loop is None:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.get_event_loop()
update_count = 0
# Update RPCs from ExtensionManagers
for manager in _EXTENSION_MANAGERS:
if not hasattr(manager, "extensions"):
continue
for name, extension in manager.extensions.items():
if hasattr(extension, "rpc") and extension.rpc is not None:
if hasattr(extension.rpc, "update_event_loop"):
extension.rpc.update_event_loop(loop)
update_count += 1
logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'")
# Also update RPCs from running extensions (they may have direct RPC refs)
for name, extension in _RUNNING_EXTENSIONS.items():
if hasattr(extension, "rpc") and extension.rpc is not None:
if hasattr(extension.rpc, "update_event_loop"):
extension.rpc.update_event_loop(loop)
update_count += 1
logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'")
if update_count > 0:
logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances")
else:
logger.debug(
f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})"
)
__all__ = [
"LOG_PREFIX",
"initialize_proxies",
"initialize_isolation_nodes",
"start_isolation_loading_early",
"await_isolation_loading",
"notify_execution_graph",
"flush_running_extensions_transport_state",
"wait_for_model_patcher_quiescence",
"get_claimed_paths",
"update_rpc_event_loops",
"IsolatedNodeSpec",
"get_class_types_for_extension",
]

641
comfy/isolation/adapter.py Normal file
View File

@@ -0,0 +1,641 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
try:
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
from comfy.isolation.model_patcher_proxy import (
ModelPatcherProxy,
ModelPatcherRegistry,
)
from comfy.isolation.model_sampling_proxy import (
ModelSamplingProxy,
ModelSamplingRegistry,
)
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
from comfy.isolation.proxies.utils_proxy import UtilsProxy
from comfy.isolation.proxies.progress_proxy import ProgressProxy
except ImportError as exc: # Fail loud if Comfy environment is incomplete
raise ImportError(f"ComfyUI environment incomplete: {exc}")
logger = logging.getLogger(__name__)
# Force /dev/shm for shared memory (bwrap makes /tmp private)
import tempfile
if os.path.exists("/dev/shm"):
# Only override if not already set or if default is not /dev/shm
current_tmp = tempfile.gettempdir()
if not current_tmp.startswith("/dev/shm"):
logger.debug(
f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm"
)
os.environ["TMPDIR"] = "/dev/shm"
tempfile.tempdir = None # Clear cache to force re-evaluation
class ComfyUIAdapter(IsolationAdapter):
# ComfyUI-specific IsolationAdapter implementation
@property
def identifier(self) -> str:
return "comfyui"
def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]:
if "ComfyUI" in module_path and "custom_nodes" in module_path:
parts = module_path.split("ComfyUI")
if len(parts) > 1:
comfy_root = parts[0] + "ComfyUI"
return {
"preferred_root": comfy_root,
"additional_paths": [
os.path.join(comfy_root, "custom_nodes"),
os.path.join(comfy_root, "comfy"),
],
}
return None
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
comfy_root = snapshot.get("preferred_root")
if not comfy_root:
return
requirements_path = Path(comfy_root) / "requirements.txt"
if requirements_path.exists():
import re
for line in requirements_path.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
pkg_name = re.split(r"[<>=!~\[]", line)[0].strip()
if pkg_name:
logging.getLogger(pkg_name).setLevel(logging.ERROR)
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
import torch
def serialize_device(obj: Any) -> Dict[str, Any]:
return {"__type__": "device", "device_str": str(obj)}
def deserialize_device(data: Dict[str, Any]) -> Any:
return torch.device(data["device_str"])
registry.register("device", serialize_device, deserialize_device)
_VALID_DTYPES = {
"float16", "float32", "float64", "bfloat16",
"int8", "int16", "int32", "int64",
"uint8", "bool",
}
def serialize_dtype(obj: Any) -> Dict[str, Any]:
return {"__type__": "dtype", "dtype_str": str(obj)}
def deserialize_dtype(data: Dict[str, Any]) -> Any:
dtype_name = data["dtype_str"].replace("torch.", "")
if dtype_name not in _VALID_DTYPES:
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
return getattr(torch, dtype_name)
registry.register("dtype", serialize_dtype, deserialize_dtype)
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1":
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
raise RuntimeError(
f"ModelPatcher in child lacks _instance_id: "
f"{type(obj).__module__}.{type(obj).__name__}"
)
# Host-side: register with registry
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
model_id = ModelPatcherRegistry().register(obj)
return {"__type__": "ModelPatcherRef", "model_id": model_id}
def deserialize_model_patcher(data: Any) -> Any:
"""Deserialize ModelPatcher refs; pass through already-materialized objects."""
if isinstance(data, dict):
return ModelPatcherProxy(
data["model_id"], registry=None, manage_lifecycle=False
)
return data
def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any:
"""Context-aware ModelPatcherRef deserializer for both host and child."""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
if is_child:
return ModelPatcherProxy(
data["model_id"], registry=None, manage_lifecycle=False
)
else:
return ModelPatcherRegistry()._get_instance(data["model_id"])
# Register ModelPatcher type for serialization
registry.register(
"ModelPatcher", serialize_model_patcher, deserialize_model_patcher
)
# Register ModelPatcherProxy type (already a proxy, just return ref)
registry.register(
"ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher
)
# Register ModelPatcherRef for deserialization (context-aware: host or child)
registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref)
def serialize_clip(obj: Any) -> Dict[str, Any]:
if hasattr(obj, "_instance_id"):
return {"__type__": "CLIPRef", "clip_id": obj._instance_id}
clip_id = CLIPRegistry().register(obj)
return {"__type__": "CLIPRef", "clip_id": clip_id}
def deserialize_clip(data: Any) -> Any:
if isinstance(data, dict):
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
return data
def deserialize_clip_ref(data: Dict[str, Any]) -> Any:
"""Context-aware CLIPRef deserializer for both host and child."""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
if is_child:
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
else:
return CLIPRegistry()._get_instance(data["clip_id"])
# Register CLIP type for serialization
registry.register("CLIP", serialize_clip, deserialize_clip)
# Register CLIPProxy type (already a proxy, just return ref)
registry.register("CLIPProxy", serialize_clip, deserialize_clip)
# Register CLIPRef for deserialization (context-aware: host or child)
registry.register("CLIPRef", None, deserialize_clip_ref)
def serialize_vae(obj: Any) -> Dict[str, Any]:
if hasattr(obj, "_instance_id"):
return {"__type__": "VAERef", "vae_id": obj._instance_id}
vae_id = VAERegistry().register(obj)
return {"__type__": "VAERef", "vae_id": vae_id}
def deserialize_vae(data: Any) -> Any:
if isinstance(data, dict):
return VAEProxy(data["vae_id"])
return data
def deserialize_vae_ref(data: Dict[str, Any]) -> Any:
"""Context-aware VAERef deserializer for both host and child."""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
if is_child:
# Child: create a proxy
return VAEProxy(data["vae_id"])
else:
# Host: lookup real VAE from registry
return VAERegistry()._get_instance(data["vae_id"])
# Register VAE type for serialization
registry.register("VAE", serialize_vae, deserialize_vae)
# Register VAEProxy type (already a proxy, just return ref)
registry.register("VAEProxy", serialize_vae, deserialize_vae)
# Register VAERef for deserialization (context-aware: host or child)
registry.register("VAERef", None, deserialize_vae_ref)
# ModelSampling serialization - handles ModelSampling* types
# copyreg removed - no pickle fallback allowed
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1":
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
raise RuntimeError(
f"ModelSampling in child lacks _instance_id: "
f"{type(obj).__module__}.{type(obj).__name__}"
)
# Host-side pass-through for proxies: do not re-register a proxy as a
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
ms_id = ModelSamplingRegistry().register(obj)
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
def deserialize_model_sampling(data: Any) -> Any:
"""Deserialize ModelSampling refs; pass through already-materialized objects."""
if isinstance(data, dict):
return ModelSamplingProxy(data["ms_id"])
return data
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
"""Context-aware ModelSamplingRef deserializer for both host and child."""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
if is_child:
return ModelSamplingProxy(data["ms_id"])
else:
return ModelSamplingRegistry()._get_instance(data["ms_id"])
# Register all ModelSampling* and StableCascadeSampling classes dynamically
import comfy.model_sampling
for ms_cls in vars(comfy.model_sampling).values():
if not isinstance(ms_cls, type):
continue
if not issubclass(ms_cls, torch.nn.Module):
continue
if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
continue
registry.register(
ms_cls.__name__,
serialize_model_sampling,
deserialize_model_sampling,
)
registry.register(
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
)
# Register ModelSamplingRef for deserialization (context-aware: host or child)
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
def serialize_cond(obj: Any) -> Dict[str, Any]:
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
return {
"__type__": type_key,
"cond": obj.cond,
}
def deserialize_cond(data: Dict[str, Any]) -> Any:
import importlib
type_key = data["__type__"]
module_name, class_name = type_key.rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
return cls(data["cond"])
def _serialize_public_state(obj: Any) -> Dict[str, Any]:
state: Dict[str, Any] = {}
for key, value in obj.__dict__.items():
if key.startswith("_"):
continue
if callable(value):
continue
state[key] = value
return state
def serialize_latent_format(obj: Any) -> Dict[str, Any]:
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
return {
"__type__": type_key,
"state": _serialize_public_state(obj),
}
def deserialize_latent_format(data: Dict[str, Any]) -> Any:
import importlib
type_key = data["__type__"]
module_name, class_name = type_key.rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
obj = cls()
for key, value in data.get("state", {}).items():
prop = getattr(type(obj), key, None)
if isinstance(prop, property) and prop.fset is None:
continue
setattr(obj, key, value)
return obj
import comfy.conds
for cond_cls in vars(comfy.conds).values():
if not isinstance(cond_cls, type):
continue
if not issubclass(cond_cls, comfy.conds.CONDRegular):
continue
type_key = f"{cond_cls.__module__}.{cond_cls.__name__}"
registry.register(type_key, serialize_cond, deserialize_cond)
registry.register(cond_cls.__name__, serialize_cond, deserialize_cond)
import comfy.latent_formats
for latent_cls in vars(comfy.latent_formats).values():
if not isinstance(latent_cls, type):
continue
if not issubclass(latent_cls, comfy.latent_formats.LatentFormat):
continue
type_key = f"{latent_cls.__module__}.{latent_cls.__name__}"
registry.register(
type_key, serialize_latent_format, deserialize_latent_format
)
registry.register(
latent_cls.__name__, serialize_latent_format, deserialize_latent_format
)
# V3 API: unwrap NodeOutput.args
def deserialize_node_output(data: Any) -> Any:
return getattr(data, "args", data)
registry.register("NodeOutput", None, deserialize_node_output)
# KSAMPLER serializer: stores sampler name instead of function object
# sampler_function is a callable which gets filtered out by JSONSocketTransport
def serialize_ksampler(obj: Any) -> Dict[str, Any]:
func_name = obj.sampler_function.__name__
# Map function name back to sampler name
if func_name == "sample_unipc":
sampler_name = "uni_pc"
elif func_name == "sample_unipc_bh2":
sampler_name = "uni_pc_bh2"
elif func_name == "dpm_fast_function":
sampler_name = "dpm_fast"
elif func_name == "dpm_adaptive_function":
sampler_name = "dpm_adaptive"
elif func_name.startswith("sample_"):
sampler_name = func_name[7:] # Remove "sample_" prefix
else:
sampler_name = func_name
return {
"__type__": "KSAMPLER",
"sampler_name": sampler_name,
"extra_options": obj.extra_options,
"inpaint_options": obj.inpaint_options,
}
def deserialize_ksampler(data: Dict[str, Any]) -> Any:
import comfy.samplers
return comfy.samplers.ksampler(
data["sampler_name"],
data.get("extra_options", {}),
data.get("inpaint_options", {}),
)
registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler)
from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers
register_hooks_serializers(registry)
# Generic Numpy Serializer
def serialize_numpy(obj: Any) -> Any:
import torch
try:
# Attempt zero-copy conversion to Tensor
return torch.from_numpy(obj)
except Exception:
# Fallback for non-numeric arrays (strings, objects, mixes)
return obj.tolist()
registry.register("ndarray", serialize_numpy, None)
def serialize_ply(obj: Any) -> Dict[str, Any]:
import base64
import torch
if obj.raw_data is not None:
return {
"__type__": "PLY",
"raw_data": base64.b64encode(obj.raw_data).decode("ascii"),
}
result: Dict[str, Any] = {"__type__": "PLY", "points": torch.from_numpy(obj.points)}
if obj.colors is not None:
result["colors"] = torch.from_numpy(obj.colors)
if obj.confidence is not None:
result["confidence"] = torch.from_numpy(obj.confidence)
if obj.view_id is not None:
result["view_id"] = torch.from_numpy(obj.view_id)
return result
def deserialize_ply(data: Any) -> Any:
import base64
from comfy_api.latest._util.ply_types import PLY
if "raw_data" in data:
return PLY(raw_data=base64.b64decode(data["raw_data"]))
return PLY(
points=data["points"],
colors=data.get("colors"),
confidence=data.get("confidence"),
view_id=data.get("view_id"),
)
registry.register("PLY", serialize_ply, deserialize_ply, data_type=True)
def serialize_npz(obj: Any) -> Dict[str, Any]:
import base64
return {
"__type__": "NPZ",
"frames": [base64.b64encode(f).decode("ascii") for f in obj.frames],
}
def deserialize_npz(data: Any) -> Any:
import base64
from comfy_api.latest._util.npz_types import NPZ
return NPZ(frames=[base64.b64decode(f) for f in data["frames"]])
registry.register("NPZ", serialize_npz, deserialize_npz, data_type=True)
def serialize_file3d(obj: Any) -> Dict[str, Any]:
import base64
return {
"__type__": "File3D",
"format": obj.format,
"data": base64.b64encode(obj.get_bytes()).decode("ascii"),
}
def deserialize_file3d(data: Any) -> Any:
import base64
from io import BytesIO
from comfy_api.latest._util.geometry_types import File3D
return File3D(BytesIO(base64.b64decode(data["data"])), file_format=data["format"])
registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True)
def serialize_video(obj: Any) -> Dict[str, Any]:
components = obj.get_components()
images = components.images.detach() if components.images.requires_grad else components.images
result: Dict[str, Any] = {
"__type__": "VIDEO",
"images": images,
"frame_rate_num": components.frame_rate.numerator,
"frame_rate_den": components.frame_rate.denominator,
}
if components.audio is not None:
waveform = components.audio["waveform"]
if waveform.requires_grad:
waveform = waveform.detach()
result["audio_waveform"] = waveform
result["audio_sample_rate"] = components.audio["sample_rate"]
if components.metadata is not None:
result["metadata"] = components.metadata
return result
def deserialize_video(data: Any) -> Any:
from fractions import Fraction
from comfy_api.latest._input_impl.video_types import VideoFromComponents
from comfy_api.latest._util.video_types import VideoComponents
audio = None
if "audio_waveform" in data:
audio = {"waveform": data["audio_waveform"], "sample_rate": data["audio_sample_rate"]}
components = VideoComponents(
images=data["images"],
frame_rate=Fraction(data["frame_rate_num"], data["frame_rate_den"]),
audio=audio,
metadata=data.get("metadata"),
)
return VideoFromComponents(components)
registry.register("VIDEO", serialize_video, deserialize_video, data_type=True)
registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True)
registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True)
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
return [
PromptServerService,
FolderPathsProxy,
ModelManagementProxy,
UtilsProxy,
ProgressProxy,
VAERegistry,
CLIPRegistry,
ModelPatcherRegistry,
ModelSamplingRegistry,
FirstStageModelRegistry,
]
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
# Resolve the real name whether it's an instance or the Singleton class itself
api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__
if api_name == "FolderPathsProxy":
import folder_paths
# Replace module-level functions with proxy methods
# This is aggressive but necessary for transparent proxying
# Handle both instance and class cases
instance = api() if isinstance(api, type) else api
for name in dir(instance):
if not name.startswith("_"):
setattr(folder_paths, name, getattr(instance, name))
# Fence: isolated children get writable temp inside sandbox
if os.environ.get("PYISOLATE_CHILD") == "1":
_child_temp = os.path.join("/tmp", "comfyui_temp")
os.makedirs(_child_temp, exist_ok=True)
folder_paths.temp_directory = _child_temp
return
if api_name == "ModelManagementProxy":
import comfy.model_management
instance = api() if isinstance(api, type) else api
# Replace module-level functions with proxy methods
for name in dir(instance):
if not name.startswith("_"):
setattr(comfy.model_management, name, getattr(instance, name))
return
if api_name == "UtilsProxy":
import comfy.utils
# Static Injection of RPC mechanism to ensure Child can access it
# independent of instance lifecycle.
api.set_rpc(rpc)
# Don't overwrite host hook (infinite recursion)
return
if api_name == "PromptServerProxy":
# Defer heavy import to child context
import server
instance = api() if isinstance(api, type) else api
proxy = (
instance.instance
) # PromptServerProxy instance has .instance property returning self
original_register_route = proxy.register_route
def register_route_wrapper(
method: str, path: str, handler: Callable[..., Any]
) -> None:
callback_id = rpc.register_callback(handler)
loop = getattr(rpc, "loop", None)
if loop and loop.is_running():
import asyncio
asyncio.create_task(
original_register_route(
method, path, handler=callback_id, is_callback=True
)
)
else:
original_register_route(
method, path, handler=callback_id, is_callback=True
)
return None
proxy.register_route = register_route_wrapper
class RouteTableDefProxy:
def __init__(self, proxy_instance: Any):
self.proxy = proxy_instance
def get(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("GET", path, handler)
return handler
return decorator
def post(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("POST", path, handler)
return handler
return decorator
def patch(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("PATCH", path, handler)
return handler
return decorator
def put(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("PUT", path, handler)
return handler
return decorator
def delete(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("DELETE", path, handler)
return handler
return decorator
proxy.routes = RouteTableDefProxy(proxy)
if (
hasattr(server, "PromptServer")
and getattr(server.PromptServer, "instance", None) != proxy
):
server.PromptServer.instance = proxy

View File

@@ -0,0 +1,141 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation
# Child process initialization for PyIsolate
import logging
import os
logger = logging.getLogger(__name__)
def is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
def initialize_child_process() -> None:
# Manual RPC injection
try:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc:
_setup_prompt_server_stub(rpc)
_setup_utils_proxy(rpc)
else:
logger.warning("Could not get child RPC instance for manual injection")
_setup_prompt_server_stub()
_setup_utils_proxy()
except Exception as e:
logger.error(f"Manual RPC Injection failed: {e}")
_setup_prompt_server_stub()
_setup_utils_proxy()
_setup_logging()
def _setup_prompt_server_stub(rpc=None) -> None:
try:
from .proxies.prompt_server_impl import PromptServerStub
import sys
import types
# Mock server module
if "server" not in sys.modules:
mock_server = types.ModuleType("server")
sys.modules["server"] = mock_server
server = sys.modules["server"]
if not hasattr(server, "PromptServer"):
class MockPromptServer:
pass
server.PromptServer = MockPromptServer
stub = PromptServerStub()
if rpc:
PromptServerStub.set_rpc(rpc)
if hasattr(stub, "set_rpc"):
stub.set_rpc(rpc)
server.PromptServer.instance = stub
except Exception as e:
logger.error(f"Failed to setup PromptServerStub: {e}")
def _setup_utils_proxy(rpc=None) -> None:
try:
import comfy.utils
import asyncio
# Capture main loop during initialization (safe context)
main_loop = None
try:
main_loop = asyncio.get_running_loop()
except RuntimeError:
try:
main_loop = asyncio.get_event_loop()
except RuntimeError:
pass
try:
from .proxies.base import set_global_loop
if main_loop:
set_global_loop(main_loop)
except ImportError:
pass
# Sync hook wrapper for progress updates
def sync_hook_wrapper(
value: int, total: int, preview: None = None, node_id: None = None
) -> None:
if node_id is None:
try:
from comfy_execution.utils import get_executing_context
ctx = get_executing_context()
if ctx:
node_id = ctx.node_id
else:
pass
except Exception:
pass
# Bypass blocked event loop by direct outbox injection
if rpc:
try:
# Use captured main loop if available (for threaded execution), or current loop
loop = main_loop
if loop is None:
loop = asyncio.get_event_loop()
rpc.outbox.put(
{
"kind": "call",
"object_id": "UtilsProxy",
"parent_call_id": None, # We are root here usually
"calling_loop": loop,
"future": loop.create_future(), # Dummy future
"method": "progress_bar_hook",
"args": (value, total, preview, node_id),
"kwargs": {},
}
)
except Exception as e:
logging.getLogger(__name__).error(f"Manual Inject Failed: {e}")
else:
logging.getLogger(__name__).warning(
"No RPC instance available for progress update"
)
comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper
except Exception as e:
logger.error(f"Failed to setup UtilsProxy hook: {e}")
def _setup_logging() -> None:
logging.getLogger().setLevel(logging.INFO)

View File

@@ -0,0 +1,327 @@
# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation
# CLIP Proxy implementation
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Optional
from comfy.isolation.proxies.base import (
IS_CHILD_PROCESS,
BaseProxy,
BaseRegistry,
detach_if_grad,
)
if TYPE_CHECKING:
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
class CondStageModelRegistry(BaseRegistry[Any]):
_type_prefix = "cond_stage_model"
async def get_property(self, instance_id: str, name: str) -> Any:
obj = self._get_instance(instance_id)
return getattr(obj, name)
class CondStageModelProxy(BaseProxy[CondStageModelRegistry]):
_registry_class = CondStageModelRegistry
__module__ = "comfy.sd"
def __getattr__(self, name: str) -> Any:
try:
return self._call_rpc("get_property", name)
except Exception as e:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
) from e
def __repr__(self) -> str:
return f"<CondStageModelProxy {self._instance_id}>"
class TokenizerRegistry(BaseRegistry[Any]):
_type_prefix = "tokenizer"
async def get_property(self, instance_id: str, name: str) -> Any:
obj = self._get_instance(instance_id)
return getattr(obj, name)
class TokenizerProxy(BaseProxy[TokenizerRegistry]):
_registry_class = TokenizerRegistry
__module__ = "comfy.sd"
def __getattr__(self, name: str) -> Any:
try:
return self._call_rpc("get_property", name)
except Exception as e:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
) from e
def __repr__(self) -> str:
return f"<TokenizerProxy {self._instance_id}>"
logger = logging.getLogger(__name__)
class CLIPRegistry(BaseRegistry[Any]):
_type_prefix = "clip"
_allowed_setters = {
"layer_idx",
"tokenizer_options",
"use_clip_schedule",
"apply_hooks_to_conds",
}
async def get_ram_usage(self, instance_id: str) -> int:
return self._get_instance(instance_id).get_ram_usage()
async def get_patcher_id(self, instance_id: str) -> str:
from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry
return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher)
async def get_cond_stage_model_id(self, instance_id: str) -> str:
return CondStageModelRegistry().register(
self._get_instance(instance_id).cond_stage_model
)
async def get_tokenizer_id(self, instance_id: str) -> str:
return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer)
async def load_model(self, instance_id: str) -> None:
self._get_instance(instance_id).load_model()
async def clip_layer(self, instance_id: str, layer_idx: int) -> None:
self._get_instance(instance_id).clip_layer(layer_idx)
async def set_tokenizer_option(
self, instance_id: str, option_name: str, value: Any
) -> None:
self._get_instance(instance_id).set_tokenizer_option(option_name, value)
async def get_property(self, instance_id: str, name: str) -> Any:
return getattr(self._get_instance(instance_id), name)
async def set_property(self, instance_id: str, name: str, value: Any) -> None:
if name not in self._allowed_setters:
raise PermissionError(f"Setting '{name}' is not allowed via RPC")
setattr(self._get_instance(instance_id), name, value)
async def tokenize(
self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any
) -> Any:
return self._get_instance(instance_id).tokenize(
text, return_word_ids=return_word_ids, **kwargs
)
async def encode(self, instance_id: str, text: str) -> Any:
return detach_if_grad(self._get_instance(instance_id).encode(text))
async def encode_from_tokens(
self,
instance_id: str,
tokens: Any,
return_pooled: bool = False,
return_dict: bool = False,
) -> Any:
return detach_if_grad(
self._get_instance(instance_id).encode_from_tokens(
tokens, return_pooled=return_pooled, return_dict=return_dict
)
)
async def encode_from_tokens_scheduled(
self,
instance_id: str,
tokens: Any,
unprojected: bool = False,
add_dict: Optional[dict] = None,
show_pbar: bool = True,
) -> Any:
add_dict = add_dict or {}
return detach_if_grad(
self._get_instance(instance_id).encode_from_tokens_scheduled(
tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar
)
)
async def add_patches(
self,
instance_id: str,
patches: Any,
strength_patch: float = 1.0,
strength_model: float = 1.0,
) -> Any:
return self._get_instance(instance_id).add_patches(
patches, strength_patch=strength_patch, strength_model=strength_model
)
async def get_key_patches(self, instance_id: str) -> Any:
return self._get_instance(instance_id).get_key_patches()
async def load_sd(
self, instance_id: str, sd: dict, full_model: bool = False
) -> Any:
return self._get_instance(instance_id).load_sd(sd, full_model=full_model)
async def get_sd(self, instance_id: str) -> Any:
return self._get_instance(instance_id).get_sd()
async def clone(self, instance_id: str) -> str:
return self.register(self._get_instance(instance_id).clone())
class CLIPProxy(BaseProxy[CLIPRegistry]):
_registry_class = CLIPRegistry
__module__ = "comfy.sd"
def get_ram_usage(self) -> int:
return self._call_rpc("get_ram_usage")
@property
def patcher(self) -> "ModelPatcherProxy":
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
if not hasattr(self, "_patcher_proxy"):
patcher_id = self._call_rpc("get_patcher_id")
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
return self._patcher_proxy
@patcher.setter
def patcher(self, value: Any) -> None:
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
if isinstance(value, ModelPatcherProxy):
self._patcher_proxy = value
else:
logger.warning(
f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}"
)
@property
def cond_stage_model(self) -> CondStageModelProxy:
if not hasattr(self, "_cond_stage_model_proxy"):
csm_id = self._call_rpc("get_cond_stage_model_id")
self._cond_stage_model_proxy = CondStageModelProxy(
csm_id, manage_lifecycle=False
)
return self._cond_stage_model_proxy
@property
def tokenizer(self) -> TokenizerProxy:
if not hasattr(self, "_tokenizer_proxy"):
tok_id = self._call_rpc("get_tokenizer_id")
self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False)
return self._tokenizer_proxy
def load_model(self) -> ModelPatcherProxy:
self._call_rpc("load_model")
return self.patcher
@property
def layer_idx(self) -> Optional[int]:
return self._call_rpc("get_property", "layer_idx")
@layer_idx.setter
def layer_idx(self, value: Optional[int]) -> None:
self._call_rpc("set_property", "layer_idx", value)
@property
def tokenizer_options(self) -> dict:
return self._call_rpc("get_property", "tokenizer_options")
@tokenizer_options.setter
def tokenizer_options(self, value: dict) -> None:
self._call_rpc("set_property", "tokenizer_options", value)
@property
def use_clip_schedule(self) -> bool:
return self._call_rpc("get_property", "use_clip_schedule")
@use_clip_schedule.setter
def use_clip_schedule(self, value: bool) -> None:
self._call_rpc("set_property", "use_clip_schedule", value)
@property
def apply_hooks_to_conds(self) -> Any:
return self._call_rpc("get_property", "apply_hooks_to_conds")
@apply_hooks_to_conds.setter
def apply_hooks_to_conds(self, value: Any) -> None:
self._call_rpc("set_property", "apply_hooks_to_conds", value)
def clip_layer(self, layer_idx: int) -> None:
return self._call_rpc("clip_layer", layer_idx)
def set_tokenizer_option(self, option_name: str, value: Any) -> None:
return self._call_rpc("set_tokenizer_option", option_name, value)
def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any:
return self._call_rpc(
"tokenize", text, return_word_ids=return_word_ids, **kwargs
)
def encode(self, text: str) -> Any:
return self._call_rpc("encode", text)
def encode_from_tokens(
self, tokens: Any, return_pooled: bool = False, return_dict: bool = False
) -> Any:
res = self._call_rpc(
"encode_from_tokens",
tokens,
return_pooled=return_pooled,
return_dict=return_dict,
)
if return_pooled and isinstance(res, list) and not return_dict:
return tuple(res)
return res
def encode_from_tokens_scheduled(
self,
tokens: Any,
unprojected: bool = False,
add_dict: Optional[dict] = None,
show_pbar: bool = True,
) -> Any:
add_dict = add_dict or {}
return self._call_rpc(
"encode_from_tokens_scheduled",
tokens,
unprojected=unprojected,
add_dict=add_dict,
show_pbar=show_pbar,
)
def add_patches(
self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0
) -> Any:
return self._call_rpc(
"add_patches",
patches,
strength_patch=strength_patch,
strength_model=strength_model,
)
def get_key_patches(self) -> Any:
return self._call_rpc("get_key_patches")
def load_sd(self, sd: dict, full_model: bool = False) -> Any:
return self._call_rpc("load_sd", sd, full_model=full_model)
def get_sd(self) -> Any:
return self._call_rpc("get_sd")
def clone(self) -> CLIPProxy:
new_id = self._call_rpc("clone")
return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS)
if not IS_CHILD_PROCESS:
_CLIP_REGISTRY_SINGLETON = CLIPRegistry()
_COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry()
_TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry()

View File

@@ -0,0 +1,388 @@
# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
from __future__ import annotations
import logging
import os
import inspect
import sys
import types
import platform
from pathlib import Path
from typing import Callable, Dict, List, Tuple
import pyisolate
from pyisolate import ExtensionManager, ExtensionManagerConfig
from packaging.requirements import InvalidRequirement, Requirement
from packaging.utils import canonicalize_name
from .extension_wrapper import ComfyNodeExtension
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
from .host_policy import load_host_policy
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
logger = logging.getLogger(__name__)
def _register_web_directory(extension_name: str, node_dir: Path) -> None:
"""Register an isolated extension's web directory on the host side."""
import nodes
# Method 1: pyproject.toml [tool.comfy] web field
pyproject = node_dir / "pyproject.toml"
if pyproject.exists():
try:
with pyproject.open("rb") as f:
data = tomllib.load(f)
web_dir_name = data.get("tool", {}).get("comfy", {}).get("web")
if web_dir_name:
web_dir_path = str(node_dir / web_dir_name)
if os.path.isdir(web_dir_path):
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
return
except Exception:
pass
# Method 2: __init__.py WEB_DIRECTORY constant (parse without importing)
init_file = node_dir / "__init__.py"
if init_file.exists():
try:
source = init_file.read_text()
for line in source.splitlines():
stripped = line.strip()
if stripped.startswith("WEB_DIRECTORY"):
# Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web"
_, _, value = stripped.partition("=")
value = value.strip().strip("\"'")
if value:
web_dir_path = str((node_dir / value).resolve())
if os.path.isdir(web_dir_path):
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
return
except Exception:
pass
async def _stop_extension_safe(
extension: ComfyNodeExtension, extension_name: str
) -> None:
try:
stop_result = extension.stop()
if inspect.isawaitable(stop_result):
await stop_result
except Exception:
logger.debug("][ %s stop failed", extension_name, exc_info=True)
def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
req, sep, marker = dep.partition(";")
req = req.strip()
marker_suffix = f";{marker}" if sep else ""
def _resolve_local_path(local_path: str) -> Path | None:
for base in base_paths:
candidate = (base / local_path).resolve()
if candidate.exists():
return candidate
return None
if req.startswith("./") or req.startswith("../"):
resolved = _resolve_local_path(req)
if resolved is not None:
return f"{resolved}{marker_suffix}"
if req.startswith("file://"):
raw = req[len("file://") :]
if raw.startswith("./") or raw.startswith("../"):
resolved = _resolve_local_path(raw)
if resolved is not None:
return f"file://{resolved}{marker_suffix}"
return dep
def _dependency_name_from_spec(dep: str) -> str | None:
stripped = dep.strip()
if not stripped or stripped == "-e" or stripped.startswith("-e "):
return None
if stripped.startswith(("/", "./", "../", "file://")):
return None
try:
return canonicalize_name(Requirement(stripped).name)
except InvalidRequirement:
return None
def _parse_cuda_wheels_config(
tool_config: dict[str, object], dependencies: list[str]
) -> dict[str, object] | None:
raw_config = tool_config.get("cuda_wheels")
if raw_config is None:
return None
if not isinstance(raw_config, dict):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels] must be a table"
)
index_url = raw_config.get("index_url")
if not isinstance(index_url, str) or not index_url.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
)
packages = raw_config.get("packages")
if not isinstance(packages, list) or not all(
isinstance(package_name, str) and package_name.strip()
for package_name in packages
):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings"
)
declared_dependencies = {
dependency_name
for dep in dependencies
if (dependency_name := _dependency_name_from_spec(dep)) is not None
}
normalized_packages = [canonicalize_name(package_name) for package_name in packages]
missing = [
package_name
for package_name in normalized_packages
if package_name not in declared_dependencies
]
if missing:
missing_joined = ", ".join(sorted(missing))
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: "
f"{missing_joined}"
)
package_map = raw_config.get("package_map", {})
if not isinstance(package_map, dict):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] must be a table"
)
normalized_package_map: dict[str, str] = {}
for dependency_name, index_package_name in package_map.items():
if not isinstance(dependency_name, str) or not dependency_name.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings"
)
if not isinstance(index_package_name, str) or not index_package_name.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings"
)
canonical_dependency_name = canonicalize_name(dependency_name)
if canonical_dependency_name not in normalized_packages:
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in "
"[tool.comfy.isolation.cuda_wheels.packages]"
)
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
return {
"index_url": index_url.rstrip("/") + "/",
"packages": normalized_packages,
"package_map": normalized_package_map,
}
def get_enforcement_policy() -> Dict[str, bool]:
return {
"force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
"force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
}
class ExtensionLoadError(RuntimeError):
pass
def register_dummy_module(extension_name: str, node_dir: Path) -> None:
normalized_name = extension_name.replace("-", "_").replace(".", "_")
if normalized_name not in sys.modules:
dummy_module = types.ModuleType(normalized_name)
dummy_module.__file__ = str(node_dir / "__init__.py")
dummy_module.__path__ = [str(node_dir)]
dummy_module.__package__ = normalized_name
sys.modules[normalized_name] = dummy_module
def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
for details in cached_data.values():
if not isinstance(details, dict):
return True
if details.get("is_v3") and "schema_v1" not in details:
return True
return False
async def load_isolated_node(
node_dir: Path,
manifest_path: Path,
logger: logging.Logger,
build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type],
venv_root: Path,
extension_managers: List[ExtensionManager],
) -> List[Tuple[str, str, type]]:
try:
with manifest_path.open("rb") as handle:
manifest_data = tomllib.load(handle)
except Exception as e:
logger.warning(f"][ Failed to parse {manifest_path}: {e}")
return []
# Parse [tool.comfy.isolation]
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
can_isolate = tool_config.get("can_isolate", False)
share_torch = tool_config.get("share_torch", False)
# Parse [project] dependencies
project_config = manifest_data.get("project", {})
dependencies = project_config.get("dependencies", [])
if not isinstance(dependencies, list):
dependencies = []
# Get extension name (default to folder name if not in project.name)
extension_name = project_config.get("name", node_dir.name)
# LOGIC: Isolation Decision
policy = get_enforcement_policy()
isolated = can_isolate or policy["force_isolated"]
if not isolated:
return []
logger.info(f"][ Loading isolated node: {extension_name}")
import folder_paths
base_paths = [Path(folder_paths.base_path), node_dir]
dependencies = [
_normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
for dep in dependencies
]
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
manager: ExtensionManager = pyisolate.ExtensionManager(
ComfyNodeExtension, manager_config
)
extension_managers.append(manager)
host_policy = load_host_policy(Path(folder_paths.base_path))
sandbox_config = {}
is_linux = platform.system() == "Linux"
if is_linux and isolated:
sandbox_config = {
"network": host_policy["allow_network"],
"writable_paths": host_policy["writable_paths"],
"readonly_paths": host_policy["readonly_paths"],
}
share_cuda_ipc = share_torch and is_linux
extension_config = {
"name": extension_name,
"module_path": str(node_dir),
"isolated": True,
"dependencies": dependencies,
"share_torch": share_torch,
"share_cuda_ipc": share_cuda_ipc,
"sandbox_mode": host_policy["sandbox_mode"],
"sandbox": sandbox_config,
}
if cuda_wheels is not None:
extension_config["cuda_wheels"] = cuda_wheels
extension = manager.load_extension(extension_config)
register_dummy_module(extension_name, node_dir)
# Register web directory on the host — only when sandbox is disabled.
# In sandbox mode, serving untrusted JS to the browser is not safe.
if host_policy["sandbox_mode"] == "disabled":
_register_web_directory(extension_name, node_dir)
# Try cache first (lazy spawn)
if is_cache_valid(node_dir, manifest_path, venv_root):
cached_data = load_from_cache(node_dir, venv_root)
if cached_data:
if _is_stale_node_cache(cached_data):
logger.debug(
"][ %s cache is stale/incompatible; rebuilding metadata",
extension_name,
)
else:
logger.debug(f"][ {extension_name} loaded from cache")
specs: List[Tuple[str, str, type]] = []
for node_name, details in cached_data.items():
stub_cls = build_stub_class(node_name, details, extension)
specs.append(
(node_name, details.get("display_name", node_name), stub_cls)
)
return specs
# Cache miss - spawn process and get metadata
logger.debug(f"][ {extension_name} cache miss, spawning process for metadata")
try:
remote_nodes: Dict[str, str] = await extension.list_nodes()
except Exception as exc:
logger.warning(
"][ %s metadata discovery failed, skipping isolated load: %s",
extension_name,
exc,
)
await _stop_extension_safe(extension, extension_name)
return []
if not remote_nodes:
logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
await _stop_extension_safe(extension, extension_name)
return []
specs: List[Tuple[str, str, type]] = []
cache_data: Dict[str, Dict] = {}
for node_name, display_name in remote_nodes.items():
try:
details = await extension.get_node_details(node_name)
except Exception as exc:
logger.warning(
"][ %s failed to load metadata for %s, skipping node: %s",
extension_name,
node_name,
exc,
)
continue
details["display_name"] = display_name
cache_data[node_name] = details
stub_cls = build_stub_class(node_name, details, extension)
specs.append((node_name, display_name, stub_cls))
if not specs:
logger.warning(
"][ %s produced no usable nodes after metadata scan; skipping",
extension_name,
)
await _stop_extension_safe(extension, extension_name)
return []
# Save metadata to cache for future runs
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
logger.debug(f"][ {extension_name} metadata cached")
# EJECT: Kill process after getting metadata (will respawn on first execution)
await _stop_extension_safe(extension, extension_name)
return specs
__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]

View File

@@ -0,0 +1,699 @@
# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position
from __future__ import annotations
import asyncio
import torch
class AttrDict(dict):
def __getattr__(self, item):
try:
return self[item]
except KeyError as e:
raise AttributeError(item) from e
def copy(self):
return AttrDict(super().copy())
import importlib
import inspect
import json
import logging
import os
import sys
import uuid
from dataclasses import asdict
from typing import Any, Dict, List, Tuple
from pyisolate import ExtensionBase
from comfy_api.internal import _ComfyNodeInternal
LOG_PREFIX = "]["
V3_DISCOVERY_TIMEOUT = 30
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
logger = logging.getLogger(__name__)
def _flush_tensor_transport_state(marker: str) -> int:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
except Exception:
return 0
if not callable(flush_tensor_keeper):
return 0
flushed = flush_tensor_keeper()
if flushed > 0:
logger.debug(
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
)
return flushed
def _relieve_child_vram_pressure(marker: str) -> None:
import comfy.model_management as model_management
model_management.cleanup_models_gc()
model_management.cleanup_models()
device = model_management.get_torch_device()
if not hasattr(device, "type") or device.type == "cpu":
return
required = max(
model_management.minimum_inference_memory(),
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
)
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=True)
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=False)
model_management.cleanup_models()
model_management.soft_empty_cache()
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
def _sanitize_for_transport(value):
primitives = (str, int, float, bool, type(None))
if isinstance(value, primitives):
return value
cls_name = value.__class__.__name__
if cls_name == "FlexibleOptionalInputType":
return {
"__pyisolate_flexible_optional__": True,
"type": _sanitize_for_transport(getattr(value, "type", "*")),
}
if cls_name == "AnyType":
return {"__pyisolate_any_type__": True, "value": str(value)}
if cls_name == "ByPassTypeTuple":
return {
"__pyisolate_bypass_tuple__": [
_sanitize_for_transport(v) for v in tuple(value)
]
}
if isinstance(value, dict):
return {k: _sanitize_for_transport(v) for k, v in value.items()}
if isinstance(value, tuple):
return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]}
if isinstance(value, list):
return [_sanitize_for_transport(v) for v in value]
return str(value)
# Re-export RemoteObjectHandle from pyisolate for backward compatibility
# The canonical definition is now in pyisolate._internal.remote_handle
from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401
class ComfyNodeExtension(ExtensionBase):
def __init__(self) -> None:
super().__init__()
self.node_classes: Dict[str, type] = {}
self.display_names: Dict[str, str] = {}
self.node_instances: Dict[str, Any] = {}
self.remote_objects: Dict[str, Any] = {}
self._route_handlers: Dict[str, Any] = {}
self._module: Any = None
async def on_module_loaded(self, module: Any) -> None:
self._module = module
# Registries are initialized in host_hooks.py initialize_host_process()
# They auto-register via ProxiedSingleton when instantiated
# NO additional setup required here - if a registry is missing from host_hooks, it WILL fail
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
try:
from comfy_api.latest import ComfyExtension
for name, obj in inspect.getmembers(module):
if not (
inspect.isclass(obj)
and issubclass(obj, ComfyExtension)
and obj is not ComfyExtension
):
continue
if not obj.__module__.startswith(module.__name__):
continue
try:
ext_instance = obj()
try:
await asyncio.wait_for(
ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT
)
except asyncio.TimeoutError:
logger.error(
"%s V3 Extension %s timed out in on_load()",
LOG_PREFIX,
name,
)
continue
try:
v3_nodes = await asyncio.wait_for(
ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT
)
except asyncio.TimeoutError:
logger.error(
"%s V3 Extension %s timed out in get_node_list()",
LOG_PREFIX,
name,
)
continue
for node_cls in v3_nodes:
if hasattr(node_cls, "GET_SCHEMA"):
schema = node_cls.GET_SCHEMA()
self.node_classes[schema.node_id] = node_cls
if schema.display_name:
self.display_names[schema.node_id] = schema.display_name
except Exception as e:
logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e)
except ImportError:
pass
module_name = getattr(module, "__name__", "isolated_nodes")
for node_cls in self.node_classes.values():
if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
node_cls.__module__ = module_name
self.node_instances = {}
async def list_nodes(self) -> Dict[str, str]:
return {name: self.display_names.get(name, name) for name in self.node_classes}
async def get_node_info(self, node_name: str) -> Dict[str, Any]:
return await self.get_node_details(node_name)
async def get_node_details(self, node_name: str) -> Dict[str, Any]:
node_cls = self._get_node_class(node_name)
is_v3 = issubclass(node_cls, _ComfyNodeInternal)
input_types_raw = (
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
)
output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
if output_is_list is not None:
output_is_list = tuple(bool(x) for x in output_is_list)
details: Dict[str, Any] = {
"input_types": _sanitize_for_transport(input_types_raw),
"return_types": tuple(
str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
),
"return_names": getattr(node_cls, "RETURN_NAMES", None),
"function": str(getattr(node_cls, "FUNCTION", "execute")),
"category": str(getattr(node_cls, "CATEGORY", "")),
"output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
"output_is_list": output_is_list,
"is_v3": is_v3,
}
if is_v3:
try:
schema = node_cls.GET_SCHEMA()
schema_v1 = asdict(schema.get_v1_info(node_cls))
try:
schema_v3 = asdict(schema.get_v3_info(node_cls))
except (AttributeError, TypeError):
schema_v3 = self._build_schema_v3_fallback(schema)
details.update(
{
"schema_v1": schema_v1,
"schema_v3": schema_v3,
"hidden": [h.value for h in (schema.hidden or [])],
"description": getattr(schema, "description", ""),
"deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
"experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)),
"api_node": bool(getattr(node_cls, "API_NODE", False)),
"input_is_list": bool(
getattr(node_cls, "INPUT_IS_LIST", False)
),
"not_idempotent": bool(
getattr(node_cls, "NOT_IDEMPOTENT", False)
),
"accept_all_inputs": bool(
getattr(node_cls, "ACCEPT_ALL_INPUTS", False)
),
}
)
except Exception as exc:
logger.warning(
"%s V3 schema serialization failed for %s: %s",
LOG_PREFIX,
node_name,
exc,
)
return details
def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
input_dict: Dict[str, Any] = {}
output_dict: Dict[str, Any] = {}
hidden_list: List[str] = []
if getattr(schema, "inputs", None):
for inp in schema.inputs:
self._add_schema_io_v3(inp, input_dict)
if getattr(schema, "outputs", None):
for out in schema.outputs:
self._add_schema_io_v3(out, output_dict)
if getattr(schema, "hidden", None):
for h in schema.hidden:
hidden_list.append(getattr(h, "value", str(h)))
return {
"input": input_dict,
"output": output_dict,
"hidden": hidden_list,
"name": getattr(schema, "node_id", None),
"display_name": getattr(schema, "display_name", None),
"description": getattr(schema, "description", None),
"category": getattr(schema, "category", None),
"output_node": getattr(schema, "is_output_node", False),
"deprecated": getattr(schema, "is_deprecated", False),
"experimental": getattr(schema, "is_experimental", False),
"api_node": getattr(schema, "is_api_node", False),
}
def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
io_id = getattr(io_obj, "id", None)
if io_id is None:
return
io_type_fn = getattr(io_obj, "get_io_type", None)
io_type = (
io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
)
as_dict_fn = getattr(io_obj, "as_dict", None)
payload = as_dict_fn() if callable(as_dict_fn) else {}
target[str(io_id)] = (io_type, payload)
async def get_input_types(self, node_name: str) -> Dict[str, Any]:
node_cls = self._get_node_class(node_name)
if hasattr(node_cls, "INPUT_TYPES"):
return node_cls.INPUT_TYPES()
return {}
async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]:
logger.debug(
"%s ISO:child_execute_start ext=%s node=%s input_keys=%d",
LOG_PREFIX,
getattr(self, "name", "?"),
node_name,
len(inputs),
)
if os.environ.get("PYISOLATE_CHILD") == "1":
_relieve_child_vram_pressure("EXT:pre_execute")
resolved_inputs = self._resolve_remote_objects(inputs)
instance = self._get_node_instance(node_name)
node_cls = self._get_node_class(node_name)
# V3 API nodes expect hidden parameters in cls.hidden, not as kwargs
# Hidden params come through RPC as string keys like "Hidden.prompt"
from comfy_api.latest._io import Hidden, HiddenHolder
# Map string representations back to Hidden enum keys
hidden_string_map = {
"Hidden.unique_id": Hidden.unique_id,
"Hidden.prompt": Hidden.prompt,
"Hidden.extra_pnginfo": Hidden.extra_pnginfo,
"Hidden.dynprompt": Hidden.dynprompt,
"Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
"Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
# Uppercase enum VALUE forms — V3 execution engine passes these
"UNIQUE_ID": Hidden.unique_id,
"PROMPT": Hidden.prompt,
"EXTRA_PNGINFO": Hidden.extra_pnginfo,
"DYNPROMPT": Hidden.dynprompt,
"AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org,
"API_KEY_COMFY_ORG": Hidden.api_key_comfy_org,
}
# Find and extract hidden parameters (both enum and string form)
hidden_found = {}
keys_to_remove = []
for key in list(resolved_inputs.keys()):
# Check string form first (from RPC serialization)
if key in hidden_string_map:
hidden_found[hidden_string_map[key]] = resolved_inputs[key]
keys_to_remove.append(key)
# Also check enum form (direct calls)
elif isinstance(key, Hidden):
hidden_found[key] = resolved_inputs[key]
keys_to_remove.append(key)
# Remove hidden params from kwargs
for key in keys_to_remove:
resolved_inputs.pop(key)
# Set hidden on node class if any hidden params found
if hidden_found:
if not hasattr(node_cls, "hidden") or node_cls.hidden is None:
node_cls.hidden = HiddenHolder.from_dict(hidden_found)
else:
# Update existing hidden holder
for key, value in hidden_found.items():
setattr(node_cls.hidden, key.value.lower(), value)
function_name = getattr(node_cls, "FUNCTION", "execute")
if not hasattr(instance, function_name):
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
handler = getattr(instance, function_name)
try:
import torch
if asyncio.iscoroutinefunction(handler):
with torch.inference_mode():
result = await handler(**resolved_inputs)
else:
import functools
def _run_with_inference_mode(**kwargs):
with torch.inference_mode():
return handler(**kwargs)
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None, functools.partial(_run_with_inference_mode, **resolved_inputs)
)
except Exception:
logger.exception(
"%s ISO:child_execute_error ext=%s node=%s",
LOG_PREFIX,
getattr(self, "name", "?"),
node_name,
)
raise
if type(result).__name__ == "NodeOutput":
node_output_dict = {
"__node_output__": True,
"args": self._wrap_unpicklable_objects(result.args),
}
if result.ui is not None:
node_output_dict["ui"] = result.ui
if getattr(result, "expand", None) is not None:
node_output_dict["expand"] = result.expand
if getattr(result, "block_execution", None) is not None:
node_output_dict["block_execution"] = result.block_execution
return node_output_dict
if self._is_comfy_protocol_return(result):
wrapped = self._wrap_unpicklable_objects(result)
return wrapped
if not isinstance(result, tuple):
result = (result,)
wrapped = self._wrap_unpicklable_objects(result)
return wrapped
async def flush_transport_state(self) -> int:
if os.environ.get("PYISOLATE_CHILD") != "1":
return 0
logger.debug(
"%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
)
flushed = _flush_tensor_transport_state("EXT:workflow_end")
try:
from comfy.isolation.model_patcher_proxy_registry import (
ModelPatcherRegistry,
)
registry = ModelPatcherRegistry()
removed = registry.sweep_pending_cleanup()
if removed > 0:
logger.debug(
"%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
)
except Exception:
logger.debug(
"%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
)
logger.debug(
"%s ISO:child_flush_done ext=%s flushed=%d",
LOG_PREFIX,
getattr(self, "name", "?"),
flushed,
)
return flushed
async def get_remote_object(self, object_id: str) -> Any:
"""Retrieve a remote object by ID for host-side deserialization."""
if object_id not in self.remote_objects:
raise KeyError(f"Remote object {object_id} not found")
return self.remote_objects[object_id]
def _wrap_unpicklable_objects(self, data: Any) -> Any:
if isinstance(data, (str, int, float, bool, type(None))):
return data
if isinstance(data, torch.Tensor):
tensor = data.detach() if data.requires_grad else data
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
return tensor.cpu()
return tensor
# Special-case clip vision outputs: preserve attribute access by packing fields
if hasattr(data, "penultimate_hidden_states") or hasattr(
data, "last_hidden_state"
):
fields = {}
for attr in (
"penultimate_hidden_states",
"last_hidden_state",
"image_embeds",
"text_embeds",
):
if hasattr(data, attr):
try:
fields[attr] = self._wrap_unpicklable_objects(
getattr(data, attr)
)
except Exception:
pass
if fields:
return {"__pyisolate_attribute_container__": True, "data": fields}
# Avoid converting arbitrary objects with stateful methods (models, etc.)
# They will be handled via RemoteObjectHandle below.
type_name = type(data).__name__
if type_name == "ModelPatcherProxy":
return {"__type__": "ModelPatcherRef", "model_id": data._instance_id}
if type_name == "CLIPProxy":
return {"__type__": "CLIPRef", "clip_id": data._instance_id}
if type_name == "VAEProxy":
return {"__type__": "VAERef", "vae_id": data._instance_id}
if type_name == "ModelSamplingProxy":
return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id}
if isinstance(data, (list, tuple)):
wrapped = [self._wrap_unpicklable_objects(item) for item in data]
return tuple(wrapped) if isinstance(data, tuple) else wrapped
if isinstance(data, dict):
converted_dict = {
k: self._wrap_unpicklable_objects(v) for k, v in data.items()
}
return {"__pyisolate_attrdict__": True, "data": converted_dict}
from pyisolate._internal.serialization_registry import SerializerRegistry
registry = SerializerRegistry.get_instance()
if registry.is_data_type(type_name):
serializer = registry.get_serializer(type_name)
if serializer:
return serializer(data)
object_id = str(uuid.uuid4())
self.remote_objects[object_id] = data
return RemoteObjectHandle(object_id, type(data).__name__)
def _resolve_remote_objects(self, data: Any) -> Any:
if isinstance(data, RemoteObjectHandle):
if data.object_id not in self.remote_objects:
raise KeyError(f"Remote object {data.object_id} not found")
return self.remote_objects[data.object_id]
if isinstance(data, dict):
ref_type = data.get("__type__")
if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"):
from pyisolate._internal.model_serialization import (
deserialize_proxy_result,
)
return deserialize_proxy_result(data)
if ref_type == "ModelSamplingRef":
from pyisolate._internal.model_serialization import (
deserialize_proxy_result,
)
return deserialize_proxy_result(data)
return {k: self._resolve_remote_objects(v) for k, v in data.items()}
if isinstance(data, (list, tuple)):
resolved = [self._resolve_remote_objects(item) for item in data]
return tuple(resolved) if isinstance(data, tuple) else resolved
return data
def _get_node_class(self, node_name: str) -> type:
if node_name not in self.node_classes:
raise KeyError(f"Unknown node: {node_name}")
return self.node_classes[node_name]
def _get_node_instance(self, node_name: str) -> Any:
if node_name not in self.node_instances:
if node_name not in self.node_classes:
raise KeyError(f"Unknown node: {node_name}")
self.node_instances[node_name] = self.node_classes[node_name]()
return self.node_instances[node_name]
async def before_module_loaded(self) -> None:
# Inject initialization here if we think this is the child
try:
from comfy.isolation import initialize_proxies
initialize_proxies()
except Exception as e:
logging.getLogger(__name__).error(
f"Failed to call initialize_proxies in before_module_loaded: {e}"
)
await super().before_module_loaded()
try:
from comfy_api.latest import ComfyAPI_latest
from .proxies.progress_proxy import ProgressProxy
ComfyAPI_latest.Execution = ProgressProxy
# ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision
# fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision
# latest_ui.folder_paths = fp_proxy
# latest_resources.folder_paths = fp_proxy
except Exception:
pass
async def call_route_handler(
self,
handler_module: str,
handler_func: str,
request_data: Dict[str, Any],
) -> Any:
cache_key = f"{handler_module}.{handler_func}"
if cache_key not in self._route_handlers:
if self._module is not None and hasattr(self._module, "__file__"):
node_dir = os.path.dirname(self._module.__file__)
if node_dir not in sys.path:
sys.path.insert(0, node_dir)
try:
module = importlib.import_module(handler_module)
self._route_handlers[cache_key] = getattr(module, handler_func)
except (ImportError, AttributeError) as e:
raise ValueError(f"Route handler not found: {cache_key}") from e
handler = self._route_handlers[cache_key]
mock_request = MockRequest(request_data)
if asyncio.iscoroutinefunction(handler):
result = await handler(mock_request)
else:
result = handler(mock_request)
return self._serialize_response(result)
def _is_comfy_protocol_return(self, result: Any) -> bool:
"""
Check if the result matches the ComfyUI 'Protocol Return' schema.
A Protocol Return is a dictionary containing specific reserved keys that
ComfyUI's execution engine interprets as instructions (UI updates,
Workflow expansion, etc.) rather than purely data outputs.
Schema:
- Must be a dict
- Must contain at least one of: 'ui', 'result', 'expand'
"""
if not isinstance(result, dict):
return False
return any(key in result for key in ("ui", "result", "expand"))
def _serialize_response(self, response: Any) -> Dict[str, Any]:
if response is None:
return {"type": "text", "body": "", "status": 204}
if isinstance(response, dict):
return {"type": "json", "body": response, "status": 200}
if isinstance(response, str):
return {"type": "text", "body": response, "status": 200}
if hasattr(response, "text") and hasattr(response, "status"):
return {
"type": "text",
"body": response.text
if hasattr(response, "text")
else str(response.body),
"status": response.status,
"headers": dict(response.headers)
if hasattr(response, "headers")
else {},
}
if hasattr(response, "body") and hasattr(response, "status"):
body = response.body
if isinstance(body, bytes):
try:
return {
"type": "text",
"body": body.decode("utf-8"),
"status": response.status,
}
except UnicodeDecodeError:
return {
"type": "binary",
"body": body.hex(),
"status": response.status,
}
return {"type": "json", "body": body, "status": response.status}
return {"type": "text", "body": str(response), "status": 200}
class MockRequest:
def __init__(self, data: Dict[str, Any]):
self.method = data.get("method", "GET")
self.path = data.get("path", "/")
self.query = data.get("query", {})
self._body = data.get("body", {})
self._text = data.get("text", "")
self.headers = data.get("headers", {})
self.content_type = data.get(
"content_type", self.headers.get("Content-Type", "application/json")
)
self.match_info = data.get("match_info", {})
async def json(self) -> Any:
if isinstance(self._body, dict):
return self._body
if isinstance(self._body, str):
return json.loads(self._body)
return {}
async def post(self) -> Dict[str, Any]:
if isinstance(self._body, dict):
return self._body
return {}
async def text(self) -> str:
if self._text:
return self._text
if isinstance(self._body, str):
return self._body
if isinstance(self._body, dict):
return json.dumps(self._body)
return ""
async def read(self) -> bytes:
return (await self.text()).encode("utf-8")

View File

@@ -0,0 +1,26 @@
# pylint: disable=import-outside-toplevel
# Host process initialization for PyIsolate
import logging
logger = logging.getLogger(__name__)
def initialize_host_process() -> None:
root = logging.getLogger()
for handler in root.handlers[:]:
root.removeHandler(handler)
root.addHandler(logging.NullHandler())
from .proxies.folder_paths_proxy import FolderPathsProxy
from .proxies.model_management_proxy import ModelManagementProxy
from .proxies.progress_proxy import ProgressProxy
from .proxies.prompt_server_impl import PromptServerService
from .proxies.utils_proxy import UtilsProxy
from .vae_proxy import VAERegistry
FolderPathsProxy()
ModelManagementProxy()
ProgressProxy()
PromptServerService()
UtilsProxy()
VAERegistry()

View File

@@ -0,0 +1,107 @@
# pylint: disable=logging-fstring-interpolation
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Dict, List, TypedDict
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
logger = logging.getLogger(__name__)
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
class HostSecurityPolicy(TypedDict):
sandbox_mode: str
allow_network: bool
writable_paths: List[str]
readonly_paths: List[str]
whitelist: Dict[str, str]
DEFAULT_POLICY: HostSecurityPolicy = {
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": ["/dev/shm", "/tmp"],
"readonly_paths": [],
"whitelist": {},
}
def _default_policy() -> HostSecurityPolicy:
return {
"sandbox_mode": DEFAULT_POLICY["sandbox_mode"],
"allow_network": DEFAULT_POLICY["allow_network"],
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
}
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
config_override = os.environ.get(HOST_POLICY_PATH_ENV)
config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
policy = _default_policy()
if not config_path.exists():
logger.debug("Host policy file missing at %s, using defaults.", config_path)
return policy
try:
with config_path.open("rb") as f:
data = tomllib.load(f)
except Exception:
logger.warning(
"Failed to parse host policy from %s, using defaults.",
config_path,
exc_info=True,
)
return policy
tool_config = data.get("tool", {}).get("comfy", {}).get("host", {})
if not isinstance(tool_config, dict):
logger.debug("No [tool.comfy.host] section found, using defaults.")
return policy
sandbox_mode = tool_config.get("sandbox_mode")
if isinstance(sandbox_mode, str):
normalized_sandbox_mode = sandbox_mode.strip().lower()
if normalized_sandbox_mode in VALID_SANDBOX_MODES:
policy["sandbox_mode"] = normalized_sandbox_mode
else:
logger.warning(
"Invalid host sandbox_mode %r in %s, using default %r.",
sandbox_mode,
config_path,
DEFAULT_POLICY["sandbox_mode"],
)
if "allow_network" in tool_config:
policy["allow_network"] = bool(tool_config["allow_network"])
if "writable_paths" in tool_config:
policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]]
if "readonly_paths" in tool_config:
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
whitelist_raw = tool_config.get("whitelist")
if isinstance(whitelist_raw, dict):
policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()}
logger.debug(
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
len(policy["whitelist"]),
policy["sandbox_mode"],
policy["allow_network"],
)
return policy
__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"]

View File

@@ -0,0 +1,186 @@
# pylint: disable=import-outside-toplevel
from __future__ import annotations
import hashlib
import json
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import folder_paths
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
LOG_PREFIX = "]["
logger = logging.getLogger(__name__)
CACHE_SUBDIR = "cache"
CACHE_KEY_FILE = "cache_key"
CACHE_DATA_FILE = "node_info.json"
CACHE_KEY_LENGTH = 16
def find_manifest_directories() -> List[Tuple[Path, Path]]:
"""Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation]."""
manifest_dirs: List[Tuple[Path, Path]] = []
# Standard custom_nodes paths
for base_path in folder_paths.get_folder_paths("custom_nodes"):
base = Path(base_path)
if not base.exists() or not base.is_dir():
continue
for entry in base.iterdir():
if not entry.is_dir():
continue
# Look for pyproject.toml
manifest = entry / "pyproject.toml"
if not manifest.exists():
continue
# Validate [tool.comfy.isolation] section existence
try:
with manifest.open("rb") as f:
data = tomllib.load(f)
if (
"tool" in data
and "comfy" in data["tool"]
and "isolation" in data["tool"]["comfy"]
):
manifest_dirs.append((entry, manifest))
except Exception:
continue
return manifest_dirs
def compute_cache_key(node_dir: Path, manifest_path: Path) -> str:
"""Hash manifest + .py mtimes + Python version + PyIsolate version."""
hasher = hashlib.sha256()
try:
# Hashing the manifest content ensures config changes invalidate cache
hasher.update(manifest_path.read_bytes())
except OSError:
hasher.update(b"__manifest_read_error__")
try:
py_files = sorted(node_dir.rglob("*.py"))
for py_file in py_files:
rel_path = py_file.relative_to(node_dir)
if "__pycache__" in str(rel_path) or ".venv" in str(rel_path):
continue
hasher.update(str(rel_path).encode("utf-8"))
try:
hasher.update(str(py_file.stat().st_mtime).encode("utf-8"))
except OSError:
hasher.update(b"__file_stat_error__")
except OSError:
hasher.update(b"__dir_scan_error__")
hasher.update(sys.version.encode("utf-8"))
try:
import pyisolate
hasher.update(pyisolate.__version__.encode("utf-8"))
except (ImportError, AttributeError):
hasher.update(b"__pyisolate_unknown__")
return hasher.hexdigest()[:CACHE_KEY_LENGTH]
def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]:
"""Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/."""
cache_dir = venv_root / node_dir.name / CACHE_SUBDIR
return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE)
def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool:
"""Return True only if stored cache key matches current computed key."""
try:
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
if not cache_key_file.exists() or not cache_data_file.exists():
return False
current_key = compute_cache_key(node_dir, manifest_path)
stored_key = cache_key_file.read_text(encoding="utf-8").strip()
return current_key == stored_key
except Exception as e:
logger.debug(
"%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e
)
return False
def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]:
"""Load node metadata from cache, return None on any error."""
try:
_, cache_data_file = get_cache_path(node_dir, venv_root)
if not cache_data_file.exists():
return None
data = json.loads(cache_data_file.read_text(encoding="utf-8"))
if not isinstance(data, dict):
return None
return data
except Exception:
return None
def save_to_cache(
node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path
) -> None:
"""Save node metadata and cache key atomically."""
try:
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
cache_dir = cache_key_file.parent
cache_dir.mkdir(parents=True, exist_ok=True)
cache_key = compute_cache_key(node_dir, manifest_path)
# Atomic write: data
tmp_data_fd, tmp_data_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
try:
with os.fdopen(tmp_data_fd, "w", encoding="utf-8") as f:
json.dump(node_data, f, indent=2)
os.replace(tmp_data_path, cache_data_file)
except Exception:
try:
os.unlink(tmp_data_path)
except OSError:
pass
raise
# Atomic write: key
tmp_key_fd, tmp_key_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
try:
with os.fdopen(tmp_key_fd, "w", encoding="utf-8") as f:
f.write(cache_key)
os.replace(tmp_key_path, cache_key_file)
except Exception:
try:
os.unlink(tmp_key_path)
except OSError:
pass
raise
except Exception as e:
logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e)
__all__ = [
"LOG_PREFIX",
"find_manifest_directories",
"compute_cache_key",
"get_cache_path",
"is_cache_valid",
"load_from_cache",
"save_to_cache",
]

View File

@@ -0,0 +1,861 @@
# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access
# RPC proxy for ModelPatcher (parent process)
from __future__ import annotations
import logging
from typing import Any, Optional, List, Set, Dict, Callable
from comfy.isolation.proxies.base import (
IS_CHILD_PROCESS,
BaseProxy,
)
from comfy.isolation.model_patcher_proxy_registry import (
ModelPatcherRegistry,
AutoPatcherEjector,
)
logger = logging.getLogger(__name__)
class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
_registry_class = ModelPatcherRegistry
__module__ = "comfy.model_patcher"
_APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024
def _get_rpc(self) -> Any:
if self._rpc_caller is None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc is not None:
self._rpc_caller = rpc.create_caller(
self._registry_class, self._registry_class.get_remote_id()
)
else:
self._rpc_caller = self._registry
return self._rpc_caller
def get_all_callbacks(self, call_type: str = None) -> Any:
return self._call_rpc("get_all_callbacks", call_type)
def get_all_wrappers(self, wrapper_type: str = None) -> Any:
return self._call_rpc("get_all_wrappers", wrapper_type)
def _load_list(self, *args, **kwargs) -> Any:
return self._call_rpc("load_list_internal", *args, **kwargs)
def prepare_hook_patches_current_keyframe(
self, t: Any, hook_group: Any, model_options: Any
) -> None:
self._call_rpc(
"prepare_hook_patches_current_keyframe", t, hook_group, model_options
)
def add_hook_patches(
self,
hook: Any,
patches: Any,
strength_patch: float = 1.0,
strength_model: float = 1.0,
) -> None:
self._call_rpc(
"add_hook_patches", hook, patches, strength_patch, strength_model
)
def clear_cached_hook_weights(self) -> None:
self._call_rpc("clear_cached_hook_weights")
def get_combined_hook_patches(self, hooks: Any) -> Any:
return self._call_rpc("get_combined_hook_patches", hooks)
def get_additional_models_with_key(self, key: str) -> Any:
return self._call_rpc("get_additional_models_with_key", key)
@property
def object_patches(self) -> Any:
return self._call_rpc("get_object_patches")
@property
def patches(self) -> Any:
res = self._call_rpc("get_patches")
if isinstance(res, dict):
new_res = {}
for k, v in res.items():
new_list = []
for item in v:
if isinstance(item, list):
new_list.append(tuple(item))
else:
new_list.append(item)
new_res[k] = new_list
return new_res
return res
@property
def pinned(self) -> Set:
val = self._call_rpc("get_patcher_attr", "pinned")
return set(val) if val is not None else set()
@property
def hook_patches(self) -> Dict:
val = self._call_rpc("get_patcher_attr", "hook_patches")
if val is None:
return {}
try:
from comfy.hooks import _HookRef
import json
new_val = {}
for k, v in val.items():
if isinstance(k, str):
if k.startswith("PYISOLATE_HOOKREF:"):
ref_id = k.split(":", 1)[1]
h = _HookRef()
h._pyisolate_id = ref_id
new_val[h] = v
elif k.startswith("__pyisolate_key__"):
try:
json_str = k[len("__pyisolate_key__") :]
data = json.loads(json_str)
ref_id = None
if isinstance(data, list):
for item in data:
if (
isinstance(item, list)
and len(item) == 2
and item[0] == "id"
):
ref_id = item[1]
break
if ref_id:
h = _HookRef()
h._pyisolate_id = ref_id
new_val[h] = v
else:
new_val[k] = v
except Exception:
new_val[k] = v
else:
new_val[k] = v
else:
new_val[k] = v
return new_val
except ImportError:
return val
def set_hook_mode(self, hook_mode: Any) -> None:
self._call_rpc("set_hook_mode", hook_mode)
def register_all_hook_patches(
self,
hooks: Any,
target_dict: Any,
model_options: Any = None,
registered: Any = None,
) -> None:
self._call_rpc(
"register_all_hook_patches", hooks, target_dict, model_options, registered
)
def is_clone(self, other: Any) -> bool:
if isinstance(other, ModelPatcherProxy):
return self._call_rpc("is_clone_by_id", other._instance_id)
return False
def clone(self) -> ModelPatcherProxy:
new_id = self._call_rpc("clone")
return ModelPatcherProxy(
new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
)
def clone_has_same_weights(self, clone: Any) -> bool:
if isinstance(clone, ModelPatcherProxy):
return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id)
if not IS_CHILD_PROCESS:
return self._call_rpc("is_clone", clone)
return False
def get_model_object(self, name: str) -> Any:
return self._call_rpc("get_model_object", name)
@property
def model_options(self) -> dict:
data = self._call_rpc("get_model_options")
import json
def _decode_keys(obj):
if isinstance(obj, dict):
new_d = {}
for k, v in obj.items():
if isinstance(k, str) and k.startswith("__pyisolate_key__"):
try:
json_str = k[17:]
val = json.loads(json_str)
if isinstance(val, list):
val = tuple(val)
new_d[val] = _decode_keys(v)
except:
new_d[k] = _decode_keys(v)
else:
new_d[k] = _decode_keys(v)
return new_d
if isinstance(obj, list):
return [_decode_keys(x) for x in obj]
return obj
return _decode_keys(data)
@model_options.setter
def model_options(self, value: dict) -> None:
self._call_rpc("set_model_options", value)
def apply_hooks(self, hooks: Any) -> Any:
return self._call_rpc("apply_hooks", hooks)
def prepare_state(self, timestep: Any) -> Any:
return self._call_rpc("prepare_state", timestep)
def restore_hook_patches(self) -> None:
self._call_rpc("restore_hook_patches")
def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None:
self._call_rpc("unpatch_hooks", whitelist_keys_set)
def model_patches_to(self, device: Any) -> Any:
return self._call_rpc("model_patches_to", device)
def partially_load(
self, device: Any, extra_memory: Any, force_patch_weights: bool = False
) -> Any:
return self._call_rpc(
"partially_load", device, extra_memory, force_patch_weights
)
def partially_unload(
self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False
) -> int:
return self._call_rpc(
"partially_unload", device_to, memory_to_free, force_patch_weights
)
def load(
self,
device_to: Any = None,
lowvram_model_memory: int = 0,
force_patch_weights: bool = False,
full_load: bool = False,
) -> None:
self._call_rpc(
"load", device_to, lowvram_model_memory, force_patch_weights, full_load
)
def patch_model(
self,
device_to: Any = None,
lowvram_model_memory: int = 0,
load_weights: bool = True,
force_patch_weights: bool = False,
) -> Any:
self._call_rpc(
"patch_model",
device_to,
lowvram_model_memory,
load_weights,
force_patch_weights,
)
return self
def unpatch_model(
self, device_to: Any = None, unpatch_weights: bool = True
) -> None:
self._call_rpc("unpatch_model", device_to, unpatch_weights)
def detach(self, unpatch_all: bool = True) -> Any:
self._call_rpc("detach", unpatch_all)
return self.model
def _cpu_tensor_bytes(self, obj: Any) -> int:
import torch
if isinstance(obj, torch.Tensor):
if obj.device.type == "cpu":
return obj.nbytes
return 0
if isinstance(obj, dict):
return sum(self._cpu_tensor_bytes(v) for v in obj.values())
if isinstance(obj, (list, tuple)):
return sum(self._cpu_tensor_bytes(v) for v in obj)
return 0
def _ensure_apply_model_headroom(self, required_bytes: int) -> bool:
if required_bytes <= 0:
return True
import torch
import comfy.model_management as model_management
target_raw = self.load_device
try:
if isinstance(target_raw, torch.device):
target = target_raw
elif isinstance(target_raw, str):
target = torch.device(target_raw)
elif isinstance(target_raw, int):
target = torch.device(f"cuda:{target_raw}")
else:
target = torch.device(target_raw)
except Exception:
return True
if target.type != "cuda":
return True
required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES
if model_management.get_free_memory(target) >= required:
return True
model_management.cleanup_models_gc()
model_management.cleanup_models()
model_management.soft_empty_cache()
if model_management.get_free_memory(target) < required:
model_management.free_memory(required, target, for_dynamic=True)
model_management.soft_empty_cache()
if model_management.get_free_memory(target) < required:
# Escalate to non-dynamic unloading before dispatching CUDA transfer.
model_management.free_memory(required, target, for_dynamic=False)
model_management.soft_empty_cache()
if model_management.get_free_memory(target) < required:
model_management.load_models_gpu(
[self],
minimum_memory_required=required,
)
return model_management.get_free_memory(target) >= required
def apply_model(self, *args, **kwargs) -> Any:
import torch
def _preferred_device() -> Any:
for value in args:
if isinstance(value, torch.Tensor):
return value.device
for value in kwargs.values():
if isinstance(value, torch.Tensor):
return value.device
return None
def _move_result_to_device(obj: Any, device: Any) -> Any:
if device is None:
return obj
if isinstance(obj, torch.Tensor):
return obj.to(device) if obj.device != device else obj
if isinstance(obj, dict):
return {k: _move_result_to_device(v, device) for k, v in obj.items()}
if isinstance(obj, list):
return [_move_result_to_device(v, device) for v in obj]
if isinstance(obj, tuple):
return tuple(_move_result_to_device(v, device) for v in obj)
return obj
# DynamicVRAM models must keep load/offload decisions in host process.
# Child-side CUDA staging here can deadlock before first inference RPC.
if self.is_dynamic():
out = self._call_rpc("inner_model_apply_model", args, kwargs)
return _move_result_to_device(out, _preferred_device())
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
self._ensure_apply_model_headroom(required_bytes)
def _to_cuda(obj: Any) -> Any:
if isinstance(obj, torch.Tensor) and obj.device.type == "cpu":
return obj.to("cuda")
if isinstance(obj, dict):
return {k: _to_cuda(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cuda(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cuda(v) for v in obj)
return obj
try:
args_cuda = _to_cuda(args)
kwargs_cuda = _to_cuda(kwargs)
except torch.OutOfMemoryError:
self._ensure_apply_model_headroom(required_bytes)
args_cuda = _to_cuda(args)
kwargs_cuda = _to_cuda(kwargs)
out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
return _move_result_to_device(out, _preferred_device())
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
keys = self._call_rpc("model_state_dict", filter_prefix)
return dict.fromkeys(keys, None)
def add_patches(self, *args: Any, **kwargs: Any) -> Any:
res = self._call_rpc("add_patches", *args, **kwargs)
if isinstance(res, list):
return [tuple(x) if isinstance(x, list) else x for x in res]
return res
def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any:
return self._call_rpc("get_key_patches", filter_prefix)
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
self._call_rpc("patch_weight_to_device", key, device_to, inplace_update)
def pin_weight_to_device(self, key):
self._call_rpc("pin_weight_to_device", key)
def unpin_weight(self, key):
self._call_rpc("unpin_weight", key)
def unpin_all_weights(self):
self._call_rpc("unpin_all_weights")
def calculate_weight(self, patches, weight, key, intermediate_dtype=None):
return self._call_rpc(
"calculate_weight", patches, weight, key, intermediate_dtype
)
def inject_model(self) -> None:
self._call_rpc("inject_model")
def eject_model(self) -> None:
self._call_rpc("eject_model")
def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any:
return AutoPatcherEjector(
self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only
)
@property
def is_injected(self) -> bool:
return self._call_rpc("get_is_injected")
@property
def skip_injection(self) -> bool:
return self._call_rpc("get_skip_injection")
@skip_injection.setter
def skip_injection(self, value: bool) -> None:
self._call_rpc("set_skip_injection", value)
def clean_hooks(self) -> None:
self._call_rpc("clean_hooks")
def pre_run(self) -> None:
self._call_rpc("pre_run")
def cleanup(self) -> None:
try:
self._call_rpc("cleanup")
except Exception:
logger.debug(
"ModelPatcherProxy cleanup RPC failed for %s",
self._instance_id,
exc_info=True,
)
finally:
super().cleanup()
@property
def model(self) -> _InnerModelProxy:
return _InnerModelProxy(self)
def __getattr__(self, name: str) -> Any:
_whitelisted_attrs = {
"hook_patches_backup",
"hook_backup",
"cached_hook_patches",
"current_hooks",
"forced_hooks",
"is_clip",
"patches_uuid",
"pinned",
"attachments",
"additional_models",
"injections",
"hook_patches",
"model_lowvram",
"model_loaded_weight_memory",
"backup",
"object_patches_backup",
"weight_wrapper_patches",
"weight_inplace_update",
"force_cast_weights",
}
if name in _whitelisted_attrs:
return self._call_rpc("get_patcher_attr", name)
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
def load_lora(
self,
lora_path: str,
strength_model: float,
clip: Optional[Any] = None,
strength_clip: float = 1.0,
) -> tuple:
clip_id = None
if clip is not None:
clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None))
result = self._call_rpc(
"load_lora", lora_path, strength_model, clip_id, strength_clip
)
new_model = None
if result.get("model_id"):
new_model = ModelPatcherProxy(
result["model_id"],
self._registry,
manage_lifecycle=not IS_CHILD_PROCESS,
)
new_clip = None
if result.get("clip_id"):
from comfy.isolation.clip_proxy import CLIPProxy
new_clip = CLIPProxy(result["clip_id"])
return (new_model, new_clip)
@property
def load_device(self) -> Any:
return self._call_rpc("get_load_device")
@property
def offload_device(self) -> Any:
return self._call_rpc("get_offload_device")
@property
def device(self) -> Any:
return self.load_device
def current_loaded_device(self) -> Any:
return self._call_rpc("current_loaded_device")
@property
def size(self) -> int:
return self._call_rpc("get_size")
def model_size(self) -> Any:
return self._call_rpc("model_size")
def loaded_size(self) -> Any:
return self._call_rpc("loaded_size")
def get_ram_usage(self) -> int:
return self._call_rpc("get_ram_usage")
def lowvram_patch_counter(self) -> int:
return self._call_rpc("lowvram_patch_counter")
def memory_required(self, input_shape: Any) -> Any:
return self._call_rpc("memory_required", input_shape)
def get_operation_state(self) -> Dict[str, Any]:
state = self._call_rpc("get_operation_state")
return state if isinstance(state, dict) else {}
def wait_for_idle(self, timeout_ms: int = 0) -> bool:
return bool(self._call_rpc("wait_for_idle", timeout_ms))
def is_dynamic(self) -> bool:
return bool(self._call_rpc("is_dynamic"))
def get_free_memory(self, device: Any) -> Any:
return self._call_rpc("get_free_memory", device)
def partially_unload_ram(self, ram_to_unload: int) -> Any:
return self._call_rpc("partially_unload_ram", ram_to_unload)
def model_dtype(self) -> Any:
res = self._call_rpc("model_dtype")
if isinstance(res, str) and res.startswith("torch."):
try:
import torch
attr = res.split(".")[-1]
if hasattr(torch, attr):
return getattr(torch, attr)
except ImportError:
pass
return res
@property
def hook_mode(self) -> Any:
return self._call_rpc("get_hook_mode")
@hook_mode.setter
def hook_mode(self, value: Any) -> None:
self._call_rpc("set_hook_mode", value)
def set_model_sampler_cfg_function(
self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False
) -> None:
self._call_rpc(
"set_model_sampler_cfg_function",
sampler_cfg_function,
disable_cfg1_optimization,
)
def set_model_sampler_post_cfg_function(
self, post_cfg_function: Any, disable_cfg1_optimization: bool = False
) -> None:
self._call_rpc(
"set_model_sampler_post_cfg_function",
post_cfg_function,
disable_cfg1_optimization,
)
def set_model_sampler_pre_cfg_function(
self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False
) -> None:
self._call_rpc(
"set_model_sampler_pre_cfg_function",
pre_cfg_function,
disable_cfg1_optimization,
)
def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None:
self._call_rpc("set_model_sampler_calc_cond_batch_function", fn)
def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None:
self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function)
def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None:
self._call_rpc("set_model_denoise_mask_function", denoise_mask_function)
def set_model_patch(self, patch: Any, name: str) -> None:
self._call_rpc("set_model_patch", patch, name)
def set_model_patch_replace(
self,
patch: Any,
name: str,
block_name: str,
number: int,
transformer_index: Optional[int] = None,
) -> None:
self._call_rpc(
"set_model_patch_replace",
patch,
name,
block_name,
number,
transformer_index,
)
def set_model_attn1_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(
self,
patch: Any,
block_name: str,
number: int,
transformer_index: Optional[int] = None,
) -> None:
self.set_model_patch_replace(
patch, "attn1", block_name, number, transformer_index
)
def set_model_attn2_replace(
self,
patch: Any,
block_name: str,
number: int,
transformer_index: Optional[int] = None,
) -> None:
self.set_model_patch_replace(
patch, "attn2", block_name, number, transformer_index
)
def set_model_attn1_output_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "attn2_output_patch")
def set_model_input_block_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "input_block_patch")
def set_model_input_block_patch_after_skip(self, patch: Any) -> None:
self.set_model_patch(patch, "input_block_patch_after_skip")
def set_model_output_block_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "output_block_patch")
def set_model_emb_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "emb_patch")
def set_model_forward_timestep_embed_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "forward_timestep_embed_patch")
def set_model_double_block_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "double_block")
def set_model_post_input_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "post_input")
def set_model_rope_options(
self,
scale_x=1.0,
shift_x=0.0,
scale_y=1.0,
shift_y=0.0,
scale_t=1.0,
shift_t=0.0,
**kwargs: Any,
) -> None:
options = {
"scale_x": scale_x,
"shift_x": shift_x,
"scale_y": scale_y,
"shift_y": shift_y,
"scale_t": scale_t,
"shift_t": shift_t,
}
options.update(kwargs)
self._call_rpc("set_model_rope_options", options)
def set_model_compute_dtype(self, dtype: Any) -> None:
self._call_rpc("set_model_compute_dtype", dtype)
def add_object_patch(self, name: str, obj: Any) -> None:
self._call_rpc("add_object_patch", name, obj)
def add_weight_wrapper(self, name: str, function: Any) -> None:
self._call_rpc("add_weight_wrapper", name, function)
def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None:
self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn)
def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None:
self.add_wrapper_with_key(wrapper_type, None, wrapper)
def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None:
self._call_rpc("remove_wrappers_with_key", wrapper_type, key)
@property
def wrappers(self) -> Any:
return self._call_rpc("get_wrappers")
def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None:
self._call_rpc("add_callback_with_key", call_type, key, callback)
def add_callback(self, call_type: str, callback: Any) -> None:
self.add_callback_with_key(call_type, None, callback)
def remove_callbacks_with_key(self, call_type: str, key: str) -> None:
self._call_rpc("remove_callbacks_with_key", call_type, key)
@property
def callbacks(self) -> Any:
return self._call_rpc("get_callbacks")
def set_attachments(self, key: str, attachment: Any) -> None:
self._call_rpc("set_attachments", key, attachment)
def get_attachment(self, key: str) -> Any:
return self._call_rpc("get_attachment", key)
def remove_attachments(self, key: str) -> None:
self._call_rpc("remove_attachments", key)
def set_injections(self, key: str, injections: Any) -> None:
self._call_rpc("set_injections", key, injections)
def get_injections(self, key: str) -> Any:
return self._call_rpc("get_injections", key)
def remove_injections(self, key: str) -> None:
self._call_rpc("remove_injections", key)
def set_additional_models(self, key: str, models: Any) -> None:
ids = [m._instance_id for m in models]
self._call_rpc("set_additional_models", key, ids)
def remove_additional_models(self, key: str) -> None:
self._call_rpc("remove_additional_models", key)
def get_nested_additional_models(self) -> Any:
return self._call_rpc("get_nested_additional_models")
def get_additional_models(self) -> List[ModelPatcherProxy]:
ids = self._call_rpc("get_additional_models")
return [
ModelPatcherProxy(
mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
)
for mid in ids
]
def model_patches_models(self) -> Any:
return self._call_rpc("model_patches_models")
@property
def parent(self) -> Any:
return self._call_rpc("get_parent")
class _InnerModelProxy:
def __init__(self, parent: ModelPatcherProxy):
self._parent = parent
self._model_sampling = None
def __getattr__(self, name: str) -> Any:
if name.startswith("_"):
raise AttributeError(name)
if name in (
"model_config",
"latent_format",
"model_type",
"current_weight_patches_uuid",
):
return self._parent._call_rpc("get_inner_model_attr", name)
if name == "load_device":
return self._parent._call_rpc("get_inner_model_attr", "load_device")
if name == "device":
return self._parent._call_rpc("get_inner_model_attr", "device")
if name == "current_patcher":
return ModelPatcherProxy(
self._parent._instance_id,
self._parent._registry,
manage_lifecycle=False,
)
if name == "model_sampling":
if self._model_sampling is None:
self._model_sampling = self._parent._call_rpc(
"get_model_object", "model_sampling"
)
return self._model_sampling
if name == "extra_conds_shapes":
return lambda *a, **k: self._parent._call_rpc(
"inner_model_extra_conds_shapes", a, k
)
if name == "extra_conds":
return lambda *a, **k: self._parent._call_rpc(
"inner_model_extra_conds", a, k
)
if name == "memory_required":
return lambda *a, **k: self._parent._call_rpc(
"inner_model_memory_required", a, k
)
if name == "apply_model":
# Delegate to parent's method to get the CPU->CUDA optimization
return self._parent.apply_model
if name == "process_latent_in":
return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k)
if name == "process_latent_out":
return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k)
if name == "scale_latent_inpaint":
return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k)
if name == "diffusion_model":
return self._parent._call_rpc("get_inner_model_attr", "diffusion_model")
raise AttributeError(f"'{name}' not supported on isolated InnerModel")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,156 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access
# Isolation utilities and serializers for ModelPatcherProxy
from __future__ import annotations
import logging
import os
from typing import Any
from comfy.cli_args import args
logger = logging.getLogger(__name__)
def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any:
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
isolation_active = args.use_process_isolation or is_child
if not isolation_active:
return model_patcher
if is_child:
return model_patcher
if isinstance(model_patcher, ModelPatcherProxy):
return model_patcher
registry = ModelPatcherRegistry()
model_id = registry.register(model_patcher)
logger.debug(f"Isolated ModelPatcher: {model_id}")
return ModelPatcherProxy(model_id, registry, manage_lifecycle=True)
def register_hooks_serializers(registry=None):
from pyisolate._internal.serialization_registry import SerializerRegistry
import comfy.hooks
if registry is None:
registry = SerializerRegistry.get_instance()
def serialize_enum(obj):
return {"__enum__": f"{type(obj).__name__}.{obj.name}"}
def deserialize_enum(data):
cls_name, val_name = data["__enum__"].split(".")
cls = getattr(comfy.hooks, cls_name)
return cls[val_name]
registry.register("EnumHookType", serialize_enum, deserialize_enum)
registry.register("EnumHookScope", serialize_enum, deserialize_enum)
registry.register("EnumHookMode", serialize_enum, deserialize_enum)
registry.register("EnumWeightTarget", serialize_enum, deserialize_enum)
def serialize_hook_group(obj):
return {"__type__": "HookGroup", "hooks": obj.hooks}
def deserialize_hook_group(data):
hg = comfy.hooks.HookGroup()
for h in data["hooks"]:
hg.add(h)
return hg
registry.register("HookGroup", serialize_hook_group, deserialize_hook_group)
def serialize_dict_state(obj):
d = obj.__dict__.copy()
d["__type__"] = type(obj).__name__
if "custom_should_register" in d:
del d["custom_should_register"]
return d
def deserialize_dict_state_generic(cls):
def _deserialize(data):
h = cls()
h.__dict__.update(data)
return h
return _deserialize
def deserialize_hook_keyframe(data):
h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0))
h.__dict__.update(data)
return h
registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe)
def deserialize_hook_keyframe_group(data):
h = comfy.hooks.HookKeyframeGroup()
h.__dict__.update(data)
return h
registry.register(
"HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group
)
def deserialize_hook(data):
h = comfy.hooks.Hook()
h.__dict__.update(data)
return h
registry.register("Hook", serialize_dict_state, deserialize_hook)
def deserialize_weight_hook(data):
h = comfy.hooks.WeightHook()
h.__dict__.update(data)
return h
registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook)
def serialize_set(obj):
return {"__set__": list(obj)}
def deserialize_set(data):
return set(data["__set__"])
registry.register("set", serialize_set, deserialize_set)
try:
from comfy.weight_adapter.lora import LoRAAdapter
def serialize_lora(obj):
return {"weights": {}, "loaded_keys": list(obj.loaded_keys)}
def deserialize_lora(data):
return LoRAAdapter(set(data["loaded_keys"]), data["weights"])
registry.register("LoRAAdapter", serialize_lora, deserialize_lora)
except Exception:
pass
try:
from comfy.hooks import _HookRef
import uuid
def serialize_hook_ref(obj):
return {
"__hook_ref__": True,
"id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())),
}
def deserialize_hook_ref(data):
h = _HookRef()
h._pyisolate_id = data.get("id", str(uuid.uuid4()))
return h
registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref)
except ImportError:
pass
except Exception as e:
logger.warning(f"Failed to register _HookRef: {e}")
try:
register_hooks_serializers()
except Exception as e:
logger.error(f"Failed to initialize hook serializers: {e}")

View File

@@ -0,0 +1,360 @@
# pylint: disable=import-outside-toplevel
from __future__ import annotations
import asyncio
import logging
import os
import threading
import time
from typing import Any
from comfy.isolation.proxies.base import (
BaseProxy,
BaseRegistry,
detach_if_grad,
get_thread_loop,
run_coro_in_new_loop,
)
logger = logging.getLogger(__name__)
def _describe_value(obj: Any) -> str:
try:
import torch
except Exception:
torch = None
try:
if torch is not None and isinstance(obj, torch.Tensor):
return (
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
)
except Exception:
pass
return "%s(id=%s)" % (type(obj).__name__, id(obj))
def _prefer_device(*tensors: Any) -> Any:
try:
import torch
except Exception:
return None
for t in tensors:
if isinstance(t, torch.Tensor) and t.is_cuda:
return t.device
for t in tensors:
if isinstance(t, torch.Tensor):
return t.device
return None
def _to_device(obj: Any, device: Any) -> Any:
try:
import torch
except Exception:
return obj
if device is None:
return obj
if isinstance(obj, torch.Tensor):
if obj.device != device:
return obj.to(device)
return obj
if isinstance(obj, (list, tuple)):
converted = [_to_device(x, device) for x in obj]
return type(obj)(converted) if isinstance(obj, tuple) else converted
if isinstance(obj, dict):
return {k: _to_device(v, device) for k, v in obj.items()}
return obj
def _to_cpu_for_rpc(obj: Any) -> Any:
try:
import torch
except Exception:
return obj
if isinstance(obj, torch.Tensor):
t = obj.detach() if obj.requires_grad else obj
if t.is_cuda:
return t.to("cpu")
return t
if isinstance(obj, (list, tuple)):
converted = [_to_cpu_for_rpc(x) for x in obj]
return type(obj)(converted) if isinstance(obj, tuple) else converted
if isinstance(obj, dict):
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
return obj
class ModelSamplingRegistry(BaseRegistry[Any]):
_type_prefix = "modelsampling"
async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.calculate_input(sigma, noise))
async def calculate_denoised(
self, instance_id: str, sigma: Any, model_output: Any, model_input: Any
) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(
sampling.calculate_denoised(sigma, model_output, model_input)
)
async def noise_scaling(
self,
instance_id: str,
sigma: Any,
noise: Any,
latent_image: Any,
max_denoise: bool = False,
) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(
sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise)
)
async def inverse_noise_scaling(
self, instance_id: str, sigma: Any, latent: Any
) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent))
async def timestep(self, instance_id: str, sigma: Any) -> Any:
sampling = self._get_instance(instance_id)
return sampling.timestep(sigma)
async def sigma(self, instance_id: str, timestep: Any) -> Any:
sampling = self._get_instance(instance_id)
return sampling.sigma(timestep)
async def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
sampling = self._get_instance(instance_id)
return sampling.percent_to_sigma(percent)
async def get_sigma_min(self, instance_id: str) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.sigma_min)
async def get_sigma_max(self, instance_id: str) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.sigma_max)
async def get_sigma_data(self, instance_id: str) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.sigma_data)
async def get_sigmas(self, instance_id: str) -> Any:
sampling = self._get_instance(instance_id)
return detach_if_grad(sampling.sigmas)
async def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
sampling = self._get_instance(instance_id)
sampling.set_sigmas(sigmas)
class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
_registry_class = ModelSamplingRegistry
__module__ = "comfy.isolation.model_sampling_proxy"
def _get_rpc(self) -> Any:
if self._rpc_caller is None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc is not None:
self._rpc_caller = rpc.create_caller(
ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id()
)
else:
registry = ModelSamplingRegistry()
class _LocalCaller:
def calculate_input(
self, instance_id: str, sigma: Any, noise: Any
) -> Any:
return registry.calculate_input(instance_id, sigma, noise)
def calculate_denoised(
self,
instance_id: str,
sigma: Any,
model_output: Any,
model_input: Any,
) -> Any:
return registry.calculate_denoised(
instance_id, sigma, model_output, model_input
)
def noise_scaling(
self,
instance_id: str,
sigma: Any,
noise: Any,
latent_image: Any,
max_denoise: bool = False,
) -> Any:
return registry.noise_scaling(
instance_id, sigma, noise, latent_image, max_denoise
)
def inverse_noise_scaling(
self, instance_id: str, sigma: Any, latent: Any
) -> Any:
return registry.inverse_noise_scaling(
instance_id, sigma, latent
)
def timestep(self, instance_id: str, sigma: Any) -> Any:
return registry.timestep(instance_id, sigma)
def sigma(self, instance_id: str, timestep: Any) -> Any:
return registry.sigma(instance_id, timestep)
def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
return registry.percent_to_sigma(instance_id, percent)
def get_sigma_min(self, instance_id: str) -> Any:
return registry.get_sigma_min(instance_id)
def get_sigma_max(self, instance_id: str) -> Any:
return registry.get_sigma_max(instance_id)
def get_sigma_data(self, instance_id: str) -> Any:
return registry.get_sigma_data(instance_id)
def get_sigmas(self, instance_id: str) -> Any:
return registry.get_sigmas(instance_id)
def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
return registry.set_sigmas(instance_id, sigmas)
self._rpc_caller = _LocalCaller()
return self._rpc_caller
def _call(self, method_name: str, *args: Any) -> Any:
rpc = self._get_rpc()
method = getattr(rpc, method_name)
result = method(self._instance_id, *args)
timeout_ms = self._rpc_timeout_ms()
start_epoch = time.time()
start_perf = time.perf_counter()
thread_id = threading.get_ident()
call_id = "%s:%s:%s:%.6f" % (
self._instance_id,
method_name,
thread_id,
start_perf,
)
logger.debug(
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
method_name,
self._instance_id,
call_id,
start_epoch,
thread_id,
timeout_ms,
)
if asyncio.iscoroutine(result):
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
try:
asyncio.get_running_loop()
out = run_coro_in_new_loop(result)
except RuntimeError:
loop = get_thread_loop()
out = loop.run_until_complete(result)
else:
out = result
logger.debug(
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
method_name,
self._instance_id,
call_id,
_describe_value(out),
)
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
method_name,
self._instance_id,
call_id,
elapsed_ms,
thread_id,
)
logger.debug(
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
method_name,
self._instance_id,
call_id,
)
return out
@staticmethod
def _rpc_timeout_ms() -> int:
raw = os.environ.get(
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
)
try:
timeout_ms = int(raw)
except ValueError:
timeout_ms = 30000
return max(1, timeout_ms)
@property
def sigma_min(self) -> Any:
return self._call("get_sigma_min")
@property
def sigma_max(self) -> Any:
return self._call("get_sigma_max")
@property
def sigma_data(self) -> Any:
return self._call("get_sigma_data")
@property
def sigmas(self) -> Any:
return self._call("get_sigmas")
def calculate_input(self, sigma: Any, noise: Any) -> Any:
return self._call("calculate_input", sigma, noise)
def calculate_denoised(
self, sigma: Any, model_output: Any, model_input: Any
) -> Any:
return self._call("calculate_denoised", sigma, model_output, model_input)
def noise_scaling(
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
) -> Any:
preferred_device = _prefer_device(noise, latent_image)
out = self._call(
"noise_scaling",
_to_cpu_for_rpc(sigma),
_to_cpu_for_rpc(noise),
_to_cpu_for_rpc(latent_image),
max_denoise,
)
return _to_device(out, preferred_device)
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
preferred_device = _prefer_device(latent)
out = self._call(
"inverse_noise_scaling",
_to_cpu_for_rpc(sigma),
_to_cpu_for_rpc(latent),
)
return _to_device(out, preferred_device)
def timestep(self, sigma: Any) -> Any:
return self._call("timestep", sigma)
def sigma(self, timestep: Any) -> Any:
return self._call("sigma", timestep)
def percent_to_sigma(self, percent: float) -> Any:
return self._call("percent_to_sigma", percent)
def set_sigmas(self, sigmas: Any) -> None:
return self._call("set_sigmas", sigmas)

View File

@@ -0,0 +1,17 @@
from .base import (
IS_CHILD_PROCESS,
BaseProxy,
BaseRegistry,
detach_if_grad,
get_thread_loop,
run_coro_in_new_loop,
)
__all__ = [
"IS_CHILD_PROCESS",
"BaseRegistry",
"BaseProxy",
"get_thread_loop",
"run_coro_in_new_loop",
"detach_if_grad",
]

View File

@@ -0,0 +1,283 @@
# pylint: disable=global-statement,import-outside-toplevel,protected-access
from __future__ import annotations
import asyncio
import concurrent.futures
import logging
import os
import threading
import time
import weakref
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
try:
from pyisolate import ProxiedSingleton
except ImportError:
class ProxiedSingleton: # type: ignore[no-redef]
pass
logger = logging.getLogger(__name__)
IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"
_thread_local = threading.local()
T = TypeVar("T")
def get_thread_loop() -> asyncio.AbstractEventLoop:
loop = getattr(_thread_local, "loop", None)
if loop is None or loop.is_closed():
loop = asyncio.new_event_loop()
_thread_local.loop = loop
return loop
def run_coro_in_new_loop(coro: Any) -> Any:
result_box: Dict[str, Any] = {}
exc_box: Dict[str, BaseException] = {}
def runner() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result_box["value"] = loop.run_until_complete(coro)
except Exception as exc: # noqa: BLE001
exc_box["exc"] = exc
finally:
loop.close()
t = threading.Thread(target=runner, daemon=True)
t.start()
t.join()
if "exc" in exc_box:
raise exc_box["exc"]
return result_box.get("value")
def detach_if_grad(obj: Any) -> Any:
try:
import torch
except Exception:
return obj
if isinstance(obj, torch.Tensor):
return obj.detach() if obj.requires_grad else obj
if isinstance(obj, (list, tuple)):
return type(obj)(detach_if_grad(x) for x in obj)
if isinstance(obj, dict):
return {k: detach_if_grad(v) for k, v in obj.items()}
return obj
class BaseRegistry(ProxiedSingleton, Generic[T]):
_type_prefix: str = "base"
def __init__(self) -> None:
if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
super().__init__()
self._registry: Dict[str, T] = {}
self._id_map: Dict[int, str] = {}
self._counter = 0
self._lock = threading.Lock()
def register(self, instance: T) -> str:
with self._lock:
obj_id = id(instance)
if obj_id in self._id_map:
return self._id_map[obj_id]
instance_id = f"{self._type_prefix}_{self._counter}"
self._counter += 1
self._registry[instance_id] = instance
self._id_map[obj_id] = instance_id
return instance_id
def unregister_sync(self, instance_id: str) -> None:
with self._lock:
instance = self._registry.pop(instance_id, None)
if instance:
self._id_map.pop(id(instance), None)
def _get_instance(self, instance_id: str) -> T:
if IS_CHILD_PROCESS:
raise RuntimeError(
f"[{self.__class__.__name__}] _get_instance called in child"
)
with self._lock:
instance = self._registry.get(instance_id)
if instance is None:
raise ValueError(f"{instance_id} not found")
return instance
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None
def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
global _GLOBAL_LOOP
_GLOBAL_LOOP = loop
class BaseProxy(Generic[T]):
_registry_class: type = BaseRegistry # type: ignore[type-arg]
__module__: str = "comfy.isolation.proxies.base"
_TIMEOUT_RPC_METHODS = frozenset(
{
"partially_load",
"partially_unload",
"load",
"patch_model",
"unpatch_model",
"inner_model_apply_model",
"memory_required",
"model_dtype",
"inner_model_memory_required",
"inner_model_extra_conds_shapes",
"inner_model_extra_conds",
"process_latent_in",
"process_latent_out",
"scale_latent_inpaint",
}
)
def __init__(
self,
instance_id: str,
registry: Optional[Any] = None,
manage_lifecycle: bool = False,
) -> None:
self._instance_id = instance_id
self._rpc_caller: Optional[Any] = None
self._registry = registry if registry is not None else self._registry_class()
self._manage_lifecycle = manage_lifecycle
self._cleaned_up = False
if manage_lifecycle and not IS_CHILD_PROCESS:
self._finalizer = weakref.finalize(
self, self._registry.unregister_sync, instance_id
)
def _get_rpc(self) -> Any:
if self._rpc_caller is None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc is None:
raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
self._rpc_caller = rpc.create_caller(
self._registry_class, self._registry_class.get_remote_id()
)
return self._rpc_caller
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
if method_name not in self._TIMEOUT_RPC_METHODS:
return None
try:
timeout_ms = int(
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
)
except ValueError:
timeout_ms = 120000
return max(1, timeout_ms)
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
rpc = self._get_rpc()
method = getattr(rpc, method_name)
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
coro = method(self._instance_id, *args, **kwargs)
if timeout_ms is not None:
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
start_epoch = time.time()
start_perf = time.perf_counter()
thread_id = threading.get_ident()
try:
running_loop = asyncio.get_running_loop()
loop_id: Optional[int] = id(running_loop)
except RuntimeError:
loop_id = None
logger.debug(
"ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
"thread=%s loop=%s timeout_ms=%s",
self.__class__.__name__,
method_name,
self._instance_id,
start_epoch,
thread_id,
loop_id,
timeout_ms,
)
try:
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
try:
curr_loop = asyncio.get_running_loop()
if curr_loop is _GLOBAL_LOOP:
pass
except RuntimeError:
# No running loop - we are in a worker thread.
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result(
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
except asyncio.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
except concurrent.futures.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
finally:
end_epoch = time.time()
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
"elapsed_ms=%.3f thread=%s loop=%s",
self.__class__.__name__,
method_name,
self._instance_id,
end_epoch,
elapsed_ms,
thread_id,
loop_id,
)
def __getstate__(self) -> Dict[str, Any]:
return {"_instance_id": self._instance_id}
def __setstate__(self, state: Dict[str, Any]) -> None:
self._instance_id = state["_instance_id"]
self._rpc_caller = None
self._registry = self._registry_class()
self._manage_lifecycle = False
self._cleaned_up = False
def cleanup(self) -> None:
if self._cleaned_up or IS_CHILD_PROCESS:
return
self._cleaned_up = True
finalizer = getattr(self, "_finalizer", None)
if finalizer is not None:
finalizer.detach()
self._registry.unregister_sync(self._instance_id)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self._instance_id}>"
def create_rpc_method(method_name: str) -> Callable[..., Any]:
def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
return self._call_rpc(method_name, *args, **kwargs)
method.__name__ = method_name
return method

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from typing import Dict
import folder_paths
from pyisolate import ProxiedSingleton
class FolderPathsProxy(ProxiedSingleton):
"""
Dynamic proxy for folder_paths.
Uses __getattr__ for most lookups, with explicit handling for
mutable collections to ensure efficient by-value transfer.
"""
def __getattr__(self, name):
return getattr(folder_paths, name)
# Return dict snapshots (avoid RPC chatter)
@property
def folder_names_and_paths(self) -> Dict:
return dict(folder_paths.folder_names_and_paths)
@property
def extension_mimetypes_cache(self) -> Dict:
return dict(folder_paths.extension_mimetypes_cache)
@property
def filename_list_cache(self) -> Dict:
return dict(folder_paths.filename_list_cache)

View File

@@ -0,0 +1,98 @@
from __future__ import annotations
from typing import Any, Dict, Optional
class AnyTypeProxy(str):
"""Replacement for custom AnyType objects used by some nodes."""
def __new__(cls, value: str = "*"):
return super().__new__(cls, value)
def __ne__(self, other): # type: ignore[override]
return False
class FlexibleOptionalInputProxy(dict):
"""Replacement for FlexibleOptionalInputType to allow dynamic inputs."""
def __init__(self, flex_type, data: Optional[Dict[str, object]] = None):
super().__init__()
self.type = flex_type
if data:
self.update(data)
def __getitem__(self, key): # type: ignore[override]
return (self.type,)
def __contains__(self, key): # type: ignore[override]
return True
class ByPassTypeTupleProxy(tuple):
"""Replacement for ByPassTypeTuple to mirror wildcard fallback behavior."""
def __new__(cls, values):
return super().__new__(cls, values)
def __getitem__(self, index): # type: ignore[override]
if index >= len(self):
return AnyTypeProxy("*")
return super().__getitem__(index)
def _restore_special_value(value: Any) -> Any:
if isinstance(value, dict):
if value.get("__pyisolate_any_type__"):
return AnyTypeProxy(value.get("value", "*"))
if value.get("__pyisolate_flexible_optional__"):
flex_type = _restore_special_value(value.get("type"))
data_raw = value.get("data")
data = (
{k: _restore_special_value(v) for k, v in data_raw.items()}
if isinstance(data_raw, dict)
else {}
)
return FlexibleOptionalInputProxy(flex_type, data)
if value.get("__pyisolate_tuple__") is not None:
return tuple(
_restore_special_value(v) for v in value["__pyisolate_tuple__"]
)
if value.get("__pyisolate_bypass_tuple__") is not None:
return ByPassTypeTupleProxy(
tuple(
_restore_special_value(v)
for v in value["__pyisolate_bypass_tuple__"]
)
)
return {k: _restore_special_value(v) for k, v in value.items()}
if isinstance(value, list):
return [_restore_special_value(v) for v in value]
return value
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
"""Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
if not isinstance(raw, dict):
return raw # type: ignore[return-value]
restored: Dict[str, object] = {}
for section, entries in raw.items():
if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"):
restored[section] = _restore_special_value(entries)
elif isinstance(entries, dict):
restored[section] = {
k: _restore_special_value(v) for k, v in entries.items()
}
else:
restored[section] = _restore_special_value(entries)
return restored
__all__ = [
"AnyTypeProxy",
"FlexibleOptionalInputProxy",
"ByPassTypeTupleProxy",
"restore_input_types",
]

View File

@@ -0,0 +1,27 @@
import comfy.model_management as mm
from pyisolate import ProxiedSingleton
class ModelManagementProxy(ProxiedSingleton):
"""
Dynamic proxy for comfy.model_management.
Uses __getattr__ to forward all calls to the underlying module,
reducing maintenance burden.
"""
# Explicitly expose Enums/Classes as properties
@property
def VRAMState(self):
return mm.VRAMState
@property
def CPUState(self):
return mm.CPUState
@property
def OOM_EXCEPTION(self):
return mm.OOM_EXCEPTION
def __getattr__(self, name):
"""Forward all other attribute access to the module."""
return getattr(mm, name)

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
import logging
from typing import Any, Optional
try:
from pyisolate import ProxiedSingleton
except ImportError:
class ProxiedSingleton:
pass
from comfy_execution.progress import get_progress_state
logger = logging.getLogger(__name__)
class ProgressProxy(ProxiedSingleton):
def set_progress(
self,
value: float,
max_value: float,
node_id: Optional[str] = None,
image: Any = None,
) -> None:
get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,
image=image,
)
__all__ = ["ProgressProxy"]

View File

@@ -0,0 +1,265 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called
"""Stateless RPC Implementation for PromptServer.
Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture.
- Host: PromptServerService (RPC Handler)
- Child: PromptServerStub (Interface Implementation)
"""
from __future__ import annotations
import asyncio
import os
from typing import Any, Dict, Optional, Callable
import logging
from aiohttp import web
# IMPORTS
from pyisolate import ProxiedSingleton
logger = logging.getLogger(__name__)
LOG_PREFIX = "[Isolation:C<->H]"
# ...
# =============================================================================
# CHILD SIDE: PromptServerStub
# =============================================================================
class PromptServerStub:
"""Stateless Stub for PromptServer."""
# Masquerade as the real server module
__module__ = "server"
_instance: Optional["PromptServerStub"] = None
_rpc: Optional[Any] = None # This will be the Caller object
_source_file: Optional[str] = None
def __init__(self):
self.routes = RouteStub(self)
@classmethod
def set_rpc(cls, rpc: Any) -> None:
"""Inject RPC client (called by adapter.py or manually)."""
# Create caller for HOST Service
# Assuming Host Service is registered as "PromptServerService" (class name)
# We target the Host Service Class
target_id = "PromptServerService"
# We need to pass a class to create_caller? Usually yes.
# But we don't have the Service class imported here necessarily (if running on child).
# pyisolate check verify_service type?
# If we pass PromptServerStub as the 'class', it might mismatch if checking types.
# But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub.
# We need a dummy class with right name?
# Or just rely on string ID if create_caller supports it?
# Standard: rpc.create_caller(PromptServerStub, target_id)
# But wait, PromptServerStub is the *Local* class.
# We want to call *Remote* class.
# If we use PromptServerStub as the type, returning object will be typed as PromptServerStub?
# The first arg is 'service_cls'.
cls._rpc = rpc.create_caller(
PromptServerService, target_id
) # We import Service below?
# We need PromptServerService available for the create_caller call?
# Or just use the Stub class if ID matches?
# prompt_server_impl.py defines BOTH. So PromptServerService IS available!
@property
def instance(self) -> "PromptServerStub":
return self
# ... Compatibility ...
@classmethod
def _get_source_file(cls) -> str:
if cls._source_file is None:
import folder_paths
cls._source_file = os.path.join(folder_paths.base_path, "server.py")
return cls._source_file
@property
def __file__(self) -> str:
return self._get_source_file()
# --- Properties ---
@property
def client_id(self) -> Optional[str]:
return "isolated_client"
def supports(self, feature: str) -> bool:
return True
@property
def app(self):
raise RuntimeError(
"PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
)
@property
def prompt_queue(self):
raise RuntimeError(
"PromptServer.prompt_queue is not accessible in isolated nodes."
)
# --- UI Communication (RPC Delegates) ---
async def send_sync(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
) -> None:
if self._rpc:
await self._rpc.ui_send_sync(event, data, sid)
async def send(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
) -> None:
if self._rpc:
await self._rpc.ui_send(event, data, sid)
def send_progress_text(self, text: str, node_id: str, sid=None) -> None:
if self._rpc:
# Fire and forget likely needed. If method is async on host, caller invocation returns coroutine.
# We must schedule it?
# Or use fire_remote equivalent?
# Caller object usually proxies calls. If host method is async, it returns coro.
# If we are sync here (send_progress_text checks imply sync usage), we must background it.
# But UtilsProxy hook wrapper creates task.
# Does send_progress_text need to be sync? Yes, node code calls it sync.
import asyncio
try:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
except RuntimeError:
pass # Sync context without loop?
# --- Route Registration Logic ---
def register_route(self, method: str, path: str, handler: Callable):
"""Register a route handler via RPC."""
if not self._rpc:
logger.error("RPC not initialized in PromptServerStub")
return
# Fire registration async
try:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
except RuntimeError:
pass
class RouteStub:
"""Simulates aiohttp.web.RouteTableDef."""
def __init__(self, stub: PromptServerStub):
self._stub = stub
def get(self, path: str):
def decorator(handler):
self._stub.register_route("GET", path, handler)
return handler
return decorator
def post(self, path: str):
def decorator(handler):
self._stub.register_route("POST", path, handler)
return handler
return decorator
def patch(self, path: str):
def decorator(handler):
self._stub.register_route("PATCH", path, handler)
return handler
return decorator
def put(self, path: str):
def decorator(handler):
self._stub.register_route("PUT", path, handler)
return handler
return decorator
def delete(self, path: str):
def decorator(handler):
self._stub.register_route("DELETE", path, handler)
return handler
return decorator
# =============================================================================
# HOST SIDE: PromptServerService
# =============================================================================
class PromptServerService(ProxiedSingleton):
"""Host-side RPC Service for PromptServer."""
def __init__(self):
# We will bind to the real server instance lazily or via global import
pass
@property
def server(self):
from server import PromptServer
return PromptServer.instance
async def ui_send_sync(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
):
await self.server.send_sync(event, data, sid)
async def ui_send(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
):
await self.server.send(event, data, sid)
async def ui_send_progress_text(self, text: str, node_id: str, sid=None):
# Made async to be awaitable by RPC layer
self.server.send_progress_text(text, node_id, sid)
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
"""RPC Target: Register a route that forwards to the Child."""
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
async def route_wrapper(request: web.Request) -> web.Response:
# 1. Capture request data
req_data = {
"method": request.method,
"path": request.path,
"query": dict(request.query),
}
if request.can_read_body:
req_data["text"] = await request.text()
try:
# 2. Call Child Handler via RPC (child_handler_proxy is async callable)
result = await child_handler_proxy(req_data)
# 3. Serialize Response
return self._serialize_response(result)
except Exception as e:
logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
return web.Response(status=500, text=str(e))
# Register loop
self.server.app.router.add_route(method, path, route_wrapper)
def _serialize_response(self, result: Any) -> web.Response:
"""Helper to convert Child result -> web.Response"""
if isinstance(result, web.Response):
return result
# Handle dict (json)
if isinstance(result, dict):
return web.json_response(result)
# Handle string
if isinstance(result, str):
return web.Response(text=result)
# Fallback
return web.Response(text=str(result))

View File

@@ -0,0 +1,64 @@
# pylint: disable=cyclic-import,import-outside-toplevel
from __future__ import annotations
from typing import Optional, Any
import comfy.utils
from pyisolate import ProxiedSingleton
import os
class UtilsProxy(ProxiedSingleton):
"""
Proxy for comfy.utils.
Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates
from isolated nodes reach the host.
"""
# _instance and __new__ removed to rely on SingletonMetaclass
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
# Create caller using class name as ID (standard for Singletons)
cls._rpc = rpc.create_caller(cls, "UtilsProxy")
async def progress_bar_hook(
self,
value: int,
total: int,
preview: Optional[bytes] = None,
node_id: Optional[str] = None,
) -> Any:
"""
Host-side implementation: forwards the call to the real global hook.
Child-side: this method call is intercepted by RPC and sent to host.
"""
if os.environ.get("PYISOLATE_CHILD") == "1":
# Manual RPC dispatch for Child process
# Use class-level RPC storage (Static Injection)
if UtilsProxy._rpc:
return await UtilsProxy._rpc.progress_bar_hook(
value, total, preview, node_id
)
# Fallback channel: global child rpc
try:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
get_child_rpc_instance()
# If we have an RPC instance but no UtilsProxy._rpc, we *could* try to use it,
# but we need a caller. For now, just pass to avoid crashing.
pass
except (ImportError, LookupError):
pass
return None
# Host Execution
if comfy.utils.PROGRESS_BAR_HOOK is not None:
comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
def set_progress_bar_global_hook(self, hook: Any) -> None:
"""Forward hook registration (though usually not needed from child)."""
comfy.utils.set_progress_bar_global_hook(hook)

View File

@@ -0,0 +1,49 @@
import asyncio
import logging
import threading
logger = logging.getLogger(__name__)
class RpcBridge:
"""Minimal helper to run coroutines synchronously inside isolated processes.
If an event loop is already running, the coroutine is executed on a fresh
thread with its own loop to avoid nested run_until_complete errors.
"""
def run_sync(self, maybe_coro):
if not asyncio.iscoroutine(maybe_coro):
return maybe_coro
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
result_container = {}
exc_container = {}
def _runner():
try:
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
result_container["value"] = new_loop.run_until_complete(maybe_coro)
except Exception as exc: # pragma: no cover
exc_container["error"] = exc
finally:
try:
new_loop.close()
except Exception:
pass
t = threading.Thread(target=_runner, daemon=True)
t.start()
t.join()
if "error" in exc_container:
raise exc_container["error"]
return result_container.get("value")
return asyncio.run(maybe_coro)

View File

@@ -0,0 +1,363 @@
# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member
from __future__ import annotations
import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Set, TYPE_CHECKING
from .proxies.helper_proxies import restore_input_types
from comfy_api.internal import _ComfyNodeInternal
from comfy_api.latest import _io as latest_io
from .shm_forensics import scan_shm_forensics
if TYPE_CHECKING:
from .extension_wrapper import ComfyNodeExtension
LOG_PREFIX = "]["
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
def _resource_snapshot() -> Dict[str, int]:
fd_count = -1
shm_sender_files = 0
try:
fd_count = len(os.listdir("/proc/self/fd"))
except Exception:
pass
try:
shm_root = Path("/dev/shm")
if shm_root.exists():
prefix = f"torch_{os.getpid()}_"
shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*"))
except Exception:
pass
return {"fd_count": fd_count, "shm_sender_files": shm_sender_files}
def _tensor_transport_summary(value: Any) -> Dict[str, int]:
summary: Dict[str, int] = {
"tensor_count": 0,
"cpu_tensors": 0,
"cuda_tensors": 0,
"shared_cpu_tensors": 0,
"tensor_bytes": 0,
}
try:
import torch
except Exception:
return summary
def visit(node: Any) -> None:
if isinstance(node, torch.Tensor):
summary["tensor_count"] += 1
summary["tensor_bytes"] += int(node.numel() * node.element_size())
if node.device.type == "cpu":
summary["cpu_tensors"] += 1
if node.is_shared():
summary["shared_cpu_tensors"] += 1
elif node.device.type == "cuda":
summary["cuda_tensors"] += 1
return
if isinstance(node, dict):
for v in node.values():
visit(v)
return
if isinstance(node, (list, tuple)):
for v in node:
visit(v)
visit(value)
return summary
def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None:
for key, value in inputs.items():
key_text = str(key)
if "unique_id" in key_text:
return str(value)
return None
def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
except Exception:
return
if not callable(flush_tensor_keeper):
return
flushed = flush_tensor_keeper()
if flushed > 0:
logger.debug(
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
)
def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None:
import comfy.model_management as model_management
model_management.cleanup_models_gc()
model_management.cleanup_models()
device = model_management.get_torch_device()
if not hasattr(device, "type") or device.type == "cpu":
return
required = max(
model_management.minimum_inference_memory(),
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
)
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=True)
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=False)
model_management.cleanup_models()
model_management.soft_empty_cache()
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
def _detach_shared_cpu_tensors(value: Any) -> Any:
try:
import torch
except Exception:
return value
if isinstance(value, torch.Tensor):
if value.device.type == "cpu" and value.is_shared():
clone = value.clone()
if value.requires_grad:
clone.requires_grad_(True)
return clone
return value
if isinstance(value, list):
return [_detach_shared_cpu_tensors(v) for v in value]
if isinstance(value, tuple):
return tuple(_detach_shared_cpu_tensors(v) for v in value)
if isinstance(value, dict):
return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()}
return value
def build_stub_class(
node_name: str,
info: Dict[str, object],
extension: "ComfyNodeExtension",
running_extensions: Dict[str, "ComfyNodeExtension"],
logger: logging.Logger,
) -> type:
is_v3 = bool(info.get("is_v3", False))
function_name = "_pyisolate_execute"
restored_input_types = restore_input_types(info.get("input_types", {}))
async def _execute(self, **inputs):
from comfy.isolation import _RUNNING_EXTENSIONS
# Update BOTH the local dict AND the module-level dict
running_extensions[extension.name] = extension
_RUNNING_EXTENSIONS[extension.name] = extension
prev_child = None
node_unique_id = _extract_hidden_unique_id(inputs)
summary = _tensor_transport_summary(inputs)
resources = _resource_snapshot()
logger.debug(
"%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
summary["tensor_count"],
summary["cpu_tensors"],
summary["cuda_tensors"],
summary["shared_cpu_tensors"],
summary["tensor_bytes"],
resources["fd_count"],
resources["shm_sender_files"],
)
scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
try:
if os.environ.get("PYISOLATE_CHILD") != "1":
_relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
from pyisolate._internal.model_serialization import (
serialize_for_isolation,
deserialize_from_isolation,
)
prev_child = os.environ.pop("PYISOLATE_CHILD", None)
logger.debug(
"%s ISO:serialize_start ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
serialized = serialize_for_isolation(inputs)
logger.debug(
"%s ISO:serialize_done ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
logger.debug(
"%s ISO:dispatch_start ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
result = await extension.execute_node(node_name, **serialized)
logger.debug(
"%s ISO:dispatch_done ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
# Reconstruct NodeOutput if the child serialized one
if isinstance(result, dict) and result.get("__node_output__"):
from comfy_api.latest import io as latest_io
args_raw = result.get("args", ())
deserialized_args = await deserialize_from_isolation(args_raw, extension)
deserialized_args = _detach_shared_cpu_tensors(deserialized_args)
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
return latest_io.NodeOutput(
*deserialized_args,
ui=result.get("ui"),
expand=result.get("expand"),
block_execution=result.get("block_execution"),
)
deserialized = await deserialize_from_isolation(result, extension)
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
return _detach_shared_cpu_tensors(deserialized)
except ImportError:
return await extension.execute_node(node_name, **inputs)
except Exception:
logger.exception(
"%s ISO:execute_error ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
raise
finally:
if prev_child is not None:
os.environ["PYISOLATE_CHILD"] = prev_child
logger.debug(
"%s ISO:execute_end ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True)
def _input_types(
cls,
include_hidden: bool = True,
return_schema: bool = False,
live_inputs: Any = None,
):
if not is_v3:
return restored_input_types
inputs_copy = copy.deepcopy(restored_input_types)
if not include_hidden:
inputs_copy.pop("hidden", None)
v3_data: Dict[str, Any] = {"hidden_inputs": {}}
dynamic = inputs_copy.pop("dynamic_paths", None)
if dynamic is not None:
v3_data["dynamic_paths"] = dynamic
if return_schema:
hidden_vals = info.get("hidden", []) or []
hidden_enums = []
for h in hidden_vals:
try:
hidden_enums.append(latest_io.Hidden(h))
except Exception:
hidden_enums.append(h)
class SchemaProxy:
hidden = hidden_enums
return inputs_copy, SchemaProxy, v3_data
return inputs_copy
def _validate_class(cls):
return True
def _get_node_info_v1(cls):
node_info = copy.deepcopy(info.get("schema_v1", {}))
relative_python_module = node_info.get("python_module")
if not isinstance(relative_python_module, str) or not relative_python_module:
relative_python_module = f"custom_nodes.{extension.name}"
node_info["python_module"] = relative_python_module
return node_info
def _get_base_class(cls):
return latest_io.ComfyNode
attributes: Dict[str, object] = {
"FUNCTION": function_name,
"CATEGORY": info.get("category", ""),
"OUTPUT_NODE": info.get("output_node", False),
"RETURN_TYPES": tuple(info.get("return_types", ()) or ()),
"RETURN_NAMES": info.get("return_names"),
function_name: _execute,
"_pyisolate_extension": extension,
"_pyisolate_node_name": node_name,
"INPUT_TYPES": classmethod(_input_types),
}
output_is_list = info.get("output_is_list")
if output_is_list is not None:
attributes["OUTPUT_IS_LIST"] = tuple(output_is_list)
if is_v3:
attributes["VALIDATE_CLASS"] = classmethod(_validate_class)
attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1)
attributes["GET_BASE_CLASS"] = classmethod(_get_base_class)
attributes["DESCRIPTION"] = info.get("description", "")
attributes["EXPERIMENTAL"] = info.get("experimental", False)
attributes["DEPRECATED"] = info.get("deprecated", False)
attributes["API_NODE"] = info.get("api_node", False)
attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False)
attributes["ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
attributes["_ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
attributes["INPUT_IS_LIST"] = info.get("input_is_list", False)
class_name = f"PyIsolate_{node_name}".replace(" ", "_")
bases = (_ComfyNodeInternal,) if is_v3 else ()
stub_cls = type(class_name, bases, attributes)
if is_v3:
try:
stub_cls.VALIDATE_CLASS()
except Exception as e:
logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e)
return stub_cls
def get_class_types_for_extension(
extension_name: str,
running_extensions: Dict[str, "ComfyNodeExtension"],
specs: List[Any],
) -> Set[str]:
extension = running_extensions.get(extension_name)
if not extension:
return set()
ext_path = Path(extension.module_path)
class_types = set()
for spec in specs:
if spec.module_path.resolve() == ext_path.resolve():
class_types.add(spec.node_name)
return class_types
__all__ = ["build_stub_class", "get_class_types_for_extension"]

View File

@@ -0,0 +1,217 @@
# pylint: disable=consider-using-from-import,import-outside-toplevel
from __future__ import annotations
import atexit
import hashlib
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Set
LOG_PREFIX = "]["
logger = logging.getLogger(__name__)
def _shm_debug_enabled() -> bool:
return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1"
class _SHMForensicsTracker:
def __init__(self) -> None:
self._started = False
self._tracked_files: Set[str] = set()
self._current_model_context: Dict[str, str] = {
"id": "unknown",
"name": "unknown",
"hash": "????",
}
@staticmethod
def _snapshot_shm() -> Set[str]:
shm_path = Path("/dev/shm")
if not shm_path.exists():
return set()
return {f.name for f in shm_path.glob("torch_*")}
def start(self) -> None:
if self._started or not _shm_debug_enabled():
return
self._tracked_files = self._snapshot_shm()
self._started = True
logger.debug(
"%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files)
)
def stop(self) -> None:
if not self._started:
return
self.scan("shutdown", refresh_model_context=True)
self._started = False
logger.debug("%s SHM:forensics_disabled", LOG_PREFIX)
def _compute_model_hash(self, model_patcher: Any) -> str:
try:
model_instance_id = getattr(model_patcher, "_instance_id", None)
if model_instance_id is not None:
model_id_text = str(model_instance_id)
return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text
import torch
real_model = (
model_patcher.model
if hasattr(model_patcher, "model")
else model_patcher
)
tensor = None
if hasattr(real_model, "parameters"):
for p in real_model.parameters():
if torch.is_tensor(p) and p.numel() > 0:
tensor = p
break
if tensor is None:
return "0000"
flat = tensor.flatten()
values = []
indices = [0, flat.shape[0] // 2, flat.shape[0] - 1]
for i in indices:
if i < flat.shape[0]:
values.append(flat[i].item())
size = 0
if hasattr(model_patcher, "model_size"):
size = model_patcher.model_size()
sample_str = f"{values}_{id(model_patcher):016x}_{size}"
return hashlib.sha256(sample_str.encode()).hexdigest()[-4:]
except Exception:
return "err!"
def _get_models_snapshot(self) -> List[Dict[str, Any]]:
try:
import comfy.model_management as model_management
except Exception:
return []
snapshot: List[Dict[str, Any]] = []
try:
for loaded_model in model_management.current_loaded_models:
model = loaded_model.model
if model is None:
continue
if str(getattr(loaded_model, "device", "")) != "cuda:0":
continue
name = (
model.model.__class__.__name__
if hasattr(model, "model")
else type(model).__name__
)
model_hash = self._compute_model_hash(model)
model_instance_id = getattr(model, "_instance_id", None)
if model_instance_id is None:
model_instance_id = model_hash
snapshot.append(
{
"name": str(name),
"id": str(model_instance_id),
"hash": str(model_hash or "????"),
"used": bool(getattr(loaded_model, "currently_used", False)),
}
)
except Exception:
return []
return snapshot
def _update_model_context(self) -> None:
snapshot = self._get_models_snapshot()
selected = None
used_models = [m for m in snapshot if m.get("used") and m.get("id")]
if used_models:
selected = used_models[-1]
else:
live_models = [m for m in snapshot if m.get("id")]
if live_models:
selected = live_models[-1]
if selected is None:
self._current_model_context = {
"id": "unknown",
"name": "unknown",
"hash": "????",
}
return
self._current_model_context = {
"id": str(selected.get("id", "unknown")),
"name": str(selected.get("name", "unknown")),
"hash": str(selected.get("hash", "????") or "????"),
}
def scan(self, marker: str, refresh_model_context: bool = True) -> None:
if not self._started or not _shm_debug_enabled():
return
if refresh_model_context:
self._update_model_context()
current = self._snapshot_shm()
added = current - self._tracked_files
removed = self._tracked_files - current
self._tracked_files = current
if not added and not removed:
logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker)
return
for filename in sorted(added):
logger.info("%s SHM:created | %s", LOG_PREFIX, filename)
model_id = self._current_model_context["id"]
if model_id == "unknown":
logger.error(
"%s SHM:model_association_missing | file=%s | reason=no_active_model_context",
LOG_PREFIX,
filename,
)
else:
logger.info(
"%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s",
LOG_PREFIX,
model_id,
filename,
self._current_model_context["name"],
self._current_model_context["hash"],
)
for filename in sorted(removed):
logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename)
logger.debug(
"%s SHM:scan marker=%s created=%d deleted=%d active=%d",
LOG_PREFIX,
marker,
len(added),
len(removed),
len(self._tracked_files),
)
_TRACKER = _SHMForensicsTracker()
def start_shm_forensics() -> None:
_TRACKER.start()
def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None:
_TRACKER.scan(marker, refresh_model_context=refresh_model_context)
def stop_shm_forensics() -> None:
_TRACKER.stop()
atexit.register(stop_shm_forensics)

View File

@@ -0,0 +1,214 @@
# pylint: disable=attribute-defined-outside-init
import logging
from typing import Any
from comfy.isolation.proxies.base import (
IS_CHILD_PROCESS,
BaseProxy,
BaseRegistry,
detach_if_grad,
)
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy, ModelPatcherRegistry
logger = logging.getLogger(__name__)
class FirstStageModelRegistry(BaseRegistry[Any]):
_type_prefix = "first_stage_model"
async def get_property(self, instance_id: str, name: str) -> Any:
obj = self._get_instance(instance_id)
return getattr(obj, name)
async def has_property(self, instance_id: str, name: str) -> bool:
obj = self._get_instance(instance_id)
return hasattr(obj, name)
class FirstStageModelProxy(BaseProxy[FirstStageModelRegistry]):
_registry_class = FirstStageModelRegistry
__module__ = "comfy.ldm.models.autoencoder"
def __getattr__(self, name: str) -> Any:
try:
return self._call_rpc("get_property", name)
except Exception as e:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
) from e
def __repr__(self) -> str:
return f"<FirstStageModelProxy {self._instance_id}>"
class VAERegistry(BaseRegistry[Any]):
_type_prefix = "vae"
async def get_patcher_id(self, instance_id: str) -> str:
vae = self._get_instance(instance_id)
return ModelPatcherRegistry().register(vae.patcher)
async def get_first_stage_model_id(self, instance_id: str) -> str:
vae = self._get_instance(instance_id)
return FirstStageModelRegistry().register(vae.first_stage_model)
async def encode(self, instance_id: str, pixels: Any) -> Any:
return detach_if_grad(self._get_instance(instance_id).encode(pixels))
async def encode_tiled(
self,
instance_id: str,
pixels: Any,
tile_x: int = 512,
tile_y: int = 512,
overlap: int = 64,
) -> Any:
return detach_if_grad(
self._get_instance(instance_id).encode_tiled(
pixels, tile_x=tile_x, tile_y=tile_y, overlap=overlap
)
)
async def decode(self, instance_id: str, samples: Any, **kwargs: Any) -> Any:
return detach_if_grad(self._get_instance(instance_id).decode(samples, **kwargs))
async def decode_tiled(
self,
instance_id: str,
samples: Any,
tile_x: int = 64,
tile_y: int = 64,
overlap: int = 16,
**kwargs: Any,
) -> Any:
return detach_if_grad(
self._get_instance(instance_id).decode_tiled(
samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap, **kwargs
)
)
async def get_property(self, instance_id: str, name: str) -> Any:
return getattr(self._get_instance(instance_id), name)
async def memory_used_encode(self, instance_id: str, shape: Any, dtype: Any) -> int:
return self._get_instance(instance_id).memory_used_encode(shape, dtype)
async def memory_used_decode(self, instance_id: str, shape: Any, dtype: Any) -> int:
return self._get_instance(instance_id).memory_used_decode(shape, dtype)
async def process_input(self, instance_id: str, image: Any) -> Any:
return detach_if_grad(self._get_instance(instance_id).process_input(image))
async def process_output(self, instance_id: str, image: Any) -> Any:
return detach_if_grad(self._get_instance(instance_id).process_output(image))
class VAEProxy(BaseProxy[VAERegistry]):
_registry_class = VAERegistry
__module__ = "comfy.sd"
@property
def patcher(self) -> ModelPatcherProxy:
if not hasattr(self, "_patcher_proxy"):
patcher_id = self._call_rpc("get_patcher_id")
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
return self._patcher_proxy
@property
def first_stage_model(self) -> FirstStageModelProxy:
if not hasattr(self, "_first_stage_model_proxy"):
fsm_id = self._call_rpc("get_first_stage_model_id")
self._first_stage_model_proxy = FirstStageModelProxy(
fsm_id, manage_lifecycle=False
)
return self._first_stage_model_proxy
@property
def vae_dtype(self) -> Any:
return self._get_property("vae_dtype")
def encode(self, pixels: Any) -> Any:
return self._call_rpc("encode", pixels)
def encode_tiled(
self, pixels: Any, tile_x: int = 512, tile_y: int = 512, overlap: int = 64
) -> Any:
return self._call_rpc("encode_tiled", pixels, tile_x, tile_y, overlap)
def decode(self, samples: Any, **kwargs: Any) -> Any:
return self._call_rpc("decode", samples, **kwargs)
def decode_tiled(
self,
samples: Any,
tile_x: int = 64,
tile_y: int = 64,
overlap: int = 16,
**kwargs: Any,
) -> Any:
return self._call_rpc(
"decode_tiled", samples, tile_x, tile_y, overlap, **kwargs
)
def get_sd(self) -> Any:
return self._call_rpc("get_sd")
def _get_property(self, name: str) -> Any:
return self._call_rpc("get_property", name)
@property
def latent_dim(self) -> int:
return self._get_property("latent_dim")
@property
def latent_channels(self) -> int:
return self._get_property("latent_channels")
@property
def downscale_ratio(self) -> Any:
return self._get_property("downscale_ratio")
@property
def upscale_ratio(self) -> Any:
return self._get_property("upscale_ratio")
@property
def output_channels(self) -> int:
return self._get_property("output_channels")
@property
def check_not_vide(self) -> bool:
return self._get_property("not_video")
@property
def device(self) -> Any:
return self._get_property("device")
@property
def working_dtypes(self) -> Any:
return self._get_property("working_dtypes")
@property
def disable_offload(self) -> bool:
return self._get_property("disable_offload")
@property
def size(self) -> Any:
return self._get_property("size")
def memory_used_encode(self, shape: Any, dtype: Any) -> int:
return self._call_rpc("memory_used_encode", shape, dtype)
def memory_used_decode(self, shape: Any, dtype: Any) -> int:
return self._call_rpc("memory_used_decode", shape, dtype)
def process_input(self, image: Any) -> Any:
return self._call_rpc("process_input", image)
def process_output(self, image: Any) -> Any:
return self._call_rpc("process_output", image)
if not IS_CHILD_PROCESS:
_VAE_REGISTRY_SINGLETON = VAERegistry()
_FIRST_STAGE_MODEL_REGISTRY_SINGLETON = FirstStageModelRegistry()

View File

@@ -1,4 +1,5 @@
import math
import os
from functools import partial
from scipy import integrate
@@ -12,8 +13,8 @@ from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
from comfy.cli_args import args
from comfy.utils import model_trange as trange
def append_zero(x):
@@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
if isolation_active:
target_device = sigmas.device
if x.device != target_device:
x = x.to(target_device)
s_in = s_in.to(target_device)
for i in trange(len(sigmas) - 1, disable=disable):
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.

View File

@@ -136,7 +136,16 @@ class ResBlock(nn.Module):
ops.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
# Init weights
def _basic_init(module):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

View File

@@ -44,22 +44,6 @@ class FluxParams:
txt_norm: bool = False
def invert_slices(slices, length):
sorted_slices = sorted(slices)
result = []
current = 0
for start, end in sorted_slices:
if current < start:
result.append((current, start))
current = max(current, end)
if current < length:
result.append((current, length))
return result
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -154,7 +138,6 @@ class Flux(nn.Module):
y: Tensor,
guidance: Tensor = None,
control = None,
timestep_zero_index=None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
@@ -181,6 +164,10 @@ class Flux(nn.Module):
txt = self.txt_norm(txt)
txt = self.txt_in(txt)
vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
@@ -195,24 +182,6 @@ class Flux(nn.Module):
else:
pe = None
vec_orig = vec
txt_vec = vec
extra_kwargs = {}
if timestep_zero_index is not None:
modulation_dims = []
batch = vec.shape[0] // 2
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
invert = invert_slices(timestep_zero_index, img.shape[1])
for s in invert:
modulation_dims.append((s[0], s[1], 0))
for s in timestep_zero_index:
modulation_dims.append((s[0], s[1], 1))
extra_kwargs["modulation_dims_img"] = modulation_dims
txt_vec = vec[:batch]
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
@@ -226,8 +195,7 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"),
**extra_kwargs)
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
@@ -245,8 +213,7 @@ class Flux(nn.Module):
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options,
**extra_kwargs)
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -263,12 +230,6 @@ class Flux(nn.Module):
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
extra_kwargs = {}
if timestep_zero_index is not None:
lambda a: 0 if a == 0 else a + txt.shape[1]
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
extra_kwargs["modulation_dims"] = modulation_dims_combined
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
@@ -281,8 +242,7 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"),
**extra_kwargs)
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
@@ -293,7 +253,7 @@ class Flux(nn.Module):
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -304,11 +264,7 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
extra_kwargs = {}
if timestep_zero_index is not None:
extra_kwargs["modulation_dims"] = modulation_dims
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@@ -356,16 +312,13 @@ class Flux(nn.Module):
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
timestep_zero_index = None
if ref_latents is not None:
ref_num_tokens = []
h = 0
w = 0
index = 0
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
timestep_zero = ref_latents_method == "index_timestep_zero"
for ref in ref_latents:
if ref_latents_method in ("index", "index_timestep_zero"):
if ref_latents_method == "index":
index += self.params.ref_index_scale
h_offset = 0
w_offset = 0
@@ -389,13 +342,6 @@ class Flux(nn.Module):
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
@@ -403,6 +349,6 @@ class Flux(nn.Module):
for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

View File

@@ -343,7 +343,6 @@ class CrossAttention(nn.Module):
k.reshape(b, s2, self.num_heads * self.head_dim),
v,
heads=self.num_heads,
low_precision_attention=False,
)
out = self.out_proj(x)
@@ -413,7 +412,6 @@ class Attention(nn.Module):
key.reshape(B, N, self.num_heads * self.head_dim),
value,
heads=self.num_heads,
low_precision_attention=False,
)
x = self.out_proj(x)

View File

@@ -23,11 +23,6 @@ class CausalConv3d(nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels
if isinstance(stride, int):
self.time_stride = stride
else:
self.time_stride = stride[0]
kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]
@@ -63,25 +58,16 @@ class CausalConv3d(nn.Module):
pieces = [ cached, x ]
if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
input_length = sum([piece.shape[2] for piece in pieces])
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
needs_caching = not is_end
if needs_caching and cache_length == 0:
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
needs_caching = False
if needs_caching and x.shape[2] >= cache_length:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
del pieces
del cached
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]

View File

@@ -11,7 +11,6 @@ from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
import comfy.model_management
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
@@ -233,7 +232,10 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample)
checkpoint_fn = (
@@ -244,14 +246,10 @@ class Encoder(nn.Module):
for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
if sample is None or sample.shape[2] == 0:
return None
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if sample is None or sample.shape[2] == 0:
return None
if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
@@ -283,35 +281,9 @@ class Encoder(nn.Module):
return sample
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
outputs = []
samples = [sample[:, :, :1, :, :]]
if sample.shape[2] > 1:
chunk_t = max(2, max_chunk_size // frame_size)
if chunk_t < 4:
chunk_t = 2
elif chunk_t < 8:
chunk_t = 4
else:
chunk_t = (chunk_t // 8) * 8
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
for chunk_idx, chunk in enumerate(samples):
if chunk_idx == len(samples) - 1:
mark_conv3d_ended(self)
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
output = self._forward_chunk(chunk)
if output is not None:
outputs.append(output)
return torch_cat_if_needed(outputs, dim=2)
def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesnt use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
@@ -324,23 +296,7 @@ class Encoder(nn.Module):
module.temporal_cache_state.pop(tid, None)
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
MIN_CHUNK_SIZE = 32 * 1024 ** 2
MAX_CHUNK_SIZE = 128 * 1024 ** 2
def get_max_chunk_size(device: torch.device) -> int:
total_memory = comfy.model_management.get_total_memory(dev=device)
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
return MIN_CHUNK_SIZE
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
return MAX_CHUNK_SIZE
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
)
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
class Decoder(nn.Module):
r"""
@@ -500,17 +456,6 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
ts, hs, ws, to = 1, 1, 1, 0
for block in self.up_blocks:
if isinstance(block, DepthToSpaceUpsample):
ts *= block.stride[0]
hs *= block.stride[1]
ws *= block.stride[2]
if block.stride[0] > 1:
to = to * block.stride[0] + 1
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
@@ -532,62 +477,11 @@ class Decoder(nn.Module):
)
def decode_output_shape(self, input_shape):
c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
t = sample.shape[2]
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
output_offset[0] += t
return
up_block = self.up_blocks[idx]
if ended:
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward_orig(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
output_buffer: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
@@ -602,7 +496,6 @@ class Decoder(nn.Module):
)
timestep_shift_scale = None
scaled_timestep = None
if self.timestep_conditioning:
assert (
timestep is not None
@@ -630,18 +523,48 @@ class Decoder(nn.Module):
)
timestep_shift_scale = ada_values.unbind(dim=1)
if output_buffer is None:
output_buffer = torch.empty(
self.decode_output_shape(sample.shape),
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
)
output_offset = [0]
output = []
max_chunk_size = get_max_chunk_size(sample.device)
def run_up(idx, sample, ended):
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
return
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
return output_buffer
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
def forward(self, *args, **kwargs):
try:
@@ -765,25 +688,12 @@ class SpaceToDepthDownsample(nn.Module):
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True):
tid = threading.get_ident()
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
if cached_input is not None:
x = torch_cat_if_needed([cached_input, x], dim=2)
cached_input = None
if self.stride[0] == 2 and pad_first:
if self.stride[0] == 2:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
pad_first = False
if x.shape[2] < self.stride[0]:
cached_input = x
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return None
# skip connection
x_in = rearrange(
@@ -798,26 +708,15 @@ class SpaceToDepthDownsample(nn.Module):
# conv
x = self.conv(x, causal=causal)
if self.stride[0] == 2 and x.shape[2] == 1:
if cached_x is not None:
x = torch_cat_if_needed([cached_x, x], dim=2)
cached_x = None
else:
cached_x = x
x = None
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if x is not None:
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
cached = add_exchange_cache(x, cached, x_in, dim=2)
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
x = x + x_in
return x
@@ -1150,8 +1049,6 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
comfy_has_chunked_io = True
def __init__(self, version=0, config=None):
super().__init__()
@@ -1294,15 +1191,14 @@ class VideoVAE(nn.Module):
}
return config
def encode(self, x, device=None):
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode_output_shape(self, input_shape):
return self.decoder.decode_output_shape(input_shape)
def decode(self, x, output_buffer=None):
def decode(self, x):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)

View File

@@ -99,7 +99,7 @@ class Resample(nn.Module):
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
@@ -109,7 +109,22 @@ class Resample(nn.Module):
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
@@ -130,24 +145,19 @@ class Resample(nn.Module):
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :]
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
deferred_x = feat_cache[idx + 1]
if deferred_x is not None:
x = torch.cat([deferred_x, x], 2)
feat_cache[idx + 1] = None
if x.shape[2] == 1 and not final:
feat_cache[idx + 1] = x
x = None
feat_idx[0] += 2
feat_idx[0] += 1
return x
@@ -167,12 +177,19 @@ class ResidualBlock(nn.Module):
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
def forward(self, x, feat_cache=None, feat_idx=[0]):
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -196,7 +213,7 @@ class AttentionBlock(nn.Module):
self.proj = ops.Conv2d(dim, dim, 1)
self.optimized_attention = vae_attention()
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
@@ -266,10 +283,17 @@ class Encoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -279,16 +303,14 @@ class Encoder3d(nn.Module):
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, final=final)
if x is None:
return None
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, final=final)
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
@@ -296,7 +318,14 @@ class Encoder3d(nn.Module):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -360,48 +389,18 @@ class Decoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, output_channels, 3, padding=1))
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return
layer = self.upsamples[layer_idx]
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2:
for frame_idx in range(0, x.shape[2], 2):
self.run_up(
layer_idx + 1,
[x[:, :, frame_idx:frame_idx + 2, :, :]],
feat_cache,
feat_idx.copy(),
out_chunks,
)
del x
return
next_x_ref = [x]
del x
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -410,21 +409,42 @@ class Decoder3d(nn.Module):
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
out_chunks = []
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
return out_chunks
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_cache_layers(model):
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
if isinstance(m, CausalConv3d):
count += 1
return count
@@ -462,12 +482,11 @@ class WanVAE(nn.Module):
conv_idx = [0]
## cache
t = x.shape[2]
t = 1 + ((t - 1) // 4) * 4
iter_ = 1 + (t - 1) // 2
iter_ = 1 + (t - 1) // 4
feat_map = None
if iter_ > 1:
feat_map = [None] * count_cache_layers(self.encoder)
## 对encode输入的x按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
feat_map = [None] * count_conv3d(self.encoder)
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
conv_idx = [0]
if i == 0:
@@ -477,23 +496,20 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx,
final=(i == (iter_ - 1)))
if out_ is None:
continue
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
return mu
def decode(self, z):
conv_idx = [0]
# z: [b,c,t,h,w]
iter_ = 1 + z.shape[2] // 2
iter_ = z.shape[2]
feat_map = None
if iter_ > 1:
feat_map = [None] * count_cache_layers(self.decoder)
feat_map = [None] * count_conv3d(self.decoder)
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]
@@ -504,8 +520,8 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
x[:, :, i:i + 1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
out += out_
return torch.cat(out, 2)
out = torch.cat([out, out_], 2)
return out

View File

@@ -1,71 +1,9 @@
import math
import ctypes
import threading
import dataclasses
import torch
from typing import NamedTuple
from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
file_ref: object
thread_id: int
offset: int
size: int
def read_tensor_file_slice_into(tensor, destination):
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
return False
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
return False
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
return True
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
if info is None:
return False
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
or not tensor.is_contiguous()):
return False
if info.size == 0:
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
try:
file_obj.seek(info.offset)
done = 0
while done < info.size:
try:
n = file_obj.readinto(view[done:])
except OSError:
return False
if n <= 0:
return False
done += n
return True
finally:
view.release()
class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype

View File

@@ -20,8 +20,8 @@ import comfy.ldm.hunyuan3dv2_1
import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import os
import comfy.ldm.lightricks.av_model
import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
@@ -113,8 +113,20 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
from comfy.cli_args import args
isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
class ModelSampling(s, c):
pass
if isolation_runtime_enabled:
def __reduce__(self):
"""Ensure pickling yields a proxy instead of failing on local class."""
try:
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
registry = ModelSamplingRegistry()
ms_id = registry.register(self)
return (ModelSamplingProxy, (ms_id,))
except Exception as exc:
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
return ModelSampling(model_config)
@@ -286,12 +298,6 @@ class BaseModel(torch.nn.Module):
return data
return None
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
"""Override in subclasses to handle model-specific cond slicing for context windows.
Return a sliced cond object, or None to fall through to default handling.
Use comfy.context_windows.slice_cond() for common cases."""
return None
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
@@ -1382,11 +1388,6 @@ class WAN21_Vace(WAN21):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "vace_context":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
@@ -1439,11 +1440,6 @@ class WAN21_HuMo(WAN21):
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "audio_embed":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22_Animate(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
@@ -1461,13 +1457,6 @@ class WAN22_Animate(WAN21):
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "face_pixel_values":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
if cond_key == "pose_latents":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
@@ -1504,11 +1493,6 @@ class WAN22_S2V(WAN21):
out['reference_motion'] = reference_motion.shape
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "audio_embed":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

View File

@@ -372,7 +372,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
@@ -400,7 +400,7 @@ try:
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@@ -497,6 +497,9 @@ except:
current_loaded_models = []
def _isolation_mode_enabled():
return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
def module_size(module):
module_mem = 0
sd = module.state_dict()
@@ -505,28 +508,6 @@ def module_size(module):
module_mem += t.nbytes
return module_mem
def module_mmap_residency(module, free=False):
mmap_touched_mem = 0
module_mem = 0
bounced_mmaps = set()
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nbytes
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
continue
mmap_touched_mem += t.nbytes
if not free:
continue
storage._comfy_tensor_mmap_touched = False
mmap_obj = storage._comfy_tensor_mmap_refs[0]
if mmap_obj in bounced_mmaps:
continue
mmap_obj.bounce()
bounced_mmaps.add(mmap_obj)
return mmap_touched_mem, module_mem
class LoadedModel:
def __init__(self, model):
self._set_model(model)
@@ -541,7 +522,6 @@ class LoadedModel:
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
self._patcher_finalizer.atexit = False
def _switch_parent(self):
model = self._parent_model()
@@ -555,9 +535,6 @@ class LoadedModel:
def model_memory(self):
return self.model.model_size()
def model_mmap_residency(self, free=False):
return self.model.model_mmap_residency(free=free)
def model_loaded_memory(self):
return self.model.loaded_size()
@@ -588,7 +565,6 @@ class LoadedModel:
self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
self.model_finalizer.atexit = False
return real_model
def should_reload_model(self, force_patch_weights=False):
@@ -603,8 +579,9 @@ class LoadedModel:
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
self.model_finalizer.detach()
self.model_finalizer = None
if self.model_finalizer is not None:
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
return True
@@ -618,8 +595,15 @@ class LoadedModel:
if self._patcher_finalizer is not None:
self._patcher_finalizer.detach()
def dead_state(self):
model_ref_gone = self.model is None
real_model_ref = self.real_model
real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None
return model_ref_gone, real_model_ref_gone
def is_dead(self):
return self.real_model() is not None and self.model is None
model_ref_gone, real_model_ref_gone = self.dead_state()
return model_ref_gone or real_model_ref_gone
def use_more_memory(extra_memory, loaded_models, device):
@@ -660,11 +644,12 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
cleanup_models_gc()
unloaded_model = []
can_unload = []
unloaded_models = []
isolation_active = _isolation_mode_enabled()
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
@@ -673,14 +658,24 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
can_unload_sorted = sorted(can_unload)
for x in can_unload_sorted:
if can_unload and isolation_active:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
except Exception:
flush_tensor_keeper = None
if callable(flush_tensor_keeper):
flushed = flush_tensor_keeper()
if flushed > 0:
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
gc.collect()
for x in sorted(can_unload):
i = x[-1]
memory_to_free = 1e32
pins_to_free = 1e32
ram_to_free = 1e32
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram()
ram_to_free = ram_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
@@ -689,21 +684,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)
if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
if ram_to_free > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
unloaded = current_loaded_models.pop(i)
model_obj = unloaded.model
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
unloaded_models.append(unloaded)
if len(unloaded_model) > 0:
soft_empty_cache()
@@ -762,31 +754,23 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
for i in to_unload:
model_to_unload = current_loaded_models.pop(i)
model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach()
if model_to_unload.model_finalizer is not None:
model_to_unload.model_finalizer.detach()
model_to_unload.model_finalizer = None
total_memory_required = {}
total_pins_required = {}
total_ram_required = {}
for loaded_model in models_to_load:
device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
resident_memory, model_memory = loaded_model.model_mmap_residency()
pinned_memory = loaded_model.model.pinned_memory_size()
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
#want to do.
#FIXME: This should subtract off the to_load current pin consumption.
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic,
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
for device in total_memory_required:
if device != torch.device("cpu"):
@@ -835,25 +819,62 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
do_gc = False
reset_cast_buffers()
if not _isolation_mode_enabled():
dead_found = False
for i in range(len(current_loaded_models)):
if current_loaded_models[i].is_dead():
dead_found = True
break
if dead_found:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models) - 1, -1, -1):
cur = current_loaded_models[i]
if cur.is_dead():
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
leaked = current_loaded_models.pop(i)
model_obj = getattr(leaked, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
return
dead_found = False
has_real_model_leak = False
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
do_gc = True
break
model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
if model_ref_gone or real_model_ref_gone:
dead_found = True
if real_model_ref_gone and not model_ref_gone:
has_real_model_leak = True
if do_gc:
if dead_found:
if has_real_model_leak:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
else:
logging.debug("Cleaning stale loaded-model entries with released patcher references.")
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models)):
for i in range(len(current_loaded_models) - 1, -1, -1):
cur = current_loaded_models[i]
if cur.is_dead():
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
model_ref_gone, real_model_ref_gone = cur.dead_state()
if model_ref_gone or real_model_ref_gone:
if real_model_ref_gone and not model_ref_gone:
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
else:
logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
leaked = current_loaded_models.pop(i)
model_obj = getattr(leaked, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
def archive_model_dtypes(model):
@@ -867,11 +888,20 @@ def archive_model_dtypes(model):
def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
if current_loaded_models[i].real_model() is None:
real_model_ref = current_loaded_models[i].real_model
if real_model_ref is None:
to_delete = [i] + to_delete
continue
if callable(real_model_ref) and real_model_ref() is None:
to_delete = [i] + to_delete
for i in to_delete:
x = current_loaded_models.pop(i)
model_obj = getattr(x, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
del x
def dtype_size(dtype):
@@ -1052,12 +1082,6 @@ def intermediate_device():
else:
return torch.device("cpu")
def intermediate_dtype():
if args.fp16_intermediates:
return torch.float16
else:
return torch.float32
def vae_device():
if args.cpu_vae:
return torch.device("cpu")
@@ -1278,11 +1302,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
dest_view = dest_views.pop(0)
if tensor is None:
continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking)
@@ -1720,19 +1739,6 @@ def supports_nvfp4_compute(device=None):
return True
def supports_mxfp8_compute(device=None):
if not is_nvidia():
return False
if torch_version_numeric < (2, 10):
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):

View File

@@ -297,9 +297,6 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def get_ram_usage(self):
return self.model_size()
@@ -1066,10 +1063,6 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used
def pinned_memory_size(self):
# Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0
def partially_unload_ram(self, ram_to_unload):
pass
@@ -1660,16 +1653,6 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
def pinned_memory_size(self):
total = 0
loading = self._load_list(for_dynamic=True)
for x in loading:
_, _, _, _, m, _ = x
pin = comfy.pinned_memory.get_pin(m)
if pin is not None:
total += pin.numel() * pin.element_size()
return total
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:

View File

@@ -306,40 +306,10 @@ class CastWeightBiasOp:
bias_function = []
class disable_weight_init:
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape,
bias_shape=None):
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k, v in state_dict.items():
key = k[prefix_len:]
if key == "weight":
if not assign_to_params_buffers:
v = v.clone()
module.weight = torch.nn.Parameter(v, requires_grad=False)
elif bias_shape is not None and key == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
module.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
if module.weight is None:
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
missing_keys.append(prefix + "weight")
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
missing_keys.append(prefix + "bias")
class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
# don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
super().__init__(in_features, out_features, bias, device, dtype)
return
@@ -360,21 +330,32 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.in_features, self.out_features),
bias_shape=(self.out_features,),
)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k,v in state_dict.items():
if k[prefix_len:] == "weight":
if not assign_to_params_buffers:
v = v.clone()
self.weight = torch.nn.Parameter(v, requires_grad=False)
elif k[prefix_len:] == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
#Reconcile default construction of the weight if its missing.
if self.weight is None:
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"weight")
if self.bias is None and self.comfy_need_lazy_init_bias:
v = torch.zeros(self.out_features,)
self.bias = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"bias")
def reset_parameters(self):
@@ -566,53 +547,6 @@ class disable_weight_init:
return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
_freeze=False, device=None, dtype=None):
# don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight,
_freeze, device, dtype)
return
torch.nn.Module.__init__(self)
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
# Keep shape/dtype visible for module introspection without reserving storage.
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
self.weight = torch.nn.Parameter(
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
requires_grad=False,
)
self.bias = None
self.weight_comfy_model_dtype = dtype
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.num_embeddings, self.embedding_dim),
)
def reset_parameters(self):
self.bias = None
return None
@@ -776,71 +710,6 @@ from .quant_ops import (
)
class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
Handles any input rank by flattening to 2D for matmul and restoring shape after.
"""
@staticmethod
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
# Quantize input (same as inference path)
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
q_input = inp
w = weight.detach() if weight.requires_grad else weight
b = bias.detach() if bias is not None and bias.requires_grad else bias
output = torch.nn.functional.linear(q_input, w, b)
# Restore original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])
ctx.save_for_backward(input_float, weight)
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
input_float, weight = ctx.saved_tensors
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
# Dequantize weight to compute dtype for backward matmul
if isinstance(weight, QuantizedTensor):
weight_f = weight.dequantize().to(compute_dtype)
else:
weight_f = weight.to(compute_dtype)
# grad_input = grad_output @ weight
grad_input = torch.mm(grad_2d, weight_f)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
# grad_weight (only if weight requires grad, typically frozen for quantized training)
grad_weight = None
if ctx.weight_requires_grad:
input_f = input_float.flatten(0, -2).to(compute_dtype)
grad_weight = torch.mm(grad_2d.t(), input_f)
# grad_bias
grad_bias = None
if ctx.has_bias:
grad_bias = grad_2d.sum(dim=0)
return grad_input, grad_weight, grad_bias, None, None, None
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
@@ -932,22 +801,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@@ -1035,37 +888,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype
_use_quantized = (
getattr(self, 'layout_type', None) is not None and
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0
)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=True
)
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
output = QuantLinearFunc.apply(
input, weight, bias, self.layout_type, scale, compute_dtype
)
uncast_bias_weight(self, weight, bias, offload_stream)
return output
# Inference path (unchanged)
if _use_quantized:
len(self.weight_function) == 0 and len(self.bias_function) == 0):
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@@ -1113,10 +939,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
for key, param in self._parameters.items():
if param is None:
continue
p = fn(param)
if p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
@@ -1127,15 +950,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
if not mxfp8_compute:
disabled.add("mxfp8")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")

View File

@@ -1,7 +1,6 @@
import torch
import comfy.model_management
import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
from comfy.cli_args import args
@@ -13,31 +12,18 @@ def pin_memory(module):
return
#FIXME: This is a RAM cache trigger event
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin):
module._pin = pin
else:
module.pin_failed = True
return False
try:
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
except RuntimeError:
module.pin_failed = True
return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
module._pin_hostbuf = hostbuf
comfy.model_management.TOTAL_PINNED_MEMORY += size
return True
def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
comfy.model_management.unpin_memory(module._pin)
del module._pin
del module._pin_hostbuf
return size

View File

@@ -43,18 +43,6 @@ except ImportError as e:
def get_layout_class(name):
return None
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
_CK_MXFP8_AVAILABLE = True
except ImportError:
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
if not _CK_MXFP8_AVAILABLE:
class _CKMxfp8Layout:
pass
import comfy.float
# ==============================================================================
@@ -96,31 +84,6 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
params = cls.Params(
scale=block_scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
)
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@@ -174,8 +137,6 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
@@ -196,14 +157,6 @@ QUANT_ALGOS = {
},
}
if _CK_MXFP8_AVAILABLE:
QUANT_ALGOS["mxfp8"] = {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
"group_size": 32,
}
# ==============================================================================
# Re-exports for backward compatibility

View File

@@ -8,12 +8,12 @@ import comfy.nested_tensor
def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
samples = samples.to(comfy.model_management.intermediate_device())
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
samples = samples.to(comfy.model_management.intermediate_device())
return samples

View File

@@ -11,12 +11,14 @@ from functools import partial
import collections
import math
import logging
import os
import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.utils
from comfy.cli_args import args
import scipy.stats
import numpy
@@ -210,9 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
_calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
result = executor.execute(model, conds, x_in, timestep, model_options)
return result
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -269,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
if memory_required * 1.5 < free_memory:
to_batch = batch_amount
break
@@ -294,9 +299,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
if isolation_active:
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
input_x = torch.cat(input_x).to(target_device)
else:
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if isolation_active:
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
else:
timestep_ = torch.cat([timestep] * batch_chunks)
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
@@ -327,9 +340,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
out_t = output[o]
mult_t = mult[o]
if isolation_active:
target_dev = out_conds[cond_index].device
if hasattr(out_t, "device") and out_t.device != target_dev:
out_t = out_t.to(target_dev)
if hasattr(mult_t, "device") and mult_t.device != target_dev:
mult_t = mult_t.to(target_dev)
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
out_conds[cond_index] += out_t * mult_t
out_counts[cond_index] += mult_t
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
@@ -337,8 +358,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
out_c += out_t * mult_t
out_cts += mult_t
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
@@ -392,14 +413,31 @@ class KSamplerX0Inpaint:
self.inner_model = model
self.sigmas = sigmas
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
if denoise_mask is not None:
if isolation_active and denoise_mask.device != x.device:
denoise_mask = denoise_mask.to(x.device)
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
if isolation_active:
latent_image = self.latent_image
if hasattr(latent_image, "device") and latent_image.device != x.device:
latent_image = latent_image.to(x.device)
scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image)
if hasattr(scaled, "device") and scaled.device != x.device:
scaled = scaled.to(x.device)
else:
scaled = self.inner_model.inner_model.scale_latent_inpaint(
x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image
)
x = x * denoise_mask + scaled * latent_mask
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
if denoise_mask is not None:
out = out * denoise_mask + self.latent_image * latent_mask
latent_image = self.latent_image
if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device:
latent_image = latent_image.to(out.device)
out = out * denoise_mask + latent_image * latent_mask
return out
def simple_scheduler(model_sampling, steps):
@@ -741,7 +779,11 @@ class KSAMPLER(Sampler):
else:
model_k.noise = noise
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
max_denoise = self.max_denoise(model_wrap, sigmas)
model_sampling = model_wrap.inner_model.model_sampling
noise = model_sampling.noise_scaling(
sigmas[0], noise, latent_image, max_denoise
)
k_callback = None
total_steps = len(sigmas) - 1
@@ -985,8 +1027,8 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
@@ -1028,7 +1070,6 @@ class CFGGuider:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]
denoise_mask = denoise_mask.float()
self.conds = {}
for k in self.original_conds:

View File

@@ -455,7 +455,7 @@ class VAE:
self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
@@ -871,16 +871,13 @@ class VAE:
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels
def vae_output_dtype(self):
return model_management.intermediate_dtype()
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@@ -890,16 +887,16 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -908,7 +905,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -917,7 +914,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
@@ -926,7 +923,7 @@ class VAE:
tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
@@ -935,7 +932,7 @@ class VAE:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}):
@@ -951,23 +948,12 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
@@ -978,7 +964,6 @@ class VAE:
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@@ -1039,15 +1024,10 @@ class VAE:
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
except Exception as e:
@@ -1060,7 +1040,6 @@ class VAE:
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4

View File

@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
out, pooled = o[:2]
if pooled is not None:
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
output.append(z)
if (len(output) == 0):
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
else:
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
if len(o) > 2:
extra = {}
for k in o[2]:
v = o[2][k]
if k == "attention_mask":
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
extra[k] = v
r = r + (extra,)

View File

@@ -20,8 +20,6 @@
import torch
import math
import struct
import ctypes
import os
import comfy.memory_management
import safetensors.torch
import numpy as np
@@ -34,7 +32,7 @@ from einops import rearrange
from comfy.cli_args import args
import json
import time
import threading
import mmap
import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
@@ -83,17 +81,14 @@ _TYPES = {
}
def load_safetensors(ckpt):
import comfy_aimdo.model_mmap
f = open(ckpt, "rb")
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
mv = memoryview(mapping)
f = open(ckpt, "rb", buffering=0)
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
header_size = struct.unpack("<Q", mapping[:8])[0]
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
header_size = struct.unpack("<Q", mv[:8])[0]
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
mv = mv[(data_base_offset := 8 + header_size):]
mv = mv[8 + header_size:]
sd = {}
for name, info in header.items():
@@ -107,14 +102,7 @@ def load_safetensors(ckpt):
with warnings.catch_warnings():
#We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable")
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
storage = tensor.untyped_storage()
setattr(storage,
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
setattr(storage, "_comfy_tensor_mmap_touched", False)
sd[name] = tensor
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
return sd, header.get("__metadata__", {}),
@@ -897,10 +885,6 @@ def set_attr(obj, attr, value):
return prev
def set_attr_param(obj, attr, value):
# Clone inference tensors (created under torch.inference_mode) since
# their version counter is frozen and nn.Parameter() cannot wrap them.
if (not torch.is_inference_mode_enabled()) and value.is_inference():
value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value):
@@ -1135,8 +1119,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
pbar.update(1)
continue
out = output[b:b+1].zero_()
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
@@ -1151,7 +1135,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device)
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
mask = torch.ones_like(ps)
for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2]))
@@ -1174,7 +1158,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
if pbar is not None:
pbar.update(1)
out.div_(out_div)
output[b:b+1] = out/out_div
return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):

View File

@@ -25,7 +25,6 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
@@ -85,36 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
class Caching(ProxiedSingleton):
"""
External cache provider API for sharing cached node outputs
across ComfyUI instances.
Example::
from comfy_api.latest import Caching
class MyCacheProvider(Caching.CacheProvider):
async def on_lookup(self, context):
... # check external storage
async def on_store(self, context, value):
... # store to external storage
Caching.register_provider(MyCacheProvider())
"""
from ._caching import CacheProvider, CacheContext, CacheValue
async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Register an external cache provider. Providers are called in registration order."""
from comfy_execution.cache_provider import register_cache_provider
register_cache_provider(provider)
async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Unregister a previously registered cache provider."""
from comfy_execution.cache_provider import unregister_cache_provider
unregister_cache_provider(provider)
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
@@ -147,9 +116,6 @@ class Types:
VOXEL = VOXEL
File3D = File3D
Caching = ComfyAPI_latest.Caching
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
@@ -169,7 +135,6 @@ __all__ = [
"Input",
"InputImpl",
"Types",
"Caching",
"ComfyExtension",
"io",
"IO",

View File

@@ -1,42 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
@dataclass
class CacheContext:
node_id: str
class_type: str
cache_key_hash: str # SHA256 hex digest
@dataclass
class CacheValue:
outputs: list
ui: dict = None
class CacheProvider(ABC):
"""Abstract base class for external cache providers.
Exceptions from provider methods are caught by the caller and never break execution.
"""
@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
pass
@abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""Called after local store. Dispatched via asyncio.create_task."""
pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""Return False to skip external caching for this node. Default: True."""
return True
def on_prompt_start(self, prompt_id: str) -> None:
pass
def on_prompt_end(self, prompt_id: str) -> None:
pass

View File

@@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
to_skip = int(offset_seconds * audio_stream.sample_rate)
if to_skip < frame.samples:
has_first_frame = True
break
@@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if self.__duration and frame.time > start_time + self.__duration:
if frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:

View File

@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL, SVG as _SVG, File3D
from ._util import MESH, VOXEL, SVG as _SVG, File3D, PLY as _PLY, NPZ as _NPZ
class FolderType(str, Enum):
@@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
@@ -678,6 +678,16 @@ class Mesh(ComfyTypeIO):
Type = MESH
@comfytype(io_type="PLY")
class Ply(ComfyTypeIO):
Type = _PLY
@comfytype(io_type="NPZ")
class Npz(ComfyTypeIO):
Type = _NPZ
@comfytype(io_type="FILE_3D")
class File3DAny(ComfyTypeIO):
"""General 3D file type - accepts any supported 3D format."""
@@ -2197,6 +2207,8 @@ __all__ = [
"LossMap",
"Voxel",
"Mesh",
"Ply",
"Npz",
"File3DAny",
"File3DGLB",
"File3DGLTF",

View File

@@ -1,6 +1,8 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH, File3D
from .image_types import SVG
from .ply_types import PLY
from .npz_types import NPZ
__all__ = [
# Utility Types
@@ -11,4 +13,6 @@ __all__ = [
"MESH",
"File3D",
"SVG",
"PLY",
"NPZ",
]

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
import os
class NPZ:
"""Ordered collection of NPZ file payloads.
Each entry in ``frames`` is a complete compressed ``.npz`` file stored
as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO).
``save_to`` writes numbered files into a directory.
"""
def __init__(self, frames: list[bytes]) -> None:
self.frames = frames
@property
def num_frames(self) -> int:
return len(self.frames)
def save_to(self, directory: str, prefix: str = "frame") -> str:
os.makedirs(directory, exist_ok=True)
for i, frame_bytes in enumerate(self.frames):
path = os.path.join(directory, f"{prefix}_{i:06d}.npz")
with open(path, "wb") as f:
f.write(frame_bytes)
return directory

View File

@@ -0,0 +1,97 @@
from __future__ import annotations
import numpy as np
class PLY:
"""Point cloud payload for PLY file output.
Supports two schemas:
- Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format)
- Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format)
When ``raw_data`` is provided, the object acts as an opaque binary PLY
carrier and ``save_to`` writes the bytes directly.
"""
def __init__(
self,
points: np.ndarray | None = None,
colors: np.ndarray | None = None,
confidence: np.ndarray | None = None,
view_id: np.ndarray | None = None,
raw_data: bytes | None = None,
) -> None:
self.raw_data = raw_data
if raw_data is not None:
self.points = None
self.colors = None
self.confidence = None
self.view_id = None
return
if points is None:
raise ValueError("Either points or raw_data must be provided")
if points.ndim != 2 or points.shape[1] != 3:
raise ValueError(f"points must be (N, 3), got {points.shape}")
self.points = np.ascontiguousarray(points, dtype=np.float32)
self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None
self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None
self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None
@property
def is_gaussian(self) -> bool:
return self.raw_data is not None
@property
def num_points(self) -> int:
if self.points is not None:
return self.points.shape[0]
return 0
@staticmethod
def _to_numpy(arr, dtype):
if arr is None:
return None
if hasattr(arr, "numpy"):
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
return np.ascontiguousarray(arr, dtype=dtype)
def save_to(self, path: str) -> str:
if self.raw_data is not None:
with open(path, "wb") as f:
f.write(self.raw_data)
return path
self.points = self._to_numpy(self.points, np.float32)
self.colors = self._to_numpy(self.colors, np.float32)
self.confidence = self._to_numpy(self.confidence, np.float32)
self.view_id = self._to_numpy(self.view_id, np.int32)
N = self.num_points
header_lines = [
"ply",
"format ascii 1.0",
f"element vertex {N}",
"property float x",
"property float y",
"property float z",
]
if self.colors is not None:
header_lines += ["property uchar red", "property uchar green", "property uchar blue"]
if self.confidence is not None:
header_lines.append("property float confidence")
if self.view_id is not None:
header_lines.append("property int view_id")
header_lines.append("end_header")
with open(path, "w") as f:
f.write("\n".join(header_lines) + "\n")
for i in range(N):
parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"]
if self.colors is not None:
r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8)
parts.append(f"{r} {g} {b}")
if self.confidence is not None:
parts.append(f"{self.confidence[i]}")
if self.view_id is not None:
parts.append(f"{int(self.view_id[i])}")
f.write(" ".join(parts) + "\n")
return path

View File

@@ -67,7 +67,6 @@ class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None)
thought: bool | None = Field(None)
class GeminiTextPart(BaseModel):

View File

@@ -29,21 +29,13 @@ class ImageEditRequest(BaseModel):
class VideoGenerationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
image: InputUrlObject | None = Field(None)
reference_images: list[InputUrlObject] | None = Field(None)
image: InputUrlObject | None = Field(...)
duration: int = Field(...)
aspect_ratio: str | None = Field(...)
resolution: str = Field(...)
seed: int = Field(...)
class VideoExtensionRequest(BaseModel):
prompt: str = Field(...)
video: InputUrlObject = Field(...)
duration: int = Field(default=6)
model: str | None = Field(default=None)
class VideoEditRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)

View File

@@ -1,43 +0,0 @@
from pydantic import BaseModel, Field
class QuiverImageObject(BaseModel):
url: str = Field(...)
class QuiverTextToSVGRequest(BaseModel):
model: str = Field(default="arrow-preview")
prompt: str = Field(...)
instructions: str | None = Field(default=None)
references: list[QuiverImageObject] | None = Field(default=None, max_length=4)
temperature: float | None = Field(default=None, ge=0, le=2)
top_p: float | None = Field(default=None, ge=0, le=1)
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
class QuiverImageToSVGRequest(BaseModel):
model: str = Field(default="arrow-preview")
image: QuiverImageObject = Field(...)
auto_crop: bool | None = Field(default=None)
target_size: int | None = Field(default=None, ge=128, le=4096)
temperature: float | None = Field(default=None, ge=0, le=2)
top_p: float | None = Field(default=None, ge=0, le=1)
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
class QuiverSVGResponseItem(BaseModel):
svg: str = Field(...)
mime_type: str | None = Field(default="image/svg+xml")
class QuiverSVGUsage(BaseModel):
total_tokens: int | None = Field(default=None)
input_tokens: int | None = Field(default=None)
output_tokens: int | None = Field(default=None)
class QuiverSVGResponse(BaseModel):
id: str | None = Field(default=None)
created: int | None = Field(default=None)
data: list[QuiverSVGResponseItem] = Field(...)
usage: QuiverSVGUsage | None = Field(default=None)

View File

@@ -47,10 +47,6 @@ SEEDREAM_MODELS = {
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
logger = logging.getLogger(__name__)
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
@@ -139,7 +135,6 @@ class ByteDanceImageNode(IO.ComfyNode):
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
is_deprecated=True,
)
@classmethod
@@ -947,7 +942,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
]
return await process_video_task(
cls,
payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
payload=Image2VideoTaskCreationRequest(model=model, content=x),
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
)
@@ -957,12 +952,6 @@ async def process_video_task(
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
if payload.model in DEPRECATED_MODELS:
logger.warning(
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
payload.model,
)
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),

View File

@@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
$m := widgets.model;
$r := widgets.resolution;
$isFlash := $contains($m, "nano banana 2");
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
$prices := $isFlash ? $flashPrices : $proPrices;
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
@@ -188,12 +188,10 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/*")
for part in parts:
if (part.thought is True) != thought:
continue
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
@@ -933,11 +931,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
outputs=[
IO.Image.Output(),
IO.String.Output(),
IO.Image.Output(
display_name="thought_image",
tooltip="First image from the model's thinking process. "
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -999,11 +992,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(
await get_image_from_response(response),
get_text_from_response(response),
await get_image_from_response(response, thought=True),
)
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension):

View File

@@ -8,7 +8,6 @@ from comfy_api_nodes.apis.grok import (
ImageGenerationResponse,
InputUrlObject,
VideoEditRequest,
VideoExtensionRequest,
VideoGenerationRequest,
VideoGenerationResponse,
VideoStatusResponse,
@@ -22,7 +21,6 @@ from comfy_api_nodes.util import (
poll_op,
sync_op,
tensor_to_base64_string,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
validate_video_duration,
@@ -35,13 +33,6 @@ def _extract_grok_price(response) -> float | None:
return None
def _extract_grok_video_price(response) -> float | None:
price = _extract_grok_price(response)
if price is not None:
return price * 1.43
return None
class GrokImageNode(IO.ComfyNode):
@classmethod
@@ -363,8 +354,6 @@ class GrokVideoNode(IO.ComfyNode):
seed: int,
image: Input.Image | None = None,
) -> IO.NodeOutput:
if model == "grok-imagine-video-beta":
model = "grok-imagine-video"
image_url = None
if image is not None:
if get_number_of_images(image) != 1:
@@ -473,244 +462,6 @@ class GrokVideoEditNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokVideoReferenceNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokVideoReferenceNode",
display_name="Grok Reference-to-Video",
category="api node/video/Grok",
description="Generate video guided by reference images as style and content references.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text description of the desired video.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"grok-imagine-video",
[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="reference_",
min=1,
max=7,
),
tooltip="Up to 7 reference images to guide the video generation.",
),
IO.Combo.Input(
"resolution",
options=["480p", "720p"],
tooltip="The resolution of the output video.",
),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
tooltip="The aspect ratio of the output video.",
),
IO.Int.Input(
"duration",
default=6,
min=2,
max=10,
step=1,
tooltip="The duration of the output video in seconds.",
display_mode=IO.NumberDisplay.slider,
),
],
),
],
tooltip="The model to use for video generation.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model.duration", "model.resolution"],
input_groups=["model.reference_images"],
),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$refs := inputGroups["model.reference_images"];
$rate := $res = "720p" ? 0.07 : 0.05;
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
{"type":"usd","usd": $price}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
ref_image_urls = await upload_images_to_comfyapi(
cls,
list(model["reference_images"].values()),
mime_type="image/png",
wait_label="Uploading base images",
max_images=7,
)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
data=VideoGenerationRequest(
model=model["model"],
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
prompt=prompt,
resolution=model["resolution"],
duration=model["duration"],
aspect_ratio=model["aspect_ratio"],
seed=seed,
),
response_model=VideoGenerationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_video_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokVideoExtendNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokVideoExtendNode",
display_name="Grok Video Extend",
category="api node/video/Grok",
description="Extend an existing video with a seamless continuation based on a text prompt.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text description of what should happen next in the video.",
),
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"grok-imagine-video",
[
IO.Int.Input(
"duration",
default=8,
min=2,
max=10,
step=1,
tooltip="Length of the extension in seconds.",
display_mode=IO.NumberDisplay.slider,
),
],
),
],
tooltip="The model to use for video extension.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
expr="""
(
$dur := $lookup(widgets, "model.duration");
{
"type": "range_usd",
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
"max_usd": (0.15 + 0.05 * $dur) * 1.43
}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
video: Input.Video,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
validate_video_duration(video, min_duration=2, max_duration=15)
video_size = get_fs_object_size(video.get_stream_source())
if video_size > 50 * 1024 * 1024:
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
data=VideoExtensionRequest(
prompt=prompt,
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
duration=model["duration"],
),
response_model=VideoGenerationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_video_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -718,9 +469,7 @@ class GrokExtension(ComfyExtension):
GrokImageNode,
GrokImageEditNode,
GrokVideoNode,
GrokVideoReferenceNode,
GrokVideoEditNode,
GrokVideoExtendNode,
]

View File

@@ -1,7 +1,3 @@
import zipfile
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, Types
@@ -21,10 +17,7 @@ from comfy_api_nodes.apis.hunyuan3d import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
download_url_to_bytesio,
download_url_to_file_3d,
download_url_to_image_tensor,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
@@ -43,68 +36,6 @@ def _is_tencent_rate_limited(status: int, body: object) -> bool:
)
class ObjZipResult:
__slots__ = ("obj", "texture", "metallic", "normal", "roughness")
def __init__(
self,
obj: Types.File3D,
texture: Input.Image | None = None,
metallic: Input.Image | None = None,
normal: Input.Image | None = None,
roughness: Input.Image | None = None,
):
self.obj = obj
self.texture = texture
self.metallic = metallic
self.normal = normal
self.roughness = roughness
async def download_and_extract_obj_zip(url: str) -> ObjZipResult:
"""The Tencent API returns OBJ results as ZIP archives containing the .obj mesh, and texture images.
When PBR is enabled, the ZIP may contain additional metallic, normal, and roughness maps
identified by their filename suffixes.
"""
data = BytesIO()
await download_url_to_bytesio(url, data)
data.seek(0)
if not zipfile.is_zipfile(data):
data.seek(0)
return ObjZipResult(obj=Types.File3D(source=data, file_format="obj"))
data.seek(0)
obj_bytes = None
textures: dict[str, Input.Image] = {}
with zipfile.ZipFile(data) as zf:
for name in zf.namelist():
lower = name.lower()
if lower.endswith(".obj"):
obj_bytes = zf.read(name)
elif any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")):
stem = lower.rsplit(".", 1)[0]
tensor = bytesio_to_image_tensor(BytesIO(zf.read(name)), mode="RGB")
matched_key = "texture"
for suffix, key in {
"_metallic": "metallic",
"_normal": "normal",
"_roughness": "roughness",
}.items():
if stem.endswith(suffix):
matched_key = key
break
textures[matched_key] = tensor
if obj_bytes is None:
raise ValueError("ZIP archive does not contain an OBJ file.")
return ObjZipResult(
obj=Types.File3D(source=BytesIO(obj_bytes), file_format="obj"),
texture=textures.get("texture"),
metallic=textures.get("metallic"),
normal=textures.get("normal"),
roughness=textures.get("roughness"),
)
def get_file_from_response(
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
) -> ResultFile3D | None:
@@ -162,7 +93,6 @@ class TencentTextToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -221,14 +151,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
obj_result.obj,
obj_result.texture,
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
),
)
@@ -281,10 +211,6 @@ class TencentImageToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
IO.Image.Output(display_name="optional_metallic"),
IO.Image.Output(display_name="optional_normal"),
IO.Image.Output(display_name="optional_roughness"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -378,17 +304,14 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
obj_result.obj,
obj_result.texture,
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
),
)
@@ -508,8 +431,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
],
outputs=[
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -558,8 +480,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
)
@@ -733,7 +654,7 @@ class TencentHunyuan3DExtension(ComfyExtension):
TencentTextToModelNode,
TencentImageToModelNode,
TencentModelTo3DUVNode,
Tencent3DTextureEditNode,
# Tencent3DTextureEditNode,
Tencent3DPartNode,
TencentSmartTopologyNode,
]

View File

@@ -1459,7 +1459,6 @@ class OmniProEditVideoNode(IO.ComfyNode):
node_id="KlingOmniProEditVideoNode",
display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling",
essentials_category="Video Generation",
description="Edit an existing video with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),

View File

@@ -1,291 +0,0 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.quiver import (
QuiverImageObject,
QuiverImageToSVGRequest,
QuiverSVGResponse,
QuiverTextToSVGRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
sync_op,
upload_image_to_comfyapi,
validate_string,
)
from comfy_extras.nodes_images import SVG
class QuiverTextToSVGNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="QuiverTextToSVGNode",
display_name="Quiver Text to SVG",
category="api node/image/Quiver",
description="Generate an SVG from a text prompt using Quiver AI.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired SVG output.",
),
IO.String.Input(
"instructions",
multiline=True,
default="",
tooltip="Additional style or formatting guidance.",
optional=True,
),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="ref_",
min=0,
max=4,
),
tooltip="Up to 4 reference images to guide the generation.",
optional=True,
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"arrow-preview",
[
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
],
),
],
tooltip="Model to use for SVG generation.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.SVG.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.429}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
instructions: str = None,
reference_images: IO.Autogrow.Type = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1)
references = None
if reference_images:
references = []
for key in reference_images:
url = await upload_image_to_comfyapi(cls, reference_images[key])
references.append(QuiverImageObject(url=url))
if len(references) > 4:
raise ValueError("Maximum 4 reference images are allowed.")
instructions_val = instructions.strip() if instructions else None
if instructions_val == "":
instructions_val = None
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"),
response_model=QuiverSVGResponse,
data=QuiverTextToSVGRequest(
model=model["model"],
prompt=prompt,
instructions=instructions_val,
references=references,
temperature=model.get("temperature"),
top_p=model.get("top_p"),
presence_penalty=model.get("presence_penalty"),
),
)
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
return IO.NodeOutput(SVG(svg_data))
class QuiverImageToSVGNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="QuiverImageToSVGNode",
display_name="Quiver Image to SVG",
category="api node/image/Quiver",
description="Vectorize a raster image into SVG using Quiver AI.",
inputs=[
IO.Image.Input(
"image",
tooltip="Input image to vectorize.",
),
IO.Boolean.Input(
"auto_crop",
default=False,
tooltip="Automatically crop to the dominant subject.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"arrow-preview",
[
IO.Int.Input(
"target_size",
default=1024,
min=128,
max=4096,
tooltip="Square resize target in pixels.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
],
),
],
tooltip="Model to use for SVG vectorization.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.SVG.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.429}""",
),
)
@classmethod
async def execute(
cls,
image,
auto_crop: bool,
model: dict,
seed: int,
) -> IO.NodeOutput:
image_url = await upload_image_to_comfyapi(cls, image)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"),
response_model=QuiverSVGResponse,
data=QuiverImageToSVGRequest(
model=model["model"],
image=QuiverImageObject(url=image_url),
auto_crop=auto_crop if auto_crop else None,
target_size=model.get("target_size"),
temperature=model.get("temperature"),
top_p=model.get("top_p"),
presence_penalty=model.get("presence_penalty"),
),
)
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
return IO.NodeOutput(SVG(svg_data))
class QuiverExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
QuiverTextToSVGNode,
QuiverImageToSVGNode,
]
async def comfy_entrypoint() -> QuiverExtension:
return QuiverExtension()

View File

@@ -833,7 +833,6 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
node_id="RecraftVectorizeImageNode",
display_name="Recraft Vectorize Image",
category="api node/image/Recraft",
essentials_category="Image Tools",
description="Generates SVG synchronously from an input image.",
inputs=[
IO.Image.Input("image"),

View File

@@ -38,7 +38,6 @@ from comfy_api_nodes.util import (
UPSCALER_MODELS_MAP = {
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
"Starlight Precise 2.5": "slp-2.5",
}

View File

@@ -1,138 +0,0 @@
from typing import Any, Optional, Tuple, List
import hashlib
import json
import logging
import threading
# Public types — source of truth is comfy_api.latest._caching
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
_logger = logging.getLogger(__name__)
_providers: List[CacheProvider] = []
_providers_lock = threading.Lock()
_providers_snapshot: Tuple[CacheProvider, ...] = ()
def register_cache_provider(provider: CacheProvider) -> None:
"""Register an external cache provider. Providers are called in registration order."""
global _providers_snapshot
with _providers_lock:
if provider in _providers:
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
return
_providers.append(provider)
_providers_snapshot = tuple(_providers)
_logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
def unregister_cache_provider(provider: CacheProvider) -> None:
global _providers_snapshot
with _providers_lock:
try:
_providers.remove(provider)
_providers_snapshot = tuple(_providers)
_logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
except ValueError:
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
return _providers_snapshot
def _has_cache_providers() -> bool:
return bool(_providers_snapshot)
def _clear_cache_providers() -> None:
global _providers_snapshot
with _providers_lock:
_providers.clear()
_providers_snapshot = ()
def _canonicalize(obj: Any) -> Any:
# Convert to canonical JSON-serializable form with deterministic ordering.
# Frozensets have non-deterministic iteration order between Python sessions.
# Raises ValueError for non-cacheable types (Unhashable, unknown) so that
# _serialize_cache_key returns None and external caching is skipped.
if isinstance(obj, frozenset):
return ("__frozenset__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, set):
return ("__set__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, tuple):
return ("__tuple__", [_canonicalize(item) for item in obj])
elif isinstance(obj, list):
return [_canonicalize(item) for item in obj]
elif isinstance(obj, dict):
return {"__dict__": sorted(
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
key=lambda x: json.dumps(x, sort_keys=True)
)}
elif isinstance(obj, (int, float, str, bool, type(None))):
return (type(obj).__name__, obj)
elif isinstance(obj, bytes):
return ("__bytes__", obj.hex())
else:
raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
# Returns deterministic SHA256 hex digest, or None on failure.
# Uses JSON (not pickle) because pickle is non-deterministic across sessions.
try:
canonical = _canonicalize(cache_key)
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
except Exception as e:
_logger.warning(f"Failed to serialize cache key: {e}")
return None
def _contains_self_unequal(obj: Any) -> bool:
# Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
# never hit locally, but serialized form would match externally. Skip these.
try:
if not (obj == obj):
return True
except Exception:
return True
if isinstance(obj, (frozenset, tuple, list, set)):
return any(_contains_self_unequal(item) for item in obj)
if isinstance(obj, dict):
return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
if hasattr(obj, 'value'):
return _contains_self_unequal(obj.value)
return False
def _estimate_value_size(value: CacheValue) -> int:
try:
import torch
except ImportError:
return 0
total = 0
def estimate(obj):
nonlocal total
if isinstance(obj, torch.Tensor):
total += obj.numel() * obj.element_size()
elif isinstance(obj, dict):
for v in obj.values():
estimate(v)
elif isinstance(obj, (list, tuple)):
for item in obj:
estimate(item)
for output in value.outputs:
estimate(output)
return total

View File

@@ -1,4 +1,3 @@
import asyncio
import bisect
import gc
import itertools
@@ -148,15 +147,13 @@ class CacheKeySetInputSignature(CacheKeySet):
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache:
def __init__(self, key_class, enable_providers=False):
def __init__(self, key_class):
self.key_class = key_class
self.initialized = False
self.enable_providers = enable_providers
self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
self._pending_store_tasks: set = set()
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
@@ -199,138 +196,18 @@ class BasicCache:
def poll(self, **kwargs):
pass
def get_local(self, node_id):
def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
return None
def set_local(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
async def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
await self._notify_providers_store(node_id, cache_key, value)
async def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
external_result = await self._check_providers_lookup(node_id, cache_key)
if external_result is not None:
self.cache[cache_key] = external_result
return external_result
return None
async def _notify_providers_store(self, node_id, cache_key, value):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,
CacheValue, _contains_self_unequal, _logger
)
if not self.enable_providers:
return
if not _has_cache_providers():
return
if not self._is_external_cacheable_value(value):
return
if _contains_self_unequal(cache_key):
return
context = self._build_context(node_id, cache_key)
if context is None:
return
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
for provider in _get_cache_providers():
try:
if provider.should_cache(context, cache_value):
task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
self._pending_store_tasks.add(task)
task.add_done_callback(self._pending_store_tasks.discard)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
@staticmethod
async def _safe_provider_store(provider, context, cache_value):
from comfy_execution.cache_provider import _logger
try:
await provider.on_store(context, cache_value)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
async def _check_providers_lookup(self, node_id, cache_key):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,
CacheValue, _contains_self_unequal, _logger
)
if not self.enable_providers:
return None
if not _has_cache_providers():
return None
if _contains_self_unequal(cache_key):
return None
context = self._build_context(node_id, cache_key)
if context is None:
return None
for provider in _get_cache_providers():
try:
if not provider.should_cache(context):
continue
result = await provider.on_lookup(context)
if result is not None:
if not isinstance(result, CacheValue):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
continue
if not isinstance(result.outputs, (list, tuple)):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
continue
from execution import CacheEntry
return CacheEntry(ui=result.ui, outputs=list(result.outputs))
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
return None
def _is_external_cacheable_value(self, value):
return hasattr(value, 'outputs') and hasattr(value, 'ui')
def _get_class_type(self, node_id):
if not self.initialized or not self.dynprompt:
return ''
try:
return self.dynprompt.get_node(node_id).get('class_type', '')
except Exception:
return ''
def _build_context(self, node_id, cache_key):
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
try:
cache_key_hash = _serialize_cache_key(cache_key)
if cache_key_hash is None:
return None
return CacheContext(
node_id=node_id,
class_type=self._get_class_type(node_id),
cache_key_hash=cache_key_hash,
)
except Exception as e:
_logger.warning(f"Failed to build cache context for node {node_id}: {e}")
else:
return None
async def _ensure_subcache(self, node_id, children_ids):
@@ -359,8 +236,8 @@ class BasicCache:
return result
class HierarchicalCache(BasicCache):
def __init__(self, key_class, enable_providers=False):
super().__init__(key_class, enable_providers=enable_providers)
def __init__(self, key_class):
super().__init__(key_class)
def _get_cache_for(self, node_id):
assert self.dynprompt is not None
@@ -380,27 +257,16 @@ class HierarchicalCache(BasicCache):
return None
return cache
async def get(self, node_id):
def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return await cache._get_immediate(node_id)
return cache._get_immediate(node_id)
def get_local(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return BasicCache.get_local(cache, node_id)
async def set(self, node_id, value):
def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
await cache._set_immediate(node_id, value)
def set_local(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
BasicCache.set_local(cache, node_id, value)
cache._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
@@ -421,24 +287,18 @@ class NullCache:
def poll(self, **kwargs):
pass
async def get(self, node_id):
def get(self, node_id):
return None
def get_local(self, node_id):
return None
async def set(self, node_id, value):
pass
def set_local(self, node_id, value):
def set(self, node_id, value):
pass
async def ensure_subcache_for(self, node_id, children_ids):
return self
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100, enable_providers=False):
super().__init__(key_class, enable_providers=enable_providers)
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
self.max_size = max_size
self.min_generation = 0
self.generation = 0
@@ -462,18 +322,18 @@ class LRUCache(BasicCache):
del self.children[key]
self._clean_subcaches()
async def get(self, node_id):
def get(self, node_id):
self._mark_used(node_id)
return await self._get_immediate(node_id)
return self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
async def set(self, node_id, value):
def set(self, node_id, value):
self._mark_used(node_id)
return await self._set_immediate(node_id, value)
return self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
@@ -506,20 +366,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class, enable_providers=False):
super().__init__(key_class, 0, enable_providers=enable_providers)
def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}
def clean_unused(self):
self._clean_subcaches()
async def set(self, node_id, value):
def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
await super().set(node_id, value)
super().set(node_id, value)
async def get(self, node_id):
def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return await super().get(node_id)
return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():

View File

@@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners = {}
def is_cached(self, node_id):
return self.output_cache.get_local(node_id) is not None
return self.output_cache.get(node_id) is not None
def cache_link(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {}
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
@@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set_local(from_node_id, value)
self.output_cache.set(from_node_id, value)
return value
def cache_update(self, node_id, value):

View File

@@ -19,7 +19,6 @@ class EmptyLatentAudio(IO.ComfyNode):
node_id="EmptyLatentAudio",
display_name="Empty Latent Audio",
category="latent/audio",
essentials_category="Audio",
inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
@@ -186,7 +185,6 @@ class SaveAudioMP3(IO.ComfyNode):
search_aliases=["export mp3"],
display_name="Save Audio (MP3)",
category="audio",
essentials_category="Audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),

View File

@@ -3,7 +3,6 @@ from typing_extensions import override
import comfy.model_management
from comfy_api.latest import ComfyExtension, io
import torch
class Canny(io.ComfyNode):
@@ -30,8 +29,8 @@ class Canny(io.ComfyNode):
@classmethod
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1)
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
return io.NodeOutput(img_out)

View File

@@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
],
outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."),

View File

@@ -6,7 +6,6 @@ import comfy.model_management
import torch
import math
import nodes
import comfy.ldm.flux.math
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
@@ -232,68 +231,6 @@ class Flux2Scheduler(io.ComfyNode):
sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas)
class KV_Attn_Input:
def __init__(self):
self.cache = {}
def __call__(self, q, k, v, extra_options, **kwargs):
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
if len(reference_image_num_tokens) == 0:
return {}
ref_toks = sum(reference_image_num_tokens)
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
if cache_key in self.cache:
kk, vv = self.cache[cache_key]
self.set_cache = False
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
self.set_cache = True
return {"q": q, "k": k, "v": v}
def cleanup(self):
self.cache = {}
class FluxKVCache(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="FluxKVCache",
display_name="Flux KV Cache",
description="Enables KV Cache optimization for reference images on Flux family models.",
category="",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to use KV Cache on."),
],
outputs=[
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
],
)
@classmethod
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
m = model.clone()
input_patch_obj = KV_Attn_Input()
def model_input_patch(inputs):
if len(input_patch_obj.cache) > 0:
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
if ref_image_tokens > 0:
img = inputs["img"]
inputs["img"] = img[:, :-ref_image_tokens]
return inputs
m.set_model_attn1_patch(input_patch_obj)
m.set_model_post_input_patch(model_input_patch)
if hasattr(model.model.diffusion_model, "params"):
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
else:
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
return io.NodeOutput(m)
class FluxExtension(ComfyExtension):
@override
@@ -306,7 +243,6 @@ class FluxExtension(ComfyExtension):
FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage,
Flux2Scheduler,
FluxKVCache,
]

View File

@@ -14,7 +14,6 @@ class ImageCompare(IO.ComfyNode):
display_name="Image Compare",
description="Compares two images side by side with a slider.",
category="image",
essentials_category="Image Tools",
is_experimental=True,
is_output_node=True,
inputs=[

View File

@@ -58,7 +58,6 @@ class ImageCropV2(IO.ComfyNode):
search_aliases=["trim"],
display_name="Image Crop",
category="image/transform",
essentials_category="Image Tools",
inputs=[
IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"),

View File

@@ -1,127 +0,0 @@
from __future__ import annotations
import hashlib
import os
import numpy as np
import torch
from PIL import Image
import folder_paths
import node_helpers
from comfy_api.latest import ComfyExtension, io, UI
from typing_extensions import override
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
hex_color = hex_color.lstrip("#")
if len(hex_color) != 6:
return (0.0, 0.0, 0.0)
r = int(hex_color[0:2], 16) / 255.0
g = int(hex_color[2:4], 16) / 255.0
b = int(hex_color[4:6], 16) / 255.0
return (r, g, b)
class PainterNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Painter",
display_name="Painter",
category="image",
inputs=[
io.Image.Input(
"image",
optional=True,
tooltip="Optional base image to paint over",
),
io.String.Input(
"mask",
default="",
socketless=True,
extra_dict={"widgetType": "PAINTER", "image_upload": True},
),
io.Int.Input(
"width",
default=512,
min=64,
max=4096,
step=64,
socketless=True,
extra_dict={"hidden": True},
),
io.Int.Input(
"height",
default=512,
min=64,
max=4096,
step=64,
socketless=True,
extra_dict={"hidden": True},
),
io.Color.Input("bg_color", default="#000000"),
],
outputs=[
io.Image.Output("IMAGE"),
io.Mask.Output("MASK"),
],
)
@classmethod
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
if image is not None:
base_image = image[:1]
h, w = base_image.shape[1], base_image.shape[2]
else:
h, w = height, width
r, g, b = hex_to_rgb(bg_color)
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
base_image[0, :, :, 0] = r
base_image[0, :, :, 1] = g
base_image[0, :, :, 2] = b
if mask and mask.strip():
mask_path = folder_paths.get_annotated_filepath(mask)
painter_img = node_helpers.pillow(Image.open, mask_path)
painter_img = painter_img.convert("RGBA")
if painter_img.size != (w, h):
painter_img = painter_img.resize((w, h), Image.LANCZOS)
painter_np = np.array(painter_img).astype(np.float32) / 255.0
painter_rgb = painter_np[:, :, :3]
painter_alpha = painter_np[:, :, 3:4]
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
base_np = base_image[0].cpu().numpy()
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
out_image = torch.from_numpy(composited).unsqueeze(0)
else:
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
out_image = base_image
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
@classmethod
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
if mask and mask.strip():
mask_path = folder_paths.get_annotated_filepath(mask)
if os.path.exists(mask_path):
m = hashlib.sha256()
with open(mask_path, "rb") as f:
m.update(f.read())
return m.digest().hex()
return ""
class PainterExtension(ComfyExtension):
@override
async def get_node_list(self):
return [PainterNode]
async def comfy_entrypoint():
return PainterExtension()

Some files were not shown because too many files have changed in this diff Show More