mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-27 07:57:31 +00:00
Compare commits
15 Commits
release/v0
...
pyisolate-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c02372936d | ||
|
|
6aa0b838a0 | ||
|
|
54461f9ecc | ||
|
|
b602cc4533 | ||
|
|
08b92a48c3 | ||
|
|
c5e7b9cdaf | ||
|
|
623a9d21e9 | ||
|
|
9250191c65 | ||
|
|
a0f8784e9f | ||
|
|
7962db477a | ||
|
|
3c8ba051b6 | ||
|
|
a1c3124821 | ||
|
|
9ca799362d | ||
|
|
22f5e43c12 | ||
|
|
3cfd5e3311 |
103
.github/scripts/check-ai-co-authors.sh
vendored
103
.github/scripts/check-ai-co-authors.sh
vendored
@@ -1,103 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# Checks pull request commits for AI agent Co-authored-by trailers.
|
||||
# Exits non-zero when any are found and prints fix instructions.
|
||||
set -euo pipefail
|
||||
|
||||
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||
|
||||
# Known AI coding-agent trailer patterns (case-insensitive).
|
||||
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
|
||||
AGENT_PATTERNS=(
|
||||
# Anthropic — Claude Code / Amp
|
||||
'noreply@anthropic\.com'
|
||||
# Cursor
|
||||
'cursoragent@cursor\.com'
|
||||
# GitHub Copilot
|
||||
'copilot-swe-agent\[bot\]'
|
||||
'copilot@github\.com'
|
||||
# OpenAI Codex
|
||||
'noreply@openai\.com'
|
||||
'codex@openai\.com'
|
||||
# Aider
|
||||
'aider@aider\.chat'
|
||||
# Google — Gemini / Jules
|
||||
'gemini@google\.com'
|
||||
'jules@google\.com'
|
||||
# Windsurf / Codeium
|
||||
'@codeium\.com'
|
||||
# Devin
|
||||
'devin-ai-integration\[bot\]'
|
||||
'devin@cognition\.ai'
|
||||
'devin@cognition-labs\.com'
|
||||
# Amazon Q Developer
|
||||
'amazon-q-developer'
|
||||
'@amazon\.com.*[Qq].[Dd]eveloper'
|
||||
# Cline
|
||||
'cline-bot'
|
||||
'cline@cline\.ai'
|
||||
# Continue
|
||||
'continue-agent'
|
||||
'continue@continue\.dev'
|
||||
# Sourcegraph
|
||||
'noreply@sourcegraph\.com'
|
||||
# Generic catch-alls for common agent name patterns
|
||||
'Co-authored-by:.*\b[Cc]laude\b'
|
||||
'Co-authored-by:.*\b[Cc]opilot\b'
|
||||
'Co-authored-by:.*\b[Cc]ursor\b'
|
||||
'Co-authored-by:.*\b[Cc]odex\b'
|
||||
'Co-authored-by:.*\b[Gg]emini\b'
|
||||
'Co-authored-by:.*\b[Aa]ider\b'
|
||||
'Co-authored-by:.*\b[Dd]evin\b'
|
||||
'Co-authored-by:.*\b[Ww]indsurf\b'
|
||||
'Co-authored-by:.*\b[Cc]line\b'
|
||||
'Co-authored-by:.*\b[Aa]mazon Q\b'
|
||||
'Co-authored-by:.*\b[Jj]ules\b'
|
||||
'Co-authored-by:.*\bOpenCode\b'
|
||||
)
|
||||
|
||||
# Build a single alternation regex from all patterns.
|
||||
regex=""
|
||||
for pattern in "${AGENT_PATTERNS[@]}"; do
|
||||
if [[ -n "$regex" ]]; then
|
||||
regex="${regex}|${pattern}"
|
||||
else
|
||||
regex="$pattern"
|
||||
fi
|
||||
done
|
||||
|
||||
# Collect Co-authored-by lines from every commit in the PR range.
|
||||
violations=""
|
||||
while IFS= read -r sha; do
|
||||
message="$(git log -1 --format='%B' "$sha")"
|
||||
matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
|
||||
if [[ -z "$matched_lines" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
while IFS= read -r line; do
|
||||
if echo "$line" | grep -iqE "$regex"; then
|
||||
short="$(git log -1 --format='%h' "$sha")"
|
||||
violations="${violations} ${short}: ${line}"$'\n'
|
||||
fi
|
||||
done <<< "$matched_lines"
|
||||
done < <(git rev-list "${base_sha}..${head_sha}")
|
||||
|
||||
if [[ -n "$violations" ]]; then
|
||||
echo "::error::AI agent Co-authored-by trailers detected in PR commits."
|
||||
echo ""
|
||||
echo "The following commits contain Co-authored-by trailers from AI coding agents:"
|
||||
echo ""
|
||||
echo "$violations"
|
||||
echo "These trailers should be removed before merging."
|
||||
echo ""
|
||||
echo "To fix, rewrite the commit messages with:"
|
||||
echo " git rebase -i ${base_sha}"
|
||||
echo ""
|
||||
echo "and remove the Co-authored-by lines, then force-push your branch."
|
||||
echo ""
|
||||
echo "If you believe this is a false positive, please open an issue."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "No AI agent Co-authored-by trailers found."
|
||||
19
.github/workflows/check-ai-co-authors.yml
vendored
19
.github/workflows/check-ai-co-authors.yml
vendored
@@ -1,19 +0,0 @@
|
||||
name: Check AI Co-Authors
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ['*']
|
||||
|
||||
jobs:
|
||||
check-ai-co-authors:
|
||||
name: Check for AI agent co-author trailers
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Check commits for AI co-author trailers
|
||||
run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -24,3 +24,4 @@ web_custom_versions/
|
||||
openapi.yaml
|
||||
filtered-openapi.yaml
|
||||
uv.lock
|
||||
.pyisolate_venvs/
|
||||
|
||||
11
README.md
11
README.md
@@ -38,8 +38,6 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
|
||||
## Get Started
|
||||
|
||||
### Local
|
||||
|
||||
#### [Desktop Application](https://www.comfy.org/download)
|
||||
- The easiest way to get started.
|
||||
- Available on Windows & macOS.
|
||||
@@ -51,13 +49,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
#### [Manual Install](#manual-install-windows-linux)
|
||||
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
|
||||
|
||||
### Cloud
|
||||
|
||||
#### [Comfy Cloud](https://www.comfy.org/cloud)
|
||||
- Our official paid cloud version for those who can't afford local hardware.
|
||||
|
||||
## Examples
|
||||
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
|
||||
@@ -8,7 +8,7 @@ from alembic import context
|
||||
config = context.config
|
||||
|
||||
|
||||
from app.database.models import Base, NAMING_CONVENTION
|
||||
from app.database.models import Base
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
@@ -51,10 +51,7 @@ def run_migrations_online() -> None:
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
render_as_batch=True,
|
||||
naming_convention=NAMING_CONVENTION,
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
"""
|
||||
Add system_metadata and job_id columns to asset_references.
|
||||
Change preview_id FK from assets.id to asset_references.id.
|
||||
|
||||
Revision ID: 0003_add_metadata_job_id
|
||||
Revises: 0002_merge_to_asset_references
|
||||
Create Date: 2026-03-09
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from app.database.models import NAMING_CONVENTION
|
||||
|
||||
revision = "0003_add_metadata_job_id"
|
||||
down_revision = "0002_merge_to_asset_references"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("asset_references") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("system_metadata", sa.JSON(), nullable=True)
|
||||
)
|
||||
batch_op.add_column(
|
||||
sa.Column("job_id", sa.String(length=36), nullable=True)
|
||||
)
|
||||
|
||||
# Change preview_id FK from assets.id to asset_references.id (self-ref).
|
||||
# Existing values are asset-content IDs that won't match reference IDs,
|
||||
# so null them out first.
|
||||
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
|
||||
with op.batch_alter_table(
|
||||
"asset_references", naming_convention=NAMING_CONVENTION
|
||||
) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
"fk_asset_references_preview_id_assets", type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_asset_references_preview_id_asset_references",
|
||||
"asset_references",
|
||||
["preview_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
batch_op.create_index(
|
||||
"ix_asset_references_preview_id", ["preview_id"]
|
||||
)
|
||||
|
||||
# Purge any all-null meta rows before adding the constraint
|
||||
op.execute(
|
||||
"DELETE FROM asset_reference_meta"
|
||||
" WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
|
||||
)
|
||||
with op.batch_alter_table("asset_reference_meta") as batch_op:
|
||||
batch_op.create_check_constraint(
|
||||
"ck_asset_reference_meta_has_value",
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# SQLite doesn't reflect CHECK constraints, so we must declare it
|
||||
# explicitly via table_args for the batch recreate to find it.
|
||||
# Use the fully-rendered constraint name to avoid the naming convention
|
||||
# doubling the prefix.
|
||||
with op.batch_alter_table(
|
||||
"asset_reference_meta",
|
||||
table_args=[
|
||||
sa.CheckConstraint(
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
name="ck_asset_reference_meta_has_value",
|
||||
),
|
||||
],
|
||||
) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
"ck_asset_reference_meta_has_value", type_="check"
|
||||
)
|
||||
|
||||
with op.batch_alter_table(
|
||||
"asset_references", naming_convention=NAMING_CONVENTION
|
||||
) as batch_op:
|
||||
batch_op.drop_index("ix_asset_references_preview_id")
|
||||
batch_op.drop_constraint(
|
||||
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_asset_references_preview_id_assets",
|
||||
"assets",
|
||||
["preview_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
with op.batch_alter_table("asset_references") as batch_op:
|
||||
batch_op.drop_column("job_id")
|
||||
batch_op.drop_column("system_metadata")
|
||||
@@ -13,7 +13,6 @@ from pydantic import ValidationError
|
||||
import folder_paths
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in, schemas_out
|
||||
from app.assets.services import schemas
|
||||
from app.assets.api.schemas_in import (
|
||||
AssetValidationError,
|
||||
UploadError,
|
||||
@@ -39,7 +38,6 @@ from app.assets.services import (
|
||||
update_asset_metadata,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
@@ -124,61 +122,6 @@ def _validate_sort_field(requested: str | None) -> str:
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
|
||||
"""Build a /api/view preview URL from asset tags and user_metadata filename."""
|
||||
if not user_metadata:
|
||||
return None
|
||||
filename = user_metadata.get("filename")
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
if "input" in tags:
|
||||
view_type = "input"
|
||||
elif "output" in tags:
|
||||
view_type = "output"
|
||||
else:
|
||||
return None
|
||||
|
||||
subfolder = ""
|
||||
if "/" in filename:
|
||||
subfolder, filename = filename.rsplit("/", 1)
|
||||
|
||||
encoded_filename = urllib.parse.quote(filename, safe="")
|
||||
url = f"/api/view?type={view_type}&filename={encoded_filename}"
|
||||
if subfolder:
|
||||
url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
|
||||
return url
|
||||
|
||||
|
||||
def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
|
||||
"""Build an Asset response from a service result."""
|
||||
if result.ref.preview_id:
|
||||
preview_detail = get_asset_detail(result.ref.preview_id)
|
||||
if preview_detail:
|
||||
preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata)
|
||||
else:
|
||||
preview_url = None
|
||||
else:
|
||||
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
|
||||
return schemas_out.Asset(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
tags=result.tags,
|
||||
preview_url=preview_url,
|
||||
preview_id=result.ref.preview_id,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
metadata=result.ref.system_metadata,
|
||||
job_id=result.ref.job_id,
|
||||
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
|
||||
created_at=result.ref.created_at,
|
||||
updated_at=result.ref.updated_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
@_require_assets_feature_enabled
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
@@ -221,7 +164,20 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries = [_build_asset_response(item) for item in result.items]
|
||||
summaries = [
|
||||
schemas_out.AssetSummary(
|
||||
id=item.ref.id,
|
||||
name=item.ref.name,
|
||||
asset_hash=item.asset.hash if item.asset else None,
|
||||
size=int(item.asset.size_bytes) if item.asset else None,
|
||||
mime_type=item.asset.mime_type if item.asset else None,
|
||||
tags=item.tags,
|
||||
created_at=item.ref.created_at,
|
||||
updated_at=item.ref.updated_at,
|
||||
last_access_time=item.ref.last_access_time,
|
||||
)
|
||||
for item in result.items
|
||||
]
|
||||
|
||||
payload = schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
@@ -251,7 +207,18 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
||||
{"id": reference_id},
|
||||
)
|
||||
|
||||
payload = _build_asset_response(result)
|
||||
payload = schemas_out.AssetDetail(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
preview_id=result.ref.preview_id,
|
||||
created_at=result.ref.created_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
)
|
||||
except ValueError as e:
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
|
||||
@@ -263,7 +230,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
@@ -345,31 +312,32 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
|
||||
400, "INVALID_JSON", "Request body must be valid JSON."
|
||||
)
|
||||
|
||||
# Derive name from hash if not provided
|
||||
name = body.name
|
||||
if name is None:
|
||||
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
|
||||
|
||||
result = create_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=name,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
mime_type=body.mime_type,
|
||||
preview_id=body.preview_id,
|
||||
)
|
||||
if result is None:
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
|
||||
)
|
||||
|
||||
asset = _build_asset_response(result)
|
||||
payload_out = schemas_out.AssetCreated(
|
||||
**asset.model_dump(),
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash,
|
||||
size=int(result.asset.size_bytes),
|
||||
mime_type=result.asset.mime_type,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
preview_id=result.ref.preview_id,
|
||||
created_at=result.ref.created_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
created_new=result.created_new,
|
||||
)
|
||||
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
|
||||
return web.json_response(payload_out.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
@@ -390,8 +358,6 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
"name": parsed.provided_name,
|
||||
"user_metadata": parsed.user_metadata_raw,
|
||||
"hash": parsed.provided_hash,
|
||||
"mime_type": parsed.provided_mime_type,
|
||||
"preview_id": parsed.provided_preview_id,
|
||||
}
|
||||
)
|
||||
except ValidationError as ve:
|
||||
@@ -420,8 +386,6 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
mime_type=spec.mime_type,
|
||||
preview_id=spec.preview_id,
|
||||
)
|
||||
if result is None:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
@@ -446,8 +410,6 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
client_filename=parsed.file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_hash=spec.hash,
|
||||
mime_type=spec.mime_type,
|
||||
preview_id=spec.preview_id,
|
||||
)
|
||||
except AssetValidationError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
@@ -466,13 +428,21 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
logging.exception("upload_asset failed for owner_id=%s", owner_id)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
asset = _build_asset_response(result)
|
||||
payload_out = schemas_out.AssetCreated(
|
||||
**asset.model_dump(),
|
||||
payload = schemas_out.AssetCreated(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash,
|
||||
size=int(result.asset.size_bytes),
|
||||
mime_type=result.asset.mime_type,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
preview_id=result.ref.preview_id,
|
||||
created_at=result.ref.created_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
created_new=result.created_new,
|
||||
)
|
||||
status = 201 if result.created_new else 200
|
||||
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=status)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@@ -494,9 +464,15 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
||||
name=body.name,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
preview_id=body.preview_id,
|
||||
)
|
||||
payload = _build_asset_response(result)
|
||||
payload = schemas_out.AssetUpdated(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
updated_at=result.ref.updated_at,
|
||||
)
|
||||
except PermissionError as pe:
|
||||
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
|
||||
except ValueError as ve:
|
||||
@@ -510,7 +486,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@@ -579,7 +555,7 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
payload = schemas_out.TagsList(
|
||||
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
@@ -627,7 +603,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
@@ -674,29 +650,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/tags/refine")
|
||||
@_require_assets_feature_enabled
|
||||
async def get_tags_refine(request: web.Request) -> web.Response:
|
||||
"""GET request to get tag histogram for filtered assets."""
|
||||
query_dict = get_query_dict(request)
|
||||
try:
|
||||
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _build_validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
tag_counts = list_tag_histogram(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
)
|
||||
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed")
|
||||
|
||||
@@ -45,8 +45,6 @@ class ParsedUpload:
|
||||
user_metadata_raw: str | None
|
||||
provided_hash: str | None
|
||||
provided_hash_exists: bool | None
|
||||
provided_mime_type: str | None = None
|
||||
provided_preview_id: str | None = None
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
@@ -100,17 +98,11 @@ class ListAssetsQuery(BaseModel):
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: str | None = None
|
||||
user_metadata: dict[str, Any] | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_at_least_one_field(self):
|
||||
if all(
|
||||
v is None
|
||||
for v in (self.name, self.user_metadata, self.preview_id)
|
||||
):
|
||||
raise ValueError(
|
||||
"Provide at least one of: name, user_metadata, preview_id."
|
||||
)
|
||||
if self.name is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||
return self
|
||||
|
||||
|
||||
@@ -118,11 +110,9 @@ class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str | None = None
|
||||
name: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
mime_type: str | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
@@ -148,44 +138,6 @@ class CreateFromHashBody(BaseModel):
|
||||
return []
|
||||
|
||||
|
||||
class TagsRefineQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
limit: conint(ge=1, le=1000) = 100
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
@@ -234,25 +186,21 @@ class TagsRemove(TagsAdd):
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
|
||||
- tags: optional list; if provided, first is root ('models'|'input'|'output');
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
|
||||
- mime_type: optional MIME type override
|
||||
- preview_id: optional asset_reference ID for preview
|
||||
|
||||
Files are stored using the content hash as filename stem.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: str | None = Field(default=None)
|
||||
mime_type: str | None = Field(default=None)
|
||||
preview_id: str | None = Field(default=None) # references an asset_reference id
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
@@ -331,7 +279,7 @@ class UploadAssetSpec(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("at least one tag is required for uploads")
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
|
||||
@@ -4,10 +4,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class Asset(BaseModel):
|
||||
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
|
||||
``id`` here is the AssetReference id, not the content-addressed Asset id."""
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
@@ -15,14 +12,8 @@ class Asset(BaseModel):
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: str | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
is_immutable: bool = False
|
||||
metadata: dict[str, Any] | None = None
|
||||
job_id: str | None = None
|
||||
prompt_id: str | None = None # deprecated: use job_id
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
@@ -32,16 +23,50 @@ class Asset(BaseModel):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(Asset):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[Asset]
|
||||
assets: list[AssetSummary]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _serialize_updated_at(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: str | None = None
|
||||
created_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _serialize_datetime(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
@@ -66,7 +91,3 @@ class TagsRemove(BaseModel):
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagHistogram(BaseModel):
|
||||
tag_counts: dict[str, int]
|
||||
|
||||
@@ -52,8 +52,6 @@ async def parse_multipart_upload(
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
provided_mime_type: str | None = None
|
||||
provided_preview_id: str | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
@@ -130,16 +128,6 @@ async def parse_multipart_upload(
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
elif fname == "id":
|
||||
raise UploadError(
|
||||
400,
|
||||
"UNSUPPORTED_FIELD",
|
||||
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
|
||||
)
|
||||
elif fname == "mime_type":
|
||||
provided_mime_type = ((await field.text()) or "").strip() or None
|
||||
elif fname == "preview_id":
|
||||
provided_preview_id = ((await field.text()) or "").strip() or None
|
||||
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
raise UploadError(
|
||||
@@ -164,8 +152,6 @@ async def parse_multipart_upload(
|
||||
user_metadata_raw=user_metadata_raw,
|
||||
provided_hash=provided_hash,
|
||||
provided_hash_exists=provided_hash_exists,
|
||||
provided_mime_type=provided_mime_type,
|
||||
provided_preview_id=provided_preview_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -45,7 +45,13 @@ class Asset(Base):
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# preview_id on AssetReference is a self-referential FK to asset_references.id
|
||||
preview_of: Mapped[list[AssetReference]] = relationship(
|
||||
"AssetReference",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
|
||||
foreign_keys=lambda: [AssetReference.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
@@ -85,15 +91,11 @@ class AssetReference(Base):
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
preview_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
|
||||
String(36), ForeignKey("assets.id", ondelete="SET NULL")
|
||||
)
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
JSON(none_as_null=True)
|
||||
)
|
||||
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
JSON(none_as_null=True), nullable=True, default=None
|
||||
)
|
||||
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
@@ -113,10 +115,10 @@ class AssetReference(Base):
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_ref: Mapped[AssetReference | None] = relationship(
|
||||
"AssetReference",
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_id],
|
||||
remote_side=lambda: [AssetReference.id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
|
||||
@@ -150,7 +152,6 @@ class AssetReference(Base):
|
||||
Index("ix_asset_references_created_at", "created_at"),
|
||||
Index("ix_asset_references_last_access_time", "last_access_time"),
|
||||
Index("ix_asset_references_deleted_at", "deleted_at"),
|
||||
Index("ix_asset_references_preview_id", "preview_id"),
|
||||
Index("ix_asset_references_owner_name", "owner_id", "name"),
|
||||
CheckConstraint(
|
||||
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
||||
@@ -191,10 +192,6 @@ class AssetReferenceMeta(Base):
|
||||
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
|
||||
CheckConstraint(
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
name="has_value",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -31,21 +31,16 @@ from app.assets.database.queries.asset_reference import (
|
||||
get_unenriched_references,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
insert_reference,
|
||||
list_all_file_paths_by_asset_id,
|
||||
list_references_by_asset_id,
|
||||
list_references_page,
|
||||
mark_references_missing_outside_prefixes,
|
||||
rebuild_metadata_projection,
|
||||
reference_exists,
|
||||
reference_exists_for_asset_id,
|
||||
restore_references_by_paths,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
set_reference_system_metadata,
|
||||
soft_delete_reference_by_id,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_is_missing_by_asset_id,
|
||||
update_reference_timestamps,
|
||||
update_reference_updated_at,
|
||||
upsert_reference,
|
||||
@@ -59,7 +54,6 @@ from app.assets.database.queries.tags import (
|
||||
bulk_insert_tags_and_meta,
|
||||
ensure_tags_exist,
|
||||
get_reference_tags,
|
||||
list_tag_counts_for_filtered_assets,
|
||||
list_tags_with_usage,
|
||||
remove_missing_tag_for_asset_id,
|
||||
remove_tags_from_reference,
|
||||
@@ -103,26 +97,20 @@ __all__ = [
|
||||
"get_unenriched_references",
|
||||
"get_unreferenced_unhashed_asset_ids",
|
||||
"insert_reference",
|
||||
"list_all_file_paths_by_asset_id",
|
||||
"list_references_by_asset_id",
|
||||
"list_references_page",
|
||||
"list_tag_counts_for_filtered_assets",
|
||||
"list_tags_with_usage",
|
||||
"mark_references_missing_outside_prefixes",
|
||||
"reassign_asset_references",
|
||||
"rebuild_metadata_projection",
|
||||
"reference_exists",
|
||||
"reference_exists_for_asset_id",
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"remove_tags_from_reference",
|
||||
"restore_references_by_paths",
|
||||
"set_reference_metadata",
|
||||
"set_reference_preview",
|
||||
"set_reference_system_metadata",
|
||||
"soft_delete_reference_by_id",
|
||||
"set_reference_tags",
|
||||
"update_asset_hash_and_mime",
|
||||
"update_is_missing_by_asset_id",
|
||||
"update_reference_access_time",
|
||||
"update_reference_name",
|
||||
"update_reference_timestamps",
|
||||
|
||||
@@ -69,7 +69,7 @@ def upsert_asset(
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and not asset.mime_type:
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
@@ -118,7 +118,7 @@ def update_asset_hash_and_mime(
|
||||
return False
|
||||
if asset_hash is not None:
|
||||
asset.hash = asset_hash
|
||||
if mime_type is not None and not asset.mime_type:
|
||||
if mime_type is not None:
|
||||
asset.mime_type = mime_type
|
||||
return True
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from decimal import Decimal
|
||||
from typing import NamedTuple, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import delete, exists, select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, noload
|
||||
@@ -24,14 +24,12 @@ from app.assets.database.models import (
|
||||
)
|
||||
from app.assets.database.queries.common import (
|
||||
MAX_BIND_PARAMS,
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
build_prefix_like_conditions,
|
||||
build_visible_owner_clause,
|
||||
calculate_rows_per_statement,
|
||||
iter_chunks,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
|
||||
|
||||
def _check_is_scalar(v):
|
||||
@@ -46,6 +44,15 @@ def _check_is_scalar(v):
|
||||
|
||||
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
||||
"""Convert a scalar value to a typed projection row."""
|
||||
if value is None:
|
||||
return {
|
||||
"key": key,
|
||||
"ordinal": ordinal,
|
||||
"val_str": None,
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
if isinstance(value, bool):
|
||||
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
@@ -59,19 +66,96 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
||||
def convert_metadata_to_rows(key: str, value) -> list[dict]:
|
||||
"""Turn a metadata key/value into typed projection rows."""
|
||||
if value is None:
|
||||
return []
|
||||
return [_scalar_to_row(key, 0, None)]
|
||||
|
||||
if _check_is_scalar(value):
|
||||
return [_scalar_to_row(key, 0, value)]
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(_check_is_scalar(x) for x in value):
|
||||
return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
|
||||
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
|
||||
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
|
||||
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
|
||||
|
||||
return [{"key": key, "ordinal": 0, "val_json": value}]
|
||||
|
||||
|
||||
def _apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def _apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_reference_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetReferenceMeta.val_json.is_(None),
|
||||
AssetReferenceMeta.val_str.is_(None),
|
||||
AssetReferenceMeta.val_num.is_(None),
|
||||
AssetReferenceMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def get_reference_by_id(
|
||||
@@ -128,21 +212,6 @@ def reference_exists_for_asset_id(
|
||||
return session.execute(q).first() is not None
|
||||
|
||||
|
||||
def reference_exists(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
) -> bool:
|
||||
"""Return True if a reference with the given ID exists (not soft-deleted)."""
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetReference)
|
||||
.where(AssetReference.id == reference_id)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.limit(1)
|
||||
)
|
||||
return session.execute(q).first() is not None
|
||||
|
||||
|
||||
def insert_reference(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
@@ -267,8 +336,8 @@ def list_references_page(
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
@@ -297,8 +366,8 @@ def list_references_page(
|
||||
count_stmt = count_stmt.where(
|
||||
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
||||
)
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int(session.execute(count_stmt).scalar_one() or 0)
|
||||
refs = session.execute(base).unique().scalars().all()
|
||||
@@ -310,7 +379,7 @@ def list_references_page(
|
||||
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
.order_by(AssetReferenceTag.added_at)
|
||||
)
|
||||
for ref_id, tag_name in rows.all():
|
||||
tag_map[ref_id].append(tag_name)
|
||||
@@ -423,42 +492,6 @@ def update_reference_updated_at(
|
||||
)
|
||||
|
||||
|
||||
def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
|
||||
"""Delete and rebuild AssetReferenceMeta rows from merged system+user metadata.
|
||||
|
||||
The merged dict is ``{**system_metadata, **user_metadata}`` so user keys
|
||||
override system keys of the same name.
|
||||
"""
|
||||
session.execute(
|
||||
delete(AssetReferenceMeta).where(
|
||||
AssetReferenceMeta.asset_reference_id == ref.id
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
|
||||
if not merged:
|
||||
return
|
||||
|
||||
rows: list[AssetReferenceMeta] = []
|
||||
for k, v in merged.items():
|
||||
for r in convert_metadata_to_rows(k, v):
|
||||
rows.append(
|
||||
AssetReferenceMeta(
|
||||
asset_reference_id=ref.id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def set_reference_metadata(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
@@ -472,24 +505,33 @@ def set_reference_metadata(
|
||||
ref.updated_at = get_utc_now()
|
||||
session.flush()
|
||||
|
||||
rebuild_metadata_projection(session, ref)
|
||||
|
||||
|
||||
def set_reference_system_metadata(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
system_metadata: dict | None = None,
|
||||
) -> None:
|
||||
"""Set system_metadata on a reference and rebuild the merged projection."""
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
ref.system_metadata = system_metadata or {}
|
||||
ref.updated_at = get_utc_now()
|
||||
session.execute(
|
||||
delete(AssetReferenceMeta).where(
|
||||
AssetReferenceMeta.asset_reference_id == reference_id
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
rebuild_metadata_projection(session, ref)
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetReferenceMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in convert_metadata_to_rows(k, v):
|
||||
rows.append(
|
||||
AssetReferenceMeta(
|
||||
asset_reference_id=reference_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def delete_reference_by_id(
|
||||
@@ -529,19 +571,19 @@ def soft_delete_reference_by_id(
|
||||
def set_reference_preview(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
preview_reference_id: str | None = None,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
if preview_reference_id is None:
|
||||
if preview_asset_id is None:
|
||||
ref.preview_id = None
|
||||
else:
|
||||
if not session.get(AssetReference, preview_reference_id):
|
||||
raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
|
||||
ref.preview_id = preview_reference_id
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
ref.preview_id = preview_asset_id
|
||||
|
||||
ref.updated_at = get_utc_now()
|
||||
session.flush()
|
||||
@@ -567,8 +609,6 @@ def list_references_by_asset_id(
|
||||
session.execute(
|
||||
select(AssetReference)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(AssetReference.is_missing == False) # noqa: E712
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.order_by(AssetReference.id.asc())
|
||||
)
|
||||
.scalars()
|
||||
@@ -576,25 +616,6 @@ def list_references_by_asset_id(
|
||||
)
|
||||
|
||||
|
||||
def list_all_file_paths_by_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
) -> list[str]:
|
||||
"""Return every file_path for an asset, including soft-deleted/missing refs.
|
||||
|
||||
Used for orphan cleanup where all on-disk files must be removed.
|
||||
"""
|
||||
return list(
|
||||
session.execute(
|
||||
select(AssetReference.file_path)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(AssetReference.file_path.isnot(None))
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def upsert_reference(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
@@ -834,22 +855,6 @@ def bulk_update_is_missing(
|
||||
return total
|
||||
|
||||
|
||||
def update_is_missing_by_asset_id(
|
||||
session: Session, asset_id: str, value: bool
|
||||
) -> int:
|
||||
"""Set is_missing flag for ALL references belonging to an asset.
|
||||
|
||||
Returns: Number of rows updated
|
||||
"""
|
||||
result = session.execute(
|
||||
sa.update(AssetReference)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.values(is_missing=value)
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
|
||||
"""Delete references by their IDs.
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
"""Shared utilities for database query modules."""
|
||||
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from typing import Iterable, Sequence
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists
|
||||
|
||||
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
|
||||
from app.assets.helpers import escape_sql_like_string, normalize_tags
|
||||
from app.assets.database.models import AssetReference
|
||||
from app.assets.helpers import escape_sql_like_string
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
@@ -54,74 +52,3 @@ def build_prefix_like_conditions(
|
||||
escaped, esc = escape_sql_like_string(base)
|
||||
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
||||
return conds
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_reference_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
return sa.not_(
|
||||
sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
@@ -8,15 +8,12 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import (
|
||||
Asset,
|
||||
AssetReference,
|
||||
AssetReferenceMeta,
|
||||
AssetReferenceTag,
|
||||
Tag,
|
||||
)
|
||||
from app.assets.database.queries.common import (
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
build_visible_owner_clause,
|
||||
iter_row_chunks,
|
||||
)
|
||||
@@ -75,9 +72,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id == reference_id)
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
select(AssetReferenceTag.tag_name).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
@@ -120,7 +117,7 @@ def set_reference_tags(
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
|
||||
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
|
||||
|
||||
|
||||
def add_tags_to_reference(
|
||||
@@ -275,12 +272,6 @@ def list_tags_with_usage(
|
||||
.select_from(AssetReferenceTag)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetReference.is_missing == False, # noqa: E712
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
.subquery()
|
||||
@@ -317,12 +308,6 @@ def list_tags_with_usage(
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetReference.is_missing == False, # noqa: E712
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
)
|
||||
@@ -335,53 +320,6 @@ def list_tags_with_usage(
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def list_tag_counts_for_filtered_assets(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 100,
|
||||
) -> dict[str, int]:
|
||||
"""Return tag counts for assets matching the given filters.
|
||||
|
||||
Uses the same filtering logic as list_references_page but returns
|
||||
{tag_name: count} instead of paginated references.
|
||||
"""
|
||||
# Build a subquery of matching reference IDs
|
||||
ref_sq = (
|
||||
select(AssetReference.id)
|
||||
.join(Asset, Asset.id == AssetReference.asset_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(AssetReference.is_missing == False) # noqa: E712
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
|
||||
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
|
||||
ref_sq = ref_sq.subquery()
|
||||
|
||||
# Count tags across those references
|
||||
q = (
|
||||
select(
|
||||
AssetReferenceTag.tag_name,
|
||||
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
|
||||
)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
rows = session.execute(q).all()
|
||||
return {tag_name: int(cnt) for tag_name, cnt in rows}
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
tag_rows: list[dict],
|
||||
|
||||
@@ -18,7 +18,7 @@ from app.assets.database.queries import (
|
||||
mark_references_missing_outside_prefixes,
|
||||
reassign_asset_references,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_system_metadata,
|
||||
set_reference_metadata,
|
||||
update_asset_hash_and_mime,
|
||||
)
|
||||
from app.assets.services.bulk_ingest import (
|
||||
@@ -490,8 +490,8 @@ def enrich_asset(
|
||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||
|
||||
if extract_metadata and metadata:
|
||||
system_metadata = metadata.to_user_metadata()
|
||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||
user_metadata = metadata.to_user_metadata()
|
||||
set_reference_metadata(session, reference_id, user_metadata)
|
||||
|
||||
if full_hash:
|
||||
existing = get_asset_by_hash(session, full_hash)
|
||||
|
||||
@@ -16,12 +16,10 @@ from app.assets.database.queries import (
|
||||
get_reference_by_id,
|
||||
get_reference_with_owner_check,
|
||||
list_references_page,
|
||||
list_all_file_paths_by_asset_id,
|
||||
list_references_by_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_reference_updated_at,
|
||||
@@ -69,8 +67,6 @@ def update_asset_metadata(
|
||||
user_metadata: UserMetadata = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
@@ -107,21 +103,6 @@ def update_asset_metadata(
|
||||
)
|
||||
touched = True
|
||||
|
||||
if mime_type is not None:
|
||||
updated = update_asset_hash_and_mime(
|
||||
session, asset_id=ref.asset_id, mime_type=mime_type
|
||||
)
|
||||
if updated:
|
||||
touched = True
|
||||
|
||||
if preview_id is not None:
|
||||
set_reference_preview(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
preview_reference_id=preview_id,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
update_reference_updated_at(session, reference_id=reference_id)
|
||||
|
||||
@@ -178,9 +159,11 @@ def delete_asset_reference(
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
# Orphaned asset - gather ALL file paths (including
|
||||
# soft-deleted / missing refs) so their on-disk files get cleaned up.
|
||||
file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
|
||||
# Orphaned asset - delete it and its files
|
||||
refs = list_references_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [
|
||||
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
|
||||
]
|
||||
# Also include the just-deleted file path
|
||||
if file_path:
|
||||
file_paths.append(file_path)
|
||||
@@ -202,7 +185,7 @@ def delete_asset_reference(
|
||||
|
||||
def set_asset_preview(
|
||||
reference_id: str,
|
||||
preview_reference_id: str | None = None,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
@@ -211,7 +194,7 @@ def set_asset_preview(
|
||||
set_reference_preview(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
preview_reference_id=preview_reference_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
result = fetch_reference_asset_and_tags(
|
||||
@@ -280,47 +263,6 @@ def list_assets_page(
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
|
||||
|
||||
def resolve_hash_to_path(
|
||||
asset_hash: str,
|
||||
owner_id: str = "",
|
||||
) -> DownloadResolutionResult | None:
|
||||
"""Resolve a blake3 hash to an on-disk file path.
|
||||
|
||||
Only references visible to *owner_id* are considered (owner-less
|
||||
references are always visible).
|
||||
|
||||
Returns a DownloadResolutionResult with abs_path, content_type, and
|
||||
download_name, or None if no asset or live path is found.
|
||||
"""
|
||||
with create_session() as session:
|
||||
asset = queries_get_asset_by_hash(session, asset_hash)
|
||||
if not asset:
|
||||
return None
|
||||
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||
visible = [
|
||||
r for r in refs
|
||||
if r.owner_id == "" or r.owner_id == owner_id
|
||||
]
|
||||
abs_path = select_best_live_path(visible)
|
||||
if not abs_path:
|
||||
return None
|
||||
display_name = os.path.basename(abs_path)
|
||||
for ref in visible:
|
||||
if ref.file_path == abs_path and ref.name:
|
||||
display_name = ref.name
|
||||
break
|
||||
ctype = (
|
||||
asset.mime_type
|
||||
or mimetypes.guess_type(display_name)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
return DownloadResolutionResult(
|
||||
abs_path=abs_path,
|
||||
content_type=ctype,
|
||||
download_name=display_name,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_for_download(
|
||||
reference_id: str,
|
||||
owner_id: str = "",
|
||||
|
||||
@@ -11,14 +11,13 @@ from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
fetch_reference_and_asset,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
reference_exists,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
upsert_asset,
|
||||
upsert_reference,
|
||||
validate_tags_exist,
|
||||
@@ -27,7 +26,6 @@ from app.assets.helpers import normalize_tags
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
)
|
||||
@@ -67,7 +65,7 @@ def _ingest_file_from_path(
|
||||
|
||||
with create_session() as session:
|
||||
if preview_id:
|
||||
if not reference_exists(session, preview_id):
|
||||
if preview_id not in get_existing_asset_ids(session, [preview_id]):
|
||||
preview_id = None
|
||||
|
||||
asset, asset_created, asset_updated = upsert_asset(
|
||||
@@ -137,8 +135,6 @@ def _register_existing_asset(
|
||||
tags: list[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> RegisterAssetResult:
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
@@ -147,25 +143,14 @@ def _register_existing_asset(
|
||||
if not asset:
|
||||
raise ValueError(f"No asset with hash {asset_hash}")
|
||||
|
||||
if mime_type and not asset.mime_type:
|
||||
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
|
||||
|
||||
if preview_id:
|
||||
if not reference_exists(session, preview_id):
|
||||
preview_id = None
|
||||
|
||||
ref, ref_created = get_or_create_reference(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
|
||||
if not ref_created:
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
result = RegisterAssetResult(
|
||||
ref=extract_reference_data(ref),
|
||||
@@ -257,8 +242,6 @@ def upload_from_temp_path(
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_hash: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> UploadResult:
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(temp_path)
|
||||
@@ -287,8 +270,6 @@ def upload_from_temp_path(
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
mime_type=mime_type,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
@@ -310,7 +291,7 @@ def upload_from_temp_path(
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
validate_path_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = mime_type or (
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
@@ -334,7 +315,7 @@ def upload_from_temp_path(
|
||||
mime_type=content_type,
|
||||
info_name=_sanitize_filename(name or client_filename, fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=preview_id,
|
||||
preview_id=None,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags,
|
||||
tag_origin="manual",
|
||||
@@ -361,99 +342,30 @@ def upload_from_temp_path(
|
||||
)
|
||||
|
||||
|
||||
def register_file_in_place(
|
||||
abs_path: str,
|
||||
name: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
) -> UploadResult:
|
||||
"""Register an already-saved file in the asset database without moving it.
|
||||
|
||||
Tags are derived from the filesystem path (root category + subfolder names),
|
||||
merged with any caller-provided tags, matching the behavior of the scanner.
|
||||
If the path is not under a known root, only the caller-provided tags are used.
|
||||
"""
|
||||
try:
|
||||
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
except ValueError:
|
||||
path_tags = []
|
||||
merged_tags = normalize_tags([*path_tags, *tags])
|
||||
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(abs_path)
|
||||
except ImportError as e:
|
||||
raise DependencyMissingError(str(e))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||
content_type = mime_type or (
|
||||
mimetypes.guess_type(abs_path, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
ingest_result = _ingest_file_from_path(
|
||||
abs_path=abs_path,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_sanitize_filename(name, fallback=digest),
|
||||
owner_id=owner_id,
|
||||
tags=merged_tags,
|
||||
tag_origin="upload",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
reference_id = ingest_result.reference_id
|
||||
if not reference_id:
|
||||
raise RuntimeError("failed to create asset reference")
|
||||
|
||||
with create_session() as session:
|
||||
pair = fetch_reference_and_asset(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
ref, asset = pair
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
|
||||
return UploadResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created_new=ingest_result.asset_created,
|
||||
)
|
||||
|
||||
|
||||
def create_from_hash(
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> UploadResult | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
|
||||
try:
|
||||
result = _register_existing_asset(
|
||||
asset_hash=canonical,
|
||||
name=_sanitize_filename(
|
||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||
),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
mime_type=mime_type,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
except ValueError:
|
||||
logging.warning("create_from_hash: no asset found for hash %s", canonical)
|
||||
return None
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash=canonical,
|
||||
name=_sanitize_filename(
|
||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||
),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
|
||||
@@ -25,9 +25,7 @@ class ReferenceData:
|
||||
preview_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
system_metadata: dict[str, Any] | None = None
|
||||
job_id: str | None = None
|
||||
last_access_time: datetime | None = None
|
||||
last_access_time: datetime | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -95,8 +93,6 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
|
||||
file_path=ref.file_path,
|
||||
user_metadata=ref.user_metadata,
|
||||
preview_id=ref.preview_id,
|
||||
system_metadata=ref.system_metadata,
|
||||
job_id=ref.job_id,
|
||||
created_at=ref.created_at,
|
||||
updated_at=ref.updated_at,
|
||||
last_access_time=ref.last_access_time,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Sequence
|
||||
|
||||
from app.assets.database.queries import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
@@ -8,7 +6,6 @@ from app.assets.database.queries import (
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_reference,
|
||||
)
|
||||
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
|
||||
from app.assets.services.schemas import TagUsage
|
||||
from app.database.db import create_session
|
||||
|
||||
@@ -76,23 +73,3 @@ def list_tags(
|
||||
)
|
||||
|
||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||
|
||||
|
||||
def list_tag_histogram(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 100,
|
||||
) -> dict[str, int]:
|
||||
with create_session() as session:
|
||||
return list_tag_counts_for_filtered_assets(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@@ -1,18 +1,9 @@
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
NAMING_CONVENTION = {
|
||||
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
}
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
metadata = MetaData(naming_convention=NAMING_CONVENTION)
|
||||
pass
|
||||
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
|
||||
@@ -6,7 +6,6 @@ import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
@@ -378,15 +377,8 @@ class UserManager():
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
dir_name = os.path.dirname(path)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
|
||||
try:
|
||||
with os.fdopen(fd, "wb") as f:
|
||||
f.write(body)
|
||||
os.replace(tmp_path, path)
|
||||
except:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
|
||||
@@ -83,8 +83,6 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
|
||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
||||
|
||||
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
|
||||
|
||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
@@ -149,7 +147,6 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@@ -182,6 +179,8 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.")
|
||||
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
@@ -263,6 +262,4 @@ else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
def enables_dynamic_vram():
|
||||
if args.enable_dynamic_vram:
|
||||
return True
|
||||
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
||||
|
||||
@@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
gradient_stops: NotRequired[list[dict]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
|
||||
gradient_stops: NotRequired[list[list[float]]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
|
||||
@@ -93,50 +93,6 @@ class IndexListCallbacks:
|
||||
return {}
|
||||
|
||||
|
||||
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
|
||||
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
|
||||
return None
|
||||
cond_tensor = cond_value.cond
|
||||
if temporal_dim >= cond_tensor.ndim:
|
||||
return None
|
||||
|
||||
cond_size = cond_tensor.size(temporal_dim)
|
||||
|
||||
if temporal_scale == 1:
|
||||
expected_size = x_in.size(window.dim) - temporal_offset
|
||||
if cond_size != expected_size:
|
||||
return None
|
||||
|
||||
if temporal_offset == 0 and temporal_scale == 1:
|
||||
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||
if temporal_offset > 0:
|
||||
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
||||
indices = [i for i in indices if 0 <= i]
|
||||
else:
|
||||
indices = list(window.index_list)
|
||||
|
||||
if not indices:
|
||||
return None
|
||||
|
||||
if temporal_scale > 1:
|
||||
scaled = []
|
||||
for i in indices:
|
||||
for k in range(temporal_scale):
|
||||
si = i * temporal_scale + k
|
||||
if si < cond_size:
|
||||
scaled.append(si)
|
||||
indices = scaled
|
||||
if not indices:
|
||||
return None
|
||||
|
||||
idx = tuple([slice(None)] * temporal_dim + [indices])
|
||||
sliced = cond_tensor[idx].to(device)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@@ -221,17 +177,10 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if not handled and self._model is not None:
|
||||
result = self._model.resize_cond_for_context_window(
|
||||
cond_key, cond_value, window, x_in, device,
|
||||
retain_index_list=self.cond_retain_index_list)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
@@ -275,7 +224,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
@@ -209,39 +209,3 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
|
||||
output_block[i:i + slice_size].copy_(block)
|
||||
|
||||
return output_fp4, to_blocked(output_block, flatten=False)
|
||||
|
||||
|
||||
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
|
||||
def roundup(x_val, multiple):
|
||||
return ((x_val + multiple - 1) // multiple) * multiple
|
||||
|
||||
if pad_32x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 32)
|
||||
padded_cols = roundup(cols, 32)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
|
||||
F8_E4M3_MAX = 448.0
|
||||
E8M0_BIAS = 127
|
||||
BLOCK_SIZE = 32
|
||||
|
||||
rows, cols = x.shape
|
||||
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
|
||||
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
|
||||
|
||||
# E8M0 block scales (power-of-2 exponents)
|
||||
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
|
||||
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
|
||||
block_scales_e8m0 = exp_biased.to(torch.uint8)
|
||||
|
||||
zero_mask = (max_abs == 0)
|
||||
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
|
||||
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
|
||||
|
||||
# Scale per-block then stochastic round
|
||||
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
|
||||
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
|
||||
|
||||
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
|
||||
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
|
||||
|
||||
@@ -14,6 +14,9 @@ if TYPE_CHECKING:
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from comfy.cli_args import args
|
||||
import uuid
|
||||
import os
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
# #######################################################################################################
|
||||
@@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
|
||||
HookedOnly = "hooked_only"
|
||||
|
||||
|
||||
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
def __init__(self):
|
||||
if _ISOLATION_HOOKREF_MODE:
|
||||
self._pyisolate_id = str(uuid.uuid4())
|
||||
|
||||
def _ensure_pyisolate_id(self):
|
||||
pyisolate_id = getattr(self, "_pyisolate_id", None)
|
||||
if pyisolate_id is None:
|
||||
pyisolate_id = str(uuid.uuid4())
|
||||
self._pyisolate_id = pyisolate_id
|
||||
return pyisolate_id
|
||||
|
||||
def __eq__(self, other):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return self is other
|
||||
if not isinstance(other, _HookRef):
|
||||
return False
|
||||
return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()
|
||||
|
||||
def __hash__(self):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return id(self)
|
||||
return hash(self._ensure_pyisolate_id())
|
||||
|
||||
def __str__(self):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return super().__str__()
|
||||
return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
|
||||
|
||||
|
||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
@@ -168,6 +200,8 @@ class WeightHook(Hook):
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if self.weights is None:
|
||||
self.weights = {}
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
|
||||
394
comfy/isolation/__init__.py
Normal file
394
comfy/isolation/__init__.py
Normal file
@@ -0,0 +1,394 @@
|
||||
# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
||||
import folder_paths
|
||||
from .extension_loader import load_isolated_node
|
||||
from .manifest_loader import find_manifest_directories
|
||||
from .runtime_helpers import build_stub_class, get_class_types_for_extension
|
||||
from .shm_forensics import scan_shm_forensics, start_shm_forensics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyisolate import ExtensionManager
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
isolated_node_timings: List[tuple[float, Path, int]] = []
|
||||
|
||||
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
|
||||
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
|
||||
|
||||
|
||||
def initialize_proxies() -> None:
|
||||
from .child_hooks import is_child_process
|
||||
|
||||
is_child = is_child_process()
|
||||
|
||||
if is_child:
|
||||
from .child_hooks import initialize_child_process
|
||||
|
||||
initialize_child_process()
|
||||
else:
|
||||
from .host_hooks import initialize_host_process
|
||||
|
||||
initialize_host_process()
|
||||
start_shm_forensics()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IsolatedNodeSpec:
|
||||
node_name: str
|
||||
display_name: str
|
||||
stub_class: type
|
||||
module_path: Path
|
||||
|
||||
|
||||
_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = []
|
||||
_CLAIMED_PATHS: Set[Path] = set()
|
||||
_ISOLATION_SCAN_ATTEMPTED = False
|
||||
_EXTENSION_MANAGERS: List["ExtensionManager"] = []
|
||||
_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {}
|
||||
_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None
|
||||
_EARLY_START_TIME: Optional[float] = None
|
||||
|
||||
|
||||
def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None:
|
||||
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
|
||||
if _ISOLATION_BACKGROUND_TASK is not None:
|
||||
return
|
||||
_EARLY_START_TIME = time.perf_counter()
|
||||
_ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes())
|
||||
|
||||
|
||||
async def await_isolation_loading() -> List[IsolatedNodeSpec]:
|
||||
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
|
||||
if _ISOLATION_BACKGROUND_TASK is not None:
|
||||
specs = await _ISOLATION_BACKGROUND_TASK
|
||||
return specs
|
||||
return await initialize_isolation_nodes()
|
||||
|
||||
|
||||
async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
|
||||
global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS
|
||||
|
||||
if _ISOLATED_NODE_SPECS:
|
||||
return _ISOLATED_NODE_SPECS
|
||||
|
||||
if _ISOLATION_SCAN_ATTEMPTED:
|
||||
return []
|
||||
|
||||
_ISOLATION_SCAN_ATTEMPTED = True
|
||||
manifest_entries = find_manifest_directories()
|
||||
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
|
||||
|
||||
if not manifest_entries:
|
||||
return []
|
||||
|
||||
os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1"
|
||||
concurrency_limit = max(1, (os.cpu_count() or 4) // 2)
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
async def load_with_semaphore(
|
||||
node_dir: Path, manifest: Path
|
||||
) -> List[IsolatedNodeSpec]:
|
||||
async with semaphore:
|
||||
load_start = time.perf_counter()
|
||||
spec_list = await load_isolated_node(
|
||||
node_dir,
|
||||
manifest,
|
||||
logger,
|
||||
lambda name, info, extension: build_stub_class(
|
||||
name,
|
||||
info,
|
||||
extension,
|
||||
_RUNNING_EXTENSIONS,
|
||||
logger,
|
||||
),
|
||||
PYISOLATE_VENV_ROOT,
|
||||
_EXTENSION_MANAGERS,
|
||||
)
|
||||
spec_list = [
|
||||
IsolatedNodeSpec(
|
||||
node_name=node_name,
|
||||
display_name=display_name,
|
||||
stub_class=stub_cls,
|
||||
module_path=node_dir,
|
||||
)
|
||||
for node_name, display_name, stub_cls in spec_list
|
||||
]
|
||||
isolated_node_timings.append(
|
||||
(time.perf_counter() - load_start, node_dir, len(spec_list))
|
||||
)
|
||||
return spec_list
|
||||
|
||||
tasks = [
|
||||
load_with_semaphore(node_dir, manifest)
|
||||
for node_dir, manifest in manifest_entries
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
specs: List[IsolatedNodeSpec] = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(
|
||||
"%s Isolated node failed during startup; continuing: %s",
|
||||
LOG_PREFIX,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
specs.extend(result)
|
||||
|
||||
_ISOLATED_NODE_SPECS = specs
|
||||
return list(_ISOLATED_NODE_SPECS)
|
||||
|
||||
|
||||
def _get_class_types_for_extension(extension_name: str) -> Set[str]:
|
||||
"""Get all node class types (node names) belonging to an extension."""
|
||||
extension = _RUNNING_EXTENSIONS.get(extension_name)
|
||||
if not extension:
|
||||
return set()
|
||||
|
||||
ext_path = Path(extension.module_path)
|
||||
class_types = set()
|
||||
for spec in _ISOLATED_NODE_SPECS:
|
||||
if spec.module_path.resolve() == ext_path.resolve():
|
||||
class_types.add(spec.node_name)
|
||||
|
||||
return class_types
|
||||
|
||||
|
||||
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
"""Evict running extensions not needed for current execution."""
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:notify_graph_wait_idle",
|
||||
)
|
||||
|
||||
async def _stop_extension(
|
||||
ext_name: str, extension: "ComfyNodeExtension", reason: str
|
||||
) -> None:
|
||||
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
|
||||
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
|
||||
stop_result = extension.stop()
|
||||
if inspect.isawaitable(stop_result):
|
||||
await stop_result
|
||||
_RUNNING_EXTENSIONS.pop(ext_name, None)
|
||||
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
|
||||
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
|
||||
|
||||
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
|
||||
isolated_class_types_in_graph = needed_class_types.intersection(
|
||||
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
|
||||
)
|
||||
graph_uses_isolation = bool(isolated_class_types_in_graph)
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_start running=%d needed=%d",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
len(needed_class_types),
|
||||
)
|
||||
if graph_uses_isolation:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||
|
||||
# If NONE of this extension's nodes are in the execution graph -> evict.
|
||||
if not ext_class_types.intersection(needed_class_types):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
"isolated custom_node not in execution graph, evicting",
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
)
|
||||
|
||||
# Isolated child processes add steady VRAM pressure; reclaim host-side models
|
||||
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
|
||||
try:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if getattr(device, "type", None) == "cuda":
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
free_before = model_management.get_free_memory(device)
|
||||
if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
f"boundary low-vram restart (free={int(free_before)} target={int(required)})",
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.unload_all_models()
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.soft_empty_cache()
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
|
||||
)
|
||||
finally:
|
||||
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
|
||||
)
|
||||
|
||||
|
||||
async def flush_running_extensions_transport_state() -> int:
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:flush_transport_wait_idle",
|
||||
)
|
||||
total_flushed = 0
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
flush_fn = getattr(extension, "flush_transport_state", None)
|
||||
if not callable(flush_fn):
|
||||
continue
|
||||
try:
|
||||
flushed = await flush_fn()
|
||||
if isinstance(flushed, int):
|
||||
total_flushed += flushed
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s workflow-end flush released=%d",
|
||||
LOG_PREFIX,
|
||||
ext_name,
|
||||
flushed,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True
|
||||
)
|
||||
scan_shm_forensics(
|
||||
"ISO:flush_running_extensions_transport_state", refresh_model_context=True
|
||||
)
|
||||
return total_flushed
|
||||
|
||||
|
||||
async def wait_for_model_patcher_quiescence(
|
||||
timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
*,
|
||||
fail_loud: bool = False,
|
||||
marker: str = "ISO:wait_model_patcher_idle",
|
||||
) -> bool:
|
||||
try:
|
||||
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
start = time.perf_counter()
|
||||
idle = await registry.wait_all_idle(timeout_ms)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
if idle:
|
||||
logger.debug(
|
||||
"%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
timeout_ms,
|
||||
elapsed_ms,
|
||||
)
|
||||
return True
|
||||
|
||||
states = await registry.get_all_operation_states()
|
||||
logger.error(
|
||||
"%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
timeout_ms,
|
||||
elapsed_ms,
|
||||
states,
|
||||
)
|
||||
if fail_loud:
|
||||
raise TimeoutError(
|
||||
f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
if fail_loud:
|
||||
raise
|
||||
logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def get_claimed_paths() -> Set[Path]:
|
||||
return _CLAIMED_PATHS
|
||||
|
||||
|
||||
def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None:
|
||||
"""Update all active RPC instances with the current event loop.
|
||||
|
||||
This MUST be called at the start of each workflow execution to ensure
|
||||
RPC calls are scheduled on the correct event loop. This handles the case
|
||||
where asyncio.run() creates a new event loop for each workflow.
|
||||
|
||||
Args:
|
||||
loop: The event loop to use. If None, uses asyncio.get_running_loop().
|
||||
"""
|
||||
if loop is None:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
update_count = 0
|
||||
|
||||
# Update RPCs from ExtensionManagers
|
||||
for manager in _EXTENSION_MANAGERS:
|
||||
if not hasattr(manager, "extensions"):
|
||||
continue
|
||||
for name, extension in manager.extensions.items():
|
||||
if hasattr(extension, "rpc") and extension.rpc is not None:
|
||||
if hasattr(extension.rpc, "update_event_loop"):
|
||||
extension.rpc.update_event_loop(loop)
|
||||
update_count += 1
|
||||
logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'")
|
||||
|
||||
# Also update RPCs from running extensions (they may have direct RPC refs)
|
||||
for name, extension in _RUNNING_EXTENSIONS.items():
|
||||
if hasattr(extension, "rpc") and extension.rpc is not None:
|
||||
if hasattr(extension.rpc, "update_event_loop"):
|
||||
extension.rpc.update_event_loop(loop)
|
||||
update_count += 1
|
||||
logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'")
|
||||
|
||||
if update_count > 0:
|
||||
logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances")
|
||||
else:
|
||||
logger.debug(
|
||||
f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})"
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LOG_PREFIX",
|
||||
"initialize_proxies",
|
||||
"initialize_isolation_nodes",
|
||||
"start_isolation_loading_early",
|
||||
"await_isolation_loading",
|
||||
"notify_execution_graph",
|
||||
"flush_running_extensions_transport_state",
|
||||
"wait_for_model_patcher_quiescence",
|
||||
"get_claimed_paths",
|
||||
"update_rpc_event_loops",
|
||||
"IsolatedNodeSpec",
|
||||
"get_class_types_for_extension",
|
||||
]
|
||||
641
comfy/isolation/adapter.py
Normal file
641
comfy/isolation/adapter.py
Normal file
@@ -0,0 +1,641 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
|
||||
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
|
||||
|
||||
try:
|
||||
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
|
||||
from comfy.isolation.model_patcher_proxy import (
|
||||
ModelPatcherProxy,
|
||||
ModelPatcherRegistry,
|
||||
)
|
||||
from comfy.isolation.model_sampling_proxy import (
|
||||
ModelSamplingProxy,
|
||||
ModelSamplingRegistry,
|
||||
)
|
||||
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
|
||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||
except ImportError as exc: # Fail loud if Comfy environment is incomplete
|
||||
raise ImportError(f"ComfyUI environment incomplete: {exc}")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force /dev/shm for shared memory (bwrap makes /tmp private)
|
||||
import tempfile
|
||||
|
||||
if os.path.exists("/dev/shm"):
|
||||
# Only override if not already set or if default is not /dev/shm
|
||||
current_tmp = tempfile.gettempdir()
|
||||
if not current_tmp.startswith("/dev/shm"):
|
||||
logger.debug(
|
||||
f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm"
|
||||
)
|
||||
os.environ["TMPDIR"] = "/dev/shm"
|
||||
tempfile.tempdir = None # Clear cache to force re-evaluation
|
||||
|
||||
|
||||
class ComfyUIAdapter(IsolationAdapter):
|
||||
# ComfyUI-specific IsolationAdapter implementation
|
||||
|
||||
@property
|
||||
def identifier(self) -> str:
|
||||
return "comfyui"
|
||||
|
||||
def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]:
|
||||
if "ComfyUI" in module_path and "custom_nodes" in module_path:
|
||||
parts = module_path.split("ComfyUI")
|
||||
if len(parts) > 1:
|
||||
comfy_root = parts[0] + "ComfyUI"
|
||||
return {
|
||||
"preferred_root": comfy_root,
|
||||
"additional_paths": [
|
||||
os.path.join(comfy_root, "custom_nodes"),
|
||||
os.path.join(comfy_root, "comfy"),
|
||||
],
|
||||
}
|
||||
return None
|
||||
|
||||
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
|
||||
comfy_root = snapshot.get("preferred_root")
|
||||
if not comfy_root:
|
||||
return
|
||||
|
||||
requirements_path = Path(comfy_root) / "requirements.txt"
|
||||
if requirements_path.exists():
|
||||
import re
|
||||
|
||||
for line in requirements_path.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
pkg_name = re.split(r"[<>=!~\[]", line)[0].strip()
|
||||
if pkg_name:
|
||||
logging.getLogger(pkg_name).setLevel(logging.ERROR)
|
||||
|
||||
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
|
||||
import torch
|
||||
|
||||
def serialize_device(obj: Any) -> Dict[str, Any]:
|
||||
return {"__type__": "device", "device_str": str(obj)}
|
||||
|
||||
def deserialize_device(data: Dict[str, Any]) -> Any:
|
||||
return torch.device(data["device_str"])
|
||||
|
||||
registry.register("device", serialize_device, deserialize_device)
|
||||
|
||||
_VALID_DTYPES = {
|
||||
"float16", "float32", "float64", "bfloat16",
|
||||
"int8", "int16", "int32", "int64",
|
||||
"uint8", "bool",
|
||||
}
|
||||
|
||||
def serialize_dtype(obj: Any) -> Dict[str, Any]:
|
||||
return {"__type__": "dtype", "dtype_str": str(obj)}
|
||||
|
||||
def deserialize_dtype(data: Dict[str, Any]) -> Any:
|
||||
dtype_name = data["dtype_str"].replace("torch.", "")
|
||||
if dtype_name not in _VALID_DTYPES:
|
||||
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
|
||||
return getattr(torch, dtype_name)
|
||||
|
||||
registry.register("dtype", serialize_dtype, deserialize_dtype)
|
||||
|
||||
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
|
||||
# Child-side: must already have _instance_id (proxy)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
|
||||
raise RuntimeError(
|
||||
f"ModelPatcher in child lacks _instance_id: "
|
||||
f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
)
|
||||
# Host-side: register with registry
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
|
||||
model_id = ModelPatcherRegistry().register(obj)
|
||||
return {"__type__": "ModelPatcherRef", "model_id": model_id}
|
||||
|
||||
def deserialize_model_patcher(data: Any) -> Any:
|
||||
"""Deserialize ModelPatcher refs; pass through already-materialized objects."""
|
||||
if isinstance(data, dict):
|
||||
return ModelPatcherProxy(
|
||||
data["model_id"], registry=None, manage_lifecycle=False
|
||||
)
|
||||
return data
|
||||
|
||||
def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware ModelPatcherRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return ModelPatcherProxy(
|
||||
data["model_id"], registry=None, manage_lifecycle=False
|
||||
)
|
||||
else:
|
||||
return ModelPatcherRegistry()._get_instance(data["model_id"])
|
||||
|
||||
# Register ModelPatcher type for serialization
|
||||
registry.register(
|
||||
"ModelPatcher", serialize_model_patcher, deserialize_model_patcher
|
||||
)
|
||||
# Register ModelPatcherProxy type (already a proxy, just return ref)
|
||||
registry.register(
|
||||
"ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher
|
||||
)
|
||||
# Register ModelPatcherRef for deserialization (context-aware: host or child)
|
||||
registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref)
|
||||
|
||||
def serialize_clip(obj: Any) -> Dict[str, Any]:
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "CLIPRef", "clip_id": obj._instance_id}
|
||||
clip_id = CLIPRegistry().register(obj)
|
||||
return {"__type__": "CLIPRef", "clip_id": clip_id}
|
||||
|
||||
def deserialize_clip(data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
|
||||
return data
|
||||
|
||||
def deserialize_clip_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware CLIPRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
|
||||
else:
|
||||
return CLIPRegistry()._get_instance(data["clip_id"])
|
||||
|
||||
# Register CLIP type for serialization
|
||||
registry.register("CLIP", serialize_clip, deserialize_clip)
|
||||
# Register CLIPProxy type (already a proxy, just return ref)
|
||||
registry.register("CLIPProxy", serialize_clip, deserialize_clip)
|
||||
# Register CLIPRef for deserialization (context-aware: host or child)
|
||||
registry.register("CLIPRef", None, deserialize_clip_ref)
|
||||
|
||||
def serialize_vae(obj: Any) -> Dict[str, Any]:
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "VAERef", "vae_id": obj._instance_id}
|
||||
vae_id = VAERegistry().register(obj)
|
||||
return {"__type__": "VAERef", "vae_id": vae_id}
|
||||
|
||||
def deserialize_vae(data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return VAEProxy(data["vae_id"])
|
||||
return data
|
||||
|
||||
def deserialize_vae_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware VAERef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
# Child: create a proxy
|
||||
return VAEProxy(data["vae_id"])
|
||||
else:
|
||||
# Host: lookup real VAE from registry
|
||||
return VAERegistry()._get_instance(data["vae_id"])
|
||||
|
||||
# Register VAE type for serialization
|
||||
registry.register("VAE", serialize_vae, deserialize_vae)
|
||||
# Register VAEProxy type (already a proxy, just return ref)
|
||||
registry.register("VAEProxy", serialize_vae, deserialize_vae)
|
||||
# Register VAERef for deserialization (context-aware: host or child)
|
||||
registry.register("VAERef", None, deserialize_vae_ref)
|
||||
|
||||
# ModelSampling serialization - handles ModelSampling* types
|
||||
# copyreg removed - no pickle fallback allowed
|
||||
|
||||
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
|
||||
# Child-side: must already have _instance_id (proxy)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
|
||||
raise RuntimeError(
|
||||
f"ModelSampling in child lacks _instance_id: "
|
||||
f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
)
|
||||
# Host-side pass-through for proxies: do not re-register a proxy as a
|
||||
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
|
||||
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
|
||||
ms_id = ModelSamplingRegistry().register(obj)
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
|
||||
|
||||
def deserialize_model_sampling(data: Any) -> Any:
|
||||
"""Deserialize ModelSampling refs; pass through already-materialized objects."""
|
||||
if isinstance(data, dict):
|
||||
return ModelSamplingProxy(data["ms_id"])
|
||||
return data
|
||||
|
||||
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware ModelSamplingRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return ModelSamplingProxy(data["ms_id"])
|
||||
else:
|
||||
return ModelSamplingRegistry()._get_instance(data["ms_id"])
|
||||
|
||||
# Register all ModelSampling* and StableCascadeSampling classes dynamically
|
||||
import comfy.model_sampling
|
||||
|
||||
for ms_cls in vars(comfy.model_sampling).values():
|
||||
if not isinstance(ms_cls, type):
|
||||
continue
|
||||
if not issubclass(ms_cls, torch.nn.Module):
|
||||
continue
|
||||
if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
|
||||
continue
|
||||
registry.register(
|
||||
ms_cls.__name__,
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
|
||||
)
|
||||
# Register ModelSamplingRef for deserialization (context-aware: host or child)
|
||||
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
|
||||
|
||||
def serialize_cond(obj: Any) -> Dict[str, Any]:
|
||||
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
return {
|
||||
"__type__": type_key,
|
||||
"cond": obj.cond,
|
||||
}
|
||||
|
||||
def deserialize_cond(data: Dict[str, Any]) -> Any:
|
||||
import importlib
|
||||
|
||||
type_key = data["__type__"]
|
||||
module_name, class_name = type_key.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
return cls(data["cond"])
|
||||
|
||||
def _serialize_public_state(obj: Any) -> Dict[str, Any]:
|
||||
state: Dict[str, Any] = {}
|
||||
for key, value in obj.__dict__.items():
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
if callable(value):
|
||||
continue
|
||||
state[key] = value
|
||||
return state
|
||||
|
||||
def serialize_latent_format(obj: Any) -> Dict[str, Any]:
|
||||
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
return {
|
||||
"__type__": type_key,
|
||||
"state": _serialize_public_state(obj),
|
||||
}
|
||||
|
||||
def deserialize_latent_format(data: Dict[str, Any]) -> Any:
|
||||
import importlib
|
||||
|
||||
type_key = data["__type__"]
|
||||
module_name, class_name = type_key.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
obj = cls()
|
||||
for key, value in data.get("state", {}).items():
|
||||
prop = getattr(type(obj), key, None)
|
||||
if isinstance(prop, property) and prop.fset is None:
|
||||
continue
|
||||
setattr(obj, key, value)
|
||||
return obj
|
||||
|
||||
import comfy.conds
|
||||
|
||||
for cond_cls in vars(comfy.conds).values():
|
||||
if not isinstance(cond_cls, type):
|
||||
continue
|
||||
if not issubclass(cond_cls, comfy.conds.CONDRegular):
|
||||
continue
|
||||
type_key = f"{cond_cls.__module__}.{cond_cls.__name__}"
|
||||
registry.register(type_key, serialize_cond, deserialize_cond)
|
||||
registry.register(cond_cls.__name__, serialize_cond, deserialize_cond)
|
||||
|
||||
import comfy.latent_formats
|
||||
|
||||
for latent_cls in vars(comfy.latent_formats).values():
|
||||
if not isinstance(latent_cls, type):
|
||||
continue
|
||||
if not issubclass(latent_cls, comfy.latent_formats.LatentFormat):
|
||||
continue
|
||||
type_key = f"{latent_cls.__module__}.{latent_cls.__name__}"
|
||||
registry.register(
|
||||
type_key, serialize_latent_format, deserialize_latent_format
|
||||
)
|
||||
registry.register(
|
||||
latent_cls.__name__, serialize_latent_format, deserialize_latent_format
|
||||
)
|
||||
|
||||
# V3 API: unwrap NodeOutput.args
|
||||
def deserialize_node_output(data: Any) -> Any:
|
||||
return getattr(data, "args", data)
|
||||
|
||||
registry.register("NodeOutput", None, deserialize_node_output)
|
||||
|
||||
# KSAMPLER serializer: stores sampler name instead of function object
|
||||
# sampler_function is a callable which gets filtered out by JSONSocketTransport
|
||||
def serialize_ksampler(obj: Any) -> Dict[str, Any]:
|
||||
func_name = obj.sampler_function.__name__
|
||||
# Map function name back to sampler name
|
||||
if func_name == "sample_unipc":
|
||||
sampler_name = "uni_pc"
|
||||
elif func_name == "sample_unipc_bh2":
|
||||
sampler_name = "uni_pc_bh2"
|
||||
elif func_name == "dpm_fast_function":
|
||||
sampler_name = "dpm_fast"
|
||||
elif func_name == "dpm_adaptive_function":
|
||||
sampler_name = "dpm_adaptive"
|
||||
elif func_name.startswith("sample_"):
|
||||
sampler_name = func_name[7:] # Remove "sample_" prefix
|
||||
else:
|
||||
sampler_name = func_name
|
||||
return {
|
||||
"__type__": "KSAMPLER",
|
||||
"sampler_name": sampler_name,
|
||||
"extra_options": obj.extra_options,
|
||||
"inpaint_options": obj.inpaint_options,
|
||||
}
|
||||
|
||||
def deserialize_ksampler(data: Dict[str, Any]) -> Any:
|
||||
import comfy.samplers
|
||||
|
||||
return comfy.samplers.ksampler(
|
||||
data["sampler_name"],
|
||||
data.get("extra_options", {}),
|
||||
data.get("inpaint_options", {}),
|
||||
)
|
||||
|
||||
registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler)
|
||||
|
||||
from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers
|
||||
|
||||
register_hooks_serializers(registry)
|
||||
|
||||
# Generic Numpy Serializer
|
||||
def serialize_numpy(obj: Any) -> Any:
|
||||
import torch
|
||||
|
||||
try:
|
||||
# Attempt zero-copy conversion to Tensor
|
||||
return torch.from_numpy(obj)
|
||||
except Exception:
|
||||
# Fallback for non-numeric arrays (strings, objects, mixes)
|
||||
return obj.tolist()
|
||||
|
||||
registry.register("ndarray", serialize_numpy, None)
|
||||
|
||||
def serialize_ply(obj: Any) -> Dict[str, Any]:
|
||||
import base64
|
||||
import torch
|
||||
if obj.raw_data is not None:
|
||||
return {
|
||||
"__type__": "PLY",
|
||||
"raw_data": base64.b64encode(obj.raw_data).decode("ascii"),
|
||||
}
|
||||
result: Dict[str, Any] = {"__type__": "PLY", "points": torch.from_numpy(obj.points)}
|
||||
if obj.colors is not None:
|
||||
result["colors"] = torch.from_numpy(obj.colors)
|
||||
if obj.confidence is not None:
|
||||
result["confidence"] = torch.from_numpy(obj.confidence)
|
||||
if obj.view_id is not None:
|
||||
result["view_id"] = torch.from_numpy(obj.view_id)
|
||||
return result
|
||||
|
||||
def deserialize_ply(data: Any) -> Any:
|
||||
import base64
|
||||
from comfy_api.latest._util.ply_types import PLY
|
||||
if "raw_data" in data:
|
||||
return PLY(raw_data=base64.b64decode(data["raw_data"]))
|
||||
return PLY(
|
||||
points=data["points"],
|
||||
colors=data.get("colors"),
|
||||
confidence=data.get("confidence"),
|
||||
view_id=data.get("view_id"),
|
||||
)
|
||||
|
||||
registry.register("PLY", serialize_ply, deserialize_ply, data_type=True)
|
||||
|
||||
def serialize_npz(obj: Any) -> Dict[str, Any]:
|
||||
import base64
|
||||
return {
|
||||
"__type__": "NPZ",
|
||||
"frames": [base64.b64encode(f).decode("ascii") for f in obj.frames],
|
||||
}
|
||||
|
||||
def deserialize_npz(data: Any) -> Any:
|
||||
import base64
|
||||
from comfy_api.latest._util.npz_types import NPZ
|
||||
return NPZ(frames=[base64.b64decode(f) for f in data["frames"]])
|
||||
|
||||
registry.register("NPZ", serialize_npz, deserialize_npz, data_type=True)
|
||||
|
||||
def serialize_file3d(obj: Any) -> Dict[str, Any]:
|
||||
import base64
|
||||
return {
|
||||
"__type__": "File3D",
|
||||
"format": obj.format,
|
||||
"data": base64.b64encode(obj.get_bytes()).decode("ascii"),
|
||||
}
|
||||
|
||||
def deserialize_file3d(data: Any) -> Any:
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from comfy_api.latest._util.geometry_types import File3D
|
||||
return File3D(BytesIO(base64.b64decode(data["data"])), file_format=data["format"])
|
||||
|
||||
registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True)
|
||||
|
||||
def serialize_video(obj: Any) -> Dict[str, Any]:
|
||||
components = obj.get_components()
|
||||
images = components.images.detach() if components.images.requires_grad else components.images
|
||||
result: Dict[str, Any] = {
|
||||
"__type__": "VIDEO",
|
||||
"images": images,
|
||||
"frame_rate_num": components.frame_rate.numerator,
|
||||
"frame_rate_den": components.frame_rate.denominator,
|
||||
}
|
||||
if components.audio is not None:
|
||||
waveform = components.audio["waveform"]
|
||||
if waveform.requires_grad:
|
||||
waveform = waveform.detach()
|
||||
result["audio_waveform"] = waveform
|
||||
result["audio_sample_rate"] = components.audio["sample_rate"]
|
||||
if components.metadata is not None:
|
||||
result["metadata"] = components.metadata
|
||||
return result
|
||||
|
||||
def deserialize_video(data: Any) -> Any:
|
||||
from fractions import Fraction
|
||||
from comfy_api.latest._input_impl.video_types import VideoFromComponents
|
||||
from comfy_api.latest._util.video_types import VideoComponents
|
||||
audio = None
|
||||
if "audio_waveform" in data:
|
||||
audio = {"waveform": data["audio_waveform"], "sample_rate": data["audio_sample_rate"]}
|
||||
components = VideoComponents(
|
||||
images=data["images"],
|
||||
frame_rate=Fraction(data["frame_rate_num"], data["frame_rate_den"]),
|
||||
audio=audio,
|
||||
metadata=data.get("metadata"),
|
||||
)
|
||||
return VideoFromComponents(components)
|
||||
|
||||
registry.register("VIDEO", serialize_video, deserialize_video, data_type=True)
|
||||
registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True)
|
||||
registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True)
|
||||
|
||||
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
|
||||
return [
|
||||
PromptServerService,
|
||||
FolderPathsProxy,
|
||||
ModelManagementProxy,
|
||||
UtilsProxy,
|
||||
ProgressProxy,
|
||||
VAERegistry,
|
||||
CLIPRegistry,
|
||||
ModelPatcherRegistry,
|
||||
ModelSamplingRegistry,
|
||||
FirstStageModelRegistry,
|
||||
]
|
||||
|
||||
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
|
||||
# Resolve the real name whether it's an instance or the Singleton class itself
|
||||
api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__
|
||||
|
||||
if api_name == "FolderPathsProxy":
|
||||
import folder_paths
|
||||
|
||||
# Replace module-level functions with proxy methods
|
||||
# This is aggressive but necessary for transparent proxying
|
||||
# Handle both instance and class cases
|
||||
instance = api() if isinstance(api, type) else api
|
||||
for name in dir(instance):
|
||||
if not name.startswith("_"):
|
||||
setattr(folder_paths, name, getattr(instance, name))
|
||||
|
||||
# Fence: isolated children get writable temp inside sandbox
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
_child_temp = os.path.join("/tmp", "comfyui_temp")
|
||||
os.makedirs(_child_temp, exist_ok=True)
|
||||
folder_paths.temp_directory = _child_temp
|
||||
|
||||
return
|
||||
|
||||
if api_name == "ModelManagementProxy":
|
||||
import comfy.model_management
|
||||
|
||||
instance = api() if isinstance(api, type) else api
|
||||
# Replace module-level functions with proxy methods
|
||||
for name in dir(instance):
|
||||
if not name.startswith("_"):
|
||||
setattr(comfy.model_management, name, getattr(instance, name))
|
||||
return
|
||||
|
||||
if api_name == "UtilsProxy":
|
||||
import comfy.utils
|
||||
|
||||
# Static Injection of RPC mechanism to ensure Child can access it
|
||||
# independent of instance lifecycle.
|
||||
api.set_rpc(rpc)
|
||||
|
||||
# Don't overwrite host hook (infinite recursion)
|
||||
return
|
||||
|
||||
if api_name == "PromptServerProxy":
|
||||
# Defer heavy import to child context
|
||||
import server
|
||||
|
||||
instance = api() if isinstance(api, type) else api
|
||||
proxy = (
|
||||
instance.instance
|
||||
) # PromptServerProxy instance has .instance property returning self
|
||||
|
||||
original_register_route = proxy.register_route
|
||||
|
||||
def register_route_wrapper(
|
||||
method: str, path: str, handler: Callable[..., Any]
|
||||
) -> None:
|
||||
callback_id = rpc.register_callback(handler)
|
||||
loop = getattr(rpc, "loop", None)
|
||||
if loop and loop.is_running():
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(
|
||||
original_register_route(
|
||||
method, path, handler=callback_id, is_callback=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
original_register_route(
|
||||
method, path, handler=callback_id, is_callback=True
|
||||
)
|
||||
return None
|
||||
|
||||
proxy.register_route = register_route_wrapper
|
||||
|
||||
class RouteTableDefProxy:
|
||||
def __init__(self, proxy_instance: Any):
|
||||
self.proxy = proxy_instance
|
||||
|
||||
def get(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("GET", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def post(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("POST", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def patch(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("PATCH", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def put(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("PUT", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def delete(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("DELETE", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
proxy.routes = RouteTableDefProxy(proxy)
|
||||
|
||||
if (
|
||||
hasattr(server, "PromptServer")
|
||||
and getattr(server.PromptServer, "instance", None) != proxy
|
||||
):
|
||||
server.PromptServer.instance = proxy
|
||||
141
comfy/isolation/child_hooks.py
Normal file
141
comfy/isolation/child_hooks.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation
|
||||
# Child process initialization for PyIsolate
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_child_process() -> bool:
|
||||
return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
def initialize_child_process() -> None:
|
||||
# Manual RPC injection
|
||||
try:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc:
|
||||
_setup_prompt_server_stub(rpc)
|
||||
_setup_utils_proxy(rpc)
|
||||
else:
|
||||
logger.warning("Could not get child RPC instance for manual injection")
|
||||
_setup_prompt_server_stub()
|
||||
_setup_utils_proxy()
|
||||
except Exception as e:
|
||||
logger.error(f"Manual RPC Injection failed: {e}")
|
||||
_setup_prompt_server_stub()
|
||||
_setup_utils_proxy()
|
||||
|
||||
_setup_logging()
|
||||
|
||||
|
||||
def _setup_prompt_server_stub(rpc=None) -> None:
|
||||
try:
|
||||
from .proxies.prompt_server_impl import PromptServerStub
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Mock server module
|
||||
if "server" not in sys.modules:
|
||||
mock_server = types.ModuleType("server")
|
||||
sys.modules["server"] = mock_server
|
||||
|
||||
server = sys.modules["server"]
|
||||
|
||||
if not hasattr(server, "PromptServer"):
|
||||
|
||||
class MockPromptServer:
|
||||
pass
|
||||
|
||||
server.PromptServer = MockPromptServer
|
||||
|
||||
stub = PromptServerStub()
|
||||
|
||||
if rpc:
|
||||
PromptServerStub.set_rpc(rpc)
|
||||
if hasattr(stub, "set_rpc"):
|
||||
stub.set_rpc(rpc)
|
||||
|
||||
server.PromptServer.instance = stub
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup PromptServerStub: {e}")
|
||||
|
||||
|
||||
def _setup_utils_proxy(rpc=None) -> None:
|
||||
try:
|
||||
import comfy.utils
|
||||
import asyncio
|
||||
|
||||
# Capture main loop during initialization (safe context)
|
||||
main_loop = None
|
||||
try:
|
||||
main_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
try:
|
||||
main_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from .proxies.base import set_global_loop
|
||||
|
||||
if main_loop:
|
||||
set_global_loop(main_loop)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Sync hook wrapper for progress updates
|
||||
def sync_hook_wrapper(
|
||||
value: int, total: int, preview: None = None, node_id: None = None
|
||||
) -> None:
|
||||
if node_id is None:
|
||||
try:
|
||||
from comfy_execution.utils import get_executing_context
|
||||
|
||||
ctx = get_executing_context()
|
||||
if ctx:
|
||||
node_id = ctx.node_id
|
||||
else:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Bypass blocked event loop by direct outbox injection
|
||||
if rpc:
|
||||
try:
|
||||
# Use captured main loop if available (for threaded execution), or current loop
|
||||
loop = main_loop
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
rpc.outbox.put(
|
||||
{
|
||||
"kind": "call",
|
||||
"object_id": "UtilsProxy",
|
||||
"parent_call_id": None, # We are root here usually
|
||||
"calling_loop": loop,
|
||||
"future": loop.create_future(), # Dummy future
|
||||
"method": "progress_bar_hook",
|
||||
"args": (value, total, preview, node_id),
|
||||
"kwargs": {},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).error(f"Manual Inject Failed: {e}")
|
||||
else:
|
||||
logging.getLogger(__name__).warning(
|
||||
"No RPC instance available for progress update"
|
||||
)
|
||||
|
||||
comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup UtilsProxy hook: {e}")
|
||||
|
||||
|
||||
def _setup_logging() -> None:
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
327
comfy/isolation/clip_proxy.py
Normal file
327
comfy/isolation/clip_proxy.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation
|
||||
# CLIP Proxy implementation
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
|
||||
class CondStageModelRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "cond_stage_model"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
|
||||
class CondStageModelProxy(BaseProxy[CondStageModelRegistry]):
|
||||
_registry_class = CondStageModelRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<CondStageModelProxy {self._instance_id}>"
|
||||
|
||||
|
||||
class TokenizerRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "tokenizer"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
|
||||
class TokenizerProxy(BaseProxy[TokenizerRegistry]):
|
||||
_registry_class = TokenizerRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<TokenizerProxy {self._instance_id}>"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CLIPRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "clip"
|
||||
_allowed_setters = {
|
||||
"layer_idx",
|
||||
"tokenizer_options",
|
||||
"use_clip_schedule",
|
||||
"apply_hooks_to_conds",
|
||||
}
|
||||
|
||||
async def get_ram_usage(self, instance_id: str) -> int:
|
||||
return self._get_instance(instance_id).get_ram_usage()
|
||||
|
||||
async def get_patcher_id(self, instance_id: str) -> str:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry
|
||||
|
||||
return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher)
|
||||
|
||||
async def get_cond_stage_model_id(self, instance_id: str) -> str:
|
||||
return CondStageModelRegistry().register(
|
||||
self._get_instance(instance_id).cond_stage_model
|
||||
)
|
||||
|
||||
async def get_tokenizer_id(self, instance_id: str) -> str:
|
||||
return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer)
|
||||
|
||||
async def load_model(self, instance_id: str) -> None:
|
||||
self._get_instance(instance_id).load_model()
|
||||
|
||||
async def clip_layer(self, instance_id: str, layer_idx: int) -> None:
|
||||
self._get_instance(instance_id).clip_layer(layer_idx)
|
||||
|
||||
async def set_tokenizer_option(
|
||||
self, instance_id: str, option_name: str, value: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_tokenizer_option(option_name, value)
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
return getattr(self._get_instance(instance_id), name)
|
||||
|
||||
async def set_property(self, instance_id: str, name: str, value: Any) -> None:
|
||||
if name not in self._allowed_setters:
|
||||
raise PermissionError(f"Setting '{name}' is not allowed via RPC")
|
||||
setattr(self._get_instance(instance_id), name, value)
|
||||
|
||||
async def tokenize(
|
||||
self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).tokenize(
|
||||
text, return_word_ids=return_word_ids, **kwargs
|
||||
)
|
||||
|
||||
async def encode(self, instance_id: str, text: str) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).encode(text))
|
||||
|
||||
async def encode_from_tokens(
|
||||
self,
|
||||
instance_id: str,
|
||||
tokens: Any,
|
||||
return_pooled: bool = False,
|
||||
return_dict: bool = False,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_from_tokens(
|
||||
tokens, return_pooled=return_pooled, return_dict=return_dict
|
||||
)
|
||||
)
|
||||
|
||||
async def encode_from_tokens_scheduled(
|
||||
self,
|
||||
instance_id: str,
|
||||
tokens: Any,
|
||||
unprojected: bool = False,
|
||||
add_dict: Optional[dict] = None,
|
||||
show_pbar: bool = True,
|
||||
) -> Any:
|
||||
add_dict = add_dict or {}
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_from_tokens_scheduled(
|
||||
tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar
|
||||
)
|
||||
)
|
||||
|
||||
async def add_patches(
|
||||
self,
|
||||
instance_id: str,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).add_patches(
|
||||
patches, strength_patch=strength_patch, strength_model=strength_model
|
||||
)
|
||||
|
||||
async def get_key_patches(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).get_key_patches()
|
||||
|
||||
async def load_sd(
|
||||
self, instance_id: str, sd: dict, full_model: bool = False
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).load_sd(sd, full_model=full_model)
|
||||
|
||||
async def get_sd(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).get_sd()
|
||||
|
||||
async def clone(self, instance_id: str) -> str:
|
||||
return self.register(self._get_instance(instance_id).clone())
|
||||
|
||||
|
||||
class CLIPProxy(BaseProxy[CLIPRegistry]):
|
||||
_registry_class = CLIPRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def get_ram_usage(self) -> int:
|
||||
return self._call_rpc("get_ram_usage")
|
||||
|
||||
@property
|
||||
def patcher(self) -> "ModelPatcherProxy":
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
if not hasattr(self, "_patcher_proxy"):
|
||||
patcher_id = self._call_rpc("get_patcher_id")
|
||||
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
|
||||
return self._patcher_proxy
|
||||
|
||||
@patcher.setter
|
||||
def patcher(self, value: Any) -> None:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
if isinstance(value, ModelPatcherProxy):
|
||||
self._patcher_proxy = value
|
||||
else:
|
||||
logger.warning(
|
||||
f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}"
|
||||
)
|
||||
|
||||
@property
|
||||
def cond_stage_model(self) -> CondStageModelProxy:
|
||||
if not hasattr(self, "_cond_stage_model_proxy"):
|
||||
csm_id = self._call_rpc("get_cond_stage_model_id")
|
||||
self._cond_stage_model_proxy = CondStageModelProxy(
|
||||
csm_id, manage_lifecycle=False
|
||||
)
|
||||
return self._cond_stage_model_proxy
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> TokenizerProxy:
|
||||
if not hasattr(self, "_tokenizer_proxy"):
|
||||
tok_id = self._call_rpc("get_tokenizer_id")
|
||||
self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False)
|
||||
return self._tokenizer_proxy
|
||||
|
||||
def load_model(self) -> ModelPatcherProxy:
|
||||
self._call_rpc("load_model")
|
||||
return self.patcher
|
||||
|
||||
@property
|
||||
def layer_idx(self) -> Optional[int]:
|
||||
return self._call_rpc("get_property", "layer_idx")
|
||||
|
||||
@layer_idx.setter
|
||||
def layer_idx(self, value: Optional[int]) -> None:
|
||||
self._call_rpc("set_property", "layer_idx", value)
|
||||
|
||||
@property
|
||||
def tokenizer_options(self) -> dict:
|
||||
return self._call_rpc("get_property", "tokenizer_options")
|
||||
|
||||
@tokenizer_options.setter
|
||||
def tokenizer_options(self, value: dict) -> None:
|
||||
self._call_rpc("set_property", "tokenizer_options", value)
|
||||
|
||||
@property
|
||||
def use_clip_schedule(self) -> bool:
|
||||
return self._call_rpc("get_property", "use_clip_schedule")
|
||||
|
||||
@use_clip_schedule.setter
|
||||
def use_clip_schedule(self, value: bool) -> None:
|
||||
self._call_rpc("set_property", "use_clip_schedule", value)
|
||||
|
||||
@property
|
||||
def apply_hooks_to_conds(self) -> Any:
|
||||
return self._call_rpc("get_property", "apply_hooks_to_conds")
|
||||
|
||||
@apply_hooks_to_conds.setter
|
||||
def apply_hooks_to_conds(self, value: Any) -> None:
|
||||
self._call_rpc("set_property", "apply_hooks_to_conds", value)
|
||||
|
||||
def clip_layer(self, layer_idx: int) -> None:
|
||||
return self._call_rpc("clip_layer", layer_idx)
|
||||
|
||||
def set_tokenizer_option(self, option_name: str, value: Any) -> None:
|
||||
return self._call_rpc("set_tokenizer_option", option_name, value)
|
||||
|
||||
def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any:
|
||||
return self._call_rpc(
|
||||
"tokenize", text, return_word_ids=return_word_ids, **kwargs
|
||||
)
|
||||
|
||||
def encode(self, text: str) -> Any:
|
||||
return self._call_rpc("encode", text)
|
||||
|
||||
def encode_from_tokens(
|
||||
self, tokens: Any, return_pooled: bool = False, return_dict: bool = False
|
||||
) -> Any:
|
||||
res = self._call_rpc(
|
||||
"encode_from_tokens",
|
||||
tokens,
|
||||
return_pooled=return_pooled,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
if return_pooled and isinstance(res, list) and not return_dict:
|
||||
return tuple(res)
|
||||
return res
|
||||
|
||||
def encode_from_tokens_scheduled(
|
||||
self,
|
||||
tokens: Any,
|
||||
unprojected: bool = False,
|
||||
add_dict: Optional[dict] = None,
|
||||
show_pbar: bool = True,
|
||||
) -> Any:
|
||||
add_dict = add_dict or {}
|
||||
return self._call_rpc(
|
||||
"encode_from_tokens_scheduled",
|
||||
tokens,
|
||||
unprojected=unprojected,
|
||||
add_dict=add_dict,
|
||||
show_pbar=show_pbar,
|
||||
)
|
||||
|
||||
def add_patches(
|
||||
self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"add_patches",
|
||||
patches,
|
||||
strength_patch=strength_patch,
|
||||
strength_model=strength_model,
|
||||
)
|
||||
|
||||
def get_key_patches(self) -> Any:
|
||||
return self._call_rpc("get_key_patches")
|
||||
|
||||
def load_sd(self, sd: dict, full_model: bool = False) -> Any:
|
||||
return self._call_rpc("load_sd", sd, full_model=full_model)
|
||||
|
||||
def get_sd(self) -> Any:
|
||||
return self._call_rpc("get_sd")
|
||||
|
||||
def clone(self) -> CLIPProxy:
|
||||
new_id = self._call_rpc("clone")
|
||||
return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS)
|
||||
|
||||
|
||||
if not IS_CHILD_PROCESS:
|
||||
_CLIP_REGISTRY_SINGLETON = CLIPRegistry()
|
||||
_COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry()
|
||||
_TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry()
|
||||
388
comfy/isolation/extension_loader.py
Normal file
388
comfy/isolation/extension_loader.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
import pyisolate
|
||||
from pyisolate import ExtensionManager, ExtensionManagerConfig
|
||||
from packaging.requirements import InvalidRequirement, Requirement
|
||||
from packaging.utils import canonicalize_name
|
||||
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
|
||||
from .host_policy import load_host_policy
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _register_web_directory(extension_name: str, node_dir: Path) -> None:
|
||||
"""Register an isolated extension's web directory on the host side."""
|
||||
import nodes
|
||||
|
||||
# Method 1: pyproject.toml [tool.comfy] web field
|
||||
pyproject = node_dir / "pyproject.toml"
|
||||
if pyproject.exists():
|
||||
try:
|
||||
with pyproject.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
web_dir_name = data.get("tool", {}).get("comfy", {}).get("web")
|
||||
if web_dir_name:
|
||||
web_dir_path = str(node_dir / web_dir_name)
|
||||
if os.path.isdir(web_dir_path):
|
||||
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
||||
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Method 2: __init__.py WEB_DIRECTORY constant (parse without importing)
|
||||
init_file = node_dir / "__init__.py"
|
||||
if init_file.exists():
|
||||
try:
|
||||
source = init_file.read_text()
|
||||
for line in source.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("WEB_DIRECTORY"):
|
||||
# Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web"
|
||||
_, _, value = stripped.partition("=")
|
||||
value = value.strip().strip("\"'")
|
||||
if value:
|
||||
web_dir_path = str((node_dir / value).resolve())
|
||||
if os.path.isdir(web_dir_path):
|
||||
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
||||
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def _stop_extension_safe(
|
||||
extension: ComfyNodeExtension, extension_name: str
|
||||
) -> None:
|
||||
try:
|
||||
stop_result = extension.stop()
|
||||
if inspect.isawaitable(stop_result):
|
||||
await stop_result
|
||||
except Exception:
|
||||
logger.debug("][ %s stop failed", extension_name, exc_info=True)
|
||||
|
||||
|
||||
def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
|
||||
req, sep, marker = dep.partition(";")
|
||||
req = req.strip()
|
||||
marker_suffix = f";{marker}" if sep else ""
|
||||
|
||||
def _resolve_local_path(local_path: str) -> Path | None:
|
||||
for base in base_paths:
|
||||
candidate = (base / local_path).resolve()
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
if req.startswith("./") or req.startswith("../"):
|
||||
resolved = _resolve_local_path(req)
|
||||
if resolved is not None:
|
||||
return f"{resolved}{marker_suffix}"
|
||||
|
||||
if req.startswith("file://"):
|
||||
raw = req[len("file://") :]
|
||||
if raw.startswith("./") or raw.startswith("../"):
|
||||
resolved = _resolve_local_path(raw)
|
||||
if resolved is not None:
|
||||
return f"file://{resolved}{marker_suffix}"
|
||||
|
||||
return dep
|
||||
|
||||
|
||||
def _dependency_name_from_spec(dep: str) -> str | None:
|
||||
stripped = dep.strip()
|
||||
if not stripped or stripped == "-e" or stripped.startswith("-e "):
|
||||
return None
|
||||
if stripped.startswith(("/", "./", "../", "file://")):
|
||||
return None
|
||||
|
||||
try:
|
||||
return canonicalize_name(Requirement(stripped).name)
|
||||
except InvalidRequirement:
|
||||
return None
|
||||
|
||||
|
||||
def _parse_cuda_wheels_config(
|
||||
tool_config: dict[str, object], dependencies: list[str]
|
||||
) -> dict[str, object] | None:
|
||||
raw_config = tool_config.get("cuda_wheels")
|
||||
if raw_config is None:
|
||||
return None
|
||||
if not isinstance(raw_config, dict):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels] must be a table"
|
||||
)
|
||||
|
||||
index_url = raw_config.get("index_url")
|
||||
if not isinstance(index_url, str) or not index_url.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
|
||||
)
|
||||
|
||||
packages = raw_config.get("packages")
|
||||
if not isinstance(packages, list) or not all(
|
||||
isinstance(package_name, str) and package_name.strip()
|
||||
for package_name in packages
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings"
|
||||
)
|
||||
|
||||
declared_dependencies = {
|
||||
dependency_name
|
||||
for dep in dependencies
|
||||
if (dependency_name := _dependency_name_from_spec(dep)) is not None
|
||||
}
|
||||
normalized_packages = [canonicalize_name(package_name) for package_name in packages]
|
||||
missing = [
|
||||
package_name
|
||||
for package_name in normalized_packages
|
||||
if package_name not in declared_dependencies
|
||||
]
|
||||
if missing:
|
||||
missing_joined = ", ".join(sorted(missing))
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: "
|
||||
f"{missing_joined}"
|
||||
)
|
||||
|
||||
package_map = raw_config.get("package_map", {})
|
||||
if not isinstance(package_map, dict):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] must be a table"
|
||||
)
|
||||
|
||||
normalized_package_map: dict[str, str] = {}
|
||||
for dependency_name, index_package_name in package_map.items():
|
||||
if not isinstance(dependency_name, str) or not dependency_name.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings"
|
||||
)
|
||||
if not isinstance(index_package_name, str) or not index_package_name.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings"
|
||||
)
|
||||
canonical_dependency_name = canonicalize_name(dependency_name)
|
||||
if canonical_dependency_name not in normalized_packages:
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in "
|
||||
"[tool.comfy.isolation.cuda_wheels.packages]"
|
||||
)
|
||||
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
|
||||
|
||||
return {
|
||||
"index_url": index_url.rstrip("/") + "/",
|
||||
"packages": normalized_packages,
|
||||
"package_map": normalized_package_map,
|
||||
}
|
||||
|
||||
|
||||
def get_enforcement_policy() -> Dict[str, bool]:
|
||||
return {
|
||||
"force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
|
||||
"force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
|
||||
}
|
||||
|
||||
|
||||
class ExtensionLoadError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def register_dummy_module(extension_name: str, node_dir: Path) -> None:
|
||||
normalized_name = extension_name.replace("-", "_").replace(".", "_")
|
||||
if normalized_name not in sys.modules:
|
||||
dummy_module = types.ModuleType(normalized_name)
|
||||
dummy_module.__file__ = str(node_dir / "__init__.py")
|
||||
dummy_module.__path__ = [str(node_dir)]
|
||||
dummy_module.__package__ = normalized_name
|
||||
sys.modules[normalized_name] = dummy_module
|
||||
|
||||
|
||||
def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
|
||||
for details in cached_data.values():
|
||||
if not isinstance(details, dict):
|
||||
return True
|
||||
if details.get("is_v3") and "schema_v1" not in details:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def load_isolated_node(
|
||||
node_dir: Path,
|
||||
manifest_path: Path,
|
||||
logger: logging.Logger,
|
||||
build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type],
|
||||
venv_root: Path,
|
||||
extension_managers: List[ExtensionManager],
|
||||
) -> List[Tuple[str, str, type]]:
|
||||
try:
|
||||
with manifest_path.open("rb") as handle:
|
||||
manifest_data = tomllib.load(handle)
|
||||
except Exception as e:
|
||||
logger.warning(f"][ Failed to parse {manifest_path}: {e}")
|
||||
return []
|
||||
|
||||
# Parse [tool.comfy.isolation]
|
||||
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
|
||||
can_isolate = tool_config.get("can_isolate", False)
|
||||
share_torch = tool_config.get("share_torch", False)
|
||||
|
||||
# Parse [project] dependencies
|
||||
project_config = manifest_data.get("project", {})
|
||||
dependencies = project_config.get("dependencies", [])
|
||||
if not isinstance(dependencies, list):
|
||||
dependencies = []
|
||||
|
||||
# Get extension name (default to folder name if not in project.name)
|
||||
extension_name = project_config.get("name", node_dir.name)
|
||||
|
||||
# LOGIC: Isolation Decision
|
||||
policy = get_enforcement_policy()
|
||||
isolated = can_isolate or policy["force_isolated"]
|
||||
|
||||
if not isolated:
|
||||
return []
|
||||
|
||||
logger.info(f"][ Loading isolated node: {extension_name}")
|
||||
|
||||
import folder_paths
|
||||
|
||||
base_paths = [Path(folder_paths.base_path), node_dir]
|
||||
dependencies = [
|
||||
_normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
|
||||
for dep in dependencies
|
||||
]
|
||||
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
|
||||
|
||||
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
|
||||
manager: ExtensionManager = pyisolate.ExtensionManager(
|
||||
ComfyNodeExtension, manager_config
|
||||
)
|
||||
extension_managers.append(manager)
|
||||
|
||||
host_policy = load_host_policy(Path(folder_paths.base_path))
|
||||
|
||||
sandbox_config = {}
|
||||
is_linux = platform.system() == "Linux"
|
||||
if is_linux and isolated:
|
||||
sandbox_config = {
|
||||
"network": host_policy["allow_network"],
|
||||
"writable_paths": host_policy["writable_paths"],
|
||||
"readonly_paths": host_policy["readonly_paths"],
|
||||
}
|
||||
share_cuda_ipc = share_torch and is_linux
|
||||
|
||||
extension_config = {
|
||||
"name": extension_name,
|
||||
"module_path": str(node_dir),
|
||||
"isolated": True,
|
||||
"dependencies": dependencies,
|
||||
"share_torch": share_torch,
|
||||
"share_cuda_ipc": share_cuda_ipc,
|
||||
"sandbox_mode": host_policy["sandbox_mode"],
|
||||
"sandbox": sandbox_config,
|
||||
}
|
||||
if cuda_wheels is not None:
|
||||
extension_config["cuda_wheels"] = cuda_wheels
|
||||
|
||||
extension = manager.load_extension(extension_config)
|
||||
register_dummy_module(extension_name, node_dir)
|
||||
|
||||
# Register web directory on the host — only when sandbox is disabled.
|
||||
# In sandbox mode, serving untrusted JS to the browser is not safe.
|
||||
if host_policy["sandbox_mode"] == "disabled":
|
||||
_register_web_directory(extension_name, node_dir)
|
||||
|
||||
# Try cache first (lazy spawn)
|
||||
if is_cache_valid(node_dir, manifest_path, venv_root):
|
||||
cached_data = load_from_cache(node_dir, venv_root)
|
||||
if cached_data:
|
||||
if _is_stale_node_cache(cached_data):
|
||||
logger.debug(
|
||||
"][ %s cache is stale/incompatible; rebuilding metadata",
|
||||
extension_name,
|
||||
)
|
||||
else:
|
||||
logger.debug(f"][ {extension_name} loaded from cache")
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
for node_name, details in cached_data.items():
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
specs.append(
|
||||
(node_name, details.get("display_name", node_name), stub_cls)
|
||||
)
|
||||
return specs
|
||||
|
||||
# Cache miss - spawn process and get metadata
|
||||
logger.debug(f"][ {extension_name} cache miss, spawning process for metadata")
|
||||
|
||||
try:
|
||||
remote_nodes: Dict[str, str] = await extension.list_nodes()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"][ %s metadata discovery failed, skipping isolated load: %s",
|
||||
extension_name,
|
||||
exc,
|
||||
)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
if not remote_nodes:
|
||||
logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
cache_data: Dict[str, Dict] = {}
|
||||
|
||||
for node_name, display_name in remote_nodes.items():
|
||||
try:
|
||||
details = await extension.get_node_details(node_name)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"][ %s failed to load metadata for %s, skipping node: %s",
|
||||
extension_name,
|
||||
node_name,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
details["display_name"] = display_name
|
||||
cache_data[node_name] = details
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
specs.append((node_name, display_name, stub_cls))
|
||||
|
||||
if not specs:
|
||||
logger.warning(
|
||||
"][ %s produced no usable nodes after metadata scan; skipping",
|
||||
extension_name,
|
||||
)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
# Save metadata to cache for future runs
|
||||
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
|
||||
logger.debug(f"][ {extension_name} metadata cached")
|
||||
|
||||
# EJECT: Kill process after getting metadata (will respawn on first execution)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
|
||||
return specs
|
||||
|
||||
|
||||
__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]
|
||||
699
comfy/isolation/extension_wrapper.py
Normal file
699
comfy/isolation/extension_wrapper.py
Normal file
@@ -0,0 +1,699 @@
|
||||
# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self[item]
|
||||
except KeyError as e:
|
||||
raise AttributeError(item) from e
|
||||
|
||||
def copy(self):
|
||||
return AttrDict(super().copy())
|
||||
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pyisolate import ExtensionBase
|
||||
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
V3_DISCOVERY_TIMEOUT = 30
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _flush_tensor_transport_state(marker: str) -> int:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return 0
|
||||
if not callable(flush_tensor_keeper):
|
||||
return 0
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
|
||||
)
|
||||
return flushed
|
||||
|
||||
|
||||
def _relieve_child_vram_pressure(marker: str) -> None:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if not hasattr(device, "type") or device.type == "cpu":
|
||||
return
|
||||
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=True)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
|
||||
|
||||
|
||||
def _sanitize_for_transport(value):
|
||||
primitives = (str, int, float, bool, type(None))
|
||||
if isinstance(value, primitives):
|
||||
return value
|
||||
|
||||
cls_name = value.__class__.__name__
|
||||
if cls_name == "FlexibleOptionalInputType":
|
||||
return {
|
||||
"__pyisolate_flexible_optional__": True,
|
||||
"type": _sanitize_for_transport(getattr(value, "type", "*")),
|
||||
}
|
||||
if cls_name == "AnyType":
|
||||
return {"__pyisolate_any_type__": True, "value": str(value)}
|
||||
if cls_name == "ByPassTypeTuple":
|
||||
return {
|
||||
"__pyisolate_bypass_tuple__": [
|
||||
_sanitize_for_transport(v) for v in tuple(value)
|
||||
]
|
||||
}
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {k: _sanitize_for_transport(v) for k, v in value.items()}
|
||||
if isinstance(value, tuple):
|
||||
return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]}
|
||||
if isinstance(value, list):
|
||||
return [_sanitize_for_transport(v) for v in value]
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
# Re-export RemoteObjectHandle from pyisolate for backward compatibility
|
||||
# The canonical definition is now in pyisolate._internal.remote_handle
|
||||
from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401
|
||||
|
||||
|
||||
class ComfyNodeExtension(ExtensionBase):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.node_classes: Dict[str, type] = {}
|
||||
self.display_names: Dict[str, str] = {}
|
||||
self.node_instances: Dict[str, Any] = {}
|
||||
self.remote_objects: Dict[str, Any] = {}
|
||||
self._route_handlers: Dict[str, Any] = {}
|
||||
self._module: Any = None
|
||||
|
||||
async def on_module_loaded(self, module: Any) -> None:
|
||||
self._module = module
|
||||
|
||||
# Registries are initialized in host_hooks.py initialize_host_process()
|
||||
# They auto-register via ProxiedSingleton when instantiated
|
||||
# NO additional setup required here - if a registry is missing from host_hooks, it WILL fail
|
||||
|
||||
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
|
||||
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
|
||||
|
||||
try:
|
||||
from comfy_api.latest import ComfyExtension
|
||||
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if not (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, ComfyExtension)
|
||||
and obj is not ComfyExtension
|
||||
):
|
||||
continue
|
||||
if not obj.__module__.startswith(module.__name__):
|
||||
continue
|
||||
try:
|
||||
ext_instance = obj()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s V3 Extension %s timed out in on_load()",
|
||||
LOG_PREFIX,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
try:
|
||||
v3_nodes = await asyncio.wait_for(
|
||||
ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s V3 Extension %s timed out in get_node_list()",
|
||||
LOG_PREFIX,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
for node_cls in v3_nodes:
|
||||
if hasattr(node_cls, "GET_SCHEMA"):
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
self.node_classes[schema.node_id] = node_cls
|
||||
if schema.display_name:
|
||||
self.display_names[schema.node_id] = schema.display_name
|
||||
except Exception as e:
|
||||
logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
module_name = getattr(module, "__name__", "isolated_nodes")
|
||||
for node_cls in self.node_classes.values():
|
||||
if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
|
||||
node_cls.__module__ = module_name
|
||||
|
||||
self.node_instances = {}
|
||||
|
||||
async def list_nodes(self) -> Dict[str, str]:
|
||||
return {name: self.display_names.get(name, name) for name in self.node_classes}
|
||||
|
||||
async def get_node_info(self, node_name: str) -> Dict[str, Any]:
|
||||
return await self.get_node_details(node_name)
|
||||
|
||||
async def get_node_details(self, node_name: str) -> Dict[str, Any]:
|
||||
node_cls = self._get_node_class(node_name)
|
||||
is_v3 = issubclass(node_cls, _ComfyNodeInternal)
|
||||
|
||||
input_types_raw = (
|
||||
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
|
||||
)
|
||||
output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
|
||||
if output_is_list is not None:
|
||||
output_is_list = tuple(bool(x) for x in output_is_list)
|
||||
|
||||
details: Dict[str, Any] = {
|
||||
"input_types": _sanitize_for_transport(input_types_raw),
|
||||
"return_types": tuple(
|
||||
str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
|
||||
),
|
||||
"return_names": getattr(node_cls, "RETURN_NAMES", None),
|
||||
"function": str(getattr(node_cls, "FUNCTION", "execute")),
|
||||
"category": str(getattr(node_cls, "CATEGORY", "")),
|
||||
"output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
|
||||
"output_is_list": output_is_list,
|
||||
"is_v3": is_v3,
|
||||
}
|
||||
|
||||
if is_v3:
|
||||
try:
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
schema_v1 = asdict(schema.get_v1_info(node_cls))
|
||||
try:
|
||||
schema_v3 = asdict(schema.get_v3_info(node_cls))
|
||||
except (AttributeError, TypeError):
|
||||
schema_v3 = self._build_schema_v3_fallback(schema)
|
||||
details.update(
|
||||
{
|
||||
"schema_v1": schema_v1,
|
||||
"schema_v3": schema_v3,
|
||||
"hidden": [h.value for h in (schema.hidden or [])],
|
||||
"description": getattr(schema, "description", ""),
|
||||
"deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
|
||||
"experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)),
|
||||
"api_node": bool(getattr(node_cls, "API_NODE", False)),
|
||||
"input_is_list": bool(
|
||||
getattr(node_cls, "INPUT_IS_LIST", False)
|
||||
),
|
||||
"not_idempotent": bool(
|
||||
getattr(node_cls, "NOT_IDEMPOTENT", False)
|
||||
),
|
||||
"accept_all_inputs": bool(
|
||||
getattr(node_cls, "ACCEPT_ALL_INPUTS", False)
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"%s V3 schema serialization failed for %s: %s",
|
||||
LOG_PREFIX,
|
||||
node_name,
|
||||
exc,
|
||||
)
|
||||
return details
|
||||
|
||||
def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
|
||||
input_dict: Dict[str, Any] = {}
|
||||
output_dict: Dict[str, Any] = {}
|
||||
hidden_list: List[str] = []
|
||||
|
||||
if getattr(schema, "inputs", None):
|
||||
for inp in schema.inputs:
|
||||
self._add_schema_io_v3(inp, input_dict)
|
||||
if getattr(schema, "outputs", None):
|
||||
for out in schema.outputs:
|
||||
self._add_schema_io_v3(out, output_dict)
|
||||
if getattr(schema, "hidden", None):
|
||||
for h in schema.hidden:
|
||||
hidden_list.append(getattr(h, "value", str(h)))
|
||||
|
||||
return {
|
||||
"input": input_dict,
|
||||
"output": output_dict,
|
||||
"hidden": hidden_list,
|
||||
"name": getattr(schema, "node_id", None),
|
||||
"display_name": getattr(schema, "display_name", None),
|
||||
"description": getattr(schema, "description", None),
|
||||
"category": getattr(schema, "category", None),
|
||||
"output_node": getattr(schema, "is_output_node", False),
|
||||
"deprecated": getattr(schema, "is_deprecated", False),
|
||||
"experimental": getattr(schema, "is_experimental", False),
|
||||
"api_node": getattr(schema, "is_api_node", False),
|
||||
}
|
||||
|
||||
def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
|
||||
io_id = getattr(io_obj, "id", None)
|
||||
if io_id is None:
|
||||
return
|
||||
|
||||
io_type_fn = getattr(io_obj, "get_io_type", None)
|
||||
io_type = (
|
||||
io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
|
||||
)
|
||||
|
||||
as_dict_fn = getattr(io_obj, "as_dict", None)
|
||||
payload = as_dict_fn() if callable(as_dict_fn) else {}
|
||||
|
||||
target[str(io_id)] = (io_type, payload)
|
||||
|
||||
async def get_input_types(self, node_name: str) -> Dict[str, Any]:
|
||||
node_cls = self._get_node_class(node_name)
|
||||
if hasattr(node_cls, "INPUT_TYPES"):
|
||||
return node_cls.INPUT_TYPES()
|
||||
return {}
|
||||
|
||||
async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]:
|
||||
logger.debug(
|
||||
"%s ISO:child_execute_start ext=%s node=%s input_keys=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
len(inputs),
|
||||
)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
_relieve_child_vram_pressure("EXT:pre_execute")
|
||||
|
||||
resolved_inputs = self._resolve_remote_objects(inputs)
|
||||
|
||||
instance = self._get_node_instance(node_name)
|
||||
node_cls = self._get_node_class(node_name)
|
||||
|
||||
# V3 API nodes expect hidden parameters in cls.hidden, not as kwargs
|
||||
# Hidden params come through RPC as string keys like "Hidden.prompt"
|
||||
from comfy_api.latest._io import Hidden, HiddenHolder
|
||||
|
||||
# Map string representations back to Hidden enum keys
|
||||
hidden_string_map = {
|
||||
"Hidden.unique_id": Hidden.unique_id,
|
||||
"Hidden.prompt": Hidden.prompt,
|
||||
"Hidden.extra_pnginfo": Hidden.extra_pnginfo,
|
||||
"Hidden.dynprompt": Hidden.dynprompt,
|
||||
"Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
|
||||
"Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
|
||||
# Uppercase enum VALUE forms — V3 execution engine passes these
|
||||
"UNIQUE_ID": Hidden.unique_id,
|
||||
"PROMPT": Hidden.prompt,
|
||||
"EXTRA_PNGINFO": Hidden.extra_pnginfo,
|
||||
"DYNPROMPT": Hidden.dynprompt,
|
||||
"AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org,
|
||||
"API_KEY_COMFY_ORG": Hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Find and extract hidden parameters (both enum and string form)
|
||||
hidden_found = {}
|
||||
keys_to_remove = []
|
||||
|
||||
for key in list(resolved_inputs.keys()):
|
||||
# Check string form first (from RPC serialization)
|
||||
if key in hidden_string_map:
|
||||
hidden_found[hidden_string_map[key]] = resolved_inputs[key]
|
||||
keys_to_remove.append(key)
|
||||
# Also check enum form (direct calls)
|
||||
elif isinstance(key, Hidden):
|
||||
hidden_found[key] = resolved_inputs[key]
|
||||
keys_to_remove.append(key)
|
||||
|
||||
# Remove hidden params from kwargs
|
||||
for key in keys_to_remove:
|
||||
resolved_inputs.pop(key)
|
||||
|
||||
# Set hidden on node class if any hidden params found
|
||||
if hidden_found:
|
||||
if not hasattr(node_cls, "hidden") or node_cls.hidden is None:
|
||||
node_cls.hidden = HiddenHolder.from_dict(hidden_found)
|
||||
else:
|
||||
# Update existing hidden holder
|
||||
for key, value in hidden_found.items():
|
||||
setattr(node_cls.hidden, key.value.lower(), value)
|
||||
|
||||
function_name = getattr(node_cls, "FUNCTION", "execute")
|
||||
if not hasattr(instance, function_name):
|
||||
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
|
||||
|
||||
handler = getattr(instance, function_name)
|
||||
|
||||
try:
|
||||
import torch
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
with torch.inference_mode():
|
||||
result = await handler(**resolved_inputs)
|
||||
else:
|
||||
import functools
|
||||
|
||||
def _run_with_inference_mode(**kwargs):
|
||||
with torch.inference_mode():
|
||||
return handler(**kwargs)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None, functools.partial(_run_with_inference_mode, **resolved_inputs)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s ISO:child_execute_error ext=%s node=%s",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
)
|
||||
raise
|
||||
|
||||
if type(result).__name__ == "NodeOutput":
|
||||
node_output_dict = {
|
||||
"__node_output__": True,
|
||||
"args": self._wrap_unpicklable_objects(result.args),
|
||||
}
|
||||
if result.ui is not None:
|
||||
node_output_dict["ui"] = result.ui
|
||||
if getattr(result, "expand", None) is not None:
|
||||
node_output_dict["expand"] = result.expand
|
||||
if getattr(result, "block_execution", None) is not None:
|
||||
node_output_dict["block_execution"] = result.block_execution
|
||||
return node_output_dict
|
||||
if self._is_comfy_protocol_return(result):
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
async def flush_transport_state(self) -> int:
|
||||
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||
return 0
|
||||
logger.debug(
|
||||
"%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
|
||||
)
|
||||
flushed = _flush_tensor_transport_state("EXT:workflow_end")
|
||||
try:
|
||||
from comfy.isolation.model_patcher_proxy_registry import (
|
||||
ModelPatcherRegistry,
|
||||
)
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
removed = registry.sweep_pending_cleanup()
|
||||
if removed > 0:
|
||||
logger.debug(
|
||||
"%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
|
||||
)
|
||||
logger.debug(
|
||||
"%s ISO:child_flush_done ext=%s flushed=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
flushed,
|
||||
)
|
||||
return flushed
|
||||
|
||||
async def get_remote_object(self, object_id: str) -> Any:
|
||||
"""Retrieve a remote object by ID for host-side deserialization."""
|
||||
if object_id not in self.remote_objects:
|
||||
raise KeyError(f"Remote object {object_id} not found")
|
||||
|
||||
return self.remote_objects[object_id]
|
||||
|
||||
def _wrap_unpicklable_objects(self, data: Any) -> Any:
|
||||
if isinstance(data, (str, int, float, bool, type(None))):
|
||||
return data
|
||||
if isinstance(data, torch.Tensor):
|
||||
tensor = data.detach() if data.requires_grad else data
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
|
||||
return tensor.cpu()
|
||||
return tensor
|
||||
|
||||
# Special-case clip vision outputs: preserve attribute access by packing fields
|
||||
if hasattr(data, "penultimate_hidden_states") or hasattr(
|
||||
data, "last_hidden_state"
|
||||
):
|
||||
fields = {}
|
||||
for attr in (
|
||||
"penultimate_hidden_states",
|
||||
"last_hidden_state",
|
||||
"image_embeds",
|
||||
"text_embeds",
|
||||
):
|
||||
if hasattr(data, attr):
|
||||
try:
|
||||
fields[attr] = self._wrap_unpicklable_objects(
|
||||
getattr(data, attr)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if fields:
|
||||
return {"__pyisolate_attribute_container__": True, "data": fields}
|
||||
|
||||
# Avoid converting arbitrary objects with stateful methods (models, etc.)
|
||||
# They will be handled via RemoteObjectHandle below.
|
||||
|
||||
type_name = type(data).__name__
|
||||
if type_name == "ModelPatcherProxy":
|
||||
return {"__type__": "ModelPatcherRef", "model_id": data._instance_id}
|
||||
if type_name == "CLIPProxy":
|
||||
return {"__type__": "CLIPRef", "clip_id": data._instance_id}
|
||||
if type_name == "VAEProxy":
|
||||
return {"__type__": "VAERef", "vae_id": data._instance_id}
|
||||
if type_name == "ModelSamplingProxy":
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id}
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
wrapped = [self._wrap_unpicklable_objects(item) for item in data]
|
||||
return tuple(wrapped) if isinstance(data, tuple) else wrapped
|
||||
if isinstance(data, dict):
|
||||
converted_dict = {
|
||||
k: self._wrap_unpicklable_objects(v) for k, v in data.items()
|
||||
}
|
||||
return {"__pyisolate_attrdict__": True, "data": converted_dict}
|
||||
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
|
||||
registry = SerializerRegistry.get_instance()
|
||||
if registry.is_data_type(type_name):
|
||||
serializer = registry.get_serializer(type_name)
|
||||
if serializer:
|
||||
return serializer(data)
|
||||
|
||||
object_id = str(uuid.uuid4())
|
||||
self.remote_objects[object_id] = data
|
||||
return RemoteObjectHandle(object_id, type(data).__name__)
|
||||
|
||||
def _resolve_remote_objects(self, data: Any) -> Any:
|
||||
if isinstance(data, RemoteObjectHandle):
|
||||
if data.object_id not in self.remote_objects:
|
||||
raise KeyError(f"Remote object {data.object_id} not found")
|
||||
return self.remote_objects[data.object_id]
|
||||
|
||||
if isinstance(data, dict):
|
||||
ref_type = data.get("__type__")
|
||||
if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"):
|
||||
from pyisolate._internal.model_serialization import (
|
||||
deserialize_proxy_result,
|
||||
)
|
||||
|
||||
return deserialize_proxy_result(data)
|
||||
if ref_type == "ModelSamplingRef":
|
||||
from pyisolate._internal.model_serialization import (
|
||||
deserialize_proxy_result,
|
||||
)
|
||||
|
||||
return deserialize_proxy_result(data)
|
||||
return {k: self._resolve_remote_objects(v) for k, v in data.items()}
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
resolved = [self._resolve_remote_objects(item) for item in data]
|
||||
return tuple(resolved) if isinstance(data, tuple) else resolved
|
||||
return data
|
||||
|
||||
def _get_node_class(self, node_name: str) -> type:
|
||||
if node_name not in self.node_classes:
|
||||
raise KeyError(f"Unknown node: {node_name}")
|
||||
return self.node_classes[node_name]
|
||||
|
||||
def _get_node_instance(self, node_name: str) -> Any:
|
||||
if node_name not in self.node_instances:
|
||||
if node_name not in self.node_classes:
|
||||
raise KeyError(f"Unknown node: {node_name}")
|
||||
self.node_instances[node_name] = self.node_classes[node_name]()
|
||||
return self.node_instances[node_name]
|
||||
|
||||
async def before_module_loaded(self) -> None:
|
||||
# Inject initialization here if we think this is the child
|
||||
try:
|
||||
from comfy.isolation import initialize_proxies
|
||||
|
||||
initialize_proxies()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).error(
|
||||
f"Failed to call initialize_proxies in before_module_loaded: {e}"
|
||||
)
|
||||
|
||||
await super().before_module_loaded()
|
||||
try:
|
||||
from comfy_api.latest import ComfyAPI_latest
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
|
||||
ComfyAPI_latest.Execution = ProgressProxy
|
||||
# ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision
|
||||
# fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision
|
||||
# latest_ui.folder_paths = fp_proxy
|
||||
# latest_resources.folder_paths = fp_proxy
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def call_route_handler(
|
||||
self,
|
||||
handler_module: str,
|
||||
handler_func: str,
|
||||
request_data: Dict[str, Any],
|
||||
) -> Any:
|
||||
cache_key = f"{handler_module}.{handler_func}"
|
||||
if cache_key not in self._route_handlers:
|
||||
if self._module is not None and hasattr(self._module, "__file__"):
|
||||
node_dir = os.path.dirname(self._module.__file__)
|
||||
if node_dir not in sys.path:
|
||||
sys.path.insert(0, node_dir)
|
||||
try:
|
||||
module = importlib.import_module(handler_module)
|
||||
self._route_handlers[cache_key] = getattr(module, handler_func)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(f"Route handler not found: {cache_key}") from e
|
||||
|
||||
handler = self._route_handlers[cache_key]
|
||||
mock_request = MockRequest(request_data)
|
||||
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
result = await handler(mock_request)
|
||||
else:
|
||||
result = handler(mock_request)
|
||||
return self._serialize_response(result)
|
||||
|
||||
def _is_comfy_protocol_return(self, result: Any) -> bool:
|
||||
"""
|
||||
Check if the result matches the ComfyUI 'Protocol Return' schema.
|
||||
|
||||
A Protocol Return is a dictionary containing specific reserved keys that
|
||||
ComfyUI's execution engine interprets as instructions (UI updates,
|
||||
Workflow expansion, etc.) rather than purely data outputs.
|
||||
|
||||
Schema:
|
||||
- Must be a dict
|
||||
- Must contain at least one of: 'ui', 'result', 'expand'
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return False
|
||||
return any(key in result for key in ("ui", "result", "expand"))
|
||||
|
||||
def _serialize_response(self, response: Any) -> Dict[str, Any]:
|
||||
if response is None:
|
||||
return {"type": "text", "body": "", "status": 204}
|
||||
if isinstance(response, dict):
|
||||
return {"type": "json", "body": response, "status": 200}
|
||||
if isinstance(response, str):
|
||||
return {"type": "text", "body": response, "status": 200}
|
||||
if hasattr(response, "text") and hasattr(response, "status"):
|
||||
return {
|
||||
"type": "text",
|
||||
"body": response.text
|
||||
if hasattr(response, "text")
|
||||
else str(response.body),
|
||||
"status": response.status,
|
||||
"headers": dict(response.headers)
|
||||
if hasattr(response, "headers")
|
||||
else {},
|
||||
}
|
||||
if hasattr(response, "body") and hasattr(response, "status"):
|
||||
body = response.body
|
||||
if isinstance(body, bytes):
|
||||
try:
|
||||
return {
|
||||
"type": "text",
|
||||
"body": body.decode("utf-8"),
|
||||
"status": response.status,
|
||||
}
|
||||
except UnicodeDecodeError:
|
||||
return {
|
||||
"type": "binary",
|
||||
"body": body.hex(),
|
||||
"status": response.status,
|
||||
}
|
||||
return {"type": "json", "body": body, "status": response.status}
|
||||
return {"type": "text", "body": str(response), "status": 200}
|
||||
|
||||
|
||||
class MockRequest:
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.method = data.get("method", "GET")
|
||||
self.path = data.get("path", "/")
|
||||
self.query = data.get("query", {})
|
||||
self._body = data.get("body", {})
|
||||
self._text = data.get("text", "")
|
||||
self.headers = data.get("headers", {})
|
||||
self.content_type = data.get(
|
||||
"content_type", self.headers.get("Content-Type", "application/json")
|
||||
)
|
||||
self.match_info = data.get("match_info", {})
|
||||
|
||||
async def json(self) -> Any:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
if isinstance(self._body, str):
|
||||
return json.loads(self._body)
|
||||
return {}
|
||||
|
||||
async def post(self) -> Dict[str, Any]:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
return {}
|
||||
|
||||
async def text(self) -> str:
|
||||
if self._text:
|
||||
return self._text
|
||||
if isinstance(self._body, str):
|
||||
return self._body
|
||||
if isinstance(self._body, dict):
|
||||
return json.dumps(self._body)
|
||||
return ""
|
||||
|
||||
async def read(self) -> bytes:
|
||||
return (await self.text()).encode("utf-8")
|
||||
26
comfy/isolation/host_hooks.py
Normal file
26
comfy/isolation/host_hooks.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# Host process initialization for PyIsolate
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_host_process() -> None:
|
||||
root = logging.getLogger()
|
||||
for handler in root.handlers[:]:
|
||||
root.removeHandler(handler)
|
||||
root.addHandler(logging.NullHandler())
|
||||
|
||||
from .proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from .proxies.model_management_proxy import ModelManagementProxy
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
from .proxies.prompt_server_impl import PromptServerService
|
||||
from .proxies.utils_proxy import UtilsProxy
|
||||
from .vae_proxy import VAERegistry
|
||||
|
||||
FolderPathsProxy()
|
||||
ModelManagementProxy()
|
||||
ProgressProxy()
|
||||
PromptServerService()
|
||||
UtilsProxy()
|
||||
VAERegistry()
|
||||
107
comfy/isolation/host_policy.py
Normal file
107
comfy/isolation/host_policy.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# pylint: disable=logging-fstring-interpolation
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, TypedDict
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
|
||||
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
|
||||
|
||||
|
||||
class HostSecurityPolicy(TypedDict):
|
||||
sandbox_mode: str
|
||||
allow_network: bool
|
||||
writable_paths: List[str]
|
||||
readonly_paths: List[str]
|
||||
whitelist: Dict[str, str]
|
||||
|
||||
|
||||
DEFAULT_POLICY: HostSecurityPolicy = {
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": ["/dev/shm", "/tmp"],
|
||||
"readonly_paths": [],
|
||||
"whitelist": {},
|
||||
}
|
||||
|
||||
|
||||
def _default_policy() -> HostSecurityPolicy:
|
||||
return {
|
||||
"sandbox_mode": DEFAULT_POLICY["sandbox_mode"],
|
||||
"allow_network": DEFAULT_POLICY["allow_network"],
|
||||
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
|
||||
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
|
||||
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
|
||||
}
|
||||
|
||||
|
||||
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
|
||||
config_override = os.environ.get(HOST_POLICY_PATH_ENV)
|
||||
config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
|
||||
policy = _default_policy()
|
||||
|
||||
if not config_path.exists():
|
||||
logger.debug("Host policy file missing at %s, using defaults.", config_path)
|
||||
return policy
|
||||
|
||||
try:
|
||||
with config_path.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse host policy from %s, using defaults.",
|
||||
config_path,
|
||||
exc_info=True,
|
||||
)
|
||||
return policy
|
||||
|
||||
tool_config = data.get("tool", {}).get("comfy", {}).get("host", {})
|
||||
if not isinstance(tool_config, dict):
|
||||
logger.debug("No [tool.comfy.host] section found, using defaults.")
|
||||
return policy
|
||||
|
||||
sandbox_mode = tool_config.get("sandbox_mode")
|
||||
if isinstance(sandbox_mode, str):
|
||||
normalized_sandbox_mode = sandbox_mode.strip().lower()
|
||||
if normalized_sandbox_mode in VALID_SANDBOX_MODES:
|
||||
policy["sandbox_mode"] = normalized_sandbox_mode
|
||||
else:
|
||||
logger.warning(
|
||||
"Invalid host sandbox_mode %r in %s, using default %r.",
|
||||
sandbox_mode,
|
||||
config_path,
|
||||
DEFAULT_POLICY["sandbox_mode"],
|
||||
)
|
||||
|
||||
if "allow_network" in tool_config:
|
||||
policy["allow_network"] = bool(tool_config["allow_network"])
|
||||
|
||||
if "writable_paths" in tool_config:
|
||||
policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]]
|
||||
|
||||
if "readonly_paths" in tool_config:
|
||||
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
|
||||
|
||||
whitelist_raw = tool_config.get("whitelist")
|
||||
if isinstance(whitelist_raw, dict):
|
||||
policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()}
|
||||
|
||||
logger.debug(
|
||||
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
|
||||
len(policy["whitelist"]),
|
||||
policy["sandbox_mode"],
|
||||
policy["allow_network"],
|
||||
)
|
||||
return policy
|
||||
|
||||
|
||||
__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"]
|
||||
186
comfy/isolation/manifest_loader.py
Normal file
186
comfy/isolation/manifest_loader.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import folder_paths
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_SUBDIR = "cache"
|
||||
CACHE_KEY_FILE = "cache_key"
|
||||
CACHE_DATA_FILE = "node_info.json"
|
||||
CACHE_KEY_LENGTH = 16
|
||||
|
||||
|
||||
def find_manifest_directories() -> List[Tuple[Path, Path]]:
|
||||
"""Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation]."""
|
||||
manifest_dirs: List[Tuple[Path, Path]] = []
|
||||
|
||||
# Standard custom_nodes paths
|
||||
for base_path in folder_paths.get_folder_paths("custom_nodes"):
|
||||
base = Path(base_path)
|
||||
if not base.exists() or not base.is_dir():
|
||||
continue
|
||||
|
||||
for entry in base.iterdir():
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
|
||||
# Look for pyproject.toml
|
||||
manifest = entry / "pyproject.toml"
|
||||
if not manifest.exists():
|
||||
continue
|
||||
|
||||
# Validate [tool.comfy.isolation] section existence
|
||||
try:
|
||||
with manifest.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
|
||||
if (
|
||||
"tool" in data
|
||||
and "comfy" in data["tool"]
|
||||
and "isolation" in data["tool"]["comfy"]
|
||||
):
|
||||
manifest_dirs.append((entry, manifest))
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return manifest_dirs
|
||||
|
||||
|
||||
def compute_cache_key(node_dir: Path, manifest_path: Path) -> str:
|
||||
"""Hash manifest + .py mtimes + Python version + PyIsolate version."""
|
||||
hasher = hashlib.sha256()
|
||||
|
||||
try:
|
||||
# Hashing the manifest content ensures config changes invalidate cache
|
||||
hasher.update(manifest_path.read_bytes())
|
||||
except OSError:
|
||||
hasher.update(b"__manifest_read_error__")
|
||||
|
||||
try:
|
||||
py_files = sorted(node_dir.rglob("*.py"))
|
||||
for py_file in py_files:
|
||||
rel_path = py_file.relative_to(node_dir)
|
||||
if "__pycache__" in str(rel_path) or ".venv" in str(rel_path):
|
||||
continue
|
||||
hasher.update(str(rel_path).encode("utf-8"))
|
||||
try:
|
||||
hasher.update(str(py_file.stat().st_mtime).encode("utf-8"))
|
||||
except OSError:
|
||||
hasher.update(b"__file_stat_error__")
|
||||
except OSError:
|
||||
hasher.update(b"__dir_scan_error__")
|
||||
|
||||
hasher.update(sys.version.encode("utf-8"))
|
||||
|
||||
try:
|
||||
import pyisolate
|
||||
|
||||
hasher.update(pyisolate.__version__.encode("utf-8"))
|
||||
except (ImportError, AttributeError):
|
||||
hasher.update(b"__pyisolate_unknown__")
|
||||
|
||||
return hasher.hexdigest()[:CACHE_KEY_LENGTH]
|
||||
|
||||
|
||||
def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]:
|
||||
"""Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/."""
|
||||
cache_dir = venv_root / node_dir.name / CACHE_SUBDIR
|
||||
return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE)
|
||||
|
||||
|
||||
def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool:
|
||||
"""Return True only if stored cache key matches current computed key."""
|
||||
try:
|
||||
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
if not cache_key_file.exists() or not cache_data_file.exists():
|
||||
return False
|
||||
current_key = compute_cache_key(node_dir, manifest_path)
|
||||
stored_key = cache_key_file.read_text(encoding="utf-8").strip()
|
||||
return current_key == stored_key
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]:
|
||||
"""Load node metadata from cache, return None on any error."""
|
||||
try:
|
||||
_, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
if not cache_data_file.exists():
|
||||
return None
|
||||
data = json.loads(cache_data_file.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
return data
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def save_to_cache(
|
||||
node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path
|
||||
) -> None:
|
||||
"""Save node metadata and cache key atomically."""
|
||||
try:
|
||||
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
cache_dir = cache_key_file.parent
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
cache_key = compute_cache_key(node_dir, manifest_path)
|
||||
|
||||
# Atomic write: data
|
||||
tmp_data_fd, tmp_data_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(tmp_data_fd, "w", encoding="utf-8") as f:
|
||||
json.dump(node_data, f, indent=2)
|
||||
os.replace(tmp_data_path, cache_data_file)
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_data_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
# Atomic write: key
|
||||
tmp_key_fd, tmp_key_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(tmp_key_fd, "w", encoding="utf-8") as f:
|
||||
f.write(cache_key)
|
||||
os.replace(tmp_key_path, cache_key_file)
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_key_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LOG_PREFIX",
|
||||
"find_manifest_directories",
|
||||
"compute_cache_key",
|
||||
"get_cache_path",
|
||||
"is_cache_valid",
|
||||
"load_from_cache",
|
||||
"save_to_cache",
|
||||
]
|
||||
861
comfy/isolation/model_patcher_proxy.py
Normal file
861
comfy/isolation/model_patcher_proxy.py
Normal file
@@ -0,0 +1,861 @@
|
||||
# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access
|
||||
# RPC proxy for ModelPatcher (parent process)
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, List, Set, Dict, Callable
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
)
|
||||
from comfy.isolation.model_patcher_proxy_registry import (
|
||||
ModelPatcherRegistry,
|
||||
AutoPatcherEjector,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
_registry_class = ModelPatcherRegistry
|
||||
__module__ = "comfy.model_patcher"
|
||||
_APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is not None:
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
self._registry_class, self._registry_class.get_remote_id()
|
||||
)
|
||||
else:
|
||||
self._rpc_caller = self._registry
|
||||
return self._rpc_caller
|
||||
|
||||
def get_all_callbacks(self, call_type: str = None) -> Any:
|
||||
return self._call_rpc("get_all_callbacks", call_type)
|
||||
|
||||
def get_all_wrappers(self, wrapper_type: str = None) -> Any:
|
||||
return self._call_rpc("get_all_wrappers", wrapper_type)
|
||||
|
||||
def _load_list(self, *args, **kwargs) -> Any:
|
||||
return self._call_rpc("load_list_internal", *args, **kwargs)
|
||||
|
||||
def prepare_hook_patches_current_keyframe(
|
||||
self, t: Any, hook_group: Any, model_options: Any
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"prepare_hook_patches_current_keyframe", t, hook_group, model_options
|
||||
)
|
||||
|
||||
def add_hook_patches(
|
||||
self,
|
||||
hook: Any,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"add_hook_patches", hook, patches, strength_patch, strength_model
|
||||
)
|
||||
|
||||
def clear_cached_hook_weights(self) -> None:
|
||||
self._call_rpc("clear_cached_hook_weights")
|
||||
|
||||
def get_combined_hook_patches(self, hooks: Any) -> Any:
|
||||
return self._call_rpc("get_combined_hook_patches", hooks)
|
||||
|
||||
def get_additional_models_with_key(self, key: str) -> Any:
|
||||
return self._call_rpc("get_additional_models_with_key", key)
|
||||
|
||||
@property
|
||||
def object_patches(self) -> Any:
|
||||
return self._call_rpc("get_object_patches")
|
||||
|
||||
@property
|
||||
def patches(self) -> Any:
|
||||
res = self._call_rpc("get_patches")
|
||||
if isinstance(res, dict):
|
||||
new_res = {}
|
||||
for k, v in res.items():
|
||||
new_list = []
|
||||
for item in v:
|
||||
if isinstance(item, list):
|
||||
new_list.append(tuple(item))
|
||||
else:
|
||||
new_list.append(item)
|
||||
new_res[k] = new_list
|
||||
return new_res
|
||||
return res
|
||||
|
||||
@property
|
||||
def pinned(self) -> Set:
|
||||
val = self._call_rpc("get_patcher_attr", "pinned")
|
||||
return set(val) if val is not None else set()
|
||||
|
||||
@property
|
||||
def hook_patches(self) -> Dict:
|
||||
val = self._call_rpc("get_patcher_attr", "hook_patches")
|
||||
if val is None:
|
||||
return {}
|
||||
try:
|
||||
from comfy.hooks import _HookRef
|
||||
import json
|
||||
|
||||
new_val = {}
|
||||
for k, v in val.items():
|
||||
if isinstance(k, str):
|
||||
if k.startswith("PYISOLATE_HOOKREF:"):
|
||||
ref_id = k.split(":", 1)[1]
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = ref_id
|
||||
new_val[h] = v
|
||||
elif k.startswith("__pyisolate_key__"):
|
||||
try:
|
||||
json_str = k[len("__pyisolate_key__") :]
|
||||
data = json.loads(json_str)
|
||||
ref_id = None
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
if (
|
||||
isinstance(item, list)
|
||||
and len(item) == 2
|
||||
and item[0] == "id"
|
||||
):
|
||||
ref_id = item[1]
|
||||
break
|
||||
if ref_id:
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = ref_id
|
||||
new_val[h] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
except Exception:
|
||||
new_val[k] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
return new_val
|
||||
except ImportError:
|
||||
return val
|
||||
|
||||
def set_hook_mode(self, hook_mode: Any) -> None:
|
||||
self._call_rpc("set_hook_mode", hook_mode)
|
||||
|
||||
def register_all_hook_patches(
|
||||
self,
|
||||
hooks: Any,
|
||||
target_dict: Any,
|
||||
model_options: Any = None,
|
||||
registered: Any = None,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"register_all_hook_patches", hooks, target_dict, model_options, registered
|
||||
)
|
||||
|
||||
def is_clone(self, other: Any) -> bool:
|
||||
if isinstance(other, ModelPatcherProxy):
|
||||
return self._call_rpc("is_clone_by_id", other._instance_id)
|
||||
return False
|
||||
|
||||
def clone(self) -> ModelPatcherProxy:
|
||||
new_id = self._call_rpc("clone")
|
||||
return ModelPatcherProxy(
|
||||
new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
|
||||
)
|
||||
|
||||
def clone_has_same_weights(self, clone: Any) -> bool:
|
||||
if isinstance(clone, ModelPatcherProxy):
|
||||
return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id)
|
||||
if not IS_CHILD_PROCESS:
|
||||
return self._call_rpc("is_clone", clone)
|
||||
return False
|
||||
|
||||
def get_model_object(self, name: str) -> Any:
|
||||
return self._call_rpc("get_model_object", name)
|
||||
|
||||
@property
|
||||
def model_options(self) -> dict:
|
||||
data = self._call_rpc("get_model_options")
|
||||
import json
|
||||
|
||||
def _decode_keys(obj):
|
||||
if isinstance(obj, dict):
|
||||
new_d = {}
|
||||
for k, v in obj.items():
|
||||
if isinstance(k, str) and k.startswith("__pyisolate_key__"):
|
||||
try:
|
||||
json_str = k[17:]
|
||||
val = json.loads(json_str)
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
new_d[val] = _decode_keys(v)
|
||||
except:
|
||||
new_d[k] = _decode_keys(v)
|
||||
else:
|
||||
new_d[k] = _decode_keys(v)
|
||||
return new_d
|
||||
if isinstance(obj, list):
|
||||
return [_decode_keys(x) for x in obj]
|
||||
return obj
|
||||
|
||||
return _decode_keys(data)
|
||||
|
||||
@model_options.setter
|
||||
def model_options(self, value: dict) -> None:
|
||||
self._call_rpc("set_model_options", value)
|
||||
|
||||
def apply_hooks(self, hooks: Any) -> Any:
|
||||
return self._call_rpc("apply_hooks", hooks)
|
||||
|
||||
def prepare_state(self, timestep: Any) -> Any:
|
||||
return self._call_rpc("prepare_state", timestep)
|
||||
|
||||
def restore_hook_patches(self) -> None:
|
||||
self._call_rpc("restore_hook_patches")
|
||||
|
||||
def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None:
|
||||
self._call_rpc("unpatch_hooks", whitelist_keys_set)
|
||||
|
||||
def model_patches_to(self, device: Any) -> Any:
|
||||
return self._call_rpc("model_patches_to", device)
|
||||
|
||||
def partially_load(
|
||||
self, device: Any, extra_memory: Any, force_patch_weights: bool = False
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"partially_load", device, extra_memory, force_patch_weights
|
||||
)
|
||||
|
||||
def partially_unload(
|
||||
self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False
|
||||
) -> int:
|
||||
return self._call_rpc(
|
||||
"partially_unload", device_to, memory_to_free, force_patch_weights
|
||||
)
|
||||
|
||||
def load(
|
||||
self,
|
||||
device_to: Any = None,
|
||||
lowvram_model_memory: int = 0,
|
||||
force_patch_weights: bool = False,
|
||||
full_load: bool = False,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"load", device_to, lowvram_model_memory, force_patch_weights, full_load
|
||||
)
|
||||
|
||||
def patch_model(
|
||||
self,
|
||||
device_to: Any = None,
|
||||
lowvram_model_memory: int = 0,
|
||||
load_weights: bool = True,
|
||||
force_patch_weights: bool = False,
|
||||
) -> Any:
|
||||
self._call_rpc(
|
||||
"patch_model",
|
||||
device_to,
|
||||
lowvram_model_memory,
|
||||
load_weights,
|
||||
force_patch_weights,
|
||||
)
|
||||
return self
|
||||
|
||||
def unpatch_model(
|
||||
self, device_to: Any = None, unpatch_weights: bool = True
|
||||
) -> None:
|
||||
self._call_rpc("unpatch_model", device_to, unpatch_weights)
|
||||
|
||||
def detach(self, unpatch_all: bool = True) -> Any:
|
||||
self._call_rpc("detach", unpatch_all)
|
||||
return self.model
|
||||
|
||||
def _cpu_tensor_bytes(self, obj: Any) -> int:
|
||||
import torch
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if obj.device.type == "cpu":
|
||||
return obj.nbytes
|
||||
return 0
|
||||
if isinstance(obj, dict):
|
||||
return sum(self._cpu_tensor_bytes(v) for v in obj.values())
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return sum(self._cpu_tensor_bytes(v) for v in obj)
|
||||
return 0
|
||||
|
||||
def _ensure_apply_model_headroom(self, required_bytes: int) -> bool:
|
||||
if required_bytes <= 0:
|
||||
return True
|
||||
|
||||
import torch
|
||||
import comfy.model_management as model_management
|
||||
|
||||
target_raw = self.load_device
|
||||
try:
|
||||
if isinstance(target_raw, torch.device):
|
||||
target = target_raw
|
||||
elif isinstance(target_raw, str):
|
||||
target = torch.device(target_raw)
|
||||
elif isinstance(target_raw, int):
|
||||
target = torch.device(f"cuda:{target_raw}")
|
||||
else:
|
||||
target = torch.device(target_raw)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
if target.type != "cuda":
|
||||
return True
|
||||
|
||||
required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES
|
||||
if model_management.get_free_memory(target) >= required:
|
||||
return True
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
model_management.free_memory(required, target, for_dynamic=True)
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
# Escalate to non-dynamic unloading before dispatching CUDA transfer.
|
||||
model_management.free_memory(required, target, for_dynamic=False)
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
model_management.load_models_gpu(
|
||||
[self],
|
||||
minimum_memory_required=required,
|
||||
)
|
||||
|
||||
return model_management.get_free_memory(target) >= required
|
||||
|
||||
def apply_model(self, *args, **kwargs) -> Any:
|
||||
import torch
|
||||
|
||||
def _preferred_device() -> Any:
|
||||
for value in args:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
return None
|
||||
|
||||
def _move_result_to_device(obj: Any, device: Any) -> Any:
|
||||
if device is None:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.to(device) if obj.device != device else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _move_result_to_device(v, device) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_move_result_to_device(v, device) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_move_result_to_device(v, device) for v in obj)
|
||||
return obj
|
||||
|
||||
# DynamicVRAM models must keep load/offload decisions in host process.
|
||||
# Child-side CUDA staging here can deadlock before first inference RPC.
|
||||
if self.is_dynamic():
|
||||
out = self._call_rpc("inner_model_apply_model", args, kwargs)
|
||||
return _move_result_to_device(out, _preferred_device())
|
||||
|
||||
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
|
||||
self._ensure_apply_model_headroom(required_bytes)
|
||||
|
||||
def _to_cuda(obj: Any) -> Any:
|
||||
if isinstance(obj, torch.Tensor) and obj.device.type == "cpu":
|
||||
return obj.to("cuda")
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cuda(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cuda(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cuda(v) for v in obj)
|
||||
return obj
|
||||
|
||||
try:
|
||||
args_cuda = _to_cuda(args)
|
||||
kwargs_cuda = _to_cuda(kwargs)
|
||||
except torch.OutOfMemoryError:
|
||||
self._ensure_apply_model_headroom(required_bytes)
|
||||
args_cuda = _to_cuda(args)
|
||||
kwargs_cuda = _to_cuda(kwargs)
|
||||
|
||||
out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
|
||||
return _move_result_to_device(out, _preferred_device())
|
||||
|
||||
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
|
||||
keys = self._call_rpc("model_state_dict", filter_prefix)
|
||||
return dict.fromkeys(keys, None)
|
||||
|
||||
def add_patches(self, *args: Any, **kwargs: Any) -> Any:
|
||||
res = self._call_rpc("add_patches", *args, **kwargs)
|
||||
if isinstance(res, list):
|
||||
return [tuple(x) if isinstance(x, list) else x for x in res]
|
||||
return res
|
||||
|
||||
def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any:
|
||||
return self._call_rpc("get_key_patches", filter_prefix)
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||
self._call_rpc("patch_weight_to_device", key, device_to, inplace_update)
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
self._call_rpc("pin_weight_to_device", key)
|
||||
|
||||
def unpin_weight(self, key):
|
||||
self._call_rpc("unpin_weight", key)
|
||||
|
||||
def unpin_all_weights(self):
|
||||
self._call_rpc("unpin_all_weights")
|
||||
|
||||
def calculate_weight(self, patches, weight, key, intermediate_dtype=None):
|
||||
return self._call_rpc(
|
||||
"calculate_weight", patches, weight, key, intermediate_dtype
|
||||
)
|
||||
|
||||
def inject_model(self) -> None:
|
||||
self._call_rpc("inject_model")
|
||||
|
||||
def eject_model(self) -> None:
|
||||
self._call_rpc("eject_model")
|
||||
|
||||
def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any:
|
||||
return AutoPatcherEjector(
|
||||
self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only
|
||||
)
|
||||
|
||||
@property
|
||||
def is_injected(self) -> bool:
|
||||
return self._call_rpc("get_is_injected")
|
||||
|
||||
@property
|
||||
def skip_injection(self) -> bool:
|
||||
return self._call_rpc("get_skip_injection")
|
||||
|
||||
@skip_injection.setter
|
||||
def skip_injection(self, value: bool) -> None:
|
||||
self._call_rpc("set_skip_injection", value)
|
||||
|
||||
def clean_hooks(self) -> None:
|
||||
self._call_rpc("clean_hooks")
|
||||
|
||||
def pre_run(self) -> None:
|
||||
self._call_rpc("pre_run")
|
||||
|
||||
def cleanup(self) -> None:
|
||||
try:
|
||||
self._call_rpc("cleanup")
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ModelPatcherProxy cleanup RPC failed for %s",
|
||||
self._instance_id,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
super().cleanup()
|
||||
|
||||
@property
|
||||
def model(self) -> _InnerModelProxy:
|
||||
return _InnerModelProxy(self)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
_whitelisted_attrs = {
|
||||
"hook_patches_backup",
|
||||
"hook_backup",
|
||||
"cached_hook_patches",
|
||||
"current_hooks",
|
||||
"forced_hooks",
|
||||
"is_clip",
|
||||
"patches_uuid",
|
||||
"pinned",
|
||||
"attachments",
|
||||
"additional_models",
|
||||
"injections",
|
||||
"hook_patches",
|
||||
"model_lowvram",
|
||||
"model_loaded_weight_memory",
|
||||
"backup",
|
||||
"object_patches_backup",
|
||||
"weight_wrapper_patches",
|
||||
"weight_inplace_update",
|
||||
"force_cast_weights",
|
||||
}
|
||||
if name in _whitelisted_attrs:
|
||||
return self._call_rpc("get_patcher_attr", name)
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
def load_lora(
|
||||
self,
|
||||
lora_path: str,
|
||||
strength_model: float,
|
||||
clip: Optional[Any] = None,
|
||||
strength_clip: float = 1.0,
|
||||
) -> tuple:
|
||||
clip_id = None
|
||||
if clip is not None:
|
||||
clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None))
|
||||
result = self._call_rpc(
|
||||
"load_lora", lora_path, strength_model, clip_id, strength_clip
|
||||
)
|
||||
new_model = None
|
||||
if result.get("model_id"):
|
||||
new_model = ModelPatcherProxy(
|
||||
result["model_id"],
|
||||
self._registry,
|
||||
manage_lifecycle=not IS_CHILD_PROCESS,
|
||||
)
|
||||
new_clip = None
|
||||
if result.get("clip_id"):
|
||||
from comfy.isolation.clip_proxy import CLIPProxy
|
||||
|
||||
new_clip = CLIPProxy(result["clip_id"])
|
||||
return (new_model, new_clip)
|
||||
|
||||
@property
|
||||
def load_device(self) -> Any:
|
||||
return self._call_rpc("get_load_device")
|
||||
|
||||
@property
|
||||
def offload_device(self) -> Any:
|
||||
return self._call_rpc("get_offload_device")
|
||||
|
||||
@property
|
||||
def device(self) -> Any:
|
||||
return self.load_device
|
||||
|
||||
def current_loaded_device(self) -> Any:
|
||||
return self._call_rpc("current_loaded_device")
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self._call_rpc("get_size")
|
||||
|
||||
def model_size(self) -> Any:
|
||||
return self._call_rpc("model_size")
|
||||
|
||||
def loaded_size(self) -> Any:
|
||||
return self._call_rpc("loaded_size")
|
||||
|
||||
def get_ram_usage(self) -> int:
|
||||
return self._call_rpc("get_ram_usage")
|
||||
|
||||
def lowvram_patch_counter(self) -> int:
|
||||
return self._call_rpc("lowvram_patch_counter")
|
||||
|
||||
def memory_required(self, input_shape: Any) -> Any:
|
||||
return self._call_rpc("memory_required", input_shape)
|
||||
|
||||
def get_operation_state(self) -> Dict[str, Any]:
|
||||
state = self._call_rpc("get_operation_state")
|
||||
return state if isinstance(state, dict) else {}
|
||||
|
||||
def wait_for_idle(self, timeout_ms: int = 0) -> bool:
|
||||
return bool(self._call_rpc("wait_for_idle", timeout_ms))
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return bool(self._call_rpc("is_dynamic"))
|
||||
|
||||
def get_free_memory(self, device: Any) -> Any:
|
||||
return self._call_rpc("get_free_memory", device)
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload: int) -> Any:
|
||||
return self._call_rpc("partially_unload_ram", ram_to_unload)
|
||||
|
||||
def model_dtype(self) -> Any:
|
||||
res = self._call_rpc("model_dtype")
|
||||
if isinstance(res, str) and res.startswith("torch."):
|
||||
try:
|
||||
import torch
|
||||
|
||||
attr = res.split(".")[-1]
|
||||
if hasattr(torch, attr):
|
||||
return getattr(torch, attr)
|
||||
except ImportError:
|
||||
pass
|
||||
return res
|
||||
|
||||
@property
|
||||
def hook_mode(self) -> Any:
|
||||
return self._call_rpc("get_hook_mode")
|
||||
|
||||
@hook_mode.setter
|
||||
def hook_mode(self, value: Any) -> None:
|
||||
self._call_rpc("set_hook_mode", value)
|
||||
|
||||
def set_model_sampler_cfg_function(
|
||||
self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_cfg_function",
|
||||
sampler_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_post_cfg_function(
|
||||
self, post_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_post_cfg_function",
|
||||
post_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_pre_cfg_function(
|
||||
self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_pre_cfg_function",
|
||||
pre_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None:
|
||||
self._call_rpc("set_model_sampler_calc_cond_batch_function", fn)
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None:
|
||||
self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function)
|
||||
|
||||
def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None:
|
||||
self._call_rpc("set_model_denoise_mask_function", denoise_mask_function)
|
||||
|
||||
def set_model_patch(self, patch: Any, name: str) -> None:
|
||||
self._call_rpc("set_model_patch", patch, name)
|
||||
|
||||
def set_model_patch_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
name: str,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_patch_replace",
|
||||
patch,
|
||||
name,
|
||||
block_name,
|
||||
number,
|
||||
transformer_index,
|
||||
)
|
||||
|
||||
def set_model_attn1_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
|
||||
def set_model_attn2_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def set_model_attn1_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self.set_model_patch_replace(
|
||||
patch, "attn1", block_name, number, transformer_index
|
||||
)
|
||||
|
||||
def set_model_attn2_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self.set_model_patch_replace(
|
||||
patch, "attn2", block_name, number, transformer_index
|
||||
)
|
||||
|
||||
def set_model_attn1_output_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn1_output_patch")
|
||||
|
||||
def set_model_attn2_output_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn2_output_patch")
|
||||
|
||||
def set_model_input_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "input_block_patch")
|
||||
|
||||
def set_model_input_block_patch_after_skip(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "input_block_patch_after_skip")
|
||||
|
||||
def set_model_output_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "output_block_patch")
|
||||
|
||||
def set_model_emb_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "emb_patch")
|
||||
|
||||
def set_model_forward_timestep_embed_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
||||
|
||||
def set_model_double_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "double_block")
|
||||
|
||||
def set_model_post_input_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_rope_options(
|
||||
self,
|
||||
scale_x=1.0,
|
||||
shift_x=0.0,
|
||||
scale_y=1.0,
|
||||
shift_y=0.0,
|
||||
scale_t=1.0,
|
||||
shift_t=0.0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
options = {
|
||||
"scale_x": scale_x,
|
||||
"shift_x": shift_x,
|
||||
"scale_y": scale_y,
|
||||
"shift_y": shift_y,
|
||||
"scale_t": scale_t,
|
||||
"shift_t": shift_t,
|
||||
}
|
||||
options.update(kwargs)
|
||||
self._call_rpc("set_model_rope_options", options)
|
||||
|
||||
def set_model_compute_dtype(self, dtype: Any) -> None:
|
||||
self._call_rpc("set_model_compute_dtype", dtype)
|
||||
|
||||
def add_object_patch(self, name: str, obj: Any) -> None:
|
||||
self._call_rpc("add_object_patch", name, obj)
|
||||
|
||||
def add_weight_wrapper(self, name: str, function: Any) -> None:
|
||||
self._call_rpc("add_weight_wrapper", name, function)
|
||||
|
||||
def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None:
|
||||
self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn)
|
||||
|
||||
def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None:
|
||||
self.add_wrapper_with_key(wrapper_type, None, wrapper)
|
||||
|
||||
def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None:
|
||||
self._call_rpc("remove_wrappers_with_key", wrapper_type, key)
|
||||
|
||||
@property
|
||||
def wrappers(self) -> Any:
|
||||
return self._call_rpc("get_wrappers")
|
||||
|
||||
def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None:
|
||||
self._call_rpc("add_callback_with_key", call_type, key, callback)
|
||||
|
||||
def add_callback(self, call_type: str, callback: Any) -> None:
|
||||
self.add_callback_with_key(call_type, None, callback)
|
||||
|
||||
def remove_callbacks_with_key(self, call_type: str, key: str) -> None:
|
||||
self._call_rpc("remove_callbacks_with_key", call_type, key)
|
||||
|
||||
@property
|
||||
def callbacks(self) -> Any:
|
||||
return self._call_rpc("get_callbacks")
|
||||
|
||||
def set_attachments(self, key: str, attachment: Any) -> None:
|
||||
self._call_rpc("set_attachments", key, attachment)
|
||||
|
||||
def get_attachment(self, key: str) -> Any:
|
||||
return self._call_rpc("get_attachment", key)
|
||||
|
||||
def remove_attachments(self, key: str) -> None:
|
||||
self._call_rpc("remove_attachments", key)
|
||||
|
||||
def set_injections(self, key: str, injections: Any) -> None:
|
||||
self._call_rpc("set_injections", key, injections)
|
||||
|
||||
def get_injections(self, key: str) -> Any:
|
||||
return self._call_rpc("get_injections", key)
|
||||
|
||||
def remove_injections(self, key: str) -> None:
|
||||
self._call_rpc("remove_injections", key)
|
||||
|
||||
def set_additional_models(self, key: str, models: Any) -> None:
|
||||
ids = [m._instance_id for m in models]
|
||||
self._call_rpc("set_additional_models", key, ids)
|
||||
|
||||
def remove_additional_models(self, key: str) -> None:
|
||||
self._call_rpc("remove_additional_models", key)
|
||||
|
||||
def get_nested_additional_models(self) -> Any:
|
||||
return self._call_rpc("get_nested_additional_models")
|
||||
|
||||
def get_additional_models(self) -> List[ModelPatcherProxy]:
|
||||
ids = self._call_rpc("get_additional_models")
|
||||
return [
|
||||
ModelPatcherProxy(
|
||||
mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
|
||||
)
|
||||
for mid in ids
|
||||
]
|
||||
|
||||
def model_patches_models(self) -> Any:
|
||||
return self._call_rpc("model_patches_models")
|
||||
|
||||
@property
|
||||
def parent(self) -> Any:
|
||||
return self._call_rpc("get_parent")
|
||||
|
||||
|
||||
class _InnerModelProxy:
|
||||
def __init__(self, parent: ModelPatcherProxy):
|
||||
self._parent = parent
|
||||
self._model_sampling = None
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(name)
|
||||
if name in (
|
||||
"model_config",
|
||||
"latent_format",
|
||||
"model_type",
|
||||
"current_weight_patches_uuid",
|
||||
):
|
||||
return self._parent._call_rpc("get_inner_model_attr", name)
|
||||
if name == "load_device":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "load_device")
|
||||
if name == "device":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "device")
|
||||
if name == "current_patcher":
|
||||
return ModelPatcherProxy(
|
||||
self._parent._instance_id,
|
||||
self._parent._registry,
|
||||
manage_lifecycle=False,
|
||||
)
|
||||
if name == "model_sampling":
|
||||
if self._model_sampling is None:
|
||||
self._model_sampling = self._parent._call_rpc(
|
||||
"get_model_object", "model_sampling"
|
||||
)
|
||||
return self._model_sampling
|
||||
if name == "extra_conds_shapes":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_extra_conds_shapes", a, k
|
||||
)
|
||||
if name == "extra_conds":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_extra_conds", a, k
|
||||
)
|
||||
if name == "memory_required":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_memory_required", a, k
|
||||
)
|
||||
if name == "apply_model":
|
||||
# Delegate to parent's method to get the CPU->CUDA optimization
|
||||
return self._parent.apply_model
|
||||
if name == "process_latent_in":
|
||||
return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k)
|
||||
if name == "process_latent_out":
|
||||
return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k)
|
||||
if name == "scale_latent_inpaint":
|
||||
return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k)
|
||||
if name == "diffusion_model":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "diffusion_model")
|
||||
raise AttributeError(f"'{name}' not supported on isolated InnerModel")
|
||||
1230
comfy/isolation/model_patcher_proxy_registry.py
Normal file
1230
comfy/isolation/model_patcher_proxy_registry.py
Normal file
File diff suppressed because it is too large
Load Diff
156
comfy/isolation/model_patcher_proxy_utils.py
Normal file
156
comfy/isolation/model_patcher_proxy_utils.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access
|
||||
# Isolation utilities and serializers for ModelPatcherProxy
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any:
|
||||
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
isolation_active = args.use_process_isolation or is_child
|
||||
|
||||
if not isolation_active:
|
||||
return model_patcher
|
||||
if is_child:
|
||||
return model_patcher
|
||||
if isinstance(model_patcher, ModelPatcherProxy):
|
||||
return model_patcher
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
model_id = registry.register(model_patcher)
|
||||
logger.debug(f"Isolated ModelPatcher: {model_id}")
|
||||
return ModelPatcherProxy(model_id, registry, manage_lifecycle=True)
|
||||
|
||||
|
||||
def register_hooks_serializers(registry=None):
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
import comfy.hooks
|
||||
|
||||
if registry is None:
|
||||
registry = SerializerRegistry.get_instance()
|
||||
|
||||
def serialize_enum(obj):
|
||||
return {"__enum__": f"{type(obj).__name__}.{obj.name}"}
|
||||
|
||||
def deserialize_enum(data):
|
||||
cls_name, val_name = data["__enum__"].split(".")
|
||||
cls = getattr(comfy.hooks, cls_name)
|
||||
return cls[val_name]
|
||||
|
||||
registry.register("EnumHookType", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumHookScope", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumHookMode", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumWeightTarget", serialize_enum, deserialize_enum)
|
||||
|
||||
def serialize_hook_group(obj):
|
||||
return {"__type__": "HookGroup", "hooks": obj.hooks}
|
||||
|
||||
def deserialize_hook_group(data):
|
||||
hg = comfy.hooks.HookGroup()
|
||||
for h in data["hooks"]:
|
||||
hg.add(h)
|
||||
return hg
|
||||
|
||||
registry.register("HookGroup", serialize_hook_group, deserialize_hook_group)
|
||||
|
||||
def serialize_dict_state(obj):
|
||||
d = obj.__dict__.copy()
|
||||
d["__type__"] = type(obj).__name__
|
||||
if "custom_should_register" in d:
|
||||
del d["custom_should_register"]
|
||||
return d
|
||||
|
||||
def deserialize_dict_state_generic(cls):
|
||||
def _deserialize(data):
|
||||
h = cls()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
return _deserialize
|
||||
|
||||
def deserialize_hook_keyframe(data):
|
||||
h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0))
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe)
|
||||
|
||||
def deserialize_hook_keyframe_group(data):
|
||||
h = comfy.hooks.HookKeyframeGroup()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register(
|
||||
"HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group
|
||||
)
|
||||
|
||||
def deserialize_hook(data):
|
||||
h = comfy.hooks.Hook()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("Hook", serialize_dict_state, deserialize_hook)
|
||||
|
||||
def deserialize_weight_hook(data):
|
||||
h = comfy.hooks.WeightHook()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook)
|
||||
|
||||
def serialize_set(obj):
|
||||
return {"__set__": list(obj)}
|
||||
|
||||
def deserialize_set(data):
|
||||
return set(data["__set__"])
|
||||
|
||||
registry.register("set", serialize_set, deserialize_set)
|
||||
|
||||
try:
|
||||
from comfy.weight_adapter.lora import LoRAAdapter
|
||||
|
||||
def serialize_lora(obj):
|
||||
return {"weights": {}, "loaded_keys": list(obj.loaded_keys)}
|
||||
|
||||
def deserialize_lora(data):
|
||||
return LoRAAdapter(set(data["loaded_keys"]), data["weights"])
|
||||
|
||||
registry.register("LoRAAdapter", serialize_lora, deserialize_lora)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from comfy.hooks import _HookRef
|
||||
import uuid
|
||||
|
||||
def serialize_hook_ref(obj):
|
||||
return {
|
||||
"__hook_ref__": True,
|
||||
"id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())),
|
||||
}
|
||||
|
||||
def deserialize_hook_ref(data):
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = data.get("id", str(uuid.uuid4()))
|
||||
return h
|
||||
|
||||
registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register _HookRef: {e}")
|
||||
|
||||
|
||||
try:
|
||||
register_hooks_serializers()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize hook serializers: {e}")
|
||||
360
comfy/isolation/model_sampling_proxy.py
Normal file
360
comfy/isolation/model_sampling_proxy.py
Normal file
@@ -0,0 +1,360 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
get_thread_loop,
|
||||
run_coro_in_new_loop,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _describe_value(obj: Any) -> str:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
torch = None
|
||||
try:
|
||||
if torch is not None and isinstance(obj, torch.Tensor):
|
||||
return (
|
||||
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
|
||||
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return "%s(id=%s)" % (type(obj).__name__, id(obj))
|
||||
|
||||
|
||||
def _prefer_device(*tensors: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return None
|
||||
for t in tensors:
|
||||
if isinstance(t, torch.Tensor) and t.is_cuda:
|
||||
return t.device
|
||||
for t in tensors:
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.device
|
||||
return None
|
||||
|
||||
|
||||
def _to_device(obj: Any, device: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
if device is None:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if obj.device != device:
|
||||
return obj.to(device)
|
||||
return obj
|
||||
if isinstance(obj, (list, tuple)):
|
||||
converted = [_to_device(x, device) for x in obj]
|
||||
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_device(v, device) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
def _to_cpu_for_rpc(obj: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
t = obj.detach() if obj.requires_grad else obj
|
||||
if t.is_cuda:
|
||||
return t.to("cpu")
|
||||
return t
|
||||
if isinstance(obj, (list, tuple)):
|
||||
converted = [_to_cpu_for_rpc(x) for x in obj]
|
||||
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class ModelSamplingRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "modelsampling"
|
||||
|
||||
async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.calculate_input(sigma, noise))
|
||||
|
||||
async def calculate_denoised(
|
||||
self, instance_id: str, sigma: Any, model_output: Any, model_input: Any
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(
|
||||
sampling.calculate_denoised(sigma, model_output, model_input)
|
||||
)
|
||||
|
||||
async def noise_scaling(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
noise: Any,
|
||||
latent_image: Any,
|
||||
max_denoise: bool = False,
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(
|
||||
sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise)
|
||||
)
|
||||
|
||||
async def inverse_noise_scaling(
|
||||
self, instance_id: str, sigma: Any, latent: Any
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent))
|
||||
|
||||
async def timestep(self, instance_id: str, sigma: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.timestep(sigma)
|
||||
|
||||
async def sigma(self, instance_id: str, timestep: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.sigma(timestep)
|
||||
|
||||
async def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.percent_to_sigma(percent)
|
||||
|
||||
async def get_sigma_min(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_min)
|
||||
|
||||
async def get_sigma_max(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_max)
|
||||
|
||||
async def get_sigma_data(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_data)
|
||||
|
||||
async def get_sigmas(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigmas)
|
||||
|
||||
async def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
|
||||
sampling = self._get_instance(instance_id)
|
||||
sampling.set_sigmas(sigmas)
|
||||
|
||||
|
||||
class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
||||
_registry_class = ModelSamplingRegistry
|
||||
__module__ = "comfy.isolation.model_sampling_proxy"
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is not None:
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id()
|
||||
)
|
||||
else:
|
||||
registry = ModelSamplingRegistry()
|
||||
|
||||
class _LocalCaller:
|
||||
def calculate_input(
|
||||
self, instance_id: str, sigma: Any, noise: Any
|
||||
) -> Any:
|
||||
return registry.calculate_input(instance_id, sigma, noise)
|
||||
|
||||
def calculate_denoised(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
model_output: Any,
|
||||
model_input: Any,
|
||||
) -> Any:
|
||||
return registry.calculate_denoised(
|
||||
instance_id, sigma, model_output, model_input
|
||||
)
|
||||
|
||||
def noise_scaling(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
noise: Any,
|
||||
latent_image: Any,
|
||||
max_denoise: bool = False,
|
||||
) -> Any:
|
||||
return registry.noise_scaling(
|
||||
instance_id, sigma, noise, latent_image, max_denoise
|
||||
)
|
||||
|
||||
def inverse_noise_scaling(
|
||||
self, instance_id: str, sigma: Any, latent: Any
|
||||
) -> Any:
|
||||
return registry.inverse_noise_scaling(
|
||||
instance_id, sigma, latent
|
||||
)
|
||||
|
||||
def timestep(self, instance_id: str, sigma: Any) -> Any:
|
||||
return registry.timestep(instance_id, sigma)
|
||||
|
||||
def sigma(self, instance_id: str, timestep: Any) -> Any:
|
||||
return registry.sigma(instance_id, timestep)
|
||||
|
||||
def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
|
||||
return registry.percent_to_sigma(instance_id, percent)
|
||||
|
||||
def get_sigma_min(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_min(instance_id)
|
||||
|
||||
def get_sigma_max(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_max(instance_id)
|
||||
|
||||
def get_sigma_data(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_data(instance_id)
|
||||
|
||||
def get_sigmas(self, instance_id: str) -> Any:
|
||||
return registry.get_sigmas(instance_id)
|
||||
|
||||
def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
|
||||
return registry.set_sigmas(instance_id, sigmas)
|
||||
|
||||
self._rpc_caller = _LocalCaller()
|
||||
return self._rpc_caller
|
||||
|
||||
def _call(self, method_name: str, *args: Any) -> Any:
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
result = method(self._instance_id, *args)
|
||||
timeout_ms = self._rpc_timeout_ms()
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
call_id = "%s:%s:%s:%.6f" % (
|
||||
self._instance_id,
|
||||
method_name,
|
||||
thread_id,
|
||||
start_perf,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
timeout_ms,
|
||||
)
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
out = run_coro_in_new_loop(result)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
out = loop.run_until_complete(result)
|
||||
else:
|
||||
out = result
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
_describe_value(out),
|
||||
)
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _rpc_timeout_ms() -> int:
|
||||
raw = os.environ.get(
|
||||
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
|
||||
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
|
||||
)
|
||||
try:
|
||||
timeout_ms = int(raw)
|
||||
except ValueError:
|
||||
timeout_ms = 30000
|
||||
return max(1, timeout_ms)
|
||||
|
||||
@property
|
||||
def sigma_min(self) -> Any:
|
||||
return self._call("get_sigma_min")
|
||||
|
||||
@property
|
||||
def sigma_max(self) -> Any:
|
||||
return self._call("get_sigma_max")
|
||||
|
||||
@property
|
||||
def sigma_data(self) -> Any:
|
||||
return self._call("get_sigma_data")
|
||||
|
||||
@property
|
||||
def sigmas(self) -> Any:
|
||||
return self._call("get_sigmas")
|
||||
|
||||
def calculate_input(self, sigma: Any, noise: Any) -> Any:
|
||||
return self._call("calculate_input", sigma, noise)
|
||||
|
||||
def calculate_denoised(
|
||||
self, sigma: Any, model_output: Any, model_input: Any
|
||||
) -> Any:
|
||||
return self._call("calculate_denoised", sigma, model_output, model_input)
|
||||
|
||||
def noise_scaling(
|
||||
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
|
||||
) -> Any:
|
||||
preferred_device = _prefer_device(noise, latent_image)
|
||||
out = self._call(
|
||||
"noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(noise),
|
||||
_to_cpu_for_rpc(latent_image),
|
||||
max_denoise,
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
|
||||
preferred_device = _prefer_device(latent)
|
||||
out = self._call(
|
||||
"inverse_noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(latent),
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def timestep(self, sigma: Any) -> Any:
|
||||
return self._call("timestep", sigma)
|
||||
|
||||
def sigma(self, timestep: Any) -> Any:
|
||||
return self._call("sigma", timestep)
|
||||
|
||||
def percent_to_sigma(self, percent: float) -> Any:
|
||||
return self._call("percent_to_sigma", percent)
|
||||
|
||||
def set_sigmas(self, sigmas: Any) -> None:
|
||||
return self._call("set_sigmas", sigmas)
|
||||
17
comfy/isolation/proxies/__init__.py
Normal file
17
comfy/isolation/proxies/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
get_thread_loop,
|
||||
run_coro_in_new_loop,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"IS_CHILD_PROCESS",
|
||||
"BaseRegistry",
|
||||
"BaseProxy",
|
||||
"get_thread_loop",
|
||||
"run_coro_in_new_loop",
|
||||
"detach_if_grad",
|
||||
]
|
||||
283
comfy/isolation/proxies/base.py
Normal file
283
comfy/isolation/proxies/base.py
Normal file
@@ -0,0 +1,283 @@
|
||||
# pylint: disable=global-statement,import-outside-toplevel,protected-access
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
||||
|
||||
try:
|
||||
from pyisolate import ProxiedSingleton
|
||||
except ImportError:
|
||||
|
||||
class ProxiedSingleton: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
_thread_local = threading.local()
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_thread_loop() -> asyncio.AbstractEventLoop:
|
||||
loop = getattr(_thread_local, "loop", None)
|
||||
if loop is None or loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
_thread_local.loop = loop
|
||||
return loop
|
||||
|
||||
|
||||
def run_coro_in_new_loop(coro: Any) -> Any:
|
||||
result_box: Dict[str, Any] = {}
|
||||
exc_box: Dict[str, BaseException] = {}
|
||||
|
||||
def runner() -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result_box["value"] = loop.run_until_complete(coro)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
exc_box["exc"] = exc
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
t = threading.Thread(target=runner, daemon=True)
|
||||
t.start()
|
||||
t.join()
|
||||
if "exc" in exc_box:
|
||||
raise exc_box["exc"]
|
||||
return result_box.get("value")
|
||||
|
||||
|
||||
def detach_if_grad(obj: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.detach() if obj.requires_grad else obj
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return type(obj)(detach_if_grad(x) for x in obj)
|
||||
if isinstance(obj, dict):
|
||||
return {k: detach_if_grad(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class BaseRegistry(ProxiedSingleton, Generic[T]):
|
||||
_type_prefix: str = "base"
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
|
||||
super().__init__()
|
||||
self._registry: Dict[str, T] = {}
|
||||
self._id_map: Dict[int, str] = {}
|
||||
self._counter = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def register(self, instance: T) -> str:
|
||||
with self._lock:
|
||||
obj_id = id(instance)
|
||||
if obj_id in self._id_map:
|
||||
return self._id_map[obj_id]
|
||||
instance_id = f"{self._type_prefix}_{self._counter}"
|
||||
self._counter += 1
|
||||
self._registry[instance_id] = instance
|
||||
self._id_map[obj_id] = instance_id
|
||||
return instance_id
|
||||
|
||||
def unregister_sync(self, instance_id: str) -> None:
|
||||
with self._lock:
|
||||
instance = self._registry.pop(instance_id, None)
|
||||
if instance:
|
||||
self._id_map.pop(id(instance), None)
|
||||
|
||||
def _get_instance(self, instance_id: str) -> T:
|
||||
if IS_CHILD_PROCESS:
|
||||
raise RuntimeError(
|
||||
f"[{self.__class__.__name__}] _get_instance called in child"
|
||||
)
|
||||
with self._lock:
|
||||
instance = self._registry.get(instance_id)
|
||||
if instance is None:
|
||||
raise ValueError(f"{instance_id} not found")
|
||||
return instance
|
||||
|
||||
|
||||
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
|
||||
def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
global _GLOBAL_LOOP
|
||||
_GLOBAL_LOOP = loop
|
||||
|
||||
|
||||
class BaseProxy(Generic[T]):
|
||||
_registry_class: type = BaseRegistry # type: ignore[type-arg]
|
||||
__module__: str = "comfy.isolation.proxies.base"
|
||||
_TIMEOUT_RPC_METHODS = frozenset(
|
||||
{
|
||||
"partially_load",
|
||||
"partially_unload",
|
||||
"load",
|
||||
"patch_model",
|
||||
"unpatch_model",
|
||||
"inner_model_apply_model",
|
||||
"memory_required",
|
||||
"model_dtype",
|
||||
"inner_model_memory_required",
|
||||
"inner_model_extra_conds_shapes",
|
||||
"inner_model_extra_conds",
|
||||
"process_latent_in",
|
||||
"process_latent_out",
|
||||
"scale_latent_inpaint",
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
registry: Optional[Any] = None,
|
||||
manage_lifecycle: bool = False,
|
||||
) -> None:
|
||||
self._instance_id = instance_id
|
||||
self._rpc_caller: Optional[Any] = None
|
||||
self._registry = registry if registry is not None else self._registry_class()
|
||||
self._manage_lifecycle = manage_lifecycle
|
||||
self._cleaned_up = False
|
||||
if manage_lifecycle and not IS_CHILD_PROCESS:
|
||||
self._finalizer = weakref.finalize(
|
||||
self, self._registry.unregister_sync, instance_id
|
||||
)
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is None:
|
||||
raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
self._registry_class, self._registry_class.get_remote_id()
|
||||
)
|
||||
return self._rpc_caller
|
||||
|
||||
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
|
||||
if method_name not in self._TIMEOUT_RPC_METHODS:
|
||||
return None
|
||||
try:
|
||||
timeout_ms = int(
|
||||
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
|
||||
)
|
||||
except ValueError:
|
||||
timeout_ms = 120000
|
||||
return max(1, timeout_ms)
|
||||
|
||||
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
|
||||
coro = method(self._instance_id, *args, **kwargs)
|
||||
if timeout_ms is not None:
|
||||
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
|
||||
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
try:
|
||||
running_loop = asyncio.get_running_loop()
|
||||
loop_id: Optional[int] = id(running_loop)
|
||||
except RuntimeError:
|
||||
loop_id = None
|
||||
logger.debug(
|
||||
"ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
|
||||
"thread=%s loop=%s timeout_ms=%s",
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self._instance_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
loop_id,
|
||||
timeout_ms,
|
||||
)
|
||||
|
||||
try:
|
||||
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
|
||||
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
|
||||
try:
|
||||
curr_loop = asyncio.get_running_loop()
|
||||
if curr_loop is _GLOBAL_LOOP:
|
||||
pass
|
||||
except RuntimeError:
|
||||
# No running loop - we are in a worker thread.
|
||||
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
|
||||
return future.result(
|
||||
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(coro)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise TimeoutError(
|
||||
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
|
||||
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
|
||||
) from exc
|
||||
except concurrent.futures.TimeoutError as exc:
|
||||
raise TimeoutError(
|
||||
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
|
||||
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
|
||||
) from exc
|
||||
finally:
|
||||
end_epoch = time.time()
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
|
||||
"elapsed_ms=%.3f thread=%s loop=%s",
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self._instance_id,
|
||||
end_epoch,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
loop_id,
|
||||
)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
return {"_instance_id": self._instance_id}
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
self._instance_id = state["_instance_id"]
|
||||
self._rpc_caller = None
|
||||
self._registry = self._registry_class()
|
||||
self._manage_lifecycle = False
|
||||
self._cleaned_up = False
|
||||
|
||||
def cleanup(self) -> None:
|
||||
if self._cleaned_up or IS_CHILD_PROCESS:
|
||||
return
|
||||
self._cleaned_up = True
|
||||
finalizer = getattr(self, "_finalizer", None)
|
||||
if finalizer is not None:
|
||||
finalizer.detach()
|
||||
self._registry.unregister_sync(self._instance_id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} {self._instance_id}>"
|
||||
|
||||
|
||||
def create_rpc_method(method_name: str) -> Callable[..., Any]:
|
||||
def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
|
||||
return self._call_rpc(method_name, *args, **kwargs)
|
||||
|
||||
method.__name__ = method_name
|
||||
return method
|
||||
29
comfy/isolation/proxies/folder_paths_proxy.py
Normal file
29
comfy/isolation/proxies/folder_paths_proxy.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
from typing import Dict
|
||||
|
||||
import folder_paths
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
|
||||
class FolderPathsProxy(ProxiedSingleton):
|
||||
"""
|
||||
Dynamic proxy for folder_paths.
|
||||
Uses __getattr__ for most lookups, with explicit handling for
|
||||
mutable collections to ensure efficient by-value transfer.
|
||||
"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(folder_paths, name)
|
||||
|
||||
# Return dict snapshots (avoid RPC chatter)
|
||||
@property
|
||||
def folder_names_and_paths(self) -> Dict:
|
||||
return dict(folder_paths.folder_names_and_paths)
|
||||
|
||||
@property
|
||||
def extension_mimetypes_cache(self) -> Dict:
|
||||
return dict(folder_paths.extension_mimetypes_cache)
|
||||
|
||||
@property
|
||||
def filename_list_cache(self) -> Dict:
|
||||
return dict(folder_paths.filename_list_cache)
|
||||
98
comfy/isolation/proxies/helper_proxies.py
Normal file
98
comfy/isolation/proxies/helper_proxies.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class AnyTypeProxy(str):
|
||||
"""Replacement for custom AnyType objects used by some nodes."""
|
||||
|
||||
def __new__(cls, value: str = "*"):
|
||||
return super().__new__(cls, value)
|
||||
|
||||
def __ne__(self, other): # type: ignore[override]
|
||||
return False
|
||||
|
||||
|
||||
class FlexibleOptionalInputProxy(dict):
|
||||
"""Replacement for FlexibleOptionalInputType to allow dynamic inputs."""
|
||||
|
||||
def __init__(self, flex_type, data: Optional[Dict[str, object]] = None):
|
||||
super().__init__()
|
||||
self.type = flex_type
|
||||
if data:
|
||||
self.update(data)
|
||||
|
||||
def __getitem__(self, key): # type: ignore[override]
|
||||
return (self.type,)
|
||||
|
||||
def __contains__(self, key): # type: ignore[override]
|
||||
return True
|
||||
|
||||
|
||||
class ByPassTypeTupleProxy(tuple):
|
||||
"""Replacement for ByPassTypeTuple to mirror wildcard fallback behavior."""
|
||||
|
||||
def __new__(cls, values):
|
||||
return super().__new__(cls, values)
|
||||
|
||||
def __getitem__(self, index): # type: ignore[override]
|
||||
if index >= len(self):
|
||||
return AnyTypeProxy("*")
|
||||
return super().__getitem__(index)
|
||||
|
||||
|
||||
def _restore_special_value(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
if value.get("__pyisolate_any_type__"):
|
||||
return AnyTypeProxy(value.get("value", "*"))
|
||||
if value.get("__pyisolate_flexible_optional__"):
|
||||
flex_type = _restore_special_value(value.get("type"))
|
||||
data_raw = value.get("data")
|
||||
data = (
|
||||
{k: _restore_special_value(v) for k, v in data_raw.items()}
|
||||
if isinstance(data_raw, dict)
|
||||
else {}
|
||||
)
|
||||
return FlexibleOptionalInputProxy(flex_type, data)
|
||||
if value.get("__pyisolate_tuple__") is not None:
|
||||
return tuple(
|
||||
_restore_special_value(v) for v in value["__pyisolate_tuple__"]
|
||||
)
|
||||
if value.get("__pyisolate_bypass_tuple__") is not None:
|
||||
return ByPassTypeTupleProxy(
|
||||
tuple(
|
||||
_restore_special_value(v)
|
||||
for v in value["__pyisolate_bypass_tuple__"]
|
||||
)
|
||||
)
|
||||
return {k: _restore_special_value(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_restore_special_value(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
|
||||
"""Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
return raw # type: ignore[return-value]
|
||||
|
||||
restored: Dict[str, object] = {}
|
||||
for section, entries in raw.items():
|
||||
if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"):
|
||||
restored[section] = _restore_special_value(entries)
|
||||
elif isinstance(entries, dict):
|
||||
restored[section] = {
|
||||
k: _restore_special_value(v) for k, v in entries.items()
|
||||
}
|
||||
else:
|
||||
restored[section] = _restore_special_value(entries)
|
||||
return restored
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AnyTypeProxy",
|
||||
"FlexibleOptionalInputProxy",
|
||||
"ByPassTypeTupleProxy",
|
||||
"restore_input_types",
|
||||
]
|
||||
27
comfy/isolation/proxies/model_management_proxy.py
Normal file
27
comfy/isolation/proxies/model_management_proxy.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import comfy.model_management as mm
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
|
||||
class ModelManagementProxy(ProxiedSingleton):
|
||||
"""
|
||||
Dynamic proxy for comfy.model_management.
|
||||
Uses __getattr__ to forward all calls to the underlying module,
|
||||
reducing maintenance burden.
|
||||
"""
|
||||
|
||||
# Explicitly expose Enums/Classes as properties
|
||||
@property
|
||||
def VRAMState(self):
|
||||
return mm.VRAMState
|
||||
|
||||
@property
|
||||
def CPUState(self):
|
||||
return mm.CPUState
|
||||
|
||||
@property
|
||||
def OOM_EXCEPTION(self):
|
||||
return mm.OOM_EXCEPTION
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Forward all other attribute access to the module."""
|
||||
return getattr(mm, name)
|
||||
35
comfy/isolation/proxies/progress_proxy.py
Normal file
35
comfy/isolation/proxies/progress_proxy.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
from pyisolate import ProxiedSingleton
|
||||
except ImportError:
|
||||
|
||||
class ProxiedSingleton:
|
||||
pass
|
||||
|
||||
|
||||
from comfy_execution.progress import get_progress_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProgressProxy(ProxiedSingleton):
|
||||
def set_progress(
|
||||
self,
|
||||
value: float,
|
||||
max_value: float,
|
||||
node_id: Optional[str] = None,
|
||||
image: Any = None,
|
||||
) -> None:
|
||||
get_progress_state().update_progress(
|
||||
node_id=node_id,
|
||||
value=value,
|
||||
max_value=max_value,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ProgressProxy"]
|
||||
265
comfy/isolation/proxies/prompt_server_impl.py
Normal file
265
comfy/isolation/proxies/prompt_server_impl.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called
|
||||
"""Stateless RPC Implementation for PromptServer.
|
||||
|
||||
Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture.
|
||||
- Host: PromptServerService (RPC Handler)
|
||||
- Child: PromptServerStub (Interface Implementation)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
# IMPORTS
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_PREFIX = "[Isolation:C<->H]"
|
||||
|
||||
# ...
|
||||
|
||||
# =============================================================================
|
||||
# CHILD SIDE: PromptServerStub
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PromptServerStub:
|
||||
"""Stateless Stub for PromptServer."""
|
||||
|
||||
# Masquerade as the real server module
|
||||
__module__ = "server"
|
||||
|
||||
_instance: Optional["PromptServerStub"] = None
|
||||
_rpc: Optional[Any] = None # This will be the Caller object
|
||||
_source_file: Optional[str] = None
|
||||
|
||||
def __init__(self):
|
||||
self.routes = RouteStub(self)
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
"""Inject RPC client (called by adapter.py or manually)."""
|
||||
# Create caller for HOST Service
|
||||
# Assuming Host Service is registered as "PromptServerService" (class name)
|
||||
# We target the Host Service Class
|
||||
target_id = "PromptServerService"
|
||||
# We need to pass a class to create_caller? Usually yes.
|
||||
# But we don't have the Service class imported here necessarily (if running on child).
|
||||
# pyisolate check verify_service type?
|
||||
# If we pass PromptServerStub as the 'class', it might mismatch if checking types.
|
||||
# But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub.
|
||||
# We need a dummy class with right name?
|
||||
# Or just rely on string ID if create_caller supports it?
|
||||
# Standard: rpc.create_caller(PromptServerStub, target_id)
|
||||
# But wait, PromptServerStub is the *Local* class.
|
||||
# We want to call *Remote* class.
|
||||
# If we use PromptServerStub as the type, returning object will be typed as PromptServerStub?
|
||||
# The first arg is 'service_cls'.
|
||||
cls._rpc = rpc.create_caller(
|
||||
PromptServerService, target_id
|
||||
) # We import Service below?
|
||||
|
||||
# We need PromptServerService available for the create_caller call?
|
||||
# Or just use the Stub class if ID matches?
|
||||
# prompt_server_impl.py defines BOTH. So PromptServerService IS available!
|
||||
|
||||
@property
|
||||
def instance(self) -> "PromptServerStub":
|
||||
return self
|
||||
|
||||
# ... Compatibility ...
|
||||
@classmethod
|
||||
def _get_source_file(cls) -> str:
|
||||
if cls._source_file is None:
|
||||
import folder_paths
|
||||
|
||||
cls._source_file = os.path.join(folder_paths.base_path, "server.py")
|
||||
return cls._source_file
|
||||
|
||||
@property
|
||||
def __file__(self) -> str:
|
||||
return self._get_source_file()
|
||||
|
||||
# --- Properties ---
|
||||
@property
|
||||
def client_id(self) -> Optional[str]:
|
||||
return "isolated_client"
|
||||
|
||||
def supports(self, feature: str) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
raise RuntimeError(
|
||||
"PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
|
||||
)
|
||||
|
||||
@property
|
||||
def prompt_queue(self):
|
||||
raise RuntimeError(
|
||||
"PromptServer.prompt_queue is not accessible in isolated nodes."
|
||||
)
|
||||
|
||||
# --- UI Communication (RPC Delegates) ---
|
||||
async def send_sync(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
) -> None:
|
||||
if self._rpc:
|
||||
await self._rpc.ui_send_sync(event, data, sid)
|
||||
|
||||
async def send(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
) -> None:
|
||||
if self._rpc:
|
||||
await self._rpc.ui_send(event, data, sid)
|
||||
|
||||
def send_progress_text(self, text: str, node_id: str, sid=None) -> None:
|
||||
if self._rpc:
|
||||
# Fire and forget likely needed. If method is async on host, caller invocation returns coroutine.
|
||||
# We must schedule it?
|
||||
# Or use fire_remote equivalent?
|
||||
# Caller object usually proxies calls. If host method is async, it returns coro.
|
||||
# If we are sync here (send_progress_text checks imply sync usage), we must background it.
|
||||
# But UtilsProxy hook wrapper creates task.
|
||||
# Does send_progress_text need to be sync? Yes, node code calls it sync.
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
|
||||
except RuntimeError:
|
||||
pass # Sync context without loop?
|
||||
|
||||
# --- Route Registration Logic ---
|
||||
def register_route(self, method: str, path: str, handler: Callable):
|
||||
"""Register a route handler via RPC."""
|
||||
if not self._rpc:
|
||||
logger.error("RPC not initialized in PromptServerStub")
|
||||
return
|
||||
|
||||
# Fire registration async
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
class RouteStub:
|
||||
"""Simulates aiohttp.web.RouteTableDef."""
|
||||
|
||||
def __init__(self, stub: PromptServerStub):
|
||||
self._stub = stub
|
||||
|
||||
def get(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("GET", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def post(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("POST", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def patch(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("PATCH", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def put(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("PUT", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def delete(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("DELETE", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HOST SIDE: PromptServerService
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PromptServerService(ProxiedSingleton):
|
||||
"""Host-side RPC Service for PromptServer."""
|
||||
|
||||
def __init__(self):
|
||||
# We will bind to the real server instance lazily or via global import
|
||||
pass
|
||||
|
||||
@property
|
||||
def server(self):
|
||||
from server import PromptServer
|
||||
|
||||
return PromptServer.instance
|
||||
|
||||
async def ui_send_sync(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
):
|
||||
await self.server.send_sync(event, data, sid)
|
||||
|
||||
async def ui_send(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
):
|
||||
await self.server.send(event, data, sid)
|
||||
|
||||
async def ui_send_progress_text(self, text: str, node_id: str, sid=None):
|
||||
# Made async to be awaitable by RPC layer
|
||||
self.server.send_progress_text(text, node_id, sid)
|
||||
|
||||
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
|
||||
"""RPC Target: Register a route that forwards to the Child."""
|
||||
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
|
||||
|
||||
async def route_wrapper(request: web.Request) -> web.Response:
|
||||
# 1. Capture request data
|
||||
req_data = {
|
||||
"method": request.method,
|
||||
"path": request.path,
|
||||
"query": dict(request.query),
|
||||
}
|
||||
if request.can_read_body:
|
||||
req_data["text"] = await request.text()
|
||||
|
||||
try:
|
||||
# 2. Call Child Handler via RPC (child_handler_proxy is async callable)
|
||||
result = await child_handler_proxy(req_data)
|
||||
|
||||
# 3. Serialize Response
|
||||
return self._serialize_response(result)
|
||||
except Exception as e:
|
||||
logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
# Register loop
|
||||
self.server.app.router.add_route(method, path, route_wrapper)
|
||||
|
||||
def _serialize_response(self, result: Any) -> web.Response:
|
||||
"""Helper to convert Child result -> web.Response"""
|
||||
if isinstance(result, web.Response):
|
||||
return result
|
||||
# Handle dict (json)
|
||||
if isinstance(result, dict):
|
||||
return web.json_response(result)
|
||||
# Handle string
|
||||
if isinstance(result, str):
|
||||
return web.Response(text=result)
|
||||
# Fallback
|
||||
return web.Response(text=str(result))
|
||||
64
comfy/isolation/proxies/utils_proxy.py
Normal file
64
comfy/isolation/proxies/utils_proxy.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Any
|
||||
import comfy.utils
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class UtilsProxy(ProxiedSingleton):
|
||||
"""
|
||||
Proxy for comfy.utils.
|
||||
Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates
|
||||
from isolated nodes reach the host.
|
||||
"""
|
||||
|
||||
# _instance and __new__ removed to rely on SingletonMetaclass
|
||||
_rpc: Optional[Any] = None
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
# Create caller using class name as ID (standard for Singletons)
|
||||
cls._rpc = rpc.create_caller(cls, "UtilsProxy")
|
||||
|
||||
async def progress_bar_hook(
|
||||
self,
|
||||
value: int,
|
||||
total: int,
|
||||
preview: Optional[bytes] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Host-side implementation: forwards the call to the real global hook.
|
||||
Child-side: this method call is intercepted by RPC and sent to host.
|
||||
"""
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
# Manual RPC dispatch for Child process
|
||||
# Use class-level RPC storage (Static Injection)
|
||||
if UtilsProxy._rpc:
|
||||
return await UtilsProxy._rpc.progress_bar_hook(
|
||||
value, total, preview, node_id
|
||||
)
|
||||
|
||||
# Fallback channel: global child rpc
|
||||
try:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
get_child_rpc_instance()
|
||||
# If we have an RPC instance but no UtilsProxy._rpc, we *could* try to use it,
|
||||
# but we need a caller. For now, just pass to avoid crashing.
|
||||
pass
|
||||
except (ImportError, LookupError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
# Host Execution
|
||||
if comfy.utils.PROGRESS_BAR_HOOK is not None:
|
||||
comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
|
||||
|
||||
def set_progress_bar_global_hook(self, hook: Any) -> None:
|
||||
"""Forward hook registration (though usually not needed from child)."""
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
49
comfy/isolation/rpc_bridge.py
Normal file
49
comfy/isolation/rpc_bridge.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RpcBridge:
|
||||
"""Minimal helper to run coroutines synchronously inside isolated processes.
|
||||
|
||||
If an event loop is already running, the coroutine is executed on a fresh
|
||||
thread with its own loop to avoid nested run_until_complete errors.
|
||||
"""
|
||||
|
||||
def run_sync(self, maybe_coro):
|
||||
if not asyncio.iscoroutine(maybe_coro):
|
||||
return maybe_coro
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
result_container = {}
|
||||
exc_container = {}
|
||||
|
||||
def _runner():
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
result_container["value"] = new_loop.run_until_complete(maybe_coro)
|
||||
except Exception as exc: # pragma: no cover
|
||||
exc_container["error"] = exc
|
||||
finally:
|
||||
try:
|
||||
new_loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
t = threading.Thread(target=_runner, daemon=True)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
if "error" in exc_container:
|
||||
raise exc_container["error"]
|
||||
return result_container.get("value")
|
||||
|
||||
return asyncio.run(maybe_coro)
|
||||
363
comfy/isolation/runtime_helpers.py
Normal file
363
comfy/isolation/runtime_helpers.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set, TYPE_CHECKING
|
||||
|
||||
from .proxies.helper_proxies import restore_input_types
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from comfy_api.latest import _io as latest_io
|
||||
from .shm_forensics import scan_shm_forensics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
|
||||
|
||||
def _resource_snapshot() -> Dict[str, int]:
|
||||
fd_count = -1
|
||||
shm_sender_files = 0
|
||||
try:
|
||||
fd_count = len(os.listdir("/proc/self/fd"))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
shm_root = Path("/dev/shm")
|
||||
if shm_root.exists():
|
||||
prefix = f"torch_{os.getpid()}_"
|
||||
shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*"))
|
||||
except Exception:
|
||||
pass
|
||||
return {"fd_count": fd_count, "shm_sender_files": shm_sender_files}
|
||||
|
||||
|
||||
def _tensor_transport_summary(value: Any) -> Dict[str, int]:
|
||||
summary: Dict[str, int] = {
|
||||
"tensor_count": 0,
|
||||
"cpu_tensors": 0,
|
||||
"cuda_tensors": 0,
|
||||
"shared_cpu_tensors": 0,
|
||||
"tensor_bytes": 0,
|
||||
}
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return summary
|
||||
|
||||
def visit(node: Any) -> None:
|
||||
if isinstance(node, torch.Tensor):
|
||||
summary["tensor_count"] += 1
|
||||
summary["tensor_bytes"] += int(node.numel() * node.element_size())
|
||||
if node.device.type == "cpu":
|
||||
summary["cpu_tensors"] += 1
|
||||
if node.is_shared():
|
||||
summary["shared_cpu_tensors"] += 1
|
||||
elif node.device.type == "cuda":
|
||||
summary["cuda_tensors"] += 1
|
||||
return
|
||||
if isinstance(node, dict):
|
||||
for v in node.values():
|
||||
visit(v)
|
||||
return
|
||||
if isinstance(node, (list, tuple)):
|
||||
for v in node:
|
||||
visit(v)
|
||||
|
||||
visit(value)
|
||||
return summary
|
||||
|
||||
|
||||
def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None:
|
||||
for key, value in inputs.items():
|
||||
key_text = str(key)
|
||||
if "unique_id" in key_text:
|
||||
return str(value)
|
||||
return None
|
||||
|
||||
|
||||
def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return
|
||||
if not callable(flush_tensor_keeper):
|
||||
return
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
|
||||
)
|
||||
|
||||
|
||||
def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if not hasattr(device, "type") or device.type == "cpu":
|
||||
return
|
||||
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=True)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
|
||||
|
||||
|
||||
def _detach_shared_cpu_tensors(value: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.device.type == "cpu" and value.is_shared():
|
||||
clone = value.clone()
|
||||
if value.requires_grad:
|
||||
clone.requires_grad_(True)
|
||||
return clone
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
return [_detach_shared_cpu_tensors(v) for v in value]
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_detach_shared_cpu_tensors(v) for v in value)
|
||||
if isinstance(value, dict):
|
||||
return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
def build_stub_class(
|
||||
node_name: str,
|
||||
info: Dict[str, object],
|
||||
extension: "ComfyNodeExtension",
|
||||
running_extensions: Dict[str, "ComfyNodeExtension"],
|
||||
logger: logging.Logger,
|
||||
) -> type:
|
||||
is_v3 = bool(info.get("is_v3", False))
|
||||
function_name = "_pyisolate_execute"
|
||||
restored_input_types = restore_input_types(info.get("input_types", {}))
|
||||
|
||||
async def _execute(self, **inputs):
|
||||
from comfy.isolation import _RUNNING_EXTENSIONS
|
||||
|
||||
# Update BOTH the local dict AND the module-level dict
|
||||
running_extensions[extension.name] = extension
|
||||
_RUNNING_EXTENSIONS[extension.name] = extension
|
||||
prev_child = None
|
||||
node_unique_id = _extract_hidden_unique_id(inputs)
|
||||
summary = _tensor_transport_summary(inputs)
|
||||
resources = _resource_snapshot()
|
||||
logger.debug(
|
||||
"%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
summary["tensor_count"],
|
||||
summary["cpu_tensors"],
|
||||
summary["cuda_tensors"],
|
||||
summary["shared_cpu_tensors"],
|
||||
summary["tensor_bytes"],
|
||||
resources["fd_count"],
|
||||
resources["shm_sender_files"],
|
||||
)
|
||||
scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
|
||||
try:
|
||||
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||
_relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
|
||||
scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
|
||||
from pyisolate._internal.model_serialization import (
|
||||
serialize_for_isolation,
|
||||
deserialize_from_isolation,
|
||||
)
|
||||
|
||||
prev_child = os.environ.pop("PYISOLATE_CHILD", None)
|
||||
logger.debug(
|
||||
"%s ISO:serialize_start ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
serialized = serialize_for_isolation(inputs)
|
||||
logger.debug(
|
||||
"%s ISO:serialize_done ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
logger.debug(
|
||||
"%s ISO:dispatch_start ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
result = await extension.execute_node(node_name, **serialized)
|
||||
logger.debug(
|
||||
"%s ISO:dispatch_done ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
# Reconstruct NodeOutput if the child serialized one
|
||||
if isinstance(result, dict) and result.get("__node_output__"):
|
||||
from comfy_api.latest import io as latest_io
|
||||
args_raw = result.get("args", ())
|
||||
deserialized_args = await deserialize_from_isolation(args_raw, extension)
|
||||
deserialized_args = _detach_shared_cpu_tensors(deserialized_args)
|
||||
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
|
||||
return latest_io.NodeOutput(
|
||||
*deserialized_args,
|
||||
ui=result.get("ui"),
|
||||
expand=result.get("expand"),
|
||||
block_execution=result.get("block_execution"),
|
||||
)
|
||||
deserialized = await deserialize_from_isolation(result, extension)
|
||||
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
|
||||
return _detach_shared_cpu_tensors(deserialized)
|
||||
except ImportError:
|
||||
return await extension.execute_node(node_name, **inputs)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s ISO:execute_error ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if prev_child is not None:
|
||||
os.environ["PYISOLATE_CHILD"] = prev_child
|
||||
logger.debug(
|
||||
"%s ISO:execute_end ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True)
|
||||
|
||||
def _input_types(
|
||||
cls,
|
||||
include_hidden: bool = True,
|
||||
return_schema: bool = False,
|
||||
live_inputs: Any = None,
|
||||
):
|
||||
if not is_v3:
|
||||
return restored_input_types
|
||||
|
||||
inputs_copy = copy.deepcopy(restored_input_types)
|
||||
if not include_hidden:
|
||||
inputs_copy.pop("hidden", None)
|
||||
|
||||
v3_data: Dict[str, Any] = {"hidden_inputs": {}}
|
||||
dynamic = inputs_copy.pop("dynamic_paths", None)
|
||||
if dynamic is not None:
|
||||
v3_data["dynamic_paths"] = dynamic
|
||||
|
||||
if return_schema:
|
||||
hidden_vals = info.get("hidden", []) or []
|
||||
hidden_enums = []
|
||||
for h in hidden_vals:
|
||||
try:
|
||||
hidden_enums.append(latest_io.Hidden(h))
|
||||
except Exception:
|
||||
hidden_enums.append(h)
|
||||
|
||||
class SchemaProxy:
|
||||
hidden = hidden_enums
|
||||
|
||||
return inputs_copy, SchemaProxy, v3_data
|
||||
return inputs_copy
|
||||
|
||||
def _validate_class(cls):
|
||||
return True
|
||||
|
||||
def _get_node_info_v1(cls):
|
||||
node_info = copy.deepcopy(info.get("schema_v1", {}))
|
||||
relative_python_module = node_info.get("python_module")
|
||||
if not isinstance(relative_python_module, str) or not relative_python_module:
|
||||
relative_python_module = f"custom_nodes.{extension.name}"
|
||||
node_info["python_module"] = relative_python_module
|
||||
return node_info
|
||||
|
||||
def _get_base_class(cls):
|
||||
return latest_io.ComfyNode
|
||||
|
||||
attributes: Dict[str, object] = {
|
||||
"FUNCTION": function_name,
|
||||
"CATEGORY": info.get("category", ""),
|
||||
"OUTPUT_NODE": info.get("output_node", False),
|
||||
"RETURN_TYPES": tuple(info.get("return_types", ()) or ()),
|
||||
"RETURN_NAMES": info.get("return_names"),
|
||||
function_name: _execute,
|
||||
"_pyisolate_extension": extension,
|
||||
"_pyisolate_node_name": node_name,
|
||||
"INPUT_TYPES": classmethod(_input_types),
|
||||
}
|
||||
|
||||
output_is_list = info.get("output_is_list")
|
||||
if output_is_list is not None:
|
||||
attributes["OUTPUT_IS_LIST"] = tuple(output_is_list)
|
||||
|
||||
if is_v3:
|
||||
attributes["VALIDATE_CLASS"] = classmethod(_validate_class)
|
||||
attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1)
|
||||
attributes["GET_BASE_CLASS"] = classmethod(_get_base_class)
|
||||
attributes["DESCRIPTION"] = info.get("description", "")
|
||||
attributes["EXPERIMENTAL"] = info.get("experimental", False)
|
||||
attributes["DEPRECATED"] = info.get("deprecated", False)
|
||||
attributes["API_NODE"] = info.get("api_node", False)
|
||||
attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False)
|
||||
attributes["ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
|
||||
attributes["_ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
|
||||
attributes["INPUT_IS_LIST"] = info.get("input_is_list", False)
|
||||
|
||||
class_name = f"PyIsolate_{node_name}".replace(" ", "_")
|
||||
bases = (_ComfyNodeInternal,) if is_v3 else ()
|
||||
stub_cls = type(class_name, bases, attributes)
|
||||
|
||||
if is_v3:
|
||||
try:
|
||||
stub_cls.VALIDATE_CLASS()
|
||||
except Exception as e:
|
||||
logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e)
|
||||
|
||||
return stub_cls
|
||||
|
||||
|
||||
def get_class_types_for_extension(
|
||||
extension_name: str,
|
||||
running_extensions: Dict[str, "ComfyNodeExtension"],
|
||||
specs: List[Any],
|
||||
) -> Set[str]:
|
||||
extension = running_extensions.get(extension_name)
|
||||
if not extension:
|
||||
return set()
|
||||
|
||||
ext_path = Path(extension.module_path)
|
||||
class_types = set()
|
||||
for spec in specs:
|
||||
if spec.module_path.resolve() == ext_path.resolve():
|
||||
class_types.add(spec.node_name)
|
||||
return class_types
|
||||
|
||||
|
||||
__all__ = ["build_stub_class", "get_class_types_for_extension"]
|
||||
217
comfy/isolation/shm_forensics.py
Normal file
217
comfy/isolation/shm_forensics.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# pylint: disable=consider-using-from-import,import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _shm_debug_enabled() -> bool:
|
||||
return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1"
|
||||
|
||||
|
||||
class _SHMForensicsTracker:
|
||||
def __init__(self) -> None:
|
||||
self._started = False
|
||||
self._tracked_files: Set[str] = set()
|
||||
self._current_model_context: Dict[str, str] = {
|
||||
"id": "unknown",
|
||||
"name": "unknown",
|
||||
"hash": "????",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _snapshot_shm() -> Set[str]:
|
||||
shm_path = Path("/dev/shm")
|
||||
if not shm_path.exists():
|
||||
return set()
|
||||
return {f.name for f in shm_path.glob("torch_*")}
|
||||
|
||||
def start(self) -> None:
|
||||
if self._started or not _shm_debug_enabled():
|
||||
return
|
||||
self._tracked_files = self._snapshot_shm()
|
||||
self._started = True
|
||||
logger.debug(
|
||||
"%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files)
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
if not self._started:
|
||||
return
|
||||
self.scan("shutdown", refresh_model_context=True)
|
||||
self._started = False
|
||||
logger.debug("%s SHM:forensics_disabled", LOG_PREFIX)
|
||||
|
||||
def _compute_model_hash(self, model_patcher: Any) -> str:
|
||||
try:
|
||||
model_instance_id = getattr(model_patcher, "_instance_id", None)
|
||||
if model_instance_id is not None:
|
||||
model_id_text = str(model_instance_id)
|
||||
return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text
|
||||
|
||||
import torch
|
||||
|
||||
real_model = (
|
||||
model_patcher.model
|
||||
if hasattr(model_patcher, "model")
|
||||
else model_patcher
|
||||
)
|
||||
tensor = None
|
||||
if hasattr(real_model, "parameters"):
|
||||
for p in real_model.parameters():
|
||||
if torch.is_tensor(p) and p.numel() > 0:
|
||||
tensor = p
|
||||
break
|
||||
|
||||
if tensor is None:
|
||||
return "0000"
|
||||
|
||||
flat = tensor.flatten()
|
||||
values = []
|
||||
indices = [0, flat.shape[0] // 2, flat.shape[0] - 1]
|
||||
for i in indices:
|
||||
if i < flat.shape[0]:
|
||||
values.append(flat[i].item())
|
||||
|
||||
size = 0
|
||||
if hasattr(model_patcher, "model_size"):
|
||||
size = model_patcher.model_size()
|
||||
sample_str = f"{values}_{id(model_patcher):016x}_{size}"
|
||||
return hashlib.sha256(sample_str.encode()).hexdigest()[-4:]
|
||||
except Exception:
|
||||
return "err!"
|
||||
|
||||
def _get_models_snapshot(self) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
import comfy.model_management as model_management
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
snapshot: List[Dict[str, Any]] = []
|
||||
try:
|
||||
for loaded_model in model_management.current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is None:
|
||||
continue
|
||||
if str(getattr(loaded_model, "device", "")) != "cuda:0":
|
||||
continue
|
||||
|
||||
name = (
|
||||
model.model.__class__.__name__
|
||||
if hasattr(model, "model")
|
||||
else type(model).__name__
|
||||
)
|
||||
model_hash = self._compute_model_hash(model)
|
||||
model_instance_id = getattr(model, "_instance_id", None)
|
||||
if model_instance_id is None:
|
||||
model_instance_id = model_hash
|
||||
snapshot.append(
|
||||
{
|
||||
"name": str(name),
|
||||
"id": str(model_instance_id),
|
||||
"hash": str(model_hash or "????"),
|
||||
"used": bool(getattr(loaded_model, "currently_used", False)),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
return snapshot
|
||||
|
||||
def _update_model_context(self) -> None:
|
||||
snapshot = self._get_models_snapshot()
|
||||
selected = None
|
||||
|
||||
used_models = [m for m in snapshot if m.get("used") and m.get("id")]
|
||||
if used_models:
|
||||
selected = used_models[-1]
|
||||
else:
|
||||
live_models = [m for m in snapshot if m.get("id")]
|
||||
if live_models:
|
||||
selected = live_models[-1]
|
||||
|
||||
if selected is None:
|
||||
self._current_model_context = {
|
||||
"id": "unknown",
|
||||
"name": "unknown",
|
||||
"hash": "????",
|
||||
}
|
||||
return
|
||||
|
||||
self._current_model_context = {
|
||||
"id": str(selected.get("id", "unknown")),
|
||||
"name": str(selected.get("name", "unknown")),
|
||||
"hash": str(selected.get("hash", "????") or "????"),
|
||||
}
|
||||
|
||||
def scan(self, marker: str, refresh_model_context: bool = True) -> None:
|
||||
if not self._started or not _shm_debug_enabled():
|
||||
return
|
||||
|
||||
if refresh_model_context:
|
||||
self._update_model_context()
|
||||
|
||||
current = self._snapshot_shm()
|
||||
added = current - self._tracked_files
|
||||
removed = self._tracked_files - current
|
||||
self._tracked_files = current
|
||||
|
||||
if not added and not removed:
|
||||
logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker)
|
||||
return
|
||||
|
||||
for filename in sorted(added):
|
||||
logger.info("%s SHM:created | %s", LOG_PREFIX, filename)
|
||||
model_id = self._current_model_context["id"]
|
||||
if model_id == "unknown":
|
||||
logger.error(
|
||||
"%s SHM:model_association_missing | file=%s | reason=no_active_model_context",
|
||||
LOG_PREFIX,
|
||||
filename,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s",
|
||||
LOG_PREFIX,
|
||||
model_id,
|
||||
filename,
|
||||
self._current_model_context["name"],
|
||||
self._current_model_context["hash"],
|
||||
)
|
||||
|
||||
for filename in sorted(removed):
|
||||
logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename)
|
||||
|
||||
logger.debug(
|
||||
"%s SHM:scan marker=%s created=%d deleted=%d active=%d",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
len(added),
|
||||
len(removed),
|
||||
len(self._tracked_files),
|
||||
)
|
||||
|
||||
|
||||
_TRACKER = _SHMForensicsTracker()
|
||||
|
||||
|
||||
def start_shm_forensics() -> None:
|
||||
_TRACKER.start()
|
||||
|
||||
|
||||
def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None:
|
||||
_TRACKER.scan(marker, refresh_model_context=refresh_model_context)
|
||||
|
||||
|
||||
def stop_shm_forensics() -> None:
|
||||
_TRACKER.stop()
|
||||
|
||||
|
||||
atexit.register(stop_shm_forensics)
|
||||
214
comfy/isolation/vae_proxy.py
Normal file
214
comfy/isolation/vae_proxy.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
)
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy, ModelPatcherRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FirstStageModelRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "first_stage_model"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
async def has_property(self, instance_id: str, name: str) -> bool:
|
||||
obj = self._get_instance(instance_id)
|
||||
return hasattr(obj, name)
|
||||
|
||||
|
||||
class FirstStageModelProxy(BaseProxy[FirstStageModelRegistry]):
|
||||
_registry_class = FirstStageModelRegistry
|
||||
__module__ = "comfy.ldm.models.autoencoder"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<FirstStageModelProxy {self._instance_id}>"
|
||||
|
||||
|
||||
class VAERegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "vae"
|
||||
|
||||
async def get_patcher_id(self, instance_id: str) -> str:
|
||||
vae = self._get_instance(instance_id)
|
||||
return ModelPatcherRegistry().register(vae.patcher)
|
||||
|
||||
async def get_first_stage_model_id(self, instance_id: str) -> str:
|
||||
vae = self._get_instance(instance_id)
|
||||
return FirstStageModelRegistry().register(vae.first_stage_model)
|
||||
|
||||
async def encode(self, instance_id: str, pixels: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).encode(pixels))
|
||||
|
||||
async def encode_tiled(
|
||||
self,
|
||||
instance_id: str,
|
||||
pixels: Any,
|
||||
tile_x: int = 512,
|
||||
tile_y: int = 512,
|
||||
overlap: int = 64,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_tiled(
|
||||
pixels, tile_x=tile_x, tile_y=tile_y, overlap=overlap
|
||||
)
|
||||
)
|
||||
|
||||
async def decode(self, instance_id: str, samples: Any, **kwargs: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).decode(samples, **kwargs))
|
||||
|
||||
async def decode_tiled(
|
||||
self,
|
||||
instance_id: str,
|
||||
samples: Any,
|
||||
tile_x: int = 64,
|
||||
tile_y: int = 64,
|
||||
overlap: int = 16,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).decode_tiled(
|
||||
samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap, **kwargs
|
||||
)
|
||||
)
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
return getattr(self._get_instance(instance_id), name)
|
||||
|
||||
async def memory_used_encode(self, instance_id: str, shape: Any, dtype: Any) -> int:
|
||||
return self._get_instance(instance_id).memory_used_encode(shape, dtype)
|
||||
|
||||
async def memory_used_decode(self, instance_id: str, shape: Any, dtype: Any) -> int:
|
||||
return self._get_instance(instance_id).memory_used_decode(shape, dtype)
|
||||
|
||||
async def process_input(self, instance_id: str, image: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).process_input(image))
|
||||
|
||||
async def process_output(self, instance_id: str, image: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).process_output(image))
|
||||
|
||||
|
||||
class VAEProxy(BaseProxy[VAERegistry]):
|
||||
_registry_class = VAERegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
@property
|
||||
def patcher(self) -> ModelPatcherProxy:
|
||||
if not hasattr(self, "_patcher_proxy"):
|
||||
patcher_id = self._call_rpc("get_patcher_id")
|
||||
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
|
||||
return self._patcher_proxy
|
||||
|
||||
@property
|
||||
def first_stage_model(self) -> FirstStageModelProxy:
|
||||
if not hasattr(self, "_first_stage_model_proxy"):
|
||||
fsm_id = self._call_rpc("get_first_stage_model_id")
|
||||
self._first_stage_model_proxy = FirstStageModelProxy(
|
||||
fsm_id, manage_lifecycle=False
|
||||
)
|
||||
return self._first_stage_model_proxy
|
||||
|
||||
@property
|
||||
def vae_dtype(self) -> Any:
|
||||
return self._get_property("vae_dtype")
|
||||
|
||||
def encode(self, pixels: Any) -> Any:
|
||||
return self._call_rpc("encode", pixels)
|
||||
|
||||
def encode_tiled(
|
||||
self, pixels: Any, tile_x: int = 512, tile_y: int = 512, overlap: int = 64
|
||||
) -> Any:
|
||||
return self._call_rpc("encode_tiled", pixels, tile_x, tile_y, overlap)
|
||||
|
||||
def decode(self, samples: Any, **kwargs: Any) -> Any:
|
||||
return self._call_rpc("decode", samples, **kwargs)
|
||||
|
||||
def decode_tiled(
|
||||
self,
|
||||
samples: Any,
|
||||
tile_x: int = 64,
|
||||
tile_y: int = 64,
|
||||
overlap: int = 16,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"decode_tiled", samples, tile_x, tile_y, overlap, **kwargs
|
||||
)
|
||||
|
||||
def get_sd(self) -> Any:
|
||||
return self._call_rpc("get_sd")
|
||||
|
||||
def _get_property(self, name: str) -> Any:
|
||||
return self._call_rpc("get_property", name)
|
||||
|
||||
@property
|
||||
def latent_dim(self) -> int:
|
||||
return self._get_property("latent_dim")
|
||||
|
||||
@property
|
||||
def latent_channels(self) -> int:
|
||||
return self._get_property("latent_channels")
|
||||
|
||||
@property
|
||||
def downscale_ratio(self) -> Any:
|
||||
return self._get_property("downscale_ratio")
|
||||
|
||||
@property
|
||||
def upscale_ratio(self) -> Any:
|
||||
return self._get_property("upscale_ratio")
|
||||
|
||||
@property
|
||||
def output_channels(self) -> int:
|
||||
return self._get_property("output_channels")
|
||||
|
||||
@property
|
||||
def check_not_vide(self) -> bool:
|
||||
return self._get_property("not_video")
|
||||
|
||||
@property
|
||||
def device(self) -> Any:
|
||||
return self._get_property("device")
|
||||
|
||||
@property
|
||||
def working_dtypes(self) -> Any:
|
||||
return self._get_property("working_dtypes")
|
||||
|
||||
@property
|
||||
def disable_offload(self) -> bool:
|
||||
return self._get_property("disable_offload")
|
||||
|
||||
@property
|
||||
def size(self) -> Any:
|
||||
return self._get_property("size")
|
||||
|
||||
def memory_used_encode(self, shape: Any, dtype: Any) -> int:
|
||||
return self._call_rpc("memory_used_encode", shape, dtype)
|
||||
|
||||
def memory_used_decode(self, shape: Any, dtype: Any) -> int:
|
||||
return self._call_rpc("memory_used_decode", shape, dtype)
|
||||
|
||||
def process_input(self, image: Any) -> Any:
|
||||
return self._call_rpc("process_input", image)
|
||||
|
||||
def process_output(self, image: Any) -> Any:
|
||||
return self._call_rpc("process_output", image)
|
||||
|
||||
|
||||
if not IS_CHILD_PROCESS:
|
||||
_VAE_REGISTRY_SINGLETON = VAERegistry()
|
||||
_FIRST_STAGE_MODEL_REGISTRY_SINGLETON = FirstStageModelRegistry()
|
||||
@@ -1,4 +1,5 @@
|
||||
import math
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
from scipy import integrate
|
||||
@@ -12,8 +13,8 @@ from . import deis
|
||||
from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
import comfy.memory_management
|
||||
from comfy.cli_args import args
|
||||
from comfy.utils import model_trange as trange
|
||||
|
||||
def append_zero(x):
|
||||
@@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if isolation_active:
|
||||
target_device = sigmas.device
|
||||
if x.device != target_device:
|
||||
x = x.to(target_device)
|
||||
s_in = s_in.to(target_device)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
if s_churn > 0:
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
|
||||
@@ -136,7 +136,16 @@ class ResBlock(nn.Module):
|
||||
ops.Linear(c_hidden, c),
|
||||
)
|
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||
|
||||
# Init weights
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
def _norm(self, x, norm):
|
||||
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
@@ -44,22 +44,6 @@ class FluxParams:
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
def invert_slices(slices, length):
|
||||
sorted_slices = sorted(slices)
|
||||
result = []
|
||||
current = 0
|
||||
|
||||
for start, end in sorted_slices:
|
||||
if current < start:
|
||||
result.append((current, start))
|
||||
current = max(current, end)
|
||||
|
||||
if current < length:
|
||||
result.append((current, length))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@@ -154,7 +138,6 @@ class Flux(nn.Module):
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control = None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
@@ -181,6 +164,10 @@ class Flux(nn.Module):
|
||||
txt = self.txt_norm(txt)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||
@@ -195,24 +182,6 @@ class Flux(nn.Module):
|
||||
else:
|
||||
pe = None
|
||||
|
||||
vec_orig = vec
|
||||
txt_vec = vec
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
modulation_dims = []
|
||||
batch = vec.shape[0] // 2
|
||||
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
|
||||
invert = invert_slices(timestep_zero_index, img.shape[1])
|
||||
for s in invert:
|
||||
modulation_dims.append((s[0], s[1], 0))
|
||||
for s in timestep_zero_index:
|
||||
modulation_dims.append((s[0], s[1], 1))
|
||||
extra_kwargs["modulation_dims_img"] = modulation_dims
|
||||
txt_vec = vec[:batch]
|
||||
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
@@ -226,8 +195,7 @@ class Flux(nn.Module):
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
**extra_kwargs)
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
@@ -245,8 +213,7 @@ class Flux(nn.Module):
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options,
|
||||
**extra_kwargs)
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -263,12 +230,6 @@ class Flux(nn.Module):
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
lambda a: 0 if a == 0 else a + txt.shape[1]
|
||||
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
|
||||
extra_kwargs["modulation_dims"] = modulation_dims_combined
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
@@ -281,8 +242,7 @@ class Flux(nn.Module):
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
**extra_kwargs)
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
@@ -293,7 +253,7 @@ class Flux(nn.Module):
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -304,11 +264,7 @@ class Flux(nn.Module):
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
extra_kwargs["modulation_dims"] = modulation_dims
|
||||
|
||||
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
@@ -356,16 +312,13 @@ class Flux(nn.Module):
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||
img_tokens = img.shape[1]
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
ref_num_tokens = []
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||
timestep_zero = ref_latents_method == "index_timestep_zero"
|
||||
for ref in ref_latents:
|
||||
if ref_latents_method in ("index", "index_timestep_zero"):
|
||||
if ref_latents_method == "index":
|
||||
index += self.params.ref_index_scale
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
@@ -389,13 +342,6 @@ class Flux(nn.Module):
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
@@ -403,6 +349,6 @@ class Flux(nn.Module):
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@@ -343,7 +343,6 @@ class CrossAttention(nn.Module):
|
||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
|
||||
out = self.out_proj(x)
|
||||
@@ -413,7 +412,6 @@ class Attention(nn.Module):
|
||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||
value,
|
||||
heads=self.num_heads,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
@@ -23,11 +23,6 @@ class CausalConv3d(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
if isinstance(stride, int):
|
||||
self.time_stride = stride
|
||||
else:
|
||||
self.time_stride = stride[0]
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
@@ -63,25 +58,16 @@ class CausalConv3d(nn.Module):
|
||||
pieces = [ cached, x ]
|
||||
if is_end and not causal:
|
||||
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
||||
input_length = sum([piece.shape[2] for piece in pieces])
|
||||
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
|
||||
|
||||
needs_caching = not is_end
|
||||
if needs_caching and cache_length == 0:
|
||||
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
|
||||
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
|
||||
needs_caching = False
|
||||
if needs_caching and x.shape[2] >= cache_length:
|
||||
needs_caching = False
|
||||
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
x = torch.cat(pieces, dim=2)
|
||||
del pieces
|
||||
del cached
|
||||
|
||||
if needs_caching:
|
||||
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||
elif is_end:
|
||||
self.temporal_cache_state[tid] = (None, True)
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from .causal_conv3d import CausalConv3d
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@@ -233,7 +232,10 @@ class Encoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
|
||||
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
checkpoint_fn = (
|
||||
@@ -244,14 +246,10 @@ class Encoder(nn.Module):
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
sample = checkpoint_fn(down_block)(sample)
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return None
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return None
|
||||
|
||||
if self.latent_log_var == "uniform":
|
||||
last_channel = sample[:, -1:, ...]
|
||||
@@ -283,35 +281,9 @@ class Encoder(nn.Module):
|
||||
|
||||
return sample
|
||||
|
||||
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
|
||||
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
|
||||
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
|
||||
|
||||
outputs = []
|
||||
samples = [sample[:, :, :1, :, :]]
|
||||
if sample.shape[2] > 1:
|
||||
chunk_t = max(2, max_chunk_size // frame_size)
|
||||
if chunk_t < 4:
|
||||
chunk_t = 2
|
||||
elif chunk_t < 8:
|
||||
chunk_t = 4
|
||||
else:
|
||||
chunk_t = (chunk_t // 8) * 8
|
||||
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
|
||||
for chunk_idx, chunk in enumerate(samples):
|
||||
if chunk_idx == len(samples) - 1:
|
||||
mark_conv3d_ended(self)
|
||||
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
|
||||
output = self._forward_chunk(chunk)
|
||||
if output is not None:
|
||||
outputs.append(output)
|
||||
|
||||
return torch_cat_if_needed(outputs, dim=2)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
#No encoder support so just flag the end so it doesnt use the cache.
|
||||
mark_conv3d_ended(self)
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
@@ -324,23 +296,7 @@ class Encoder(nn.Module):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
||||
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
|
||||
MIN_CHUNK_SIZE = 32 * 1024 ** 2
|
||||
MAX_CHUNK_SIZE = 128 * 1024 ** 2
|
||||
|
||||
def get_max_chunk_size(device: torch.device) -> int:
|
||||
total_memory = comfy.model_management.get_total_memory(dev=device)
|
||||
|
||||
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
|
||||
return MIN_CHUNK_SIZE
|
||||
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
|
||||
return MAX_CHUNK_SIZE
|
||||
|
||||
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
|
||||
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
|
||||
)
|
||||
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
|
||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
@@ -500,17 +456,6 @@ class Decoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
|
||||
ts, hs, ws, to = 1, 1, 1, 0
|
||||
for block in self.up_blocks:
|
||||
if isinstance(block, DepthToSpaceUpsample):
|
||||
ts *= block.stride[0]
|
||||
hs *= block.stride[1]
|
||||
ws *= block.stride[2]
|
||||
if block.stride[0] > 1:
|
||||
to = to * block.stride[0] + 1
|
||||
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
@@ -532,62 +477,11 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def decode_output_shape(self, input_shape):
|
||||
c, (ts, hs, ws), to = self._output_scale
|
||||
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
||||
|
||||
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
|
||||
sample = sample_ref[0]
|
||||
sample_ref[0] = None
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
t = sample.shape[2]
|
||||
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
||||
output_offset[0] += t
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
if ended:
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||
|
||||
if num_chunks == 1:
|
||||
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||
next_sample_ref = [sample]
|
||||
del sample
|
||||
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
return
|
||||
else:
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward_orig(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
output_buffer: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
@@ -602,7 +496,6 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
timestep_shift_scale = None
|
||||
scaled_timestep = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
@@ -630,18 +523,48 @@ class Decoder(nn.Module):
|
||||
)
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
if output_buffer is None:
|
||||
output_buffer = torch.empty(
|
||||
self.decode_output_shape(sample.shape),
|
||||
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
output_offset = [0]
|
||||
output = []
|
||||
|
||||
max_chunk_size = get_max_chunk_size(sample.device)
|
||||
def run_up(idx, sample, ended):
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample)
|
||||
return
|
||||
|
||||
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
up_block = self.up_blocks[idx]
|
||||
if (ended):
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
return output_buffer
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, sample, True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
try:
|
||||
@@ -765,25 +688,12 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
self.temporal_cache_state = {}
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
tid = threading.get_ident()
|
||||
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
|
||||
if cached_input is not None:
|
||||
x = torch_cat_if_needed([cached_input, x], dim=2)
|
||||
cached_input = None
|
||||
|
||||
if self.stride[0] == 2 and pad_first:
|
||||
if self.stride[0] == 2:
|
||||
x = torch.cat(
|
||||
[x[:, :, :1, :, :], x], dim=2
|
||||
) # duplicate first frames for padding
|
||||
pad_first = False
|
||||
|
||||
if x.shape[2] < self.stride[0]:
|
||||
cached_input = x
|
||||
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||
return None
|
||||
|
||||
# skip connection
|
||||
x_in = rearrange(
|
||||
@@ -798,26 +708,15 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
|
||||
# conv
|
||||
x = self.conv(x, causal=causal)
|
||||
if self.stride[0] == 2 and x.shape[2] == 1:
|
||||
if cached_x is not None:
|
||||
x = torch_cat_if_needed([cached_x, x], dim=2)
|
||||
cached_x = None
|
||||
else:
|
||||
cached_x = x
|
||||
x = None
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
if x is not None:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
cached = add_exchange_cache(x, cached, x_in, dim=2)
|
||||
|
||||
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||
x = x + x_in
|
||||
|
||||
return x
|
||||
|
||||
@@ -1150,8 +1049,6 @@ class processor(nn.Module):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
comfy_has_chunked_io = True
|
||||
|
||||
def __init__(self, version=0, config=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -1294,15 +1191,14 @@ class VideoVAE(nn.Module):
|
||||
}
|
||||
return config
|
||||
|
||||
def encode(self, x, device=None):
|
||||
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
|
||||
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
|
||||
def encode(self, x):
|
||||
frames_count = x.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode_output_shape(self, input_shape):
|
||||
return self.decoder.decode_output_shape(input_shape)
|
||||
|
||||
def decode(self, x, output_buffer=None):
|
||||
def decode(self, x):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
|
||||
@@ -99,7 +99,7 @@ class Resample(nn.Module):
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
@@ -109,7 +109,22 @@ class Resample(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
@@ -130,24 +145,19 @@ class Resample(nn.Module):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -1:, :, :]
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
||||
# # cache last frame of last two chunk
|
||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
|
||||
deferred_x = feat_cache[idx + 1]
|
||||
if deferred_x is not None:
|
||||
x = torch.cat([deferred_x, x], 2)
|
||||
feat_cache[idx + 1] = None
|
||||
|
||||
if x.shape[2] == 1 and not final:
|
||||
feat_cache[idx + 1] = x
|
||||
x = None
|
||||
|
||||
feat_idx[0] += 2
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
|
||||
@@ -167,12 +177,19 @@ class ResidualBlock(nn.Module):
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -196,7 +213,7 @@ class AttentionBlock(nn.Module):
|
||||
self.proj = ops.Conv2d(dim, dim, 1)
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
@@ -266,10 +283,17 @@ class Encoder3d(nn.Module):
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -279,16 +303,14 @@ class Encoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
if x is None:
|
||||
return None
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -296,7 +318,14 @@ class Encoder3d(nn.Module):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -360,48 +389,18 @@ class Decoder3d(nn.Module):
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||
|
||||
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
|
||||
x = x_ref[0]
|
||||
x_ref[0] = None
|
||||
if layer_idx >= len(self.upsamples):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, feat_cache[feat_idx[0]])
|
||||
feat_cache[feat_idx[0]] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
out_chunks.append(x)
|
||||
return
|
||||
|
||||
layer = self.upsamples[layer_idx]
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2:
|
||||
for frame_idx in range(0, x.shape[2], 2):
|
||||
self.run_up(
|
||||
layer_idx + 1,
|
||||
[x[:, :, frame_idx:frame_idx + 2, :, :]],
|
||||
feat_cache,
|
||||
feat_idx.copy(),
|
||||
out_chunks,
|
||||
)
|
||||
del x
|
||||
return
|
||||
|
||||
next_x_ref = [x]
|
||||
del x
|
||||
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -410,21 +409,42 @@ class Decoder3d(nn.Module):
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
out_chunks = []
|
||||
|
||||
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
|
||||
return out_chunks
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_cache_layers(model):
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
|
||||
if isinstance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@@ -462,12 +482,11 @@ class WanVAE(nn.Module):
|
||||
conv_idx = [0]
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
t = 1 + ((t - 1) // 4) * 4
|
||||
iter_ = 1 + (t - 1) // 2
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_cache_layers(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
@@ -477,23 +496,20 @@ class WanVAE(nn.Module):
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
final=(i == (iter_ - 1)))
|
||||
if out_ is None:
|
||||
continue
|
||||
feat_idx=conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
conv_idx = [0]
|
||||
# z: [b,c,t,h,w]
|
||||
iter_ = 1 + z.shape[2] // 2
|
||||
iter_ = z.shape[2]
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_cache_layers(self.decoder)
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
@@ -504,8 +520,8 @@ class WanVAE(nn.Module):
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
out += out_
|
||||
return torch.cat(out, 2)
|
||||
out = torch.cat([out, out_], 2)
|
||||
return out
|
||||
|
||||
@@ -1,71 +1,9 @@
|
||||
import math
|
||||
import ctypes
|
||||
import threading
|
||||
import dataclasses
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
|
||||
class TensorFileSlice(NamedTuple):
|
||||
file_ref: object
|
||||
thread_id: int
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
def read_tensor_file_slice_into(tensor, destination):
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
if not isinstance(destination, QuantizedTensor):
|
||||
return False
|
||||
if tensor._layout_cls != destination._layout_cls:
|
||||
return False
|
||||
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
|
||||
return False
|
||||
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
return True
|
||||
|
||||
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
|
||||
if info is None:
|
||||
return False
|
||||
|
||||
file_obj = info.file_ref
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
or not tensor.is_contiguous()):
|
||||
return False
|
||||
|
||||
if info.size == 0:
|
||||
return True
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
try:
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
return True
|
||||
finally:
|
||||
view.release()
|
||||
|
||||
class TensorGeometry(NamedTuple):
|
||||
shape: any
|
||||
dtype: torch.dtype
|
||||
|
||||
@@ -20,8 +20,8 @@ import comfy.ldm.hunyuan3dv2_1
|
||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import os
|
||||
import comfy.ldm.lightricks.av_model
|
||||
import comfy.context_windows
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.cascade.stage_b import StageB
|
||||
@@ -113,8 +113,20 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||
|
||||
from comfy.cli_args import args
|
||||
isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
if isolation_runtime_enabled:
|
||||
def __reduce__(self):
|
||||
"""Ensure pickling yields a proxy instead of failing on local class."""
|
||||
try:
|
||||
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
|
||||
registry = ModelSamplingRegistry()
|
||||
ms_id = registry.register(self)
|
||||
return (ModelSamplingProxy, (ms_id,))
|
||||
except Exception as exc:
|
||||
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
|
||||
|
||||
return ModelSampling(model_config)
|
||||
|
||||
@@ -286,12 +298,6 @@ class BaseModel(torch.nn.Module):
|
||||
return data
|
||||
return None
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
"""Override in subclasses to handle model-specific cond slicing for context windows.
|
||||
Return a sliced cond object, or None to fall through to default handling.
|
||||
Use comfy.context_windows.slice_cond() for common cases."""
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
@@ -1382,11 +1388,6 @@ class WAN21_Vace(WAN21):
|
||||
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "vace_context":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN21_Camera(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
|
||||
@@ -1439,11 +1440,6 @@ class WAN21_HuMo(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "audio_embed":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22_Animate(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
|
||||
@@ -1461,13 +1457,6 @@ class WAN22_Animate(WAN21):
|
||||
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "face_pixel_values":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
|
||||
if cond_key == "pose_latents":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22_S2V(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||
@@ -1504,11 +1493,6 @@ class WAN22_S2V(WAN21):
|
||||
out['reference_motion'] = reference_motion.shape
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "audio_embed":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
|
||||
@@ -372,7 +372,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
@@ -400,7 +400,7 @@ try:
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||
@@ -497,6 +497,9 @@ except:
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
def _isolation_mode_enabled():
|
||||
return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
@@ -505,28 +508,6 @@ def module_size(module):
|
||||
module_mem += t.nbytes
|
||||
return module_mem
|
||||
|
||||
def module_mmap_residency(module, free=False):
|
||||
mmap_touched_mem = 0
|
||||
module_mem = 0
|
||||
bounced_mmaps = set()
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nbytes
|
||||
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
|
||||
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
||||
continue
|
||||
mmap_touched_mem += t.nbytes
|
||||
if not free:
|
||||
continue
|
||||
storage._comfy_tensor_mmap_touched = False
|
||||
mmap_obj = storage._comfy_tensor_mmap_refs[0]
|
||||
if mmap_obj in bounced_mmaps:
|
||||
continue
|
||||
mmap_obj.bounce()
|
||||
bounced_mmaps.add(mmap_obj)
|
||||
return mmap_touched_mem, module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self._set_model(model)
|
||||
@@ -541,7 +522,6 @@ class LoadedModel:
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
||||
self._patcher_finalizer.atexit = False
|
||||
|
||||
def _switch_parent(self):
|
||||
model = self._parent_model()
|
||||
@@ -555,9 +535,6 @@ class LoadedModel:
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return self.model.model_mmap_residency(free=free)
|
||||
|
||||
def model_loaded_memory(self):
|
||||
return self.model.loaded_size()
|
||||
|
||||
@@ -588,7 +565,6 @@ class LoadedModel:
|
||||
|
||||
self.real_model = weakref.ref(real_model)
|
||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||
self.model_finalizer.atexit = False
|
||||
return real_model
|
||||
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
@@ -603,8 +579,9 @@ class LoadedModel:
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.detach(unpatch_weights)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
if self.model_finalizer is not None:
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
return True
|
||||
|
||||
@@ -618,8 +595,15 @@ class LoadedModel:
|
||||
if self._patcher_finalizer is not None:
|
||||
self._patcher_finalizer.detach()
|
||||
|
||||
def dead_state(self):
|
||||
model_ref_gone = self.model is None
|
||||
real_model_ref = self.real_model
|
||||
real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None
|
||||
return model_ref_gone, real_model_ref_gone
|
||||
|
||||
def is_dead(self):
|
||||
return self.real_model() is not None and self.model is None
|
||||
model_ref_gone, real_model_ref_gone = self.dead_state()
|
||||
return model_ref_gone or real_model_ref_gone
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
@@ -660,11 +644,12 @@ def extra_reserved_memory():
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
|
||||
cleanup_models_gc()
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
isolation_active = _isolation_mode_enabled()
|
||||
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
@@ -673,14 +658,24 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
can_unload_sorted = sorted(can_unload)
|
||||
for x in can_unload_sorted:
|
||||
if can_unload and isolation_active:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
flush_tensor_keeper = None
|
||||
if callable(flush_tensor_keeper):
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
|
||||
gc.collect()
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
pins_to_free = 1e32
|
||||
ram_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
memory_to_free = memory_required - get_free_memory(device)
|
||||
pins_to_free = pins_required - get_free_ram()
|
||||
ram_to_free = ram_required - get_free_ram()
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
@@ -689,21 +684,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
if pins_to_free > 0:
|
||||
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
||||
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
ram_to_free = ram_required - psutil.virtual_memory().available
|
||||
if ram_to_free <= 0 and i not in unloaded_model:
|
||||
continue
|
||||
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
|
||||
if resident_memory > 0:
|
||||
if ram_to_free > 0:
|
||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
unloaded = current_loaded_models.pop(i)
|
||||
model_obj = unloaded.model
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
unloaded_models.append(unloaded)
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
@@ -762,31 +754,23 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
for i in to_unload:
|
||||
model_to_unload = current_loaded_models.pop(i)
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
if model_to_unload.model_finalizer is not None:
|
||||
model_to_unload.model_finalizer.detach()
|
||||
model_to_unload.model_finalizer = None
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
total_pins_required = {}
|
||||
total_ram_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
device = loaded_model.device
|
||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||
resident_memory, model_memory = loaded_model.model_mmap_residency()
|
||||
pinned_memory = loaded_model.model.pinned_memory_size()
|
||||
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
|
||||
#make this JIT to keep as much pinned as possible.
|
||||
pins_required = model_memory - pinned_memory
|
||||
ram_required = model_memory - resident_memory
|
||||
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
|
||||
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
|
||||
#want to do.
|
||||
#FIXME: This should subtract off the to_load current pin consumption.
|
||||
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||
device,
|
||||
for_dynamic=free_for_dynamic,
|
||||
pins_required=total_pins_required[device],
|
||||
ram_required=total_ram_required[device])
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
@@ -835,25 +819,62 @@ def loaded_models(only_currently_used=False):
|
||||
|
||||
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
|
||||
reset_cast_buffers()
|
||||
if not _isolation_mode_enabled():
|
||||
dead_found = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].is_dead():
|
||||
dead_found = True
|
||||
break
|
||||
|
||||
if dead_found:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
leaked = current_loaded_models.pop(i)
|
||||
model_obj = getattr(leaked, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
return
|
||||
|
||||
dead_found = False
|
||||
has_real_model_leak = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
||||
do_gc = True
|
||||
break
|
||||
model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
|
||||
if model_ref_gone or real_model_ref_gone:
|
||||
dead_found = True
|
||||
if real_model_ref_gone and not model_ref_gone:
|
||||
has_real_model_leak = True
|
||||
|
||||
if do_gc:
|
||||
if dead_found:
|
||||
if has_real_model_leak:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
else:
|
||||
logging.debug("Cleaning stale loaded-model entries with released patcher references.")
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
model_ref_gone, real_model_ref_gone = cur.dead_state()
|
||||
if model_ref_gone or real_model_ref_gone:
|
||||
if real_model_ref_gone and not model_ref_gone:
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
else:
|
||||
logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
|
||||
leaked = current_loaded_models.pop(i)
|
||||
model_obj = getattr(leaked, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
|
||||
|
||||
def archive_model_dtypes(model):
|
||||
@@ -867,11 +888,20 @@ def archive_model_dtypes(model):
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].real_model() is None:
|
||||
real_model_ref = current_loaded_models[i].real_model
|
||||
if real_model_ref is None:
|
||||
to_delete = [i] + to_delete
|
||||
continue
|
||||
if callable(real_model_ref) and real_model_ref() is None:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
x = current_loaded_models.pop(i)
|
||||
model_obj = getattr(x, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
del x
|
||||
|
||||
def dtype_size(dtype):
|
||||
@@ -1052,12 +1082,6 @@ def intermediate_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def intermediate_dtype():
|
||||
if args.fp16_intermediates:
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
def vae_device():
|
||||
if args.cpu_vae:
|
||||
return torch.device("cpu")
|
||||
@@ -1278,11 +1302,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
dest_view = dest_views.pop(0)
|
||||
if tensor is None:
|
||||
continue
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||
continue
|
||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
||||
storage._comfy_tensor_mmap_touched = True
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
|
||||
|
||||
@@ -1720,19 +1739,6 @@ def supports_nvfp4_compute(device=None):
|
||||
|
||||
return True
|
||||
|
||||
def supports_mxfp8_compute(device=None):
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
if torch_version_numeric < (2, 10):
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major < 10:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
|
||||
@@ -297,9 +297,6 @@ class ModelPatcher:
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
@@ -1066,10 +1063,6 @@ class ModelPatcher:
|
||||
|
||||
return self.model.model_loaded_weight_memory - current_used
|
||||
|
||||
def pinned_memory_size(self):
|
||||
# Pinned memory pressure tracking is only implemented for DynamicVram loading
|
||||
return 0
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
pass
|
||||
|
||||
@@ -1660,16 +1653,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
return freed
|
||||
|
||||
def pinned_memory_size(self):
|
||||
total = 0
|
||||
loading = self._load_list(for_dynamic=True)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
pin = comfy.pinned_memory.get_pin(m)
|
||||
if pin is not None:
|
||||
total += pin.numel() * pin.element_size()
|
||||
return total
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
|
||||
236
comfy/ops.py
236
comfy/ops.py
@@ -306,40 +306,10 @@ class CastWeightBiasOp:
|
||||
bias_function = []
|
||||
|
||||
class disable_weight_init:
|
||||
@staticmethod
|
||||
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
||||
missing_keys, unexpected_keys, weight_shape,
|
||||
bias_shape=None):
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
prefix_len = len(prefix)
|
||||
for k, v in state_dict.items():
|
||||
key = k[prefix_len:]
|
||||
if key == "weight":
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
module.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
elif bias_shape is not None and key == "bias" and v is not None:
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
module.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
else:
|
||||
unexpected_keys.append(k)
|
||||
|
||||
if module.weight is None:
|
||||
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
|
||||
missing_keys.append(prefix + "weight")
|
||||
|
||||
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
|
||||
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
|
||||
missing_keys.append(prefix + "bias")
|
||||
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
# don't trust subclasses that BYO state dict loader to call us.
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
return
|
||||
|
||||
@@ -360,21 +330,32 @@ class disable_weight_init:
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
disable_weight_init._lazy_load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
weight_shape=(self.in_features, self.out_features),
|
||||
bias_shape=(self.out_features,),
|
||||
)
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
prefix_len = len(prefix)
|
||||
for k,v in state_dict.items():
|
||||
if k[prefix_len:] == "weight":
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
elif k[prefix_len:] == "bias" and v is not None:
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
else:
|
||||
unexpected_keys.append(k)
|
||||
|
||||
#Reconcile default construction of the weight if its missing.
|
||||
if self.weight is None:
|
||||
v = torch.zeros(self.in_features, self.out_features)
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"weight")
|
||||
if self.bias is None and self.comfy_need_lazy_init_bias:
|
||||
v = torch.zeros(self.out_features,)
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"bias")
|
||||
|
||||
|
||||
def reset_parameters(self):
|
||||
@@ -566,53 +547,6 @@ class disable_weight_init:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
|
||||
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
|
||||
_freeze=False, device=None, dtype=None):
|
||||
# don't trust subclasses that BYO state dict loader to call us.
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
|
||||
norm_type, scale_grad_by_freq, sparse, _weight,
|
||||
_freeze, device, dtype)
|
||||
return
|
||||
|
||||
torch.nn.Module.__init__(self)
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.padding_idx = padding_idx
|
||||
self.max_norm = max_norm
|
||||
self.norm_type = norm_type
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
self.sparse = sparse
|
||||
# Keep shape/dtype visible for module introspection without reserving storage.
|
||||
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.bias = None
|
||||
self.weight_comfy_model_dtype = dtype
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
disable_weight_init._lazy_load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
weight_shape=(self.num_embeddings, self.embedding_dim),
|
||||
)
|
||||
|
||||
def reset_parameters(self):
|
||||
self.bias = None
|
||||
return None
|
||||
@@ -776,71 +710,6 @@ from .quant_ops import (
|
||||
)
|
||||
|
||||
|
||||
class QuantLinearFunc(torch.autograd.Function):
|
||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
|
||||
input_shape = input_float.shape
|
||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||
|
||||
# Quantize input (same as inference path)
|
||||
if layout_type is not None:
|
||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||
else:
|
||||
q_input = inp
|
||||
|
||||
w = weight.detach() if weight.requires_grad else weight
|
||||
b = bias.detach() if bias is not None and bias.requires_grad else bias
|
||||
|
||||
output = torch.nn.functional.linear(q_input, w, b)
|
||||
|
||||
# Restore original input shape
|
||||
if len(input_shape) > 2:
|
||||
output = output.unflatten(0, input_shape[:-1])
|
||||
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
ctx.input_shape = input_shape
|
||||
ctx.has_bias = bias is not None
|
||||
ctx.compute_dtype = compute_dtype
|
||||
ctx.weight_requires_grad = weight.requires_grad
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
input_float, weight = ctx.saved_tensors
|
||||
compute_dtype = ctx.compute_dtype
|
||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||
|
||||
# Dequantize weight to compute dtype for backward matmul
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_f = weight.dequantize().to(compute_dtype)
|
||||
else:
|
||||
weight_f = weight.to(compute_dtype)
|
||||
|
||||
# grad_input = grad_output @ weight
|
||||
grad_input = torch.mm(grad_2d, weight_f)
|
||||
if len(ctx.input_shape) > 2:
|
||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||
|
||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||
grad_weight = None
|
||||
if ctx.weight_requires_grad:
|
||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||
|
||||
# grad_bias
|
||||
grad_bias = None
|
||||
if ctx.has_bias:
|
||||
grad_bias = grad_2d.sum(dim=0)
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
@@ -932,22 +801,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "mxfp8":
|
||||
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.uint8)
|
||||
|
||||
if block_scale is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||
|
||||
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "nvfp4":
|
||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
@@ -1035,37 +888,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
#If cast needs to apply lora, it should be done in the compute dtype
|
||||
compute_dtype = input.dtype
|
||||
|
||||
_use_quantized = (
|
||||
getattr(self, 'layout_type', None) is not None and
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||
)
|
||||
|
||||
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||
if (input.requires_grad and _use_quantized):
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=True
|
||||
)
|
||||
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
|
||||
output = QuantLinearFunc.apply(
|
||||
input, weight, bias, self.layout_type, scale, compute_dtype
|
||||
)
|
||||
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return output
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
@@ -1113,10 +939,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
p = fn(param)
|
||||
if p.is_inference():
|
||||
p = p.clone()
|
||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
@@ -1127,15 +950,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
|
||||
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
disabled = set()
|
||||
if not nvfp4_compute:
|
||||
disabled.add("nvfp4")
|
||||
if not mxfp8_compute:
|
||||
disabled.add("mxfp8")
|
||||
if not fp8_compute:
|
||||
disabled.add("float8_e4m3fn")
|
||||
disabled.add("float8_e5m2")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
@@ -13,31 +12,18 @@ def pin_memory(module):
|
||||
return
|
||||
#FIXME: This is a RAM cache trigger event
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
|
||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||
pin = torch.empty((size,), dtype=torch.uint8)
|
||||
if comfy.model_management.pin_memory(pin):
|
||||
module._pin = pin
|
||||
else:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
try:
|
||||
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
|
||||
except RuntimeError:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
|
||||
module._pin_hostbuf = hostbuf
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
|
||||
def unpin_memory(module):
|
||||
if get_pin(module) is None:
|
||||
return 0
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||
|
||||
comfy.model_management.unpin_memory(module._pin)
|
||||
del module._pin
|
||||
del module._pin_hostbuf
|
||||
return size
|
||||
|
||||
@@ -43,18 +43,6 @@ except ImportError as e:
|
||||
def get_layout_class(name):
|
||||
return None
|
||||
|
||||
_CK_MXFP8_AVAILABLE = False
|
||||
if _CK_AVAILABLE:
|
||||
try:
|
||||
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
|
||||
_CK_MXFP8_AVAILABLE = True
|
||||
except ImportError:
|
||||
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
|
||||
|
||||
if not _CK_MXFP8_AVAILABLE:
|
||||
class _CKMxfp8Layout:
|
||||
pass
|
||||
|
||||
import comfy.float
|
||||
|
||||
# ==============================================================================
|
||||
@@ -96,31 +84,6 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if tensor.dim() != 2:
|
||||
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
padded_shape = cls.get_padded_shape(orig_shape)
|
||||
needs_padding = padded_shape != orig_shape
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
|
||||
else:
|
||||
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
|
||||
|
||||
params = cls.Params(
|
||||
scale=block_scale,
|
||||
orig_dtype=orig_dtype,
|
||||
orig_shape=orig_shape,
|
||||
)
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
@@ -174,8 +137,6 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
@@ -196,14 +157,6 @@ QUANT_ALGOS = {
|
||||
},
|
||||
}
|
||||
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
QUANT_ALGOS["mxfp8"] = {
|
||||
"storage_t": torch.float8_e4m3fn,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
|
||||
"group_size": 32,
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
|
||||
@@ -8,12 +8,12 @@ import comfy.nested_tensor
|
||||
|
||||
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||
noises = []
|
||||
for i in range(unique_inds[-1]+1):
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
return samples
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
return samples
|
||||
|
||||
@@ -11,12 +11,14 @@ from functools import partial
|
||||
import collections
|
||||
import math
|
||||
import logging
|
||||
import os
|
||||
import comfy.sampler_helpers
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.utils
|
||||
from comfy.cli_args import args
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -210,9 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
result = executor.execute(model, conds, x_in, timestep, model_options)
|
||||
return result
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@@ -269,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
|
||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
|
||||
if memory_required * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
@@ -294,9 +299,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
if isolation_active:
|
||||
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
|
||||
input_x = torch.cat(input_x).to(target_device)
|
||||
else:
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
if isolation_active:
|
||||
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
|
||||
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
|
||||
else:
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
@@ -327,9 +340,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
out_t = output[o]
|
||||
mult_t = mult[o]
|
||||
if isolation_active:
|
||||
target_dev = out_conds[cond_index].device
|
||||
if hasattr(out_t, "device") and out_t.device != target_dev:
|
||||
out_t = out_t.to(target_dev)
|
||||
if hasattr(mult_t, "device") and mult_t.device != target_dev:
|
||||
mult_t = mult_t.to(target_dev)
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
out_conds[cond_index] += out_t * mult_t
|
||||
out_counts[cond_index] += mult_t
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
@@ -337,8 +358,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
out_c += out_t * mult_t
|
||||
out_cts += mult_t
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
@@ -392,14 +413,31 @@ class KSamplerX0Inpaint:
|
||||
self.inner_model = model
|
||||
self.sigmas = sigmas
|
||||
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if denoise_mask is not None:
|
||||
if isolation_active and denoise_mask.device != x.device:
|
||||
denoise_mask = denoise_mask.to(x.device)
|
||||
if "denoise_mask_function" in model_options:
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
|
||||
if isolation_active:
|
||||
latent_image = self.latent_image
|
||||
if hasattr(latent_image, "device") and latent_image.device != x.device:
|
||||
latent_image = latent_image.to(x.device)
|
||||
scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image)
|
||||
if hasattr(scaled, "device") and scaled.device != x.device:
|
||||
scaled = scaled.to(x.device)
|
||||
else:
|
||||
scaled = self.inner_model.inner_model.scale_latent_inpaint(
|
||||
x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image
|
||||
)
|
||||
x = x * denoise_mask + scaled * latent_mask
|
||||
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out = out * denoise_mask + self.latent_image * latent_mask
|
||||
latent_image = self.latent_image
|
||||
if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device:
|
||||
latent_image = latent_image.to(out.device)
|
||||
out = out * denoise_mask + latent_image * latent_mask
|
||||
return out
|
||||
|
||||
def simple_scheduler(model_sampling, steps):
|
||||
@@ -741,7 +779,11 @@ class KSAMPLER(Sampler):
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
|
||||
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||
model_sampling = model_wrap.inner_model.model_sampling
|
||||
noise = model_sampling.noise_scaling(
|
||||
sigmas[0], noise, latent_image, max_denoise
|
||||
)
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
@@ -985,8 +1027,8 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
@@ -1028,7 +1070,6 @@ class CFGGuider:
|
||||
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
||||
else:
|
||||
denoise_mask = denoise_masks[0]
|
||||
denoise_mask = denoise_mask.float()
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
|
||||
55
comfy/sd.py
55
comfy/sd.py
@@ -455,7 +455,7 @@ class VAE:
|
||||
self.output_channels = 3
|
||||
self.pad_channel_value = None
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = False
|
||||
self.not_video = False
|
||||
@@ -871,16 +871,13 @@ class VAE:
|
||||
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
||||
return pixels
|
||||
|
||||
def vae_output_dtype(self):
|
||||
return model_management.intermediate_dtype()
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
output = self.process_output(
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
@@ -890,16 +887,16 @@ class VAE:
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||
if samples.ndim == 3:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
else:
|
||||
og_shape = samples.shape
|
||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
|
||||
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
@@ -908,7 +905,7 @@ class VAE:
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
@@ -917,7 +914,7 @@ class VAE:
|
||||
|
||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||
if self.latent_dim == 1:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
out_channels = self.latent_channels
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
else:
|
||||
@@ -926,7 +923,7 @@ class VAE:
|
||||
tile_x = tile_x // extra_channel_size
|
||||
overlap = overlap // extra_channel_size
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
|
||||
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
if self.latent_dim == 1:
|
||||
@@ -935,7 +932,7 @@ class VAE:
|
||||
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
@@ -951,23 +948,12 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
@@ -978,7 +964,6 @@ class VAE:
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
@@ -1039,15 +1024,10 @@ class VAE:
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
@@ -1060,7 +1040,6 @@ class VAE:
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
|
||||
@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
|
||||
out, pooled = o[:2]
|
||||
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
|
||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
||||
else:
|
||||
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
|
||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
||||
|
||||
if len(o) > 2:
|
||||
extra = {}
|
||||
for k in o[2]:
|
||||
v = o[2][k]
|
||||
if k == "attention_mask":
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
|
||||
extra[k] = v
|
||||
|
||||
r = r + (extra,)
|
||||
|
||||
@@ -20,8 +20,6 @@
|
||||
import torch
|
||||
import math
|
||||
import struct
|
||||
import ctypes
|
||||
import os
|
||||
import comfy.memory_management
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
@@ -34,7 +32,7 @@ from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import mmap
|
||||
import warnings
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
@@ -83,17 +81,14 @@ _TYPES = {
|
||||
}
|
||||
|
||||
def load_safetensors(ckpt):
|
||||
import comfy_aimdo.model_mmap
|
||||
f = open(ckpt, "rb")
|
||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
mv = memoryview(mapping)
|
||||
|
||||
f = open(ckpt, "rb", buffering=0)
|
||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||
file_size = os.path.getsize(ckpt)
|
||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
||||
|
||||
header_size = struct.unpack("<Q", mv[:8])[0]
|
||||
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
|
||||
|
||||
mv = mv[(data_base_offset := 8 + header_size):]
|
||||
mv = mv[8 + header_size:]
|
||||
|
||||
sd = {}
|
||||
for name, info in header.items():
|
||||
@@ -107,14 +102,7 @@ def load_safetensors(ckpt):
|
||||
with warnings.catch_warnings():
|
||||
#We are working with read-only RAM by design
|
||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
storage = tensor.untyped_storage()
|
||||
setattr(storage,
|
||||
"_comfy_tensor_file_slice",
|
||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||
setattr(storage, "_comfy_tensor_mmap_touched", False)
|
||||
sd[name] = tensor
|
||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
|
||||
@@ -897,10 +885,6 @@ def set_attr(obj, attr, value):
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
# Clone inference tensors (created under torch.inference_mode) since
|
||||
# their version counter is frozen and nn.Parameter() cannot wrap them.
|
||||
if (not torch.is_inference_mode_enabled()) and value.is_inference():
|
||||
value = value.clone()
|
||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||
|
||||
def set_attr_buffer(obj, attr, value):
|
||||
@@ -1135,8 +1119,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
out = output[b:b+1].zero_()
|
||||
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
|
||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
|
||||
@@ -1151,7 +1135,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
upscaled.append(round(get_pos(d, pos)))
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
|
||||
for d in range(2, dims + 2):
|
||||
feather = round(get_scale(d - 2, overlap[d - 2]))
|
||||
@@ -1174,7 +1158,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
|
||||
out.div_(out_div)
|
||||
output[b:b+1] = out/out_div
|
||||
return output
|
||||
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
|
||||
@@ -25,7 +25,6 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
super().__init__()
|
||||
self.node_replacement = self.NodeReplacement()
|
||||
self.execution = self.Execution()
|
||||
self.caching = self.Caching()
|
||||
|
||||
class NodeReplacement(ProxiedSingleton):
|
||||
async def register(self, node_replace: io.NodeReplace) -> None:
|
||||
@@ -85,36 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
image=to_display,
|
||||
)
|
||||
|
||||
class Caching(ProxiedSingleton):
|
||||
"""
|
||||
External cache provider API for sharing cached node outputs
|
||||
across ComfyUI instances.
|
||||
|
||||
Example::
|
||||
|
||||
from comfy_api.latest import Caching
|
||||
|
||||
class MyCacheProvider(Caching.CacheProvider):
|
||||
async def on_lookup(self, context):
|
||||
... # check external storage
|
||||
|
||||
async def on_store(self, context, value):
|
||||
... # store to external storage
|
||||
|
||||
Caching.register_provider(MyCacheProvider())
|
||||
"""
|
||||
from ._caching import CacheProvider, CacheContext, CacheValue
|
||||
|
||||
async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
|
||||
"""Register an external cache provider. Providers are called in registration order."""
|
||||
from comfy_execution.cache_provider import register_cache_provider
|
||||
register_cache_provider(provider)
|
||||
|
||||
async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
|
||||
"""Unregister a previously registered cache provider."""
|
||||
from comfy_execution.cache_provider import unregister_cache_provider
|
||||
unregister_cache_provider(provider)
|
||||
|
||||
class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
@@ -147,9 +116,6 @@ class Types:
|
||||
VOXEL = VOXEL
|
||||
File3D = File3D
|
||||
|
||||
|
||||
Caching = ComfyAPI_latest.Caching
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
# Create a synchronous version of the API
|
||||
@@ -169,7 +135,6 @@ __all__ = [
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"Caching",
|
||||
"ComfyExtension",
|
||||
"io",
|
||||
"IO",
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheContext:
|
||||
node_id: str
|
||||
class_type: str
|
||||
cache_key_hash: str # SHA256 hex digest
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheValue:
|
||||
outputs: list
|
||||
ui: dict = None
|
||||
|
||||
|
||||
class CacheProvider(ABC):
|
||||
"""Abstract base class for external cache providers.
|
||||
Exceptions from provider methods are caught by the caller and never break execution.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||
"""Called after local store. Dispatched via asyncio.create_task."""
|
||||
pass
|
||||
|
||||
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
|
||||
"""Return False to skip external caching for this node. Default: True."""
|
||||
return True
|
||||
|
||||
def on_prompt_start(self, prompt_id: str) -> None:
|
||||
pass
|
||||
|
||||
def on_prompt_end(self, prompt_id: str) -> None:
|
||||
pass
|
||||
@@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
|
||||
has_first_frame = False
|
||||
for frame in frames:
|
||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
|
||||
to_skip = int(offset_seconds * audio_stream.sample_rate)
|
||||
if to_skip < frame.samples:
|
||||
has_first_frame = True
|
||||
break
|
||||
@@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
|
||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||
|
||||
for frame in frames:
|
||||
if self.__duration and frame.time > start_time + self.__duration:
|
||||
if frame.time > start_time + self.__duration:
|
||||
break
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
|
||||
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D, PLY as _PLY, NPZ as _NPZ
|
||||
|
||||
|
||||
class FolderType(str, Enum):
|
||||
@@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
|
||||
'''Float input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||
display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
|
||||
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.min = min
|
||||
@@ -678,6 +678,16 @@ class Mesh(ComfyTypeIO):
|
||||
Type = MESH
|
||||
|
||||
|
||||
@comfytype(io_type="PLY")
|
||||
class Ply(ComfyTypeIO):
|
||||
Type = _PLY
|
||||
|
||||
|
||||
@comfytype(io_type="NPZ")
|
||||
class Npz(ComfyTypeIO):
|
||||
Type = _NPZ
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D")
|
||||
class File3DAny(ComfyTypeIO):
|
||||
"""General 3D file type - accepts any supported 3D format."""
|
||||
@@ -2197,6 +2207,8 @@ __all__ = [
|
||||
"LossMap",
|
||||
"Voxel",
|
||||
"Mesh",
|
||||
"Ply",
|
||||
"Npz",
|
||||
"File3DAny",
|
||||
"File3DGLB",
|
||||
"File3DGLTF",
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH, File3D
|
||||
from .image_types import SVG
|
||||
from .ply_types import PLY
|
||||
from .npz_types import NPZ
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
@@ -11,4 +13,6 @@ __all__ = [
|
||||
"MESH",
|
||||
"File3D",
|
||||
"SVG",
|
||||
"PLY",
|
||||
"NPZ",
|
||||
]
|
||||
|
||||
27
comfy_api/latest/_util/npz_types.py
Normal file
27
comfy_api/latest/_util/npz_types.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class NPZ:
|
||||
"""Ordered collection of NPZ file payloads.
|
||||
|
||||
Each entry in ``frames`` is a complete compressed ``.npz`` file stored
|
||||
as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO).
|
||||
``save_to`` writes numbered files into a directory.
|
||||
"""
|
||||
|
||||
def __init__(self, frames: list[bytes]) -> None:
|
||||
self.frames = frames
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return len(self.frames)
|
||||
|
||||
def save_to(self, directory: str, prefix: str = "frame") -> str:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
for i, frame_bytes in enumerate(self.frames):
|
||||
path = os.path.join(directory, f"{prefix}_{i:06d}.npz")
|
||||
with open(path, "wb") as f:
|
||||
f.write(frame_bytes)
|
||||
return directory
|
||||
97
comfy_api/latest/_util/ply_types.py
Normal file
97
comfy_api/latest/_util/ply_types.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PLY:
|
||||
"""Point cloud payload for PLY file output.
|
||||
|
||||
Supports two schemas:
|
||||
- Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format)
|
||||
- Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format)
|
||||
|
||||
When ``raw_data`` is provided, the object acts as an opaque binary PLY
|
||||
carrier and ``save_to`` writes the bytes directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
points: np.ndarray | None = None,
|
||||
colors: np.ndarray | None = None,
|
||||
confidence: np.ndarray | None = None,
|
||||
view_id: np.ndarray | None = None,
|
||||
raw_data: bytes | None = None,
|
||||
) -> None:
|
||||
self.raw_data = raw_data
|
||||
if raw_data is not None:
|
||||
self.points = None
|
||||
self.colors = None
|
||||
self.confidence = None
|
||||
self.view_id = None
|
||||
return
|
||||
if points is None:
|
||||
raise ValueError("Either points or raw_data must be provided")
|
||||
if points.ndim != 2 or points.shape[1] != 3:
|
||||
raise ValueError(f"points must be (N, 3), got {points.shape}")
|
||||
self.points = np.ascontiguousarray(points, dtype=np.float32)
|
||||
self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None
|
||||
self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None
|
||||
self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None
|
||||
|
||||
@property
|
||||
def is_gaussian(self) -> bool:
|
||||
return self.raw_data is not None
|
||||
|
||||
@property
|
||||
def num_points(self) -> int:
|
||||
if self.points is not None:
|
||||
return self.points.shape[0]
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _to_numpy(arr, dtype):
|
||||
if arr is None:
|
||||
return None
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
|
||||
return np.ascontiguousarray(arr, dtype=dtype)
|
||||
|
||||
def save_to(self, path: str) -> str:
|
||||
if self.raw_data is not None:
|
||||
with open(path, "wb") as f:
|
||||
f.write(self.raw_data)
|
||||
return path
|
||||
self.points = self._to_numpy(self.points, np.float32)
|
||||
self.colors = self._to_numpy(self.colors, np.float32)
|
||||
self.confidence = self._to_numpy(self.confidence, np.float32)
|
||||
self.view_id = self._to_numpy(self.view_id, np.int32)
|
||||
N = self.num_points
|
||||
header_lines = [
|
||||
"ply",
|
||||
"format ascii 1.0",
|
||||
f"element vertex {N}",
|
||||
"property float x",
|
||||
"property float y",
|
||||
"property float z",
|
||||
]
|
||||
if self.colors is not None:
|
||||
header_lines += ["property uchar red", "property uchar green", "property uchar blue"]
|
||||
if self.confidence is not None:
|
||||
header_lines.append("property float confidence")
|
||||
if self.view_id is not None:
|
||||
header_lines.append("property int view_id")
|
||||
header_lines.append("end_header")
|
||||
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(header_lines) + "\n")
|
||||
for i in range(N):
|
||||
parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"]
|
||||
if self.colors is not None:
|
||||
r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8)
|
||||
parts.append(f"{r} {g} {b}")
|
||||
if self.confidence is not None:
|
||||
parts.append(f"{self.confidence[i]}")
|
||||
if self.view_id is not None:
|
||||
parts.append(f"{int(self.view_id[i])}")
|
||||
f.write(" ".join(parts) + "\n")
|
||||
return path
|
||||
@@ -67,7 +67,6 @@ class GeminiPart(BaseModel):
|
||||
inlineData: GeminiInlineData | None = Field(None)
|
||||
fileData: GeminiFileData | None = Field(None)
|
||||
text: str | None = Field(None)
|
||||
thought: bool | None = Field(None)
|
||||
|
||||
|
||||
class GeminiTextPart(BaseModel):
|
||||
|
||||
@@ -29,21 +29,13 @@ class ImageEditRequest(BaseModel):
|
||||
class VideoGenerationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
image: InputUrlObject | None = Field(None)
|
||||
reference_images: list[InputUrlObject] | None = Field(None)
|
||||
image: InputUrlObject | None = Field(...)
|
||||
duration: int = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
resolution: str = Field(...)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class VideoExtensionRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
video: InputUrlObject = Field(...)
|
||||
duration: int = Field(default=6)
|
||||
model: str | None = Field(default=None)
|
||||
|
||||
|
||||
class VideoEditRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class QuiverImageObject(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class QuiverTextToSVGRequest(BaseModel):
|
||||
model: str = Field(default="arrow-preview")
|
||||
prompt: str = Field(...)
|
||||
instructions: str | None = Field(default=None)
|
||||
references: list[QuiverImageObject] | None = Field(default=None, max_length=4)
|
||||
temperature: float | None = Field(default=None, ge=0, le=2)
|
||||
top_p: float | None = Field(default=None, ge=0, le=1)
|
||||
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
|
||||
|
||||
|
||||
class QuiverImageToSVGRequest(BaseModel):
|
||||
model: str = Field(default="arrow-preview")
|
||||
image: QuiverImageObject = Field(...)
|
||||
auto_crop: bool | None = Field(default=None)
|
||||
target_size: int | None = Field(default=None, ge=128, le=4096)
|
||||
temperature: float | None = Field(default=None, ge=0, le=2)
|
||||
top_p: float | None = Field(default=None, ge=0, le=1)
|
||||
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
|
||||
|
||||
|
||||
class QuiverSVGResponseItem(BaseModel):
|
||||
svg: str = Field(...)
|
||||
mime_type: str | None = Field(default="image/svg+xml")
|
||||
|
||||
|
||||
class QuiverSVGUsage(BaseModel):
|
||||
total_tokens: int | None = Field(default=None)
|
||||
input_tokens: int | None = Field(default=None)
|
||||
output_tokens: int | None = Field(default=None)
|
||||
|
||||
|
||||
class QuiverSVGResponse(BaseModel):
|
||||
id: str | None = Field(default=None)
|
||||
created: int | None = Field(default=None)
|
||||
data: list[QuiverSVGResponseItem] = Field(...)
|
||||
usage: QuiverSVGUsage | None = Field(default=None)
|
||||
@@ -47,10 +47,6 @@ SEEDREAM_MODELS = {
|
||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
@@ -139,7 +135,6 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03}""",
|
||||
),
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -947,7 +942,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
]
|
||||
return await process_video_task(
|
||||
cls,
|
||||
payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
|
||||
payload=Image2VideoTaskCreationRequest(model=model, content=x),
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
|
||||
@@ -957,12 +952,6 @@ async def process_video_task(
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
if payload.model in DEPRECATED_MODELS:
|
||||
logger.warning(
|
||||
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||
payload.model,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
|
||||
@@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
|
||||
$m := widgets.model;
|
||||
$r := widgets.resolution;
|
||||
$isFlash := $contains($m, "nano banana 2");
|
||||
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
|
||||
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
|
||||
$prices := $isFlash ? $flashPrices : $proPrices;
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
@@ -188,12 +188,10 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/*")
|
||||
for part in parts:
|
||||
if (part.thought is True) != thought:
|
||||
continue
|
||||
if part.inlineData:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
@@ -933,11 +931,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
IO.Image.Output(
|
||||
display_name="thought_image",
|
||||
tooltip="First image from the model's thinking process. "
|
||||
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -999,11 +992,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await get_image_from_response(response),
|
||||
get_text_from_response(response),
|
||||
await get_image_from_response(response, thought=True),
|
||||
)
|
||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
|
||||
@@ -8,7 +8,6 @@ from comfy_api_nodes.apis.grok import (
|
||||
ImageGenerationResponse,
|
||||
InputUrlObject,
|
||||
VideoEditRequest,
|
||||
VideoExtensionRequest,
|
||||
VideoGenerationRequest,
|
||||
VideoGenerationResponse,
|
||||
VideoStatusResponse,
|
||||
@@ -22,7 +21,6 @@ from comfy_api_nodes.util import (
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
@@ -35,13 +33,6 @@ def _extract_grok_price(response) -> float | None:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_grok_video_price(response) -> float | None:
|
||||
price = _extract_grok_price(response)
|
||||
if price is not None:
|
||||
return price * 1.43
|
||||
return None
|
||||
|
||||
|
||||
class GrokImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@@ -363,8 +354,6 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if model == "grok-imagine-video-beta":
|
||||
model = "grok-imagine-video"
|
||||
image_url = None
|
||||
if image is not None:
|
||||
if get_number_of_images(image) != 1:
|
||||
@@ -473,244 +462,6 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoReferenceNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoReferenceNode",
|
||||
display_name="Grok Reference-to-Video",
|
||||
category="api node/video/Grok",
|
||||
description="Generate video guided by reference images as style and content references.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of the desired video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-video",
|
||||
[
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="reference_",
|
||||
min=1,
|
||||
max=7,
|
||||
),
|
||||
tooltip="Up to 7 reference images to guide the video generation.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["480p", "720p"],
|
||||
tooltip="The resolution of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
|
||||
tooltip="The aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=6,
|
||||
min=2,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="The duration of the output video in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for video generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model.duration", "model.resolution"],
|
||||
input_groups=["model.reference_images"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$refs := inputGroups["model.reference_images"];
|
||||
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||
{"type":"usd","usd": $price}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
ref_image_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
list(model["reference_images"].values()),
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading base images",
|
||||
max_images=7,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||
data=VideoGenerationRequest(
|
||||
model=model["model"],
|
||||
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
|
||||
prompt=prompt,
|
||||
resolution=model["resolution"],
|
||||
duration=model["duration"],
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
seed=seed,
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_video_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoExtendNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoExtendNode",
|
||||
display_name="Grok Video Extend",
|
||||
category="api node/video/Grok",
|
||||
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of what should happen next in the video.",
|
||||
),
|
||||
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-video",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=8,
|
||||
min=2,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="Length of the extension in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for video extension.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
|
||||
expr="""
|
||||
(
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
|
||||
"max_usd": (0.15 + 0.05 * $dur) * 1.43
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||
video_size = get_fs_object_size(video.get_stream_source())
|
||||
if video_size > 50 * 1024 * 1024:
|
||||
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
|
||||
data=VideoExtensionRequest(
|
||||
prompt=prompt,
|
||||
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
|
||||
duration=model["duration"],
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_video_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -718,9 +469,7 @@ class GrokExtension(ComfyExtension):
|
||||
GrokImageNode,
|
||||
GrokImageEditNode,
|
||||
GrokVideoNode,
|
||||
GrokVideoReferenceNode,
|
||||
GrokVideoEditNode,
|
||||
GrokVideoExtendNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
@@ -21,10 +17,7 @@ from comfy_api_nodes.apis.hunyuan3d import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
bytesio_to_image_tensor,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
download_url_to_image_tensor,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
sync_op,
|
||||
@@ -43,68 +36,6 @@ def _is_tencent_rate_limited(status: int, body: object) -> bool:
|
||||
)
|
||||
|
||||
|
||||
class ObjZipResult:
|
||||
__slots__ = ("obj", "texture", "metallic", "normal", "roughness")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obj: Types.File3D,
|
||||
texture: Input.Image | None = None,
|
||||
metallic: Input.Image | None = None,
|
||||
normal: Input.Image | None = None,
|
||||
roughness: Input.Image | None = None,
|
||||
):
|
||||
self.obj = obj
|
||||
self.texture = texture
|
||||
self.metallic = metallic
|
||||
self.normal = normal
|
||||
self.roughness = roughness
|
||||
|
||||
|
||||
async def download_and_extract_obj_zip(url: str) -> ObjZipResult:
|
||||
"""The Tencent API returns OBJ results as ZIP archives containing the .obj mesh, and texture images.
|
||||
|
||||
When PBR is enabled, the ZIP may contain additional metallic, normal, and roughness maps
|
||||
identified by their filename suffixes.
|
||||
"""
|
||||
data = BytesIO()
|
||||
await download_url_to_bytesio(url, data)
|
||||
data.seek(0)
|
||||
if not zipfile.is_zipfile(data):
|
||||
data.seek(0)
|
||||
return ObjZipResult(obj=Types.File3D(source=data, file_format="obj"))
|
||||
data.seek(0)
|
||||
obj_bytes = None
|
||||
textures: dict[str, Input.Image] = {}
|
||||
with zipfile.ZipFile(data) as zf:
|
||||
for name in zf.namelist():
|
||||
lower = name.lower()
|
||||
if lower.endswith(".obj"):
|
||||
obj_bytes = zf.read(name)
|
||||
elif any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")):
|
||||
stem = lower.rsplit(".", 1)[0]
|
||||
tensor = bytesio_to_image_tensor(BytesIO(zf.read(name)), mode="RGB")
|
||||
matched_key = "texture"
|
||||
for suffix, key in {
|
||||
"_metallic": "metallic",
|
||||
"_normal": "normal",
|
||||
"_roughness": "roughness",
|
||||
}.items():
|
||||
if stem.endswith(suffix):
|
||||
matched_key = key
|
||||
break
|
||||
textures[matched_key] = tensor
|
||||
if obj_bytes is None:
|
||||
raise ValueError("ZIP archive does not contain an OBJ file.")
|
||||
return ObjZipResult(
|
||||
obj=Types.File3D(source=BytesIO(obj_bytes), file_format="obj"),
|
||||
texture=textures.get("texture"),
|
||||
metallic=textures.get("metallic"),
|
||||
normal=textures.get("normal"),
|
||||
roughness=textures.get("roughness"),
|
||||
)
|
||||
|
||||
|
||||
def get_file_from_response(
|
||||
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
|
||||
) -> ResultFile3D | None:
|
||||
@@ -162,7 +93,6 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.Image.Output(display_name="texture_image"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -221,14 +151,14 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
obj_result.obj,
|
||||
obj_result.texture,
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -281,10 +211,6 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.Image.Output(display_name="texture_image"),
|
||||
IO.Image.Output(display_name="optional_metallic"),
|
||||
IO.Image.Output(display_name="optional_normal"),
|
||||
IO.Image.Output(display_name="optional_roughness"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -378,17 +304,14 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
obj_result.obj,
|
||||
obj_result.texture,
|
||||
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
|
||||
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
|
||||
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -508,8 +431,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.Image.Output(display_name="texture_image"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -558,8 +480,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
)
|
||||
|
||||
|
||||
@@ -733,7 +654,7 @@ class TencentHunyuan3DExtension(ComfyExtension):
|
||||
TencentTextToModelNode,
|
||||
TencentImageToModelNode,
|
||||
TencentModelTo3DUVNode,
|
||||
Tencent3DTextureEditNode,
|
||||
# Tencent3DTextureEditNode,
|
||||
Tencent3DPartNode,
|
||||
TencentSmartTopologyNode,
|
||||
]
|
||||
|
||||
@@ -1459,7 +1459,6 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
node_id="KlingOmniProEditVideoNode",
|
||||
display_name="Kling 3.0 Omni Edit Video",
|
||||
category="api node/video/Kling",
|
||||
essentials_category="Video Generation",
|
||||
description="Edit an existing video with the latest model from Kling.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
from io import BytesIO
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.quiver import (
|
||||
QuiverImageObject,
|
||||
QuiverImageToSVGRequest,
|
||||
QuiverSVGResponse,
|
||||
QuiverTextToSVGRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_extras.nodes_images import SVG
|
||||
|
||||
|
||||
class QuiverTextToSVGNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="QuiverTextToSVGNode",
|
||||
display_name="Quiver Text to SVG",
|
||||
category="api node/image/Quiver",
|
||||
description="Generate an SVG from a text prompt using Quiver AI.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired SVG output.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"instructions",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Additional style or formatting guidance.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="ref_",
|
||||
min=0,
|
||||
max=4,
|
||||
),
|
||||
tooltip="Up to 4 reference images to guide the generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"arrow-preview",
|
||||
[
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Randomness control. Higher values increase randomness.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=1.0,
|
||||
min=0.05,
|
||||
max=1.0,
|
||||
step=0.05,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Nucleus sampling parameter.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"presence_penalty",
|
||||
default=0.0,
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Token presence penalty.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for SVG generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.SVG.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.429}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
instructions: str = None,
|
||||
reference_images: IO.Autogrow.Type = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||
|
||||
references = None
|
||||
if reference_images:
|
||||
references = []
|
||||
for key in reference_images:
|
||||
url = await upload_image_to_comfyapi(cls, reference_images[key])
|
||||
references.append(QuiverImageObject(url=url))
|
||||
if len(references) > 4:
|
||||
raise ValueError("Maximum 4 reference images are allowed.")
|
||||
|
||||
instructions_val = instructions.strip() if instructions else None
|
||||
if instructions_val == "":
|
||||
instructions_val = None
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"),
|
||||
response_model=QuiverSVGResponse,
|
||||
data=QuiverTextToSVGRequest(
|
||||
model=model["model"],
|
||||
prompt=prompt,
|
||||
instructions=instructions_val,
|
||||
references=references,
|
||||
temperature=model.get("temperature"),
|
||||
top_p=model.get("top_p"),
|
||||
presence_penalty=model.get("presence_penalty"),
|
||||
),
|
||||
)
|
||||
|
||||
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
|
||||
return IO.NodeOutput(SVG(svg_data))
|
||||
|
||||
|
||||
class QuiverImageToSVGNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="QuiverImageToSVGNode",
|
||||
display_name="Quiver Image to SVG",
|
||||
category="api node/image/Quiver",
|
||||
description="Vectorize a raster image into SVG using Quiver AI.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Input image to vectorize.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_crop",
|
||||
default=False,
|
||||
tooltip="Automatically crop to the dominant subject.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"arrow-preview",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"target_size",
|
||||
default=1024,
|
||||
min=128,
|
||||
max=4096,
|
||||
tooltip="Square resize target in pixels.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Randomness control. Higher values increase randomness.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=1.0,
|
||||
min=0.05,
|
||||
max=1.0,
|
||||
step=0.05,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Nucleus sampling parameter.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"presence_penalty",
|
||||
default=0.0,
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Token presence penalty.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for SVG vectorization.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.SVG.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.429}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image,
|
||||
auto_crop: bool,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
image_url = await upload_image_to_comfyapi(cls, image)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"),
|
||||
response_model=QuiverSVGResponse,
|
||||
data=QuiverImageToSVGRequest(
|
||||
model=model["model"],
|
||||
image=QuiverImageObject(url=image_url),
|
||||
auto_crop=auto_crop if auto_crop else None,
|
||||
target_size=model.get("target_size"),
|
||||
temperature=model.get("temperature"),
|
||||
top_p=model.get("top_p"),
|
||||
presence_penalty=model.get("presence_penalty"),
|
||||
),
|
||||
)
|
||||
|
||||
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
|
||||
return IO.NodeOutput(SVG(svg_data))
|
||||
|
||||
|
||||
class QuiverExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
QuiverTextToSVGNode,
|
||||
QuiverImageToSVGNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> QuiverExtension:
|
||||
return QuiverExtension()
|
||||
@@ -833,7 +833,6 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
|
||||
node_id="RecraftVectorizeImageNode",
|
||||
display_name="Recraft Vectorize Image",
|
||||
category="api node/image/Recraft",
|
||||
essentials_category="Image Tools",
|
||||
description="Generates SVG synchronously from an input image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
|
||||
@@ -38,7 +38,6 @@ from comfy_api_nodes.util import (
|
||||
UPSCALER_MODELS_MAP = {
|
||||
"Starlight (Astra) Fast": "slf-1",
|
||||
"Starlight (Astra) Creative": "slc-1",
|
||||
"Starlight Precise 2.5": "slp-2.5",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
from typing import Any, Optional, Tuple, List
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
|
||||
# Public types — source of truth is comfy_api.latest._caching
|
||||
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_providers: List[CacheProvider] = []
|
||||
_providers_lock = threading.Lock()
|
||||
_providers_snapshot: Tuple[CacheProvider, ...] = ()
|
||||
|
||||
|
||||
def register_cache_provider(provider: CacheProvider) -> None:
|
||||
"""Register an external cache provider. Providers are called in registration order."""
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
if provider in _providers:
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
|
||||
return
|
||||
_providers.append(provider)
|
||||
_providers_snapshot = tuple(_providers)
|
||||
_logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
|
||||
|
||||
|
||||
def unregister_cache_provider(provider: CacheProvider) -> None:
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
try:
|
||||
_providers.remove(provider)
|
||||
_providers_snapshot = tuple(_providers)
|
||||
_logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
|
||||
except ValueError:
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
|
||||
|
||||
|
||||
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
|
||||
return _providers_snapshot
|
||||
|
||||
|
||||
def _has_cache_providers() -> bool:
|
||||
return bool(_providers_snapshot)
|
||||
|
||||
|
||||
def _clear_cache_providers() -> None:
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
_providers.clear()
|
||||
_providers_snapshot = ()
|
||||
|
||||
|
||||
def _canonicalize(obj: Any) -> Any:
|
||||
# Convert to canonical JSON-serializable form with deterministic ordering.
|
||||
# Frozensets have non-deterministic iteration order between Python sessions.
|
||||
# Raises ValueError for non-cacheable types (Unhashable, unknown) so that
|
||||
# _serialize_cache_key returns None and external caching is skipped.
|
||||
if isinstance(obj, frozenset):
|
||||
return ("__frozenset__", sorted(
|
||||
[_canonicalize(item) for item in obj],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
))
|
||||
elif isinstance(obj, set):
|
||||
return ("__set__", sorted(
|
||||
[_canonicalize(item) for item in obj],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
))
|
||||
elif isinstance(obj, tuple):
|
||||
return ("__tuple__", [_canonicalize(item) for item in obj])
|
||||
elif isinstance(obj, list):
|
||||
return [_canonicalize(item) for item in obj]
|
||||
elif isinstance(obj, dict):
|
||||
return {"__dict__": sorted(
|
||||
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
)}
|
||||
elif isinstance(obj, (int, float, str, bool, type(None))):
|
||||
return (type(obj).__name__, obj)
|
||||
elif isinstance(obj, bytes):
|
||||
return ("__bytes__", obj.hex())
|
||||
else:
|
||||
raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
|
||||
|
||||
|
||||
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
|
||||
# Returns deterministic SHA256 hex digest, or None on failure.
|
||||
# Uses JSON (not pickle) because pickle is non-deterministic across sessions.
|
||||
try:
|
||||
canonical = _canonicalize(cache_key)
|
||||
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
|
||||
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
|
||||
except Exception as e:
|
||||
_logger.warning(f"Failed to serialize cache key: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _contains_self_unequal(obj: Any) -> bool:
|
||||
# Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
|
||||
# never hit locally, but serialized form would match externally. Skip these.
|
||||
try:
|
||||
if not (obj == obj):
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
if isinstance(obj, (frozenset, tuple, list, set)):
|
||||
return any(_contains_self_unequal(item) for item in obj)
|
||||
if isinstance(obj, dict):
|
||||
return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
|
||||
if hasattr(obj, 'value'):
|
||||
return _contains_self_unequal(obj.value)
|
||||
return False
|
||||
|
||||
|
||||
def _estimate_value_size(value: CacheValue) -> int:
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
return 0
|
||||
|
||||
total = 0
|
||||
|
||||
def estimate(obj):
|
||||
nonlocal total
|
||||
if isinstance(obj, torch.Tensor):
|
||||
total += obj.numel() * obj.element_size()
|
||||
elif isinstance(obj, dict):
|
||||
for v in obj.values():
|
||||
estimate(v)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
estimate(item)
|
||||
|
||||
for output in value.outputs:
|
||||
estimate(output)
|
||||
return total
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
@@ -148,15 +147,13 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||
|
||||
class BasicCache:
|
||||
def __init__(self, key_class, enable_providers=False):
|
||||
def __init__(self, key_class):
|
||||
self.key_class = key_class
|
||||
self.initialized = False
|
||||
self.enable_providers = enable_providers
|
||||
self.dynprompt: DynamicPrompt
|
||||
self.cache_key_set: CacheKeySet
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
self._pending_store_tasks: set = set()
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
@@ -199,138 +196,18 @@ class BasicCache:
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
def get_local(self, node_id):
|
||||
def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
def _get_immediate(self, node_id):
|
||||
if not self.initialized:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
return None
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
async def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
await self._notify_providers_store(node_id, cache_key, value)
|
||||
|
||||
async def _get_immediate(self, node_id):
|
||||
if not self.initialized:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
|
||||
external_result = await self._check_providers_lookup(node_id, cache_key)
|
||||
if external_result is not None:
|
||||
self.cache[cache_key] = external_result
|
||||
return external_result
|
||||
|
||||
return None
|
||||
|
||||
async def _notify_providers_store(self, node_id, cache_key, value):
|
||||
from comfy_execution.cache_provider import (
|
||||
_has_cache_providers, _get_cache_providers,
|
||||
CacheValue, _contains_self_unequal, _logger
|
||||
)
|
||||
|
||||
if not self.enable_providers:
|
||||
return
|
||||
if not _has_cache_providers():
|
||||
return
|
||||
if not self._is_external_cacheable_value(value):
|
||||
return
|
||||
if _contains_self_unequal(cache_key):
|
||||
return
|
||||
|
||||
context = self._build_context(node_id, cache_key)
|
||||
if context is None:
|
||||
return
|
||||
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
|
||||
|
||||
for provider in _get_cache_providers():
|
||||
try:
|
||||
if provider.should_cache(context, cache_value):
|
||||
task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
|
||||
self._pending_store_tasks.add(task)
|
||||
task.add_done_callback(self._pending_store_tasks.discard)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _safe_provider_store(provider, context, cache_value):
|
||||
from comfy_execution.cache_provider import _logger
|
||||
try:
|
||||
await provider.on_store(context, cache_value)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
|
||||
|
||||
async def _check_providers_lookup(self, node_id, cache_key):
|
||||
from comfy_execution.cache_provider import (
|
||||
_has_cache_providers, _get_cache_providers,
|
||||
CacheValue, _contains_self_unequal, _logger
|
||||
)
|
||||
|
||||
if not self.enable_providers:
|
||||
return None
|
||||
if not _has_cache_providers():
|
||||
return None
|
||||
if _contains_self_unequal(cache_key):
|
||||
return None
|
||||
|
||||
context = self._build_context(node_id, cache_key)
|
||||
if context is None:
|
||||
return None
|
||||
|
||||
for provider in _get_cache_providers():
|
||||
try:
|
||||
if not provider.should_cache(context):
|
||||
continue
|
||||
result = await provider.on_lookup(context)
|
||||
if result is not None:
|
||||
if not isinstance(result, CacheValue):
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
|
||||
continue
|
||||
if not isinstance(result.outputs, (list, tuple)):
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
|
||||
continue
|
||||
from execution import CacheEntry
|
||||
return CacheEntry(ui=result.ui, outputs=list(result.outputs))
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _is_external_cacheable_value(self, value):
|
||||
return hasattr(value, 'outputs') and hasattr(value, 'ui')
|
||||
|
||||
def _get_class_type(self, node_id):
|
||||
if not self.initialized or not self.dynprompt:
|
||||
return ''
|
||||
try:
|
||||
return self.dynprompt.get_node(node_id).get('class_type', '')
|
||||
except Exception:
|
||||
return ''
|
||||
|
||||
def _build_context(self, node_id, cache_key):
|
||||
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
|
||||
try:
|
||||
cache_key_hash = _serialize_cache_key(cache_key)
|
||||
if cache_key_hash is None:
|
||||
return None
|
||||
return CacheContext(
|
||||
node_id=node_id,
|
||||
class_type=self._get_class_type(node_id),
|
||||
cache_key_hash=cache_key_hash,
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Failed to build cache context for node {node_id}: {e}")
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _ensure_subcache(self, node_id, children_ids):
|
||||
@@ -359,8 +236,8 @@ class BasicCache:
|
||||
return result
|
||||
|
||||
class HierarchicalCache(BasicCache):
|
||||
def __init__(self, key_class, enable_providers=False):
|
||||
super().__init__(key_class, enable_providers=enable_providers)
|
||||
def __init__(self, key_class):
|
||||
super().__init__(key_class)
|
||||
|
||||
def _get_cache_for(self, node_id):
|
||||
assert self.dynprompt is not None
|
||||
@@ -380,27 +257,16 @@ class HierarchicalCache(BasicCache):
|
||||
return None
|
||||
return cache
|
||||
|
||||
async def get(self, node_id):
|
||||
def get(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return await cache._get_immediate(node_id)
|
||||
return cache._get_immediate(node_id)
|
||||
|
||||
def get_local(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return BasicCache.get_local(cache, node_id)
|
||||
|
||||
async def set(self, node_id, value):
|
||||
def set(self, node_id, value):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
await cache._set_immediate(node_id, value)
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
BasicCache.set_local(cache, node_id, value)
|
||||
cache._set_immediate(node_id, value)
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
cache = self._get_cache_for(node_id)
|
||||
@@ -421,24 +287,18 @@ class NullCache:
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
async def get(self, node_id):
|
||||
def get(self, node_id):
|
||||
return None
|
||||
|
||||
def get_local(self, node_id):
|
||||
return None
|
||||
|
||||
async def set(self, node_id, value):
|
||||
pass
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
def set(self, node_id, value):
|
||||
pass
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
return self
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100, enable_providers=False):
|
||||
super().__init__(key_class, enable_providers=enable_providers)
|
||||
def __init__(self, key_class, max_size=100):
|
||||
super().__init__(key_class)
|
||||
self.max_size = max_size
|
||||
self.min_generation = 0
|
||||
self.generation = 0
|
||||
@@ -462,18 +322,18 @@ class LRUCache(BasicCache):
|
||||
del self.children[key]
|
||||
self._clean_subcaches()
|
||||
|
||||
async def get(self, node_id):
|
||||
def get(self, node_id):
|
||||
self._mark_used(node_id)
|
||||
return await self._get_immediate(node_id)
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def _mark_used(self, node_id):
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key is not None:
|
||||
self.used_generation[cache_key] = self.generation
|
||||
|
||||
async def set(self, node_id, value):
|
||||
def set(self, node_id, value):
|
||||
self._mark_used(node_id)
|
||||
return await self._set_immediate(node_id, value)
|
||||
return self._set_immediate(node_id, value)
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
# Just uses subcaches for tracking 'live' nodes
|
||||
@@ -506,20 +366,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
||||
|
||||
class RAMPressureCache(LRUCache):
|
||||
|
||||
def __init__(self, key_class, enable_providers=False):
|
||||
super().__init__(key_class, 0, enable_providers=enable_providers)
|
||||
def __init__(self, key_class):
|
||||
super().__init__(key_class, 0)
|
||||
self.timestamps = {}
|
||||
|
||||
def clean_unused(self):
|
||||
self._clean_subcaches()
|
||||
|
||||
async def set(self, node_id, value):
|
||||
def set(self, node_id, value):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
await super().set(node_id, value)
|
||||
super().set(node_id, value)
|
||||
|
||||
async def get(self, node_id):
|
||||
def get(self, node_id):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
return await super().get(node_id)
|
||||
return super().get(node_id)
|
||||
|
||||
def poll(self, ram_headroom):
|
||||
def _ram_gb():
|
||||
|
||||
@@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
|
||||
self.execution_cache_listeners = {}
|
||||
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get_local(node_id) is not None
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def cache_link(self, from_node_id, to_node_id):
|
||||
if to_node_id not in self.execution_cache:
|
||||
self.execution_cache[to_node_id] = {}
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||
if from_node_id not in self.execution_cache_listeners:
|
||||
self.execution_cache_listeners[from_node_id] = set()
|
||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||
@@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
|
||||
if value is None:
|
||||
return None
|
||||
#Write back to the main cache on touch.
|
||||
self.output_cache.set_local(from_node_id, value)
|
||||
self.output_cache.set(from_node_id, value)
|
||||
return value
|
||||
|
||||
def cache_update(self, node_id, value):
|
||||
|
||||
@@ -19,7 +19,6 @@ class EmptyLatentAudio(IO.ComfyNode):
|
||||
node_id="EmptyLatentAudio",
|
||||
display_name="Empty Latent Audio",
|
||||
category="latent/audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
||||
IO.Int.Input(
|
||||
@@ -186,7 +185,6 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
search_aliases=["export mp3"],
|
||||
display_name="Save Audio (MP3)",
|
||||
category="audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import torch
|
||||
|
||||
|
||||
class Canny(io.ComfyNode):
|
||||
@@ -30,8 +29,8 @@ class Canny(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
|
||||
output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold)
|
||||
img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1)
|
||||
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
|
||||
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
|
||||
return io.NodeOutput(img_out)
|
||||
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
|
||||
@@ -6,7 +6,6 @@ import comfy.model_management
|
||||
import torch
|
||||
import math
|
||||
import nodes
|
||||
import comfy.ldm.flux.math
|
||||
|
||||
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -232,68 +231,6 @@ class Flux2Scheduler(io.ComfyNode):
|
||||
sigmas = get_schedule(steps, round(seq_len))
|
||||
return io.NodeOutput(sigmas)
|
||||
|
||||
class KV_Attn_Input:
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def __call__(self, q, k, v, extra_options, **kwargs):
|
||||
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
|
||||
if len(reference_image_num_tokens) == 0:
|
||||
return {}
|
||||
|
||||
ref_toks = sum(reference_image_num_tokens)
|
||||
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
||||
if cache_key in self.cache:
|
||||
kk, vv = self.cache[cache_key]
|
||||
self.set_cache = False
|
||||
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
||||
|
||||
self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
|
||||
self.set_cache = True
|
||||
return {"q": q, "k": k, "v": v}
|
||||
|
||||
def cleanup(self):
|
||||
self.cache = {}
|
||||
|
||||
|
||||
class FluxKVCache(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="FluxKVCache",
|
||||
display_name="Flux KV Cache",
|
||||
description="Enables KV Cache optimization for reference images on Flux family models.",
|
||||
category="",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to use KV Cache on."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
input_patch_obj = KV_Attn_Input()
|
||||
|
||||
def model_input_patch(inputs):
|
||||
if len(input_patch_obj.cache) > 0:
|
||||
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
|
||||
if ref_image_tokens > 0:
|
||||
img = inputs["img"]
|
||||
inputs["img"] = img[:, :-ref_image_tokens]
|
||||
return inputs
|
||||
|
||||
m.set_model_attn1_patch(input_patch_obj)
|
||||
m.set_model_post_input_patch(model_input_patch)
|
||||
if hasattr(model.model.diffusion_model, "params"):
|
||||
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
|
||||
else:
|
||||
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class FluxExtension(ComfyExtension):
|
||||
@override
|
||||
@@ -306,7 +243,6 @@ class FluxExtension(ComfyExtension):
|
||||
FluxKontextMultiReferenceLatentMethod,
|
||||
EmptyFlux2LatentImage,
|
||||
Flux2Scheduler,
|
||||
FluxKVCache,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ class ImageCompare(IO.ComfyNode):
|
||||
display_name="Image Compare",
|
||||
description="Compares two images side by side with a slider.",
|
||||
category="image",
|
||||
essentials_category="Image Tools",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
|
||||
@@ -58,7 +58,6 @@ class ImageCropV2(IO.ComfyNode):
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
category="image/transform",
|
||||
essentials_category="Image Tools",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io, UI
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
|
||||
hex_color = hex_color.lstrip("#")
|
||||
if len(hex_color) != 6:
|
||||
return (0.0, 0.0, 0.0)
|
||||
r = int(hex_color[0:2], 16) / 255.0
|
||||
g = int(hex_color[2:4], 16) / 255.0
|
||||
b = int(hex_color[4:6], 16) / 255.0
|
||||
return (r, g, b)
|
||||
|
||||
|
||||
class PainterNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Painter",
|
||||
display_name="Painter",
|
||||
category="image",
|
||||
inputs=[
|
||||
io.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
tooltip="Optional base image to paint over",
|
||||
),
|
||||
io.String.Input(
|
||||
"mask",
|
||||
default="",
|
||||
socketless=True,
|
||||
extra_dict={"widgetType": "PAINTER", "image_upload": True},
|
||||
),
|
||||
io.Int.Input(
|
||||
"width",
|
||||
default=512,
|
||||
min=64,
|
||||
max=4096,
|
||||
step=64,
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True},
|
||||
),
|
||||
io.Int.Input(
|
||||
"height",
|
||||
default=512,
|
||||
min=64,
|
||||
max=4096,
|
||||
step=64,
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True},
|
||||
),
|
||||
io.Color.Input("bg_color", default="#000000"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("IMAGE"),
|
||||
io.Mask.Output("MASK"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
|
||||
if image is not None:
|
||||
base_image = image[:1]
|
||||
h, w = base_image.shape[1], base_image.shape[2]
|
||||
else:
|
||||
h, w = height, width
|
||||
r, g, b = hex_to_rgb(bg_color)
|
||||
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
|
||||
base_image[0, :, :, 0] = r
|
||||
base_image[0, :, :, 1] = g
|
||||
base_image[0, :, :, 2] = b
|
||||
|
||||
if mask and mask.strip():
|
||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
||||
painter_img = node_helpers.pillow(Image.open, mask_path)
|
||||
painter_img = painter_img.convert("RGBA")
|
||||
|
||||
if painter_img.size != (w, h):
|
||||
painter_img = painter_img.resize((w, h), Image.LANCZOS)
|
||||
|
||||
painter_np = np.array(painter_img).astype(np.float32) / 255.0
|
||||
painter_rgb = painter_np[:, :, :3]
|
||||
painter_alpha = painter_np[:, :, 3:4]
|
||||
|
||||
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
|
||||
|
||||
base_np = base_image[0].cpu().numpy()
|
||||
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
|
||||
out_image = torch.from_numpy(composited).unsqueeze(0)
|
||||
else:
|
||||
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
|
||||
out_image = base_image
|
||||
|
||||
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
|
||||
if mask and mask.strip():
|
||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
||||
if os.path.exists(mask_path):
|
||||
m = hashlib.sha256()
|
||||
with open(mask_path, "rb") as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
class PainterExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self):
|
||||
return [PainterNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return PainterExtension()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user