mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-14 20:20:02 +00:00
Compare commits
17 Commits
v0.3.43
...
js/drafts/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0254d9cc11 | ||
|
|
92f9a10782 | ||
|
|
a6a6b615f4 | ||
|
|
50bf72f852 | ||
|
|
46c8311d14 | ||
|
|
772de7c006 | ||
|
|
b22e97dcfa | ||
|
|
f02de13316 | ||
|
|
c46268bf60 | ||
|
|
cf49a2c5b5 | ||
|
|
170c7bb90c | ||
|
|
2a0b138feb | ||
|
|
e195c1b13f | ||
|
|
5b4eb021cb | ||
|
|
396454fa41 | ||
|
|
a3cf272522 | ||
|
|
ba9548f756 |
108
.github/workflows/release-webhook.yml
vendored
Normal file
108
.github/workflows/release-webhook.yml
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
name: Release Webhook
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
|
||||
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
|
||||
run: |
|
||||
# Generate UUID for delivery ID
|
||||
DELIVERY_ID=$(uuidgen)
|
||||
HOOK_ID="release-webhook-$(date +%s)"
|
||||
|
||||
# Create webhook payload matching GitHub release webhook format
|
||||
PAYLOAD=$(cat <<EOF
|
||||
{
|
||||
"action": "published",
|
||||
"release": {
|
||||
"id": ${{ github.event.release.id }},
|
||||
"node_id": "${{ github.event.release.node_id }}",
|
||||
"url": "${{ github.event.release.url }}",
|
||||
"html_url": "${{ github.event.release.html_url }}",
|
||||
"assets_url": "${{ github.event.release.assets_url }}",
|
||||
"upload_url": "${{ github.event.release.upload_url }}",
|
||||
"tag_name": "${{ github.event.release.tag_name }}",
|
||||
"target_commitish": "${{ github.event.release.target_commitish }}",
|
||||
"name": ${{ toJSON(github.event.release.name) }},
|
||||
"body": ${{ toJSON(github.event.release.body) }},
|
||||
"draft": ${{ github.event.release.draft }},
|
||||
"prerelease": ${{ github.event.release.prerelease }},
|
||||
"created_at": "${{ github.event.release.created_at }}",
|
||||
"published_at": "${{ github.event.release.published_at }}",
|
||||
"author": {
|
||||
"login": "${{ github.event.release.author.login }}",
|
||||
"id": ${{ github.event.release.author.id }},
|
||||
"node_id": "${{ github.event.release.author.node_id }}",
|
||||
"avatar_url": "${{ github.event.release.author.avatar_url }}",
|
||||
"url": "${{ github.event.release.author.url }}",
|
||||
"html_url": "${{ github.event.release.author.html_url }}",
|
||||
"type": "${{ github.event.release.author.type }}",
|
||||
"site_admin": ${{ github.event.release.author.site_admin }}
|
||||
},
|
||||
"tarball_url": "${{ github.event.release.tarball_url }}",
|
||||
"zipball_url": "${{ github.event.release.zipball_url }}",
|
||||
"assets": ${{ toJSON(github.event.release.assets) }}
|
||||
},
|
||||
"repository": {
|
||||
"id": ${{ github.event.repository.id }},
|
||||
"node_id": "${{ github.event.repository.node_id }}",
|
||||
"name": "${{ github.event.repository.name }}",
|
||||
"full_name": "${{ github.event.repository.full_name }}",
|
||||
"private": ${{ github.event.repository.private }},
|
||||
"owner": {
|
||||
"login": "${{ github.event.repository.owner.login }}",
|
||||
"id": ${{ github.event.repository.owner.id }},
|
||||
"node_id": "${{ github.event.repository.owner.node_id }}",
|
||||
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
|
||||
"url": "${{ github.event.repository.owner.url }}",
|
||||
"html_url": "${{ github.event.repository.owner.html_url }}",
|
||||
"type": "${{ github.event.repository.owner.type }}",
|
||||
"site_admin": ${{ github.event.repository.owner.site_admin }}
|
||||
},
|
||||
"html_url": "${{ github.event.repository.html_url }}",
|
||||
"clone_url": "${{ github.event.repository.clone_url }}",
|
||||
"git_url": "${{ github.event.repository.git_url }}",
|
||||
"ssh_url": "${{ github.event.repository.ssh_url }}",
|
||||
"url": "${{ github.event.repository.url }}",
|
||||
"created_at": "${{ github.event.repository.created_at }}",
|
||||
"updated_at": "${{ github.event.repository.updated_at }}",
|
||||
"pushed_at": "${{ github.event.repository.pushed_at }}",
|
||||
"default_branch": "${{ github.event.repository.default_branch }}",
|
||||
"fork": ${{ github.event.repository.fork }}
|
||||
},
|
||||
"sender": {
|
||||
"login": "${{ github.event.sender.login }}",
|
||||
"id": ${{ github.event.sender.id }},
|
||||
"node_id": "${{ github.event.sender.node_id }}",
|
||||
"avatar_url": "${{ github.event.sender.avatar_url }}",
|
||||
"url": "${{ github.event.sender.url }}",
|
||||
"html_url": "${{ github.event.sender.html_url }}",
|
||||
"type": "${{ github.event.sender.type }}",
|
||||
"site_admin": ${{ github.event.sender.site_admin }}
|
||||
}
|
||||
}
|
||||
EOF
|
||||
)
|
||||
|
||||
# Generate HMAC-SHA256 signature
|
||||
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
|
||||
|
||||
# Send webhook with required headers
|
||||
curl -X POST "$WEBHOOK_URL" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "X-GitHub-Event: release" \
|
||||
-H "X-GitHub-Delivery: $DELIVERY_ID" \
|
||||
-H "X-GitHub-Hook-ID: $HOOK_ID" \
|
||||
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
|
||||
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
|
||||
-d "$PAYLOAD" \
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
||||
3
.github/workflows/stable-release.yml
vendored
3
.github/workflows/stable-release.yml
vendored
@@ -102,5 +102,4 @@ jobs:
|
||||
file: ComfyUI_windows_portable_nvidia.7z
|
||||
tag: ${{ inputs.git_tag }}
|
||||
overwrite: true
|
||||
prerelease: true
|
||||
make_latest: false
|
||||
draft: true
|
||||
|
||||
4
.github/workflows/test-unit.yml
vendored
4
.github/workflows/test-unit.yml
vendored
@@ -28,3 +28,7 @@ jobs:
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
python -m pytest tests-unit
|
||||
- name: Run Execution Model Tests
|
||||
run: |
|
||||
python -m pytest tests/inference/test_execution.py
|
||||
|
||||
|
||||
@@ -151,6 +151,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
@@ -1447,14 +1447,15 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
@@ -1462,12 +1463,18 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
def default_noise_scaler(sigma):
|
||||
return sigma * ((sigma ** 0.3).exp() + 10.0)
|
||||
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
|
||||
def default_er_sde_noise_scaler(x):
|
||||
return x * ((x ** 0.3).exp() + 10.0)
|
||||
|
||||
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
|
||||
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
|
||||
@@ -1478,32 +1485,36 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
elif stage_used == 1:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
else:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
|
||||
alpha_s = sigmas[i] / er_lambda_s
|
||||
alpha_t = sigmas[i + 1] / er_lambda_t
|
||||
r_alpha = alpha_t / alpha_s
|
||||
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
|
||||
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
sigma_step_size = -dt / num_integration_points
|
||||
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
|
||||
scaled_pos = noise_scaler(sigma_pos)
|
||||
# Stage 1 Euler
|
||||
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * sigma_step_size
|
||||
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
|
||||
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
|
||||
if stage_used >= 2:
|
||||
dt = er_lambda_t - er_lambda_s
|
||||
lambda_step_size = -dt / num_integration_points
|
||||
lambda_pos = er_lambda_t + point_indice * lambda_step_size
|
||||
scaled_pos = noise_scaler(lambda_pos)
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
|
||||
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * lambda_step_size
|
||||
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
|
||||
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
|
||||
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
@@ -1039,13 +1039,13 @@ class SchedulerHandler(NamedTuple):
|
||||
use_ms: bool = True
|
||||
|
||||
SCHEDULER_HANDLERS = {
|
||||
"normal": SchedulerHandler(normal_scheduler),
|
||||
"simple": SchedulerHandler(simple_scheduler),
|
||||
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
|
||||
"karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
|
||||
"exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
|
||||
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
|
||||
"simple": SchedulerHandler(simple_scheduler),
|
||||
"ddim_uniform": SchedulerHandler(ddim_scheduler),
|
||||
"beta": SchedulerHandler(beta_scheduler),
|
||||
"normal": SchedulerHandler(normal_scheduler),
|
||||
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
|
||||
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ class T5Attention(torch.nn.Module):
|
||||
)
|
||||
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
return values.contiguous()
|
||||
|
||||
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||
q = self.q(x)
|
||||
|
||||
@@ -997,11 +997,12 @@ def set_progress_bar_global_hook(function):
|
||||
PROGRESS_BAR_HOOK = function
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total):
|
||||
def __init__(self, total, node_id=None):
|
||||
global PROGRESS_BAR_HOOK
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.hook = PROGRESS_BAR_HOOK
|
||||
self.node_id = node_id
|
||||
|
||||
def update_absolute(self, value, total=None, preview=None):
|
||||
if total is not None:
|
||||
@@ -1010,7 +1011,7 @@ class ProgressBar:
|
||||
value = self.total
|
||||
self.current = value
|
||||
if self.hook is not None:
|
||||
self.hook(self.current, self.total, preview)
|
||||
self.hook(self.current, self.total, preview, node_id=self.node_id)
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import itertools
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import nodes
|
||||
|
||||
@@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
|
||||
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||
|
||||
class CacheKeySet:
|
||||
class CacheKeySet(ABC):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.keys = {}
|
||||
self.subcache_keys = {}
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
@abstractmethod
|
||||
async def add_keys(self, node_ids):
|
||||
raise NotImplementedError()
|
||||
|
||||
def all_node_ids(self):
|
||||
@@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
@@ -77,37 +78,36 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def include_node_id_in_input(self) -> bool:
|
||||
return False
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
def get_node_signature(self, dynprompt, node_id):
|
||||
async def get_node_signature(self, dynprompt, node_id):
|
||||
signature = []
|
||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
return to_hashable(signature)
|
||||
|
||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
if not dynprompt.has_node(node_id):
|
||||
# This node doesn't exist -- we can't cache it.
|
||||
return [float("NaN")]
|
||||
node = dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
@@ -150,9 +150,10 @@ class BasicCache:
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||
await self.cache_key_set.add_keys(node_ids)
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.initialized = True
|
||||
|
||||
@@ -201,13 +202,13 @@ class BasicCache:
|
||||
else:
|
||||
return None
|
||||
|
||||
def _ensure_subcache(self, node_id, children_ids):
|
||||
async def _ensure_subcache(self, node_id, children_ids):
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
subcache = self.subcaches.get(subcache_key, None)
|
||||
if subcache is None:
|
||||
subcache = BasicCache(self.key_class)
|
||||
self.subcaches[subcache_key] = subcache
|
||||
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
return subcache
|
||||
|
||||
def _get_subcache(self, node_id):
|
||||
@@ -259,10 +260,10 @@ class HierarchicalCache(BasicCache):
|
||||
assert cache is not None
|
||||
cache._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
return cache._ensure_subcache(node_id, children_ids)
|
||||
return await cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
@@ -273,8 +274,8 @@ class LRUCache(BasicCache):
|
||||
self.used_generation = {}
|
||||
self.children = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
self.generation += 1
|
||||
for node_id in node_ids:
|
||||
self._mark_used(node_id)
|
||||
@@ -303,11 +304,11 @@ class LRUCache(BasicCache):
|
||||
self._mark_used(node_id)
|
||||
return self._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
# Just uses subcaches for tracking 'live' nodes
|
||||
super()._ensure_subcache(node_id, children_ids)
|
||||
await super()._ensure_subcache(node_id, children_ids)
|
||||
|
||||
self.cache_key_set.add_keys(children_ids)
|
||||
await self.cache_key_set.add_keys(children_ids)
|
||||
self._mark_used(node_id)
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.children[cache_key] = []
|
||||
@@ -337,7 +338,7 @@ class DependencyAwareCache(BasicCache):
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
@@ -354,7 +355,7 @@ class DependencyAwareCache(BasicCache):
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
@@ -405,7 +406,7 @@ class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
@@ -416,7 +417,7 @@ class DependencyAwareCache(BasicCache):
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = super()._ensure_subcache(node_id, children_ids)
|
||||
subcache = await super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
from typing import Type, Literal
|
||||
|
||||
import nodes
|
||||
import asyncio
|
||||
from comfy_execution.graph_utils import is_link
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
|
||||
@@ -100,6 +101,8 @@ class TopologicalSort:
|
||||
self.pendingNodes = {}
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
self.externalBlocks = 0
|
||||
self.unblockedEvent = asyncio.Event()
|
||||
|
||||
def get_input_info(self, unique_id, input_name):
|
||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||
@@ -153,6 +156,16 @@ class TopologicalSort:
|
||||
for link in links:
|
||||
self.add_strong_link(*link)
|
||||
|
||||
def add_external_block(self, node_id):
|
||||
assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
|
||||
self.externalBlocks += 1
|
||||
self.blockCount[node_id] += 1
|
||||
def unblock():
|
||||
self.externalBlocks -= 1
|
||||
self.blockCount[node_id] -= 1
|
||||
self.unblockedEvent.set()
|
||||
return unblock
|
||||
|
||||
def is_cached(self, node_id):
|
||||
return False
|
||||
|
||||
@@ -181,11 +194,16 @@ class ExecutionList(TopologicalSort):
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def stage_node_execution(self):
|
||||
async def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
return None, None, None
|
||||
available = self.get_ready_nodes()
|
||||
while len(available) == 0 and self.externalBlocks > 0:
|
||||
# Wait for an external block to be released
|
||||
await self.unblockedEvent.wait()
|
||||
self.unblockedEvent.clear()
|
||||
available = self.get_ready_nodes()
|
||||
if len(available) == 0:
|
||||
cycled_nodes = self.get_nodes_in_cycle()
|
||||
# Because cycles composed entirely of static nodes are caught during initial validation,
|
||||
|
||||
288
comfy_execution/progress.py
Normal file
288
comfy_execution/progress.py
Normal file
@@ -0,0 +1,288 @@
|
||||
from typing import TypedDict, Dict, Optional
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
from abc import ABC
|
||||
from tqdm import tqdm
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
class NodeState(Enum):
|
||||
Pending = "pending"
|
||||
Running = "running"
|
||||
Finished = "finished"
|
||||
Error = "error"
|
||||
|
||||
class NodeProgressState(TypedDict):
|
||||
"""
|
||||
A class to represent the state of a node's progress.
|
||||
"""
|
||||
state: NodeState
|
||||
value: float
|
||||
max: float
|
||||
|
||||
class ProgressHandler(ABC):
|
||||
"""
|
||||
Abstract base class for progress handlers.
|
||||
Progress handlers receive progress updates and display them in various ways.
|
||||
"""
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.enabled = True
|
||||
|
||||
def set_registry(self, registry: "ProgressRegistry"):
|
||||
pass
|
||||
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
"""Called when a node starts processing"""
|
||||
pass
|
||||
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
"""Called when a node's progress is updated"""
|
||||
pass
|
||||
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
"""Called when a node finishes processing"""
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
"""Called when the progress registry is reset"""
|
||||
pass
|
||||
|
||||
def enable(self):
|
||||
"""Enable this handler"""
|
||||
self.enabled = True
|
||||
|
||||
def disable(self):
|
||||
"""Disable this handler"""
|
||||
self.enabled = False
|
||||
|
||||
class CLIProgressHandler(ProgressHandler):
|
||||
"""
|
||||
Handler that displays progress using tqdm progress bars in the CLI.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__("cli")
|
||||
self.progress_bars: Dict[str, tqdm] = {}
|
||||
|
||||
@override
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Create a new tqdm progress bar
|
||||
if node_id not in self.progress_bars:
|
||||
self.progress_bars[node_id] = tqdm(
|
||||
total=state["max"],
|
||||
desc=f"Node {node_id}",
|
||||
unit="steps",
|
||||
leave=True,
|
||||
position=len(self.progress_bars)
|
||||
)
|
||||
|
||||
@override
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
# Handle case where start_handler wasn't called
|
||||
if node_id not in self.progress_bars:
|
||||
self.progress_bars[node_id] = tqdm(
|
||||
total=max_value,
|
||||
desc=f"Node {node_id}",
|
||||
unit="steps",
|
||||
leave=True,
|
||||
position=len(self.progress_bars)
|
||||
)
|
||||
self.progress_bars[node_id].update(value)
|
||||
else:
|
||||
# Update existing progress bar
|
||||
if max_value != self.progress_bars[node_id].total:
|
||||
self.progress_bars[node_id].total = max_value
|
||||
# Calculate the update amount (difference from current position)
|
||||
current_position = self.progress_bars[node_id].n
|
||||
update_amount = value - current_position
|
||||
if update_amount > 0:
|
||||
self.progress_bars[node_id].update(update_amount)
|
||||
|
||||
@override
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Complete and close the progress bar if it exists
|
||||
if node_id in self.progress_bars:
|
||||
# Ensure the bar shows 100% completion
|
||||
remaining = state["max"] - self.progress_bars[node_id].n
|
||||
if remaining > 0:
|
||||
self.progress_bars[node_id].update(remaining)
|
||||
self.progress_bars[node_id].close()
|
||||
del self.progress_bars[node_id]
|
||||
|
||||
@override
|
||||
def reset(self):
|
||||
# Close all progress bars
|
||||
for bar in self.progress_bars.values():
|
||||
bar.close()
|
||||
self.progress_bars.clear()
|
||||
|
||||
class WebUIProgressHandler(ProgressHandler):
|
||||
"""
|
||||
Handler that sends progress updates to the WebUI via WebSockets.
|
||||
"""
|
||||
def __init__(self, server_instance):
|
||||
super().__init__("webui")
|
||||
self.server_instance = server_instance
|
||||
|
||||
def set_registry(self, registry: "ProgressRegistry"):
|
||||
self.registry = registry
|
||||
|
||||
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
|
||||
"""Send the current progress state to the client"""
|
||||
if self.server_instance is None:
|
||||
return
|
||||
|
||||
# Only send info for non-pending nodes
|
||||
active_nodes = {
|
||||
node_id: {
|
||||
"value": state["value"],
|
||||
"max": state["max"],
|
||||
"state": state["state"].value,
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
|
||||
}
|
||||
for node_id, state in nodes.items()
|
||||
if state["state"] != NodeState.Pending
|
||||
}
|
||||
|
||||
# Send a combined progress_state message with all node states
|
||||
self.server_instance.send_sync("progress_state", {
|
||||
"prompt_id": prompt_id,
|
||||
"nodes": active_nodes
|
||||
})
|
||||
|
||||
@override
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
@override
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
if image:
|
||||
metadata = {
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
|
||||
}
|
||||
self.server_instance.send_sync(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, (image, metadata), self.server_instance.client_id)
|
||||
|
||||
|
||||
@override
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
class ProgressRegistry:
|
||||
"""
|
||||
Registry that maintains node progress state and notifies registered handlers.
|
||||
"""
|
||||
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.nodes: Dict[str, NodeProgressState] = {}
|
||||
self.handlers: Dict[str, ProgressHandler] = {}
|
||||
|
||||
def register_handler(self, handler: ProgressHandler) -> None:
|
||||
"""Register a progress handler"""
|
||||
self.handlers[handler.name] = handler
|
||||
|
||||
def unregister_handler(self, handler_name: str) -> None:
|
||||
"""Unregister a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
# Allow handler to clean up resources
|
||||
self.handlers[handler_name].reset()
|
||||
del self.handlers[handler_name]
|
||||
|
||||
def enable_handler(self, handler_name: str) -> None:
|
||||
"""Enable a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
self.handlers[handler_name].enable()
|
||||
|
||||
def disable_handler(self, handler_name: str) -> None:
|
||||
"""Disable a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
self.handlers[handler_name].disable()
|
||||
|
||||
def ensure_entry(self, node_id: str) -> NodeProgressState:
|
||||
"""Ensure a node entry exists"""
|
||||
if node_id not in self.nodes:
|
||||
self.nodes[node_id] = NodeProgressState(
|
||||
state = NodeState.Pending,
|
||||
value = 0,
|
||||
max = 1
|
||||
)
|
||||
return self.nodes[node_id]
|
||||
|
||||
def start_progress(self, node_id: str) -> None:
|
||||
"""Start progress tracking for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Running
|
||||
entry["value"] = 0.0
|
||||
entry["max"] = 1.0
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.start_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def update_progress(self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]) -> None:
|
||||
"""Update progress for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Running
|
||||
entry["value"] = value
|
||||
entry["max"] = max_value
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.update_handler(node_id, value, max_value, entry, self.prompt_id, image)
|
||||
|
||||
def finish_progress(self, node_id: str) -> None:
|
||||
"""Finish progress tracking for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Finished
|
||||
entry["value"] = entry["max"]
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.finish_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def reset_handlers(self) -> None:
|
||||
"""Reset all handlers"""
|
||||
for handler in self.handlers.values():
|
||||
handler.reset()
|
||||
|
||||
# Global registry instance
|
||||
global_progress_registry: ProgressRegistry = ProgressRegistry(prompt_id="", dynprompt=DynamicPrompt({}))
|
||||
|
||||
def reset_progress_state(prompt_id: str, dynprompt: DynamicPrompt) -> None:
|
||||
global global_progress_registry
|
||||
|
||||
# Reset existing handlers if registry exists
|
||||
if global_progress_registry is not None:
|
||||
global_progress_registry.reset_handlers()
|
||||
|
||||
# Create new registry
|
||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
||||
|
||||
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||
handler.set_registry(global_progress_registry)
|
||||
global_progress_registry.register_handler(handler)
|
||||
|
||||
def get_progress_state() -> ProgressRegistry:
|
||||
return global_progress_registry
|
||||
46
comfy_execution/utils.py
Normal file
46
comfy_execution/utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import contextvars
|
||||
from typing import Optional, NamedTuple
|
||||
|
||||
class ExecutionContext(NamedTuple):
|
||||
"""
|
||||
Context information about the currently executing node.
|
||||
|
||||
Attributes:
|
||||
node_id: The ID of the currently executing node
|
||||
list_index: The index in a list being processed (for operations on batches/lists)
|
||||
"""
|
||||
prompt_id: str
|
||||
node_id: str
|
||||
list_index: Optional[int]
|
||||
|
||||
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
|
||||
def get_executing_context() -> Optional[ExecutionContext]:
|
||||
return current_executing_context.get(None)
|
||||
|
||||
class CurrentNodeContext:
|
||||
"""
|
||||
Context manager for setting the current executing node context.
|
||||
|
||||
Sets the current_executing_context on enter and resets it on exit.
|
||||
|
||||
Example:
|
||||
with CurrentNodeContext(node_id="123", list_index=0):
|
||||
# Code that should run with the current node context set
|
||||
process_image()
|
||||
"""
|
||||
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
|
||||
self.context = ExecutionContext(
|
||||
prompt_id= prompt_id,
|
||||
node_id= node_id,
|
||||
list_index= list_index
|
||||
)
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
self.token = current_executing_context.set(self.context)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.token is not None:
|
||||
current_executing_context.reset(self.token)
|
||||
@@ -2,6 +2,7 @@ import math
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
import latent_preview
|
||||
import torch
|
||||
import comfy.utils
|
||||
@@ -480,6 +481,46 @@ class SamplerDPMAdaptative:
|
||||
"s_noise":s_noise })
|
||||
return (sampler, )
|
||||
|
||||
|
||||
class SamplerER_SDE(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}),
|
||||
"max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}),
|
||||
"eta": (
|
||||
IO.FLOAT,
|
||||
{"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."},
|
||||
),
|
||||
"s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.SAMPLER,)
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, solver_type, max_stage, eta, s_noise):
|
||||
if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
|
||||
eta = 0
|
||||
s_noise = 0
|
||||
|
||||
def reverse_time_sde_noise_scaler(x):
|
||||
return x ** (eta + 1)
|
||||
|
||||
if solver_type == "ER-SDE":
|
||||
# Use the default one in sample_er_sde()
|
||||
noise_scaler = None
|
||||
else:
|
||||
noise_scaler = reverse_time_sde_noise_scaler
|
||||
|
||||
sampler_name = "er_sde"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
|
||||
return (sampler,)
|
||||
|
||||
|
||||
class Noise_EmptyNoise:
|
||||
def __init__(self):
|
||||
self.seed = 0
|
||||
@@ -609,8 +650,14 @@ class Guider_DualCFG(comfy.samplers.CFGGuider):
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
negative_cond = self.conds.get("negative", None)
|
||||
middle_cond = self.conds.get("middle", None)
|
||||
positive_cond = self.conds.get("positive", None)
|
||||
if model_options.get("disable_cfg1_optimization", False) == False:
|
||||
if math.isclose(self.cfg2, 1.0):
|
||||
negative_cond = None
|
||||
if math.isclose(self.cfg1, 1.0):
|
||||
middle_cond = None
|
||||
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
||||
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
|
||||
|
||||
class DualCFGGuider:
|
||||
@@ -781,6 +828,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
||||
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
|
||||
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
||||
"SamplerER_SDE": SamplerER_SDE,
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
|
||||
@@ -4,6 +4,7 @@ import comfy.sampler_helpers
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
import math
|
||||
|
||||
def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
|
||||
pos = noise_pred_pos - noise_pred_nocond
|
||||
@@ -69,8 +70,23 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
||||
negative_cond = self.conds.get("negative", None)
|
||||
empty_cond = self.conds.get("empty_negative_prompt", None)
|
||||
|
||||
(noise_pred_pos, noise_pred_neg, noise_pred_empty) = \
|
||||
comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
|
||||
if model_options.get("disable_cfg1_optimization", False) == False:
|
||||
if math.isclose(self.neg_scale, 0.0):
|
||||
negative_cond = None
|
||||
if math.isclose(self.cfg, 1.0):
|
||||
empty_cond = None
|
||||
|
||||
conds = [positive_cond, negative_cond, empty_cond]
|
||||
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, conds, x, timestep, model_options)
|
||||
|
||||
# Apply pre_cfg_functions since sampling_function() is skipped
|
||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||
args = {"conds":conds, "conds_out": out, "cond_scale": self.cfg, "timestep": timestep,
|
||||
"input": x, "sigma": timestep, "model": self.inner_model, "model_options": model_options}
|
||||
out = fn(args)
|
||||
|
||||
noise_pred_pos, noise_pred_neg, noise_pred_empty = out
|
||||
cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)
|
||||
|
||||
# normally this would be done in cfg_function, but we skipped
|
||||
@@ -82,6 +98,7 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
||||
"denoised": cfg_result,
|
||||
"cond": positive_cond,
|
||||
"uncond": negative_cond,
|
||||
"cond_scale": self.cfg,
|
||||
"model": self.inner_model,
|
||||
"uncond_denoised": noise_pred_neg,
|
||||
"cond_denoised": noise_pred_pos,
|
||||
|
||||
71
comfy_extras/nodes_tcfg.py
Normal file
71
comfy_extras/nodes_tcfg.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
|
||||
|
||||
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
|
||||
"""Drop tangential components from uncond score to align with cond score."""
|
||||
# (B, 1, ...)
|
||||
batch_num = cond_score.shape[0]
|
||||
cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
|
||||
uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
|
||||
|
||||
# Score matrix A (B, 2, ...)
|
||||
score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
|
||||
try:
|
||||
_, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
|
||||
except RuntimeError:
|
||||
# Fallback to CPU
|
||||
_, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
|
||||
|
||||
# Drop the tangential components
|
||||
v1 = Vh[:, 0:1, :].to(uncond_score_flat.device) # (B, 1, ...)
|
||||
uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
|
||||
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
|
||||
|
||||
|
||||
class TCFG(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"model": (IO.MODEL, {}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.MODEL,)
|
||||
RETURN_NAMES = ("patched_model",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/guidance"
|
||||
DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
|
||||
|
||||
def patch(self, model):
|
||||
m = model.clone()
|
||||
|
||||
def tangential_damping_cfg(args):
|
||||
# Assume [cond, uncond, ...]
|
||||
x = args["input"]
|
||||
conds_out = args["conds_out"]
|
||||
if len(conds_out) <= 1 or None in args["conds"][:2]:
|
||||
# Skip when either cond or uncond is None
|
||||
return conds_out
|
||||
cond_pred = conds_out[0]
|
||||
uncond_pred = conds_out[1]
|
||||
uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
|
||||
uncond_pred_td = x - uncond_td
|
||||
return [cond_pred, uncond_pred_td] + conds_out[2:]
|
||||
|
||||
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
|
||||
return (m,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TCFG": TCFG,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TCFG": "Tangential Damping CFG",
|
||||
}
|
||||
135
execution.py
135
execution.py
@@ -8,12 +8,14 @@ import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
import asyncio
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import nodes
|
||||
from comfy_execution.caching import (
|
||||
BasicCache,
|
||||
CacheKeySetID,
|
||||
CacheKeySetInputSignature,
|
||||
DependencyAwareCache,
|
||||
@@ -28,6 +30,8 @@ from comfy_execution.graph import (
|
||||
)
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@@ -39,12 +43,13 @@ class DuplicateNodeError(Exception):
|
||||
pass
|
||||
|
||||
class IsChangedCache:
|
||||
def __init__(self, dynprompt, outputs_cache):
|
||||
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.outputs_cache = outputs_cache
|
||||
self.is_changed = {}
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
if node_id in self.is_changed:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
@@ -62,7 +67,8 @@ class IsChangedCache:
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
logging.warning("WARNING: {}".format(e))
|
||||
@@ -164,7 +170,19 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
async def resolve_map_node_over_list_results(results):
|
||||
remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
|
||||
if len(remaining) == 0:
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
else:
|
||||
done, pending = await asyncio.wait(remaining)
|
||||
for task in done:
|
||||
exc = task.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@@ -178,7 +196,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||
|
||||
results = []
|
||||
def process_inputs(inputs, index=None, input_is_list=False):
|
||||
async def process_inputs(inputs, index=None, input_is_list=False):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
execution_block = None
|
||||
@@ -194,20 +212,37 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
if execution_block is None:
|
||||
if pre_execute_cb is not None and index is not None:
|
||||
pre_execute_cb(index)
|
||||
results.append(getattr(obj, func)(**inputs))
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
return await f(**args)
|
||||
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
|
||||
# Give the task a chance to execute without yielding
|
||||
await asyncio.sleep(0)
|
||||
if task.done():
|
||||
result = task.result()
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(task)
|
||||
else:
|
||||
with CurrentNodeContext(prompt_id, unique_id, index):
|
||||
result = f(**inputs)
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(execution_block)
|
||||
|
||||
if input_is_list:
|
||||
process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
||||
await process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
||||
elif max_len_input == 0:
|
||||
process_inputs({})
|
||||
await process_inputs({})
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
input_dict = slice_dict(input_data_all, i)
|
||||
process_inputs(input_dict, i)
|
||||
await process_inputs(input_dict, i)
|
||||
return results
|
||||
|
||||
|
||||
def merge_result_data(results, obj):
|
||||
# check which outputs need concatenating
|
||||
output = []
|
||||
@@ -229,11 +264,18 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||
if has_pending_task:
|
||||
return return_values, {}, False, has_pending_task
|
||||
output, ui, has_subgraph = get_output_from_returns(return_values, obj)
|
||||
return output, ui, has_subgraph, False
|
||||
|
||||
def get_output_from_returns(return_values, obj):
|
||||
results = []
|
||||
uis = []
|
||||
subgraph_results = []
|
||||
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
has_subgraph = False
|
||||
for i in range(len(return_values)):
|
||||
r = return_values[i]
|
||||
@@ -267,6 +309,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
|
||||
else:
|
||||
output = []
|
||||
ui = dict()
|
||||
# TODO: Think there's an existing bug here
|
||||
# If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
|
||||
# They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
|
||||
# any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui, has_subgraph
|
||||
@@ -279,7 +325,7 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
@@ -291,11 +337,26 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
if server.client_id is not None:
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
try:
|
||||
if unique_id in pending_subgraph_results:
|
||||
if unique_id in pending_async_nodes:
|
||||
results = []
|
||||
for r in pending_async_nodes[unique_id]:
|
||||
if isinstance(r, asyncio.Task):
|
||||
try:
|
||||
results.append(r.result())
|
||||
except Exception as ex:
|
||||
# An async task failed - propagate the exception up
|
||||
del pending_async_nodes[unique_id]
|
||||
raise ex
|
||||
else:
|
||||
results.append(r)
|
||||
del pending_async_nodes[unique_id]
|
||||
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
|
||||
elif unique_id in pending_subgraph_results:
|
||||
cached_results = pending_subgraph_results[unique_id]
|
||||
resolved_outputs = []
|
||||
for is_subgraph, result in cached_results:
|
||||
@@ -317,6 +378,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
output_ui = []
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
@@ -328,7 +390,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
x not in input_data_all or x in missing_keys
|
||||
@@ -357,8 +420,18 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
else:
|
||||
return block
|
||||
def pre_execute_cb(call_index):
|
||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
unblock = execution_list.add_external_block(unique_id)
|
||||
async def await_completion():
|
||||
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
unblock()
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
caches.ui.set(unique_id, {
|
||||
"meta": {
|
||||
@@ -401,7 +474,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
cached_outputs.append((True, node_outputs))
|
||||
new_node_ids = set(new_node_ids)
|
||||
for cache in caches.all:
|
||||
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
|
||||
subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
|
||||
subcache.clean_unused()
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
for link in new_output_links:
|
||||
@@ -446,6 +520,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
|
||||
return (ExecutionResult.FAILURE, error_details, ex)
|
||||
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
executed.add(unique_id)
|
||||
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
@@ -500,6 +575,11 @@ class PromptExecutor:
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
asyncio_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
if "client_id" in extra_data:
|
||||
@@ -512,9 +592,11 @@ class PromptExecutor:
|
||||
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
cached_nodes = []
|
||||
@@ -527,6 +609,7 @@ class PromptExecutor:
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
@@ -534,12 +617,13 @@ class PromptExecutor:
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = execution_list.stage_node_execution()
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
@@ -569,7 +653,7 @@ class PromptExecutor:
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated):
|
||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
return validated[unique_id]
|
||||
@@ -646,7 +730,7 @@ def validate_inputs(prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
try:
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
r = await validate_inputs(prompt_id, prompt, o_id, validated)
|
||||
if r[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
valid = False
|
||||
@@ -771,7 +855,8 @@ def validate_inputs(prompt, item, validated):
|
||||
input_filtered['input_types'] = [received_types]
|
||||
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = await resolve_map_node_over_list_results(ret)
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||
@@ -804,7 +889,7 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def validate_prompt(prompt):
|
||||
async def validate_prompt(prompt_id, prompt):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
if 'class_type' not in prompt[x]:
|
||||
@@ -847,7 +932,7 @@ def validate_prompt(prompt):
|
||||
valid = False
|
||||
reasons = []
|
||||
try:
|
||||
m = validate_inputs(prompt, o, validated)
|
||||
m = await validate_inputs(prompt_id, prompt, o, validated)
|
||||
valid = m[0]
|
||||
reasons = m[1]
|
||||
except Exception as ex:
|
||||
|
||||
35
main.py
35
main.py
@@ -11,6 +11,8 @@ import itertools
|
||||
import utils.extra_config
|
||||
import logging
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
|
||||
if __name__ == "__main__":
|
||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||
@@ -55,6 +57,9 @@ def apply_custom_paths():
|
||||
|
||||
|
||||
def execute_prestartup_script():
|
||||
if args.disable_all_custom_nodes and len(args.whitelist_custom_nodes) == 0:
|
||||
return
|
||||
|
||||
def execute_script(script_path):
|
||||
module_name = os.path.splitext(script_path)[0]
|
||||
try:
|
||||
@@ -66,9 +71,6 @@ def execute_prestartup_script():
|
||||
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
return False
|
||||
|
||||
if args.disable_all_custom_nodes:
|
||||
return
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(custom_node_path)
|
||||
@@ -81,6 +83,9 @@ def execute_prestartup_script():
|
||||
|
||||
script_path = os.path.join(module_path, "prestartup_script.py")
|
||||
if os.path.exists(script_path):
|
||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
||||
logging.info(f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||
continue
|
||||
time_before = time.perf_counter()
|
||||
success = execute_script(script_path)
|
||||
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
@@ -128,7 +133,7 @@ import comfy.utils
|
||||
|
||||
import execution
|
||||
import server
|
||||
from server import BinaryEventTypes
|
||||
from protocol import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
@@ -224,14 +229,25 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
|
||||
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
|
||||
)
|
||||
|
||||
|
||||
def hijack_progress(server_instance):
|
||||
def hook(value, total, preview_image):
|
||||
def hook(value, total, preview_image, prompt_id=None, node_id=None):
|
||||
executing_context = get_executing_context()
|
||||
if prompt_id is None and executing_context is not None:
|
||||
prompt_id = executing_context.prompt_id
|
||||
if node_id is None and executing_context is not None:
|
||||
node_id = executing_context.node_id
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
|
||||
if prompt_id is None:
|
||||
prompt_id = server_instance.last_prompt_id
|
||||
if node_id is None:
|
||||
node_id = server_instance.last_node_id
|
||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||
|
||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||
if preview_image is not None:
|
||||
# Also send old method for backward compatibility
|
||||
# TODO - Remove after this repo is updated to frontend with metadata support
|
||||
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
|
||||
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
@@ -276,7 +292,10 @@ def start_comfyui(asyncio_loop=None):
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
|
||||
nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
init_api_nodes=not args.disable_api_nodes
|
||||
)
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
|
||||
4
nodes.py
4
nodes.py
@@ -2187,6 +2187,9 @@ def init_external_custom_nodes():
|
||||
module_path = os.path.join(custom_node_path, possible_module)
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
if module_path.endswith(".disabled"): continue
|
||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
||||
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||
continue
|
||||
time_before = time.perf_counter()
|
||||
success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
@@ -2280,6 +2283,7 @@ def init_builtin_extra_nodes():
|
||||
"nodes_string.py",
|
||||
"nodes_camera_trajectory.py",
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py"
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
7
protocol.py
Normal file
7
protocol.py
Normal file
@@ -0,0 +1,7 @@
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
TEXT = 3
|
||||
PREVIEW_IMAGE_WITH_METADATA = 4
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.23.4
|
||||
comfyui-workflow-templates==0.1.30
|
||||
comfyui-workflow-templates==0.1.31
|
||||
comfyui-embedded-docs==0.2.3
|
||||
torch
|
||||
torchsde
|
||||
|
||||
51
server.py
51
server.py
@@ -35,11 +35,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
TEXT = 3
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
@@ -643,7 +639,8 @@ class PromptServer():
|
||||
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
valid = execution.validate_prompt(prompt)
|
||||
prompt_id = str(uuid.uuid4())
|
||||
valid = await execution.validate_prompt(prompt_id, prompt)
|
||||
extra_data = {}
|
||||
if "extra_data" in json_data:
|
||||
extra_data = json_data["extra_data"]
|
||||
@@ -651,7 +648,6 @@ class PromptServer():
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
if valid[0]:
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
@@ -766,6 +762,10 @@ class PromptServer():
|
||||
async def send(self, event, data, sid=None):
|
||||
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
await self.send_image(data, sid=sid)
|
||||
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
|
||||
# data is (preview_image, metadata)
|
||||
preview_image, metadata = data
|
||||
await self.send_image_with_metadata(preview_image, metadata, sid=sid)
|
||||
elif isinstance(data, (bytes, bytearray)):
|
||||
await self.send_bytes(event, data, sid)
|
||||
else:
|
||||
@@ -804,6 +804,43 @@ class PromptServer():
|
||||
preview_bytes = bytesIO.getvalue()
|
||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||
|
||||
async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
|
||||
image_type = image_data[0]
|
||||
image = image_data[1]
|
||||
max_size = image_data[2]
|
||||
if max_size is not None:
|
||||
if hasattr(Image, 'Resampling'):
|
||||
resampling = Image.Resampling.BILINEAR
|
||||
else:
|
||||
resampling = Image.Resampling.LANCZOS
|
||||
|
||||
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||
|
||||
mimetype = "image/png" if image_type == "PNG" else "image/jpeg"
|
||||
|
||||
# Prepare metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["image_type"] = mimetype
|
||||
|
||||
# Serialize metadata as JSON
|
||||
import json
|
||||
metadata_json = json.dumps(metadata).encode('utf-8')
|
||||
metadata_length = len(metadata_json)
|
||||
|
||||
# Prepare image data
|
||||
bytesIO = BytesIO()
|
||||
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||
image_bytes = bytesIO.getvalue()
|
||||
|
||||
# Combine metadata and image
|
||||
combined_data = bytearray()
|
||||
combined_data.extend(struct.pack(">I", metadata_length))
|
||||
combined_data.extend(metadata_json)
|
||||
combined_data.extend(image_bytes)
|
||||
|
||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid)
|
||||
|
||||
async def send_bytes(self, event, data, sid=None):
|
||||
message = self.encode_bytes(event, data)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pytest>=7.8.0
|
||||
pytest-aiohttp
|
||||
pytest-asyncio
|
||||
websocket-client
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Config for testing nodes
|
||||
testing:
|
||||
custom_nodes: tests/inference/testing_nodes
|
||||
custom_nodes: testing_nodes
|
||||
|
||||
|
||||
410
tests/inference/test_async_nodes.py
Normal file
410
tests/inference/test_async_nodes.py
Normal file
@@ -0,0 +1,410 @@
|
||||
import pytest
|
||||
import time
|
||||
import torch
|
||||
import urllib.error
|
||||
import numpy as np
|
||||
import subprocess
|
||||
|
||||
from pytest import fixture
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from tests.inference.test_execution import ComfyClient
|
||||
|
||||
|
||||
@pytest.mark.execution
|
||||
class TestAsyncNodes:
|
||||
@fixture(scope="class", autouse=True, params=[
|
||||
(False, 0),
|
||||
(True, 0),
|
||||
(True, 100),
|
||||
])
|
||||
def _server(self, args_pytest, request):
|
||||
pargs = [
|
||||
'python','main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
pargs += ['--cache-lru', str(lru_size)]
|
||||
# Running server with args: pargs
|
||||
p = subprocess.Popen(pargs)
|
||||
yield
|
||||
p.kill()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def shared_client(self, args_pytest, _server):
|
||||
client = ComfyClient()
|
||||
n_tries = 5
|
||||
for i in range(n_tries):
|
||||
time.sleep(4)
|
||||
try:
|
||||
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
|
||||
except ConnectionRefusedError:
|
||||
# Retrying...
|
||||
pass
|
||||
else:
|
||||
break
|
||||
yield client
|
||||
del client
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture
|
||||
def client(self, shared_client, request):
|
||||
shared_client.set_test_name(f"async_nodes[{request.node.name}]")
|
||||
yield shared_client
|
||||
|
||||
@fixture
|
||||
def builder(self, request):
|
||||
yield GraphBuilder(prefix=request.node.name)
|
||||
|
||||
# Happy Path Tests
|
||||
|
||||
def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that a basic async node executes correctly."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
|
||||
output = g.node("SaveImage", images=sleep_node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution completed
|
||||
assert result.did_run(sleep_node), "Async sleep node should have executed"
|
||||
assert result.did_run(output), "Output node should have executed"
|
||||
|
||||
# Verify the image passed through correctly
|
||||
result_images = result.get_images(output)
|
||||
assert len(result_images) == 1, "Should have 1 image"
|
||||
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
|
||||
|
||||
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that multiple async nodes execute in parallel."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create multiple async sleep nodes with different durations
|
||||
sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
|
||||
sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
|
||||
sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)
|
||||
|
||||
# Add outputs for each
|
||||
_output1 = g.node("PreviewImage", images=sleep1.out(0))
|
||||
_output2 = g.node("PreviewImage", images=sleep2.out(0))
|
||||
_output3 = g.node("PreviewImage", images=sleep3.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should take ~0.5s (max duration) not 1.2s (sum of durations)
|
||||
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
|
||||
|
||||
# Verify all nodes executed
|
||||
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)
|
||||
|
||||
def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with proper dependency handling."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Chain of async operations
|
||||
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
|
||||
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
|
||||
|
||||
# Average depends on both async results
|
||||
average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
|
||||
output = g.node("SaveImage", images=average.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution order
|
||||
assert result.did_run(sleep1) and result.did_run(sleep2)
|
||||
assert result.did_run(average) and result.did_run(output)
|
||||
|
||||
# Verify averaged result
|
||||
result_images = result.get_images(output)
|
||||
avg_value = np.array(result_images[0]).mean()
|
||||
assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"
|
||||
|
||||
def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async VALIDATE_INPUTS function."""
|
||||
g = builder
|
||||
# Create a test node with async validation
|
||||
validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
|
||||
g.node("SaveImage", images=validation_node.out(0))
|
||||
|
||||
# Should pass validation
|
||||
result = client.run(g)
|
||||
assert result.did_run(validation_node)
|
||||
|
||||
# Test validation failure
|
||||
validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold
|
||||
with pytest.raises(urllib.error.HTTPError):
|
||||
client.run(g)
|
||||
|
||||
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with lazy evaluation."""
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)
|
||||
|
||||
# Create async nodes that will be evaluated lazily
|
||||
sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
|
||||
sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)
|
||||
|
||||
# Use lazy mix that only needs sleep1 (mask=0.0)
|
||||
lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
|
||||
g.node("SaveImage", images=lazy_mix.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should only execute sleep1, not sleep2
|
||||
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
|
||||
assert result.did_run(sleep1), "Sleep1 should have executed"
|
||||
assert not result.did_run(sleep2), "Sleep2 should have been skipped"
|
||||
|
||||
def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async check_lazy_status function."""
|
||||
g = builder
|
||||
# Create a node with async check_lazy_status
|
||||
lazy_node = g.node("TestAsyncLazyCheck",
|
||||
input1="value1",
|
||||
input2="value2",
|
||||
condition=True)
|
||||
g.node("SaveImage", images=lazy_node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
assert result.did_run(lazy_node)
|
||||
|
||||
# Error Handling Tests
|
||||
|
||||
def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async execution errors are properly handled."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
# Create an async node that will error
|
||||
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
|
||||
g.node("SaveImage", images=error_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
|
||||
assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"
|
||||
|
||||
def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async validation error handling."""
|
||||
g = builder
|
||||
# Node with async validation that will fail
|
||||
validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
|
||||
g.node("SaveImage", images=validation_node.out(0))
|
||||
|
||||
with pytest.raises(urllib.error.HTTPError) as exc_info:
|
||||
client.run(g)
|
||||
# Verify it's a validation error
|
||||
assert exc_info.value.code == 400
|
||||
|
||||
def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test handling of async operations that timeout."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
# Very long sleep that would timeout
|
||||
timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
|
||||
g.node("SaveImage", images=timeout_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised a timeout error"
|
||||
except Exception as e:
|
||||
assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"
|
||||
|
||||
def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that workflow can recover after async errors."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# First run with error
|
||||
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
|
||||
g.node("SaveImage", images=error_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
except Exception:
|
||||
pass # Expected
|
||||
|
||||
# Second run should succeed
|
||||
g2 = GraphBuilder(prefix="recovery_test")
|
||||
image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
|
||||
g2.node("SaveImage", images=sleep_node.out(0))
|
||||
|
||||
result = client.run(g2)
|
||||
assert result.did_run(sleep_node), "Should be able to run after error"
|
||||
|
||||
def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test handling when sync node errors while async node is executing."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Async node that takes time
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)
|
||||
|
||||
# Sync node that will error immediately
|
||||
error_node = g.node("TestSyncError", value=image.out(0))
|
||||
|
||||
# Both feed into output
|
||||
g.node("PreviewImage", images=sleep_node.out(0))
|
||||
g.node("PreviewImage", images=error_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
# Verify the sync error was caught even though async was running
|
||||
assert 'prompt_id' in e.args[0]
|
||||
|
||||
# Edge Cases
|
||||
|
||||
def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with execution blockers."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Async sleep nodes
|
||||
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
|
||||
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
|
||||
|
||||
# Create list of images
|
||||
image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))
|
||||
|
||||
# Create list of blocking conditions - [False, True] to block only the second item
|
||||
int1 = g.node("StubInt", value=1)
|
||||
int2 = g.node("StubInt", value=2)
|
||||
block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))
|
||||
|
||||
# Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
|
||||
compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")
|
||||
|
||||
# Block based on the comparison results
|
||||
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||
|
||||
output = g.node("PreviewImage", images=blocker.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have blocked second image"
|
||||
|
||||
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async nodes are properly cached."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
|
||||
g.node("SaveImage", images=sleep_node.out(0))
|
||||
|
||||
# First run
|
||||
result1 = client.run(g)
|
||||
assert result1.did_run(sleep_node), "Should run first time"
|
||||
|
||||
# Second run - should be cached
|
||||
start_time = time.time()
|
||||
result2 = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
assert not result2.did_run(sleep_node), "Should be cached"
|
||||
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
|
||||
|
||||
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes within dynamically generated prompts."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Node that generates async nodes dynamically
|
||||
dynamic_async = g.node("TestDynamicAsyncGeneration",
|
||||
image1=image1.out(0),
|
||||
image2=image2.out(0),
|
||||
num_async_nodes=3,
|
||||
sleep_duration=0.2)
|
||||
g.node("SaveImage", images=dynamic_async.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should execute async nodes in parallel within dynamic prompt
|
||||
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
|
||||
assert result.did_run(dynamic_async)
|
||||
|
||||
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async resources are properly cleaned up."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create multiple async nodes that use resources
|
||||
resource_nodes = []
|
||||
for i in range(5):
|
||||
node = g.node("TestAsyncResourceUser",
|
||||
value=image.out(0),
|
||||
resource_id=f"resource_{i}",
|
||||
duration=0.1)
|
||||
resource_nodes.append(node)
|
||||
g.node("PreviewImage", images=node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify all nodes executed
|
||||
for node in resource_nodes:
|
||||
assert result.did_run(node)
|
||||
|
||||
# Run again to ensure resources were cleaned up
|
||||
result2 = client.run(g)
|
||||
# Should be cached but not error due to resource conflicts
|
||||
for node in resource_nodes:
|
||||
assert not result2.did_run(node), "Should be cached"
|
||||
|
||||
def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test cancellation of async operations."""
|
||||
# This would require implementing cancellation in the client
|
||||
# For now, we'll test that long-running async operations can be interrupted
|
||||
pass # TODO: Implement when cancellation API is available
|
||||
|
||||
def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test workflows with both sync and async nodes."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
|
||||
|
||||
# Mix of sync and async operations
|
||||
# Sync: lazy mix images
|
||||
sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
|
||||
# Async: sleep
|
||||
async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
|
||||
# Sync: custom validation
|
||||
sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
|
||||
# Async: sleep again
|
||||
async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)
|
||||
|
||||
output = g.node("SaveImage", images=async_op2.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify all nodes executed in correct order
|
||||
assert result.did_run(sync_op1)
|
||||
assert result.did_run(async_op1)
|
||||
assert result.did_run(sync_op2)
|
||||
assert result.did_run(async_op2)
|
||||
|
||||
# Image should be a mix of black and white (gray)
|
||||
result_images = result.get_images(output)
|
||||
avg_value = np.array(result_images[0]).mean()
|
||||
assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"
|
||||
@@ -252,7 +252,7 @@ class TestExecution:
|
||||
|
||||
@pytest.mark.parametrize("test_type, test_value", [
|
||||
("StubInt", 5),
|
||||
("StubFloat", 5.0)
|
||||
("StubMask", 5.0)
|
||||
])
|
||||
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
@@ -497,6 +497,69 @@ class TestExecution:
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create sleep nodes for each duration
|
||||
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
|
||||
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
||||
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
|
||||
|
||||
# Add outputs to verify the execution
|
||||
_output1 = g.node("PreviewImage", images=sleep_node1.out(0))
|
||||
_output2 = g.node("PreviewImage", images=sleep_node2.out(0))
|
||||
_output3 = g.node("PreviewImage", images=sleep_node3.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# The test should take around 0.4 seconds (the longest sleep duration)
|
||||
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
|
||||
# We'll allow for up to 0.8s total to account for overhead
|
||||
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
|
||||
|
||||
# Verify that all nodes executed
|
||||
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
||||
assert result.did_run(sleep_node2), "Sleep node 2 should have run"
|
||||
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
||||
|
||||
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
# Create input images with different values
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create a TestParallelSleep node that expands into multiple TestSleep nodes
|
||||
parallel_sleep = g.node("TestParallelSleep",
|
||||
image1=image1.out(0),
|
||||
image2=image2.out(0),
|
||||
image3=image3.out(0),
|
||||
sleep1=0.4,
|
||||
sleep2=0.5,
|
||||
sleep3=0.6)
|
||||
output = g.node("SaveImage", images=parallel_sleep.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Similar to the previous test, expect parallel execution of the sleep nodes
|
||||
# which should complete in less than the sum of all sleeps
|
||||
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
|
||||
|
||||
# Verify the parallel sleep node executed
|
||||
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
||||
|
||||
# Verify we get an image as output (blend of the three input images)
|
||||
result_images = result.get_images(output)
|
||||
assert len(result_images) == 1, "Should have 1 image"
|
||||
# Average pixel value should be around 170 (255 * 2 // 3)
|
||||
avg_value = numpy.array(result_images[0]).mean()
|
||||
assert avg_value == 170, f"Image average value {avg_value} should be 170"
|
||||
|
||||
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
|
||||
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
|
||||
# only that one entry in the list is blocked.
|
||||
|
||||
@@ -3,6 +3,7 @@ from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DI
|
||||
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
||||
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
@@ -13,6 +14,7 @@ NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
@@ -20,4 +22,5 @@ NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
|
||||
343
tests/inference/testing_nodes/testing-pack/async_test_nodes.py
Normal file
343
tests/inference/testing_nodes/testing-pack/async_test_nodes.py
Normal file
@@ -0,0 +1,343 @@
|
||||
import torch
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
|
||||
|
||||
class TestAsyncValidation(ComfyNodeABC):
|
||||
"""Test node with async VALIDATE_INPUTS."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 5.0}),
|
||||
"threshold": ("FLOAT", {"default": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
@classmethod
|
||||
async def VALIDATE_INPUTS(cls, value, threshold):
|
||||
# Simulate async validation (e.g., checking remote service)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
if value > threshold:
|
||||
return f"Value {value} exceeds threshold {threshold}"
|
||||
return True
|
||||
|
||||
def process(self, value, threshold):
|
||||
# Create image based on value
|
||||
intensity = value / 10.0
|
||||
image = torch.ones([1, 512, 512, 3]) * intensity
|
||||
return (image,)
|
||||
|
||||
|
||||
class TestAsyncError(ComfyNodeABC):
|
||||
"""Test node that errors during async execution."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "error_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def error_execution(self, value, error_after):
|
||||
await asyncio.sleep(error_after)
|
||||
raise RuntimeError("Intentional async execution error for testing")
|
||||
|
||||
|
||||
class TestAsyncValidationError(ComfyNodeABC):
|
||||
"""Test node with async validation that always fails."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 5.0}),
|
||||
"max_value": ("FLOAT", {"default": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
@classmethod
|
||||
async def VALIDATE_INPUTS(cls, value, max_value):
|
||||
await asyncio.sleep(0.05)
|
||||
# Always fail validation for values > max_value
|
||||
if value > max_value:
|
||||
return f"Async validation failed: {value} > {max_value}"
|
||||
return True
|
||||
|
||||
def process(self, value, max_value):
|
||||
# This won't be reached if validation fails
|
||||
image = torch.ones([1, 512, 512, 3]) * (value / max_value)
|
||||
return (image,)
|
||||
|
||||
|
||||
class TestAsyncTimeout(ComfyNodeABC):
|
||||
"""Test node that simulates timeout scenarios."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
|
||||
"operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "timeout_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def timeout_execution(self, value, timeout, operation_time):
|
||||
try:
|
||||
# This will timeout if operation_time > timeout
|
||||
await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
|
||||
return (value,)
|
||||
except asyncio.TimeoutError:
|
||||
raise RuntimeError(f"Operation timed out after {timeout} seconds")
|
||||
|
||||
|
||||
class TestSyncError(ComfyNodeABC):
|
||||
"""Test node that errors synchronously (for mixed sync/async testing)."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sync_error"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
def sync_error(self, value):
|
||||
raise RuntimeError("Intentional sync execution error for testing")
|
||||
|
||||
|
||||
class TestAsyncLazyCheck(ComfyNodeABC):
|
||||
"""Test node with async check_lazy_status."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": (IO.ANY, {"lazy": True}),
|
||||
"input2": (IO.ANY, {"lazy": True}),
|
||||
"condition": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def check_lazy_status(self, condition, input1, input2):
|
||||
# Simulate async checking (e.g., querying remote service)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
needed = []
|
||||
if condition and input1 is None:
|
||||
needed.append("input1")
|
||||
if not condition and input2 is None:
|
||||
needed.append("input2")
|
||||
return needed
|
||||
|
||||
def process(self, input1, input2, condition):
|
||||
# Return a simple image
|
||||
return (torch.ones([1, 512, 512, 3]),)
|
||||
|
||||
|
||||
class TestDynamicAsyncGeneration(ComfyNodeABC):
|
||||
"""Test node that dynamically generates async nodes."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",),
|
||||
"image2": ("IMAGE",),
|
||||
"num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
|
||||
"sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "generate_async_workflow"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create multiple async sleep nodes
|
||||
sleep_nodes = []
|
||||
for i in range(num_async_nodes):
|
||||
image = image1 if i % 2 == 0 else image2
|
||||
sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
|
||||
sleep_nodes.append(sleep_node)
|
||||
|
||||
# Average all results
|
||||
if len(sleep_nodes) == 1:
|
||||
final_node = sleep_nodes[0]
|
||||
else:
|
||||
avg_inputs = {"input1": sleep_nodes[0].out(0)}
|
||||
for i, node in enumerate(sleep_nodes[1:], 2):
|
||||
avg_inputs[f"input{i}"] = node.out(0)
|
||||
final_node = g.node("TestVariadicAverage", **avg_inputs)
|
||||
|
||||
return {
|
||||
"result": (final_node.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
|
||||
class TestAsyncResourceUser(ComfyNodeABC):
|
||||
"""Test node that uses resources during async execution."""
|
||||
|
||||
# Class-level resource tracking for testing
|
||||
_active_resources: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"resource_id": ("STRING", {"default": "resource_0"}),
|
||||
"duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "use_resource"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def use_resource(self, value, resource_id, duration):
|
||||
# Check if resource is already in use
|
||||
if self._active_resources.get(resource_id, False):
|
||||
raise RuntimeError(f"Resource {resource_id} is already in use!")
|
||||
|
||||
# Mark resource as in use
|
||||
self._active_resources[resource_id] = True
|
||||
|
||||
try:
|
||||
# Simulate resource usage
|
||||
await asyncio.sleep(duration)
|
||||
return (value,)
|
||||
finally:
|
||||
# Always clean up resource
|
||||
self._active_resources[resource_id] = False
|
||||
|
||||
|
||||
class TestAsyncBatchProcessing(ComfyNodeABC):
|
||||
"""Test async processing of batched inputs."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": ("IMAGE",),
|
||||
"process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process_batch"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def process_batch(self, images, process_time_per_item, unique_id):
|
||||
batch_size = images.shape[0]
|
||||
pbar = ProgressBar(batch_size, node_id=unique_id)
|
||||
|
||||
# Process each image in the batch
|
||||
processed = []
|
||||
for i in range(batch_size):
|
||||
# Simulate async processing
|
||||
await asyncio.sleep(process_time_per_item)
|
||||
|
||||
# Simple processing: invert the image
|
||||
processed_image = 1.0 - images[i:i+1]
|
||||
processed.append(processed_image)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
# Stack processed images
|
||||
result = torch.cat(processed, dim=0)
|
||||
return (result,)
|
||||
|
||||
|
||||
class TestAsyncConcurrentLimit(ComfyNodeABC):
|
||||
"""Test concurrent execution limits for async nodes."""
|
||||
|
||||
_semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
|
||||
"node_id": ("INT", {"default": 0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "limited_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def limited_execution(self, value, duration, node_id):
|
||||
async with self._semaphore:
|
||||
# Node {node_id} acquired semaphore
|
||||
await asyncio.sleep(duration)
|
||||
# Node {node_id} releasing semaphore
|
||||
return (value,)
|
||||
|
||||
|
||||
# Add node mappings
|
||||
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestAsyncValidation": TestAsyncValidation,
|
||||
"TestAsyncError": TestAsyncError,
|
||||
"TestAsyncValidationError": TestAsyncValidationError,
|
||||
"TestAsyncTimeout": TestAsyncTimeout,
|
||||
"TestSyncError": TestSyncError,
|
||||
"TestAsyncLazyCheck": TestAsyncLazyCheck,
|
||||
"TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
|
||||
"TestAsyncResourceUser": TestAsyncResourceUser,
|
||||
"TestAsyncBatchProcessing": TestAsyncBatchProcessing,
|
||||
"TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
|
||||
}
|
||||
|
||||
ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestAsyncValidation": "Test Async Validation",
|
||||
"TestAsyncError": "Test Async Error",
|
||||
"TestAsyncValidationError": "Test Async Validation Error",
|
||||
"TestAsyncTimeout": "Test Async Timeout",
|
||||
"TestSyncError": "Test Sync Error",
|
||||
"TestAsyncLazyCheck": "Test Async Lazy Check",
|
||||
"TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
|
||||
"TestAsyncResourceUser": "Test Async Resource User",
|
||||
"TestAsyncBatchProcessing": "Test Async Batch Processing",
|
||||
"TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
|
||||
}
|
||||
@@ -1,6 +1,11 @@
|
||||
import torch
|
||||
import time
|
||||
import asyncio
|
||||
from comfy.utils import ProgressBar
|
||||
from .tools import VariantSupport
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
|
||||
class TestLazyMixImages:
|
||||
@classmethod
|
||||
@@ -333,6 +338,131 @@ class TestMixedExpansionReturns:
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestSamplingInExpansion:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"clip": ("CLIP",),
|
||||
"vae": ("VAE",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 100}),
|
||||
"cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
|
||||
"prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
|
||||
"negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "sampling_in_expansion"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create a basic image generation workflow using the input model, clip and vae
|
||||
# 1. Setup text prompts using the provided CLIP model
|
||||
positive_prompt = g.node("CLIPTextEncode",
|
||||
text=prompt,
|
||||
clip=clip)
|
||||
negative_prompt = g.node("CLIPTextEncode",
|
||||
text=negative_prompt,
|
||||
clip=clip)
|
||||
|
||||
# 2. Create empty latent with specified size
|
||||
empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)
|
||||
|
||||
# 3. Setup sampler and generate image latent
|
||||
sampler = g.node("KSampler",
|
||||
model=model,
|
||||
positive=positive_prompt.out(0),
|
||||
negative=negative_prompt.out(0),
|
||||
latent_image=empty_latent.out(0),
|
||||
seed=seed,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
sampler_name="euler_ancestral",
|
||||
scheduler="normal")
|
||||
|
||||
# 4. Decode latent to image using VAE
|
||||
output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)
|
||||
|
||||
return {
|
||||
"result": (output.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestSleep(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sleep"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
async def sleep(self, value, seconds, unique_id):
|
||||
pbar = ProgressBar(seconds, node_id=unique_id)
|
||||
start = time.time()
|
||||
expiration = start + seconds
|
||||
now = start
|
||||
while now < expiration:
|
||||
now = time.time()
|
||||
pbar.update_absolute(now - start)
|
||||
await asyncio.sleep(0.01)
|
||||
return (value,)
|
||||
|
||||
class TestParallelSleep(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE", ),
|
||||
"image2": ("IMAGE", ),
|
||||
"image3": ("IMAGE", ),
|
||||
"sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "parallel_sleep"
|
||||
CATEGORY = "_for_testing"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
|
||||
# Create a graph dynamically with three TestSleep nodes
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create sleep nodes for each duration and image
|
||||
sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
|
||||
sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
|
||||
sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)
|
||||
|
||||
# Blend the results using TestVariadicAverage
|
||||
blend = g.node("TestVariadicAverage",
|
||||
input1=sleep_node1.out(0),
|
||||
input2=sleep_node2.out(0),
|
||||
input3=sleep_node3.out(0))
|
||||
|
||||
return {
|
||||
"result": (blend.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
@@ -345,6 +475,9 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestCustomValidation5": TestCustomValidation5,
|
||||
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
|
||||
"TestMixedExpansionReturns": TestMixedExpansionReturns,
|
||||
"TestSamplingInExpansion": TestSamplingInExpansion,
|
||||
"TestSleep": TestSleep,
|
||||
"TestParallelSleep": TestParallelSleep,
|
||||
}
|
||||
|
||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -359,4 +492,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestCustomValidation5": "Custom Validation 5",
|
||||
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
|
||||
"TestMixedExpansionReturns": "Mixed Expansion Returns",
|
||||
"TestSamplingInExpansion": "Sampling In Expansion",
|
||||
"TestSleep": "Test Sleep",
|
||||
"TestParallelSleep": "Test Parallel Sleep",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user