Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-12 19:20:03 +00:00)
Compare commits: master...feat/core (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 06408af600 | |
| | 3902c86e83 | |
| | 01ef4e50ec | |
| | 50975a7a0d | |
| | d987b0d32d | |
36  .github/workflows/release-webhook.yml (vendored)
@@ -7,8 +7,6 @@ on:
jobs:
send-webhook:
runs-on: ubuntu-latest
env:
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
steps:
- name: Send release webhook
env:
@@ -108,37 +106,3 @@ jobs:
--fail --silent --show-error

echo "✅ Release webhook sent successfully"

- name: Send repository dispatch to desktop
env:
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
RELEASE_TAG: ${{ github.event.release.tag_name }}
RELEASE_URL: ${{ github.event.release.html_url }}
run: |
set -euo pipefail

if [ -z "${DISPATCH_TOKEN:-}" ]; then
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
exit 1
fi

PAYLOAD="$(jq -n \
--arg release_tag "$RELEASE_TAG" \
--arg release_url "$RELEASE_URL" \
'{
event_type: "comfyui_release_published",
client_payload: {
release_tag: $release_tag,
release_url: $release_url
}
}')"

curl -fsSL \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
-d "$PAYLOAD"

echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
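For orientation only, the dispatch step above corresponds to the following short Python sketch (not part of the workflow; it assumes the `requests` package and the same DISPATCH_TOKEN/RELEASE_TAG/RELEASE_URL environment variables):

```python
import os
import requests

payload = {
    "event_type": "comfyui_release_published",
    "client_payload": {
        "release_tag": os.environ["RELEASE_TAG"],
        "release_url": os.environ["RELEASE_URL"],
    },
}
resp = requests.post(
    "https://api.github.com/repos/Comfy-Org/desktop/dispatches",
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {os.environ['DISPATCH_TOKEN']}",
    },
    json=payload,
    timeout=30,
)
resp.raise_for_status()  # GitHub answers 204 No Content when the dispatch is accepted
```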
@@ -1,11 +1,12 @@
import math
import time
from functools import partial

from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import tqdm
from tqdm.auto import trange as trange_, tqdm

from . import utils
from . import deis
@@ -14,7 +15,34 @@ import comfy.model_patcher
import comfy.model_sampling

import comfy.memory_management
from comfy.utils import model_trange as trange


def trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange_(*args, **kwargs)

pbar = trange_(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")

_update = pbar.update

def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based on the diff between the first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")

_update(n)

pbar.update = warmup_update
return pbar


def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
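A quick worked example of the start-time adjustment in the warm-up wrapper above (the numbers are illustrative, not from the source). tqdm reports its rate as completed iterations divided by elapsed time since `start_t`, so a slow first step that includes model loading skews the estimate; shifting `start_t` forward to one steady-state step before the first step finished makes the reported rate match the true per-step time:

```python
# Illustrative numbers only: how the start_t shift changes tqdm's rate estimate.
t0 = 0.0            # original start time
i1 = 30.0           # first step done (includes model load/compile overhead)
i2 = 32.0           # second step done (steady state: 2 s/step)

naive_rate = 2 / (i2 - t0)          # 0.0625 it/s -> misleading ETA
start_t = i1 - (i2 - i1)            # = 28.0, one steady-state step before i1
adjusted_rate = 2 / (i2 - start_t)  # 0.5 it/s -> matches the real 2 s/step
print(naive_rate, adjusted_rate)
```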
@@ -195,20 +195,8 @@ class Anima(MiniTrainDIT):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))

def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
def preprocess_text_embeds(self, text_embeds, text_ids):
if text_ids is not None:
out = self.llm_adapter(text_embeds, text_ids)
if t5xxl_weights is not None:
out = out * t5xxl_weights

if out.shape[1] < 512:
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
return out
return self.llm_adapter(text_embeds, text_ids)
else:
return text_embeds

def forward(self, x, timesteps, context, **kwargs):
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
if t5xxl_ids is not None:
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
return super().forward(x, timesteps, context, **kwargs)

@@ -29,34 +29,19 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.to(dtype=torch.float32, device=pos.device)


def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)

x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

return x_out.reshape(*x.shape).type_as(x)


def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)


try:
import comfy.quant_ops
q_apply_rope = comfy.quant_ops.ck.apply_rope
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
def apply_rope(xq, xk, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope(xq, xk, freqs_cis)
else:
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
def apply_rope1(x, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope1(x, freqs_cis)
else:
return q_apply_rope1(x, freqs_cis)
apply_rope = comfy.quant_ops.ck.apply_rope
apply_rope1 = comfy.quant_ops.ck.apply_rope1
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
apply_rope = _apply_rope
apply_rope1 = _apply_rope1
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)

x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

return x_out.reshape(*x.shape).type_as(x)

def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
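A note on the rope helpers above: assuming the Flux-style layout in which `freqs_cis` stores, per position and channel pair, the rotation matrix $\begin{pmatrix}\cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i\end{pmatrix}$ (as produced by `rope()`), every `apply_rope1` variant computes the standard rotary embedding on consecutive channel pairs,

$x'_{2i} = \cos\theta_i\, x_{2i} - \sin\theta_i\, x_{2i+1}, \qquad x'_{2i+1} = \sin\theta_i\, x_{2i} + \cos\theta_i\, x_{2i+1},$

and `apply_rope` simply applies that same rotation to both the query and key tensors.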
@@ -1160,16 +1160,12 @@ class Anima(BaseModel):
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
t5xxl_ids = t5xxl_ids.unsqueeze(0)

if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)

if cross_attn.shape[1] < 512:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out


@@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import threading
import torch
import sys
@@ -55,11 +55,6 @@ cpu_state = CPUState.GPU

total_vram = 0


# Training Related State
in_training = False


def get_supported_float8_types():
float8_types = []
try:
@@ -656,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state

@@ -752,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_loaded_models.insert(0, loaded_model)
return

def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
with torch.inference_mode():
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
soft_empty_cache()

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
#Deliberately load models outside of the Aimdo mempool so they can be retained across
#nodes. Use a dummy thread to do it, as pytorch documents that mempool contexts are
#thread local, so exploit that to escape the context.
if enables_dynamic_vram():
t = threading.Thread(
target=load_models_gpu_thread,
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
)
t.start()
t.join()
else:
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)

def load_model_gpu(model):
return load_models_gpu([model])
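A tiny self-contained sketch of the thread-locality trick the wrapper above relies on (illustrative only; it uses `threading.local` rather than PyTorch's mempool machinery): state installed in the calling thread is simply not visible inside a short-lived worker thread, so work done there falls back to the default behaviour.

```python
import threading

_ctx = threading.local()          # stands in for a thread-local allocation context
_ctx.pool = "aimdo-mempool"       # "active" pool as seen from the main thread

def allocate():
    # A worker thread never set _ctx.pool, so it sees the default instead.
    return getattr(_ctx, "pool", "default-allocator")

result = {}
t = threading.Thread(target=lambda: result.update(pool=allocate()))
t.start()
t.join()
print(allocate(), result["pool"])   # -> aimdo-mempool default-allocator
```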
@@ -1211,20 +1226,21 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None:
dtype = weight._model_dtype

r = torch.empty_like(weight, dtype=dtype, device=device)

signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = v_tensor
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
return v_tensor.to(dtype=dtype)

r = torch.empty_like(weight, dtype=dtype, device=device)
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
#a non comfy weight
r.copy_(v_tensor)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r

if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
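The caching scheme in the hunk above is easier to see in isolation. A generic sketch of the same pattern (illustrative only; `vbar_fault`, `vbar_signature_compare` and the `weight._v*` attributes are ComfyUI/aimdo internals not reproduced here): a cheap signature tells us whether the previously materialized device tensor is still valid, and only on a mismatch do we re-interpret the backing storage and copy the weight over again.

```python
import torch

class SignatureCache:
    """Reuse a device-side copy of a host tensor until its backing 'signature' changes."""
    def __init__(self):
        self.signature = None
        self.device_tensor = None

    def fetch(self, host_tensor, signature, device):
        if signature == self.signature and self.device_tensor is not None:
            return self.device_tensor              # still resident: no transfer needed
        dev = torch.empty_like(host_tensor, device=device)
        dev.copy_(host_tensor, non_blocking=True)  # re-upload only on signature change
        self.signature = signature
        self.device_tensor = dev
        return dev
```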
@@ -19,6 +19,7 @@
from __future__ import annotations

import collections
import copy
import inspect
import logging
import math
@@ -316,7 +317,7 @@ class ModelPatcher:

n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
@@ -1491,9 +1492,7 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()

#We force reserve VRAM for the non comfy-weight so we don't have to deal
#with pin and unpin synchronization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#We have way more tools for acceleration on comfy weight offloading, so always
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
@@ -1525,7 +1524,7 @@ class ModelPatcherDynamic(ModelPatcher):
setattr(m, param_key + "_function", weight_function)
geometry = weight
if not isinstance(weight, QuantizedTensor):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
@@ -1551,14 +1550,13 @@ class ModelPatcherDynamic(ModelPatcher):
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)

logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
25  comfy/ops.py
@@ -83,18 +83,14 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])

signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)

if not resident:
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None

xfer_source = [ s.weight, s.bias ]
@@ -144,13 +140,9 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
post_cast.copy_(pre_cast)
xfer_dest = cast_dest

params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
if signature is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]

def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -177,8 +169,8 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if orig.dtype == dtype and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
else:
y = x
if update_weight:
orig.copy_(y)
for f in fns:
@@ -190,6 +182,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature=signature

#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)

@@ -122,26 +122,20 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required

def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)

def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
memory_required = 1e20
minimum_memory_required = None
else:
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
real_model = model.model

return real_model, conds, models

@@ -793,6 +793,8 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()

model_management.archive_model_dtypes(self.first_stage_model)

if device is None:
device = model_management.vae_device()
self.device = device
@@ -801,7 +803,6 @@ class VAE:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
model_management.archive_model_dtypes(self.first_stage_model)
self.output_device = model_management.intermediate_device()

mp = comfy.model_patcher.CoreModelPatcher
@@ -16,7 +16,6 @@ def sample_manual_loop_no_classes(
temperature: float = 0.85,
top_p: float = 0.9,
top_k: int = None,
min_p: float = 0.000,
seed: int = 1,
min_tokens: int = 1,
max_new_tokens: int = 2048,
@@ -24,8 +23,6 @@ def sample_manual_loop_no_classes(
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
if ids is None:
return []
device = model.execution_device

if execution_dtype is None:
@@ -35,7 +32,6 @@ def sample_manual_loop_no_classes(
execution_dtype = torch.float32

embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
embeds_batch = embeds.shape[0]
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
@@ -45,27 +41,22 @@ def sample_manual_loop_no_classes(
generator = torch.Generator(device=device)
generator.manual_seed(seed)
model_config = model.transformer.model.config
past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]

for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))

progress_bar = comfy.utils.ProgressBar(max_new_tokens)

for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
for step in range(max_new_tokens):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]

if cfg_scale != 1.0:
cond_logits = next_token_logits[0:1]
uncond_logits = next_token_logits[1:2]
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
else:
cfg_logits = next_token_logits[0:1]
cond_logits = next_token_logits[0:1]
uncond_logits = next_token_logits[1:2]
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)

use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
if use_eos_score:
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
eos_score = cfg_logits[:, eos_token_id].clone()

remove_logit_value = torch.finfo(cfg_logits.dtype).min
@@ -73,7 +64,7 @@ def sample_manual_loop_no_classes(
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value

if use_eos_score:
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
cfg_logits[:, eos_token_id] = eos_score

if top_k is not None and top_k > 0:
@@ -81,12 +72,6 @@ def sample_manual_loop_no_classes(
min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value

if min_p is not None and min_p > 0:
probs = torch.softmax(cfg_logits, dim=-1)
p_max = probs.max(dim=-1, keepdim=True).values
indices_to_remove = probs < (min_p * p_max)
cfg_logits[indices_to_remove] = remove_logit_value

if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
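A quick numeric illustration of the min_p filter in the hunk above (the probabilities are illustrative, not from the source): with min_p = 0.05, only tokens whose probability is at least 5% of the most likely token's probability survive.

```python
import torch

probs = torch.tensor([0.60, 0.20, 0.05, 0.02, 0.01, 0.12])
min_p = 0.05
keep = probs >= min_p * probs.max()   # threshold = 0.05 * 0.60 = 0.03
print(keep)                           # tensor([True, True, True, False, False, True])
```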
@@ -108,8 +93,8 @@ def sample_manual_loop_no_classes(
break

embed, _, _, _ = model.process_tokens([[token]], device)
embeds = embed.repeat(embeds_batch, 1, 1)
attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
embeds = embed.repeat(2, 1, 1)
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)

output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)
@@ -117,31 +102,24 @@ def sample_manual_loop_no_classes(
return output_audio_codes


def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
positive = [[token for token, _ in inner_list] for inner_list in positive]
negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
negative = negative[0]

if cfg_scale != 1.0:
negative = [[token for token, _ in inner_list] for inner_list in negative]
negative = negative[0]
neg_pad = 0
if len(negative) < len(positive):
neg_pad = (len(positive) - len(negative))
negative = [model.special_tokens["pad"]] * neg_pad + negative

neg_pad = 0
if len(negative) < len(positive):
neg_pad = (len(positive) - len(negative))
negative = [model.special_tokens["pad"]] * neg_pad + negative
pos_pad = 0
if len(negative) > len(positive):
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive

pos_pad = 0
if len(negative) > len(positive):
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive

paddings = [pos_pad, neg_pad]
ids = [positive, negative]
else:
paddings = []
ids = [positive]

return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
paddings = [pos_pad, neg_pad]
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)


class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -151,12 +129,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
for k in ("bpm", "duration", "keyscale", "timesignature")
for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
user_metas["timesignature"] = timesignature[:-2]
user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
user_metas = {
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
for k, v in user_metas.items()
@@ -169,11 +147,8 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml

def _metas_to_cap(self, **kwargs) -> str:
use_keys = ("bpm", "timesignature", "keyscale", "duration")
use_keys = ("bpm", "duration", "keyscale", "timesignature")
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
user_metas["timesignature"] = timesignature[:-2]
duration = user_metas["duration"]
if duration == "N/A":
user_metas["duration"] = "30 seconds"
@@ -184,13 +159,9 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)

def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
text = text.strip()
text_negative = kwargs.get("caption_negative", text).strip()
out = {}
lyrics = kwargs.get("lyrics", "")
lyrics_negative = kwargs.get("lyrics_negative", lyrics)
duration = kwargs.get("duration", 120)
if isinstance(duration, str):
duration = float(duration.split(None, 1)[0])
language = kwargs.get("language")
seed = kwargs.get("seed", 0)

@@ -199,55 +170,28 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
temperature = kwargs.get("temperature", 0.85)
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
min_p = kwargs.get("min_p", 0.000)


duration = math.ceil(duration)
kwargs["duration"] = duration
tokens_duration = duration * 5
min_tokens = int(kwargs.get("min_tokens", tokens_duration))
max_tokens = int(kwargs.get("max_tokens", tokens_duration))

metas_negative = {
k.rsplit("_", 1)[0]: kwargs.pop(k)
for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
if k in kwargs
}
if not kwargs.get("use_negative_caption"):
_ = metas_negative.pop("caption", None)

cot_text = self._metas_to_cot(caption=text, **kwargs)
cot_text_negative = "<think>\n\n</think>" if not metas_negative else self._metas_to_cot(**metas_negative)
cot_text = self._metas_to_cot(caption = text, **kwargs)
meta_cap = self._metas_to_cap(**kwargs)

lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"

llm_prompts = {
"lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
"lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
"lyrics": lyrics_template.format(language if language is not None else "", lyrics),
"qwen3_06b": qwen3_06b_template.format(text, meta_cap),
}
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)

out = {
prompt_key: self.qwen3_06b.tokenize_with_weights(
prompt,
prompt_key == "qwen3_06b" and return_word_ids,
disable_weights = True,
**kwargs,
)
for prompt_key, prompt in llm_prompts.items()
}
out["lm_metadata"] = {"min_tokens": min_tokens,
"max_tokens": max_tokens,
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
out["lm_metadata"] = {"min_tokens": duration * 5,
"seed": seed,
"generate_audio_codes": generate_audio_codes,
"cfg_scale": cfg_scale,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"min_p": min_p,
}
return out

@@ -308,7 +252,7 @@ class ACE15TEModel(torch.nn.Module):

lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
out["audio_codes"] = [audio_codes]

return base_out, None, out
@@ -27,7 +27,6 @@ from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
from tqdm.auto import trange
from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram
import json
@@ -1156,32 +1155,6 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)

def model_trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange(*args, **kwargs)

pbar = trange(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")

_update = pbar.update

def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based on the diff between the first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")

_update(n)

pbar.update = warmup_update
return pbar

PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
@@ -1403,21 +1376,3 @@ def string_to_seed(data):
else:
crc >>= 1
return crc ^ 0xFFFFFFFF

def deepcopy_list_dict(obj, memo=None):
if memo is None:
memo = {}

obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]

if isinstance(obj, dict):
res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
elif isinstance(obj, list):
res = [deepcopy_list_dict(i, memo) for i in obj]
else:
res = obj

memo[obj_id] = res
return res
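A short illustration of how `deepcopy_list_dict` differs from `copy.deepcopy` (the import path is taken from the `comfy.utils.deepcopy_list_dict` reference in the ModelPatcher hunk earlier; the sample data is made up): only dict and list containers are rebuilt, while leaf objects such as tensors are shared rather than cloned.

```python
import copy
import torch
from comfy.utils import deepcopy_list_dict  # assumed location, per the ModelPatcher hunk

opts = {"sigmas": torch.linspace(1.0, 0.0, 5), "patches": [{"strength": 0.5}]}

structural_copy = deepcopy_list_dict(opts)  # new dict/list containers, same tensor object
full_copy = copy.deepcopy(opts)             # containers and the tensor are all cloned

print(structural_copy["sigmas"] is opts["sigmas"])  # True  (tensor shared)
print(full_copy["sigmas"] is opts["sigmas"])        # False (tensor duplicated)
```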
@@ -21,7 +21,6 @@ from typing import Optional, Union
import torch
import torch.nn as nn

import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection

@@ -182,21 +181,18 @@ class BypassForwardHook:
)
return # Already injected

# Move adapter weights to compute device (GPU)
# Use get_torch_device() instead of module.weight.device because
# with offloading, module weights may be on CPU while compute happens on GPU
device = comfy.model_management.get_torch_device()

# Get dtype from module weight if available
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
device = None
dtype = None
if hasattr(self.module, "weight") and self.module.weight is not None:
device = self.module.weight.device
dtype = self.module.weight.dtype
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
device = self.module.W_q.device
dtype = self.module.W_q.dtype

# Only use dtype if it's a standard float type, not quantized
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
dtype = None

self._move_adapter_weights_to_device(device, dtype)
if device is not None:
self._move_adapter_weights_to_device(device, dtype)

self.original_forward = self.module.forward
self.module.forward = self._bypass_forward

@@ -34,21 +34,6 @@ class VideoInput(ABC):
"""
pass

@abstractmethod
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = False,
) -> VideoInput | None:
"""
Create a new VideoInput which is trimmed to have the corresponding start_time and duration

Returns:
A new VideoInput, or None if the result would have negative duration
"""
pass

def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without
@@ -6,7 +6,6 @@ from typing import Optional
from .._input import AudioInput, VideoInput
import av
import io
import itertools
import json
import numpy as np
import math
@@ -30,6 +29,7 @@ def container_to_output_format(container_format: str | None) -> str | None:
formats = container_format.split(",")
return formats[0]


def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
@@ -57,14 +57,12 @@ class VideoFromFile(VideoInput):
Class representing video input from a file.
"""

def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
def __init__(self, file: str | io.BytesIO):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
self.__start_time = start_time
self.__duration = duration

def get_stream_source(self) -> str | io.BytesIO:
"""
@@ -98,16 +96,6 @@ class VideoFromFile(VideoInput):
Returns:
Duration in seconds
"""
raw_duration = self._get_raw_duration()
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
if self.__duration:
return min(self.__duration, duration_from_start)
return duration_from_start

def _get_raw_duration(self) -> float:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
@@ -125,13 +113,9 @@ class VideoFromFile(VideoInput):
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
frame_iterator = (
container.decode(video_stream)
if video_stream.codec.capabilities & 0x100
else container.demux(video_stream)
)
for packet in frame_iterator:
frame_count += 1
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
@@ -147,54 +131,36 @@ class VideoFromFile(VideoInput):

with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available and usable
if (
video_stream.frames
and video_stream.frames > 0
and not self.__start_time
and not self.__duration
):
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
return int(video_stream.frames)

# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames

if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
raw_duration = float(video_stream.duration * video_stream.time_base)
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
duration_seconds = min(self.__duration, duration_from_start)
duration_seconds = float(video_stream.duration * video_stream.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames

# 3. Last resort: decode frames and count them (streaming)
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
frame_count = 1
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
frame_iterator = (
container.decode(video_stream)
if video_stream.codec.capabilities & 0x100
else container.demux(video_stream)
)
for frame in frame_iterator:
if frame.pts >= start_pts:
break
else:
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
for frame in frame_iterator:
if frame.pts >= end_pts:
break
frame_count += 1
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1

if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
return frame_count

def get_frame_rate(self) -> Fraction:
@@ -233,21 +199,9 @@ class VideoFromFile(VideoInput):
return container.format.name

def get_components_internal(self, container: InputContainer) -> VideoComponents:
video_stream = self._get_first_video_stream(container)
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
# Get video frames
frames = []
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
for frame in container.decode(video_stream):
if frame.pts < start_pts:
continue
if self.__duration and frame.pts >= end_pts:
break
for frame in container.decode(video=0):
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
@@ -255,44 +209,31 @@ class VideoFromFile(VideoInput):
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)

# Get frame rate
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)

# Get audio if available
audio = None
container.seek(start_pts, stream=video_stream)
# Use last stream for consistency
if len(container.streams.audio):
audio_stream = container.streams.audio[-1]
audio_frames = []
resample = av.audio.resampler.AudioResampler(format='fltp').resample
frames = itertools.chain.from_iterable(
map(resample, container.decode(audio_stream))
)

has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = int(offset_seconds * audio_stream.sample_rate)
if to_skip < frame.samples:
has_first_frame = True
break
if has_first_frame:
audio_frames.append(frame.to_ndarray()[..., to_skip:])

for frame in frames:
if frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
if self.__duration:
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]

audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
})
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream

metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
@@ -309,7 +250,7 @@ class VideoFromFile(VideoInput):
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
@@ -321,14 +262,15 @@ class VideoFromFile(VideoInput):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if self.__start_time or self.__duration:
reuse_streams = False

if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path, format=format, codec=codec, metadata=metadata
path,
format=format,
codec=codec,
metadata=metadata
)

streams = container.streams
@@ -362,21 +304,10 @@ class VideoFromFile(VideoInput):
output_container.mux(packet)

def _get_first_video_stream(self, container: InputContainer):
if len(container.streams.video):
return container.streams.video[0]
raise ValueError(f"No video stream found in file '{self.__file}'")

def as_trimmed(
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
) -> VideoInput | None:
trimmed = VideoFromFile(
self.get_stream_source(),
start_time=start_time + self.__start_time,
duration=duration,
)
if trimmed.get_duration() < duration and strict_duration:
return None
return trimmed
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream


class VideoFromComponents(VideoInput):
@@ -391,7 +322,7 @@ class VideoFromComponents(VideoInput):
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate,
frame_rate=self.__components.frame_rate
)

def save_to(
@@ -399,7 +330,7 @@ class VideoFromComponents(VideoInput):
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
metadata: Optional[dict] = None
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
@@ -426,10 +357,7 @@ class VideoFromComponents(VideoInput):
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
waveform = self.__components.audio['waveform']
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
audio_stream = output.add_stream('aac', rate=audio_sample_rate)

# Encode video
for i, frame in enumerate(self.__components.images):
@@ -444,21 +372,12 @@ class VideoFromComponents(VideoInput):
output.mux(packet)

if audio_stream and self.__components.audio:
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))

# Flush encoder
output.mux(audio_stream.encode(None))

def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = True,
) -> VideoInput | None:
if self.get_duration() < start_time + duration:
return None
#TODO Consider tracking duration and trimming at time of save?
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
@@ -1430,6 +1430,11 @@ class Schema:
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
lazy_outputs: bool=False
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.

Use this for nodes that can skip computing outputs that aren't connected downstream.
Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused."""

def validate(self):
'''Validate the schema:
@@ -1875,6 +1880,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._ACCEPT_ALL_INPUTS

_LAZY_OUTPUTS = None
@final
@classproperty
def LAZY_OUTPUTS(cls): # noqa
if cls._LAZY_OUTPUTS is None:
cls.GET_SCHEMA()
return cls._LAZY_OUTPUTS

@final
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
@@ -1917,6 +1930,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._NOT_IDEMPOTENT = schema.not_idempotent
if cls._ACCEPT_ALL_INPUTS is None:
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
if cls._LAZY_OUTPUTS is None:
cls._LAZY_OUTPUTS = schema.lazy_outputs

if cls._RETURN_TYPES is None:
output = []
8  comfy_api_nodes/apis/__init__.py (generated)
@@ -1197,6 +1197,12 @@ class KlingImageGenImageReferenceType(str, Enum):
face = 'face'


class KlingImageGenModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_5 = 'kling-v1-5'
kling_v2 = 'kling-v2'


class KlingImageGenerationsRequest(BaseModel):
aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
callback_url: Optional[AnyUrl] = Field(
@@ -1212,7 +1218,7 @@ class KlingImageGenerationsRequest(BaseModel):
0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
)
image_reference: Optional[KlingImageGenImageReferenceType] = None
model_name: str = Field(...)
model_name: Optional[KlingImageGenModelName] = 'kling-v1'
n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
negative_prompt: Optional[str] = Field(
None, description='Negative text prompt', max_length=200
@@ -1,22 +1,12 @@
from pydantic import BaseModel, Field


class MultiPromptEntry(BaseModel):
index: int = Field(...)
prompt: str = Field(...)
duration: str = Field(...)


class OmniProText2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
sound: str = Field(..., description="'on' or 'off'")


class OmniParamImage(BaseModel):
@@ -36,10 +26,6 @@ class OmniProFirstLastFrameRequest(BaseModel):
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str | None = Field(None, description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)


class OmniProReferences2VideoRequest(BaseModel):
@@ -52,10 +38,6 @@ class OmniProReferences2VideoRequest(BaseModel):
duration: str | None = Field(..., description="From 3 to 10.")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str | None = Field(None, description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)


class TaskStatusVideoResult(BaseModel):
@@ -72,7 +54,6 @@ class TaskStatusImageResult(BaseModel):
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
series_images: list[TaskStatusImageResult] | None = Field(None)


class TaskStatusResponseData(BaseModel):
@@ -96,42 +77,31 @@ class OmniImageParamImage(BaseModel):


class OmniProImageRequest(BaseModel):
model_name: str = Field(...)
resolution: str = Field(...)
model_name: str = Field(..., description="kling-image-o1")
resolution: str = Field(..., description="'1k' or '2k'")
aspect_ratio: str | None = Field(...)
prompt: str = Field(...)
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
result_type: str | None = Field(None, description="Set to 'series' for series generation")
series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series")


class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(...)
model_name: str = Field(..., description="kling-v2-6")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(...)
prompt: str | None = Field(...)
negative_prompt: str | None = Field(None)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)


class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(...)
model_name: str = Field(..., description="kling-v2-6")
image: str = Field(...)
image_tail: str | None = Field(None)
duration: str = Field(...)
prompt: str | None = Field(...)
negative_prompt: str | None = Field(None)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)


class MotionControlRequest(BaseModel):
File diff suppressed because it is too large
@@ -30,30 +30,6 @@ from comfy_api_nodes.util import (
|
||||
validate_image_dimensions,
|
||||
)
|
||||
|
||||
_EUR_TO_USD = 1.19
|
||||
|
||||
|
||||
def _tier_price_eur(megapixels: float) -> float:
|
||||
"""Price in EUR for a single Magnific upscaling step based on input megapixels."""
|
||||
if megapixels <= 1.3:
|
||||
return 0.143
|
||||
if megapixels <= 3.0:
|
||||
return 0.286
|
||||
if megapixels <= 6.4:
|
||||
return 0.429
|
||||
return 1.716
|
||||
|
||||
|
||||
def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
|
||||
"""Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
|
||||
num_steps = int(math.log2(scale))
|
||||
total_eur = 0.0
|
||||
pixels = width * height
|
||||
for _ in range(num_steps):
|
||||
total_eur += _tier_price_eur(pixels / 1_000_000)
|
||||
pixels *= 4
|
||||
return round(total_eur * _EUR_TO_USD, 2)
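# Editor's note (worked example based on the tiers above, not part of the diff):
# for a 1024x1024 input upscaled 4x, num_steps = log2(4) = 2. Step 1 sees
# ~1.05 MP (0.143 EUR), step 2 sees ~4.19 MP (0.429 EUR), so total_eur = 0.572
# and the returned price is round(0.572 * 1.19, 2) == 0.68 USD.
# >>> _calculate_magnific_upscale_price_usd(1024, 1024, 4)
# 0.68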
|
||||
|
||||
|
||||
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
@@ -127,20 +103,11 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||
expr="""
|
||||
(
|
||||
$ad := widgets.auto_downscale;
|
||||
$mins := $ad
|
||||
? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
|
||||
: {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
|
||||
$maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $lookup($mins, widgets.scale_factor),
|
||||
"max_usd": $lookup($maxs, widgets.scale_factor),
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@@ -201,10 +168,6 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
f"Use a smaller input image or lower scale factor."
|
||||
)
|
||||
|
||||
final_height, final_width = get_image_dimensions(image)
|
||||
actual_scale = int(scale_factor.rstrip("x"))
|
||||
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
|
||||
@@ -226,7 +189,6 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
price_extractor=lambda _: price_usd,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
@@ -295,14 +257,8 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||
expr="""
|
||||
(
|
||||
$mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
|
||||
$maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $lookup($mins, widgets.scale_factor),
|
||||
"max_usd": $lookup($maxs, widgets.scale_factor),
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@@ -365,9 +321,6 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
f"Use a smaller input image or lower scale factor."
|
||||
)
|
||||
|
||||
final_height, final_width = get_image_dimensions(image)
|
||||
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
|
||||
@@ -386,7 +339,6 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
price_extractor=lambda _: price_usd,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
@@ -925,8 +877,8 @@ class MagnificExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
MagnificImageUpscalerCreativeNode,
|
||||
MagnificImageUpscalerPreciseV2Node,
|
||||
# MagnificImageUpscalerCreativeNode,
|
||||
# MagnificImageUpscalerPreciseV2Node,
|
||||
MagnificImageStyleTransferNode,
|
||||
MagnificImageRelightNode,
|
||||
MagnificImageSkinEnhancerNode,
|
||||
|
||||
@@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=80,
|
||||
min=75, # steps should be greater than or equal to cooldown_steps(75) + warmup_steps(0)
|
||||
default=33,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Number of denoising steps",
|
||||
@@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=60,
|
||||
min=60, # steps should be greater than or equal to cooldown_steps(36) + warmup_steps(24)
|
||||
default=33,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
@@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
video: Input.Video | None = None,
|
||||
control_type: str = "Motion Transfer",
|
||||
motion_intensity: int | None = 100,
|
||||
steps=60,
|
||||
steps=33,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
@@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=80,
|
||||
min=75, # steps should be greater than or equal to cooldown_steps(75) + warmup_steps(0)
|
||||
default=33,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Inference steps",
|
||||
|
||||
@@ -57,7 +57,6 @@ class _RequestConfig:
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None
|
||||
multipart_parser: Callable | None
|
||||
max_retries: int
|
||||
max_retries_on_rate_limit: int
|
||||
retry_delay: float
|
||||
retry_backoff: float
|
||||
wait_label: str = "Waiting"
|
||||
@@ -66,7 +65,6 @@ class _RequestConfig:
|
||||
final_label_on_success: str | None = "Completed"
|
||||
progress_origin_ts: float | None = None
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -80,7 +78,7 @@ class _PollUIState:
|
||||
active_since: float | None = None # start time of current active interval (None if queued)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
|
||||
@@ -105,8 +103,6 @@ async def sync_op(
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
@@ -126,8 +122,6 @@ async def sync_op(
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
monitor_progress=monitor_progress,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
@@ -149,9 +143,9 @@ async def poll_op(
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 160,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 10,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 1.4,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
@@ -200,8 +194,6 @@ async def sync_op_raw(
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
) -> dict[str, Any] | bytes:
|
||||
"""
|
||||
Make a single network request.
|
||||
@@ -230,8 +222,6 @@ async def sync_op_raw(
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
price_extractor=price_extractor,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
@@ -250,9 +240,9 @@ async def poll_op_raw(
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 160,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 10,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 1.4,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
@@ -516,7 +506,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
|
||||
if status == 409:
|
||||
return "There is a problem with your account. Please contact support@comfy.org."
|
||||
if status == 429:
|
||||
return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
|
||||
return "Rate Limit Exceeded: Please try again later."
|
||||
try:
|
||||
if isinstance(body, dict):
|
||||
err = body.get("error")
|
||||
@@ -596,8 +586,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
rate_limit_attempts = 0
|
||||
rate_limit_delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: int | None = None
|
||||
extracted_price: float | None = None
|
||||
@@ -665,14 +653,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
payload_headers["Content-Type"] = "application/json"
|
||||
payload_kw["json"] = cfg.data or {}
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request logging failed: %s", _log_e)
|
||||
|
||||
req_coro = sess.request(method, url, params=params, **payload_kw)
|
||||
req_task = asyncio.create_task(req_coro)
|
||||
@@ -697,33 +688,41 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
body = await resp.text()
|
||||
should_retry = False
|
||||
wait_time = 0.0
|
||||
retry_label = ""
|
||||
is_rl = resp.status == 429 or (
|
||||
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
||||
)
|
||||
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
||||
rate_limit_attempts += 1
|
||||
wait_time = min(rate_limit_delay, 30.0)
|
||||
rate_limit_delay *= cfg.retry_backoff
|
||||
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
||||
should_retry = True
|
||||
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
wait_time = delay
|
||||
delay *= cfg.retry_backoff
|
||||
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
|
||||
should_retry = True
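# Editor's sketch (illustrative, assuming retry_delay=1.0 and retry_backoff=2.0
# as used elsewhere in this diff): 429 responses keep their own counter and a
# capped wait, separate from transient-error retries, e.g.
# rate_limit_waits = [min(1.0 * 2.0 ** k, 30.0) for k in range(6)]
# -> [1.0, 2.0, 4.0, 8.0, 16.0, 30.0]
# while ordinary retries are counted as (attempt - rate_limit_attempts)
# against cfg.max_retries.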
|
||||
|
||||
if should_retry:
|
||||
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
|
||||
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
|
||||
method,
|
||||
url,
|
||||
resp.status,
|
||||
wait_time,
|
||||
retry_label,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=_friendly_http_message(resp.status, body),
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
@@ -731,27 +730,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
|
||||
error_message=msg,
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
wait_time,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
raise Exception(msg)
|
||||
|
||||
if expect_binary:
|
||||
@@ -771,14 +753,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
bytes_payload = bytes(buff)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
return bytes_payload
|
||||
else:
|
||||
try:
|
||||
@@ -795,39 +780,45 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
return payload
|
||||
|
||||
except ProcessingInterrupted:
|
||||
logging.debug("Polling was interrupted by user")
|
||||
raise
|
||||
except (ClientError, OSError) as e:
|
||||
if (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
if attempt <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
|
||||
method,
|
||||
url,
|
||||
delay,
|
||||
attempt - rate_limit_attempts,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
str(e),
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
@@ -840,6 +831,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
continue
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
@@ -847,21 +855,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise ApiServerError(
|
||||
f"The API server at {default_base_url()} is currently unreachable. "
|
||||
f"The service may be experiencing issues."
|
||||
|
||||
@@ -167,25 +167,27 @@ async def download_url_to_bytesio(
|
||||
with contextlib.suppress(Exception):
|
||||
dest.seek(0)
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any
|
||||
|
||||
import folder_paths
|
||||
|
||||
# Get the logger instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -90,41 +91,38 @@ def log_request_response(
|
||||
Filenames are sanitized and length-limited for cross-platform safety.
|
||||
If we still fail to write, we fall back to appending into api.log.
|
||||
"""
|
||||
log_dir = get_log_directory()
|
||||
filepath = _build_log_filepath(log_dir, operation_id, request_url)
|
||||
|
||||
log_content: list[str] = []
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data is not None:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content is not None:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
log_dir = get_log_directory()
|
||||
filepath = _build_log_filepath(log_dir, operation_id, request_url)
|
||||
|
||||
log_content: list[str] = []
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data is not None:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content is not None:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -255,14 +255,17 @@ async def upload_file(
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
sess: aiohttp.ClientSession | None = None
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload request logging failed: %s", e)
|
||||
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
|
||||
@@ -308,27 +311,31 @@ async def upload_file(
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
raise Exception(f"Failed to upload (HTTP {resp.status}).")
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload response logging failed: %s", e)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (aiohttp.ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cls,
|
||||
|
||||
@@ -5,7 +5,7 @@ import psutil
|
||||
import time
|
||||
import torch
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import nodes
|
||||
@@ -115,6 +115,10 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
# Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS
|
||||
if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS:
|
||||
expected = get_expected_outputs_for_node(dynprompt, node_id)
|
||||
signature.append(("expected_outputs", tuple(sorted(expected))))
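# Editor's note (sketch): sorting makes the tuple deterministic, so the same set
# of downstream links always produces the same cache signature, e.g.
# frozenset({2, 0}) -> ("expected_outputs", (0, 2)).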
|
||||
inputs = node["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
if is_link(inputs[key]):
|
||||
|
||||
@@ -19,6 +19,15 @@ class NodeInputError(Exception):
|
||||
class NodeNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
|
||||
"""Get the set of output indices that are connected downstream.
|
||||
Returns outputs that MIGHT be used.
|
||||
Outputs NOT in this set are DEFINITELY not used and safe to skip.
|
||||
"""
|
||||
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
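# Editor's sketch (usage, mirroring the call added in execution.py in this diff):
# expected = get_expected_outputs_for_node(dynprompt, unique_id)
# -> frozenset of output indices with at least one downstream link; an empty
#    frozenset means no consumer was recorded for this node.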
|
||||
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
@@ -27,6 +36,7 @@ class DynamicPrompt:
|
||||
self.ephemeral_prompt = {}
|
||||
self.ephemeral_parents = {}
|
||||
self.ephemeral_display = {}
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def get_node(self, node_id):
|
||||
if node_id in self.ephemeral_prompt:
|
||||
@@ -42,6 +52,7 @@ class DynamicPrompt:
|
||||
self.ephemeral_prompt[node_id] = node_info
|
||||
self.ephemeral_parents[node_id] = parent_id
|
||||
self.ephemeral_display[node_id] = display_id
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_parents:
|
||||
@@ -59,6 +70,26 @@ class DynamicPrompt:
|
||||
def all_node_ids(self):
|
||||
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||
|
||||
def _build_expected_outputs_map(self):
|
||||
result = {}
|
||||
for node_id in self.all_node_ids():
|
||||
try:
|
||||
node_data = self.get_node(node_id)
|
||||
except NodeNotFoundError:
|
||||
continue
|
||||
for value in node_data.get("inputs", {}).values():
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if from_node_id not in result:
|
||||
result[from_node_id] = set()
|
||||
result[from_node_id].add(from_socket)
|
||||
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
|
||||
|
||||
def get_expected_outputs_map(self):
|
||||
if self._expected_outputs_map is None:
|
||||
self._build_expected_outputs_map()
|
||||
return self._expected_outputs_map
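# Editor's sketch (illustrative two-node prompt; node types are examples only):
# prompt = {
#     "1": {"class_type": "LoadImage", "inputs": {}},
#     "2": {"class_type": "SaveImage", "inputs": {"images": ["1", 0]}},
# }
# DynamicPrompt(prompt).get_expected_outputs_map() -> {"1": frozenset({0})}
# Node "2" has no entry because nothing links to its outputs.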
|
||||
|
||||
def get_original_prompt(self):
|
||||
return self.original_prompt
|
||||
|
||||
|
||||
@@ -20,60 +20,10 @@ class JobStatus:
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
|
||||
|
||||
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
|
||||
|
||||
|
||||
def has_3d_extension(filename: str) -> bool:
|
||||
lower = filename.lower()
|
||||
return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
|
||||
|
||||
|
||||
def normalize_output_item(item):
|
||||
"""Normalize a single output list item for the jobs API.
|
||||
|
||||
Returns the normalized item, or None to exclude it.
|
||||
String items with 3D extensions become {filename, type, subfolder} dicts.
|
||||
"""
|
||||
if item is None:
|
||||
return None
|
||||
if isinstance(item, str):
|
||||
if has_3d_extension(item):
|
||||
return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
|
||||
return None
|
||||
if isinstance(item, dict):
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
def normalize_outputs(outputs: dict) -> dict:
|
||||
"""Normalize raw node outputs for the jobs API.
|
||||
|
||||
Transforms string 3D filenames into file output dicts and removes
|
||||
None items. All other items (non-3D strings, dicts, etc.) are
|
||||
preserved as-is.
|
||||
"""
|
||||
normalized = {}
|
||||
for node_id, node_outputs in outputs.items():
|
||||
if not isinstance(node_outputs, dict):
|
||||
normalized[node_id] = node_outputs
|
||||
continue
|
||||
normalized_node = {}
|
||||
for media_type, items in node_outputs.items():
|
||||
if media_type == 'animated' or not isinstance(items, list):
|
||||
normalized_node[media_type] = items
|
||||
continue
|
||||
normalized_items = []
|
||||
for item in items:
|
||||
if item is None:
|
||||
continue
|
||||
norm = normalize_output_item(item)
|
||||
normalized_items.append(norm if norm is not None else item)
|
||||
normalized_node[media_type] = normalized_items
|
||||
normalized[node_id] = normalized_node
|
||||
return normalized
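# Editor's sketch (illustrative input/output for the helper above):
# normalize_outputs({"3": {"3d": ["mesh.glb", None]}})
# -> {"3": {"3d": [{"filename": "mesh.glb", "type": "output",
#                   "subfolder": "", "mediaType": "3d"}]}}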
|
||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
@@ -95,9 +45,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
|
||||
Maintains backwards compatibility with existing logic.
|
||||
|
||||
Priority:
|
||||
1. media_type is 'images', 'video', 'audio', or '3d'
|
||||
1. media_type is 'images', 'video', or 'audio'
|
||||
2. format field starts with 'video/' or 'audio/'
|
||||
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
|
||||
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
|
||||
"""
|
||||
if media_type in PREVIEWABLE_MEDIA_TYPES:
|
||||
return True
|
||||
@@ -189,7 +139,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
|
||||
})
|
||||
|
||||
if include_outputs:
|
||||
job['outputs'] = normalize_outputs(outputs)
|
||||
job['outputs'] = outputs
|
||||
job['execution_status'] = status_info
|
||||
job['workflow'] = {
|
||||
'prompt': prompt,
|
||||
@@ -221,23 +171,18 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
|
||||
continue
|
||||
|
||||
for item in items:
|
||||
normalized = normalize_output_item(item)
|
||||
if normalized is None:
|
||||
continue
|
||||
|
||||
count += 1
|
||||
|
||||
if preview_output is not None:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
|
||||
if preview_output is None and is_previewable(media_type, item):
|
||||
enriched = {
|
||||
**normalized,
|
||||
**item,
|
||||
'nodeId': node_id,
|
||||
'mediaType': media_type
|
||||
}
|
||||
if 'mediaType' not in normalized:
|
||||
enriched['mediaType'] = media_type
|
||||
if normalized.get('type') == 'output':
|
||||
if item.get('type') == 'output':
|
||||
preview_output = enriched
|
||||
elif fallback_preview is None:
|
||||
fallback_preview = enriched
|
||||
|
||||
@@ -1,23 +1,41 @@
|
||||
import contextvars
|
||||
from typing import Optional, NamedTuple
|
||||
from typing import NamedTuple, FrozenSet
|
||||
|
||||
class ExecutionContext(NamedTuple):
|
||||
"""
|
||||
Context information about the currently executing node.
|
||||
|
||||
Attributes:
|
||||
prompt_id: The ID of the current prompt execution
|
||||
node_id: The ID of the currently executing node
|
||||
list_index: The index in a list being processed (for operations on batches/lists)
|
||||
expected_outputs: Set of output indices that might be used downstream.
|
||||
Outputs NOT in this set are definitely unused (safe to skip).
|
||||
None means the information is not available.
|
||||
"""
|
||||
prompt_id: str
|
||||
node_id: str
|
||||
list_index: Optional[int]
|
||||
list_index: int | None
|
||||
expected_outputs: FrozenSet[int] | None = None
|
||||
|
||||
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
|
||||
def get_executing_context() -> Optional[ExecutionContext]:
|
||||
def get_executing_context() -> ExecutionContext | None:
|
||||
return current_executing_context.get(None)
|
||||
|
||||
|
||||
def is_output_needed(output_index: int) -> bool:
|
||||
"""Check if an output at the given index is connected downstream.
|
||||
|
||||
Returns True if the output might be used (should be computed).
|
||||
Returns False if the output is definitely not connected (safe to skip).
|
||||
"""
|
||||
ctx = get_executing_context()
|
||||
if ctx is None or ctx.expected_outputs is None:
|
||||
return True
|
||||
return output_index in ctx.expected_outputs
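# Editor's sketch (hypothetical node code using the helper above;
# compute_mask is a placeholder, not a real API):
# def execute(cls, image):
#     mask = compute_mask(image) if is_output_needed(1) else None
#     return (image, mask)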
|
||||
|
||||
|
||||
class CurrentNodeContext:
|
||||
"""
|
||||
Context manager for setting the current executing node context.
|
||||
@@ -25,15 +43,22 @@ class CurrentNodeContext:
|
||||
Sets the current_executing_context on enter and resets it on exit.
|
||||
|
||||
Example:
|
||||
with CurrentNodeContext(node_id="123", list_index=0):
|
||||
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
|
||||
# Code that should run with the current node context set
|
||||
process_image()
|
||||
"""
|
||||
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_id: str,
|
||||
node_id: str,
|
||||
list_index: int | None = None,
|
||||
expected_outputs: FrozenSet[int] | None = None,
|
||||
):
|
||||
self.context = ExecutionContext(
|
||||
prompt_id= prompt_id,
|
||||
node_id= node_id,
|
||||
list_index= list_index
|
||||
prompt_id=prompt_id,
|
||||
node_id=node_id,
|
||||
list_index=list_index,
|
||||
expected_outputs=expected_outputs,
|
||||
)
|
||||
self.token = None
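# Editor's sketch (usage with the new parameter):
# with CurrentNodeContext("prompt-1", "7", 0, expected_outputs=frozenset({0})):
#     run_node()  # inside the context, is_output_needed(1) returns False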
|
||||
|
||||
|
||||
@@ -49,14 +49,13 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
|
||||
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import os
|
||||
import numpy as np
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from tqdm.auto import trange
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
@@ -28,11 +27,6 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||
"""
|
||||
CFGGuider with modifications for training specific logic
|
||||
"""
|
||||
|
||||
def __init__(self, *args, offloading=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.offloading = offloading
|
||||
|
||||
def outer_sample(
|
||||
self,
|
||||
noise,
|
||||
@@ -51,11 +45,9 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||
noise.shape,
|
||||
self.conds,
|
||||
self.model_options,
|
||||
force_full_load=not self.offloading,
|
||||
force_offload=self.offloading,
|
||||
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
|
||||
)
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
@@ -412,97 +404,16 @@ def find_all_highest_child_module_with_forward(
|
||||
return result
|
||||
|
||||
|
||||
def find_modules_at_depth(
|
||||
model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
|
||||
) -> list[nn.Module]:
|
||||
"""
|
||||
Find modules at a specific depth level for gradient checkpointing.
|
||||
|
||||
Args:
|
||||
model: The model to search
|
||||
depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
|
||||
result: Accumulator for results
|
||||
current_depth: Current recursion depth
|
||||
name: Current module name for logging
|
||||
|
||||
Returns:
|
||||
List of modules at the target depth
|
||||
"""
|
||||
if result is None:
|
||||
result = []
|
||||
name = name or "root"
|
||||
|
||||
# Skip container modules (they don't have meaningful forward)
|
||||
is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
|
||||
has_forward = hasattr(model, "forward") and not is_container
|
||||
|
||||
if has_forward:
|
||||
current_depth += 1
|
||||
if current_depth == depth:
|
||||
result.append(model)
|
||||
logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
|
||||
return result
|
||||
|
||||
# Recurse into children
|
||||
for next_name, child in model.named_children():
|
||||
find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
|
||||
|
||||
return result
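# Editor's sketch (illustrative call against any nn.Module):
# blocks = find_modules_at_depth(model.diffusion_model, depth=2)
# depth=1 returns the module passed in (the first module with its own forward);
# depth=2 returns its first descendants with a forward, where containers such as
# nn.ModuleList / nn.Sequential are skipped and do not count toward the depth.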
|
||||
|
||||
|
||||
class OffloadCheckpointFunction(torch.autograd.Function):
|
||||
"""
|
||||
Gradient checkpointing that works with weight offloading.
|
||||
|
||||
Forward: no_grad -> compute -> weights can be freed
|
||||
Backward: enable_grad -> recompute -> backward -> weights can be freed
|
||||
|
||||
For single input, single output modules (Linear, Conv*).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x: torch.Tensor, forward_fn):
|
||||
ctx.save_for_backward(x)
|
||||
ctx.forward_fn = forward_fn
|
||||
with torch.no_grad():
|
||||
return forward_fn(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out: torch.Tensor):
|
||||
x, = ctx.saved_tensors
|
||||
forward_fn = ctx.forward_fn
|
||||
|
||||
# Clear context early
|
||||
ctx.forward_fn = None
|
||||
|
||||
with torch.enable_grad():
|
||||
x_detached = x.detach().requires_grad_(True)
|
||||
y = forward_fn(x_detached)
|
||||
y.backward(grad_out)
|
||||
grad_x = x_detached.grad
|
||||
|
||||
# Explicit cleanup
|
||||
del y, x_detached, forward_fn
|
||||
|
||||
return grad_x, None
|
||||
|
||||
|
||||
def patch(m, offloading=False):
|
||||
def patch(m):
|
||||
if not hasattr(m, "forward"):
|
||||
return
|
||||
org_forward = m.forward
|
||||
|
||||
# Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
|
||||
if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||
def checkpointing_fwd(x):
|
||||
return OffloadCheckpointFunction.apply(x, org_forward)
|
||||
# Branch 2: Others -> standard checkpoint
|
||||
else:
|
||||
def fwd(args, kwargs):
|
||||
return org_forward(*args, **kwargs)
|
||||
def fwd(args, kwargs):
|
||||
return org_forward(*args, **kwargs)
|
||||
|
||||
def checkpointing_fwd(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
|
||||
def checkpointing_fwd(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
|
||||
|
||||
m.org_forward = org_forward
|
||||
m.forward = checkpointing_fwd
|
||||
@@ -1025,18 +936,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
default=True,
|
||||
tooltip="Use gradient checkpointing for training.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"checkpoint_depth",
|
||||
default=1,
|
||||
min=1,
|
||||
max=5,
|
||||
tooltip="Depth level for gradient checkpointing.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"offloading",
|
||||
default=False,
|
||||
tooltip="Depth level for gradient checkpointing.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"existing_lora",
|
||||
options=folder_paths.get_filename_list("loras") + ["[None]"],
|
||||
@@ -1083,8 +982,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
lora_dtype,
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
checkpoint_depth,
|
||||
offloading,
|
||||
existing_lora,
|
||||
bucket_mode,
|
||||
bypass_mode,
|
||||
@@ -1103,8 +1000,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
lora_dtype = lora_dtype[0]
|
||||
algorithm = algorithm[0]
|
||||
gradient_checkpointing = gradient_checkpointing[0]
|
||||
offloading = offloading[0]
|
||||
checkpoint_depth = checkpoint_depth[0]
|
||||
existing_lora = existing_lora[0]
|
||||
bucket_mode = bucket_mode[0]
|
||||
bypass_mode = bypass_mode[0]
|
||||
@@ -1159,18 +1054,16 @@ class TrainLoraNode(io.ComfyNode):
|
||||
|
||||
# Setup gradient checkpointing
|
||||
if gradient_checkpointing:
|
||||
modules_to_patch = find_modules_at_depth(
|
||||
mp.model.diffusion_model, depth=checkpoint_depth
|
||||
)
|
||||
logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
|
||||
for m in modules_to_patch:
|
||||
patch(m, offloading=offloading)
|
||||
for m in find_all_highest_child_module_with_forward(
|
||||
mp.model.diffusion_model
|
||||
):
|
||||
patch(m)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
# With force_full_load=False we should be able to offload the model,
|
||||
# but offloading during training needs custom autograd hooks for fwd/bwd.
|
||||
comfy.model_management.load_models_gpu(
|
||||
[mp], memory_required=1e20, force_full_load=not offloading
|
||||
[mp], memory_required=1e20, force_full_load=True
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -1207,7 +1100,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
# Setup guider
|
||||
guider = TrainGuider(mp, offloading=offloading)
|
||||
guider = TrainGuider(mp)
|
||||
guider.set_conds(positive)
|
||||
|
||||
# Inject bypass hooks if bypass mode is enabled
|
||||
@@ -1220,7 +1113,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
|
||||
# Run training loop
|
||||
try:
|
||||
comfy.model_management.in_training = True
|
||||
_run_training_loop(
|
||||
guider,
|
||||
train_sampler,
|
||||
@@ -1231,7 +1123,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
multi_res,
|
||||
)
|
||||
finally:
|
||||
comfy.model_management.in_training = False
|
||||
# Eject bypass hooks if they were injected
|
||||
if bypass_injections is not None:
|
||||
for injection in bypass_injections:
|
||||
@@ -1241,20 +1132,19 @@ class TrainLoraNode(io.ComfyNode):
|
||||
unpatch(m)
|
||||
del train_sampler, optimizer
|
||||
|
||||
for param in lora_sd:
|
||||
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
|
||||
|
||||
# Finalize adapters
|
||||
for adapter in all_weight_adapters:
|
||||
adapter.requires_grad_(False)
|
||||
del adapter
|
||||
del all_weight_adapters
|
||||
|
||||
for param in lora_sd:
|
||||
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
||||
|
||||
# mp in train node is highly specialized for training
|
||||
# use it in inference will result in bad behavior so we don't return it
|
||||
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
|
||||
|
||||
|
||||
class LoraModelLoader(io.ComfyNode):
|
||||
class LoraModelLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
@@ -1276,11 +1166,6 @@ class LoraModelLoader(io.ComfyNode):
|
||||
max=100.0,
|
||||
tooltip="How strongly to modify the diffusion model. This value can be negative.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"bypass",
|
||||
default=False,
|
||||
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(
|
||||
@@ -1290,18 +1175,13 @@ class LoraModelLoader(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, lora, strength_model, bypass=False):
|
||||
def execute(cls, model, lora, strength_model):
|
||||
if strength_model == 0:
|
||||
return io.NodeOutput(model)
|
||||
|
||||
if bypass:
|
||||
model_lora, _ = comfy.sd.load_bypass_lora_for_models(
|
||||
model, None, lora, strength_model, 0
|
||||
)
|
||||
else:
|
||||
model_lora, _ = comfy.sd.load_lora_for_models(
|
||||
model, None, lora, strength_model, 0
|
||||
)
|
||||
model_lora, _ = comfy.sd.load_lora_for_models(
|
||||
model, None, lora, strength_model, 0
|
||||
)
|
||||
return io.NodeOutput(model_lora)
|
||||
|
||||
|
||||
|
||||
@@ -202,56 +202,6 @@ class LoadVideo(io.ComfyNode):
|
||||
|
||||
return True
|
||||
|
||||
class VideoSlice(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Video Slice",
|
||||
display_name="Video Slice",
|
||||
search_aliases=[
|
||||
"trim video duration",
|
||||
"skip first frames",
|
||||
"frame load cap",
|
||||
"start time",
|
||||
],
|
||||
category="image/video",
|
||||
inputs=[
|
||||
io.Video.Input("video"),
|
||||
io.Float.Input(
|
||||
"start_time",
|
||||
default=0.0,
|
||||
max=1e5,
|
||||
min=-1e5,
|
||||
step=0.001,
|
||||
tooltip="Start time in seconds",
|
||||
),
|
||||
io.Float.Input(
|
||||
"duration",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
step=0.001,
|
||||
tooltip="Duration in seconds, or 0 for unlimited duration",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"strict_duration",
|
||||
default=False,
|
||||
tooltip="If True, when the specified duration is not possible, an error will be raised.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
|
||||
trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
|
||||
if trimmed is not None:
|
||||
return io.NodeOutput(trimmed)
|
||||
raise ValueError(
|
||||
f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
|
||||
)
|
||||
|
||||
|
||||
class VideoExtension(ComfyExtension):
|
||||
@override
|
||||
@@ -262,7 +212,6 @@ class VideoExtension(ComfyExtension):
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
LoadVideo,
|
||||
VideoSlice,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> VideoExtension:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.13.0"
|
||||
__version__ = "0.12.3"
|
||||
|
||||
49
execution.py
@@ -13,11 +13,8 @@ from contextlib import nullcontext

import torch

from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
import comfy_aimdo.model_vbar

from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
@@ -34,6 +31,7 @@ from comfy_execution.graph import (
    ExecutionBlocker,
    ExecutionList,
    get_input_info,
    get_expected_outputs_for_node,
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
@@ -230,7 +228,18 @@ async def resolve_map_node_over_list_results(results):
            raise exc
    return [x.result() if isinstance(x, asyncio.Task) else x for x in results]

async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
async def _async_map_node_over_list(
    prompt_id,
    unique_id,
    obj,
    input_data_all,
    func,
    allow_interrupt=False,
    execution_block_cb=None,
    pre_execute_cb=None,
    v3_data=None,
    expected_outputs=None,
):
    # check if node wants the lists
    input_is_list = getattr(obj, "INPUT_IS_LIST", False)

@@ -280,10 +289,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
        else:
            f = getattr(obj, func)
            if inspect.iscoroutinefunction(f):
                async def async_wrapper(f, prompt_id, unique_id, list_index, args):
                    with CurrentNodeContext(prompt_id, unique_id, list_index):
                async def async_wrapper(f, prompt_id, unique_id, list_index, args, expected_outputs):
                    with CurrentNodeContext(prompt_id, unique_id, list_index, expected_outputs):
                        return await f(**args)
                task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
                task = asyncio.create_task(
                    async_wrapper(f, prompt_id, unique_id, index, args=inputs, expected_outputs=expected_outputs)
                )
                # Give the task a chance to execute without yielding
                await asyncio.sleep(0)
                if task.done():
@@ -292,7 +303,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
                else:
                    results.append(task)
            else:
                with CurrentNodeContext(prompt_id, unique_id, index):
                with CurrentNodeContext(prompt_id, unique_id, index, expected_outputs):
                    result = f(**inputs)
                results.append(result)
    else:
@@ -330,8 +341,17 @@ def merge_result_data(results, obj):
            output.append([o[i] for o in results])
    return output

async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
async def get_output_data(
    prompt_id,
    unique_id,
    obj,
    input_data_all,
    execution_block_cb=None,
    pre_execute_cb=None,
    v3_data=None,
    expected_outputs=None,
):
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
    has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
    if has_pending_task:
        return return_values, {}, False, has_pending_task
@@ -525,15 +545,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
        #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
        #that we just want to cull out each model run.
        allocator = comfy.memory_management.aimdo_allocator
        expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
        with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
            try:
                output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
                output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
            finally:
                if allocator is not None:
                    if args.verbose == "DEBUG":
                        comfy_aimdo.model_vbar.vbars_analyze()
                    comfy.model_management.reset_cast_buffers()
                    comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
                    torch.cuda.synchronize()

        if has_pending_tasks:
            pending_async_nodes[unique_id] = output_data
@@ -623,8 +642,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
            logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
            logging.error("Got an OOM, unloading all loaded models.")
            comfy.model_management.unload_all_models()
        elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:
            tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected."

        error_details = {
            "node_id": real_node_id,

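The execution.py hunks above cover the scheduler side of the change: get_expected_outputs_for_node works out which output sockets of a node are actually linked downstream, and that frozenset is threaded through get_output_data and _async_map_node_over_list into CurrentNodeContext. A minimal sketch of the node-side counterpart, assuming only the get_executing_context helper that the tests below import from comfy_execution.utils; the produce_outputs function itself is hypothetical.

from comfy_execution.utils import get_executing_context

def produce_outputs(image):
    # Hypothetical node body: output 0 is cheap, output 1 is an optional extra.
    ctx = get_executing_context()
    expected = ctx.expected_outputs if ctx is not None else None

    full = image
    preview = None
    # With no context (or no connectivity info) behave as before and compute everything;
    # otherwise skip outputs nothing downstream consumes.
    if expected is None or 1 in expected:
        preview = image[:, ::4, ::4, :]  # illustrative downscale for output 1
    return full, preview
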
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.13.0"
version = "0.12.3"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.38.13
comfyui-workflow-templates==0.8.38
comfyui-workflow-templates==0.8.31
comfyui-embedded-docs==0.4.1
torch
torchsde
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.1.8
comfy-aimdo>=0.1.7
requests

#non essential dependencies:

322
tests-unit/execution_test/expected_outputs_test.py
Normal file
@@ -0,0 +1,322 @@
"""Unit tests for the expected_outputs feature.

This feature allows nodes to know at runtime which outputs are connected downstream,
enabling them to skip computing outputs that aren't needed.
"""

from comfy_api.latest import IO
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from comfy_execution.utils import (
    CurrentNodeContext,
    ExecutionContext,
    get_executing_context,
    is_output_needed,
)


class TestGetExpectedOutputsForNode:
    """Tests for get_expected_outputs_for_node() function."""

    def test_single_output_connected(self):
        """Test node with single output connected to one downstream node."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

    def test_multiple_outputs_partial_connected(self):
        """Test node with multiple outputs, only some connected."""
        prompt = {
            "1": {"class_type": "MultiOutputNode", "inputs": {}},
            "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
            # Output 1 is not connected
            "3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0, 2})
        assert 1 not in expected  # Output 1 is definitely unused

    def test_no_outputs_connected(self):
        """Test node with no outputs connected."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "OtherNode", "inputs": {}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset()

    def test_same_output_connected_multiple_times(self):
        """Test same output connected to multiple downstream nodes."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}},
            "4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

    def test_node_not_in_prompt(self):
        """Test getting expected outputs for a node not in the prompt."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "999")
        assert expected == frozenset()

    def test_chained_nodes(self):
        """Test expected outputs in a chain of nodes."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)

        # Node 1's output 0 is connected to node 2
        expected_1 = get_expected_outputs_for_node(dynprompt, "1")
        assert expected_1 == frozenset({0})

        # Node 2's output 0 is connected to node 3
        expected_2 = get_expected_outputs_for_node(dynprompt, "2")
        assert expected_2 == frozenset({0})

        # Node 3 has no downstream connections
        expected_3 = get_expected_outputs_for_node(dynprompt, "3")
        assert expected_3 == frozenset()

    def test_complex_graph(self):
        """Test expected outputs in a complex graph with multiple connections."""
        prompt = {
            "1": {"class_type": "MultiOutputNode", "inputs": {}},
            "2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}},
            "3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}},
            "4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)

        # Node 1 has outputs 0, 1, 2 all connected
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0, 1, 2})

    def test_constant_inputs_ignored(self):
        """Test that constant (non-link) inputs don't affect expected outputs."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {
                "class_type": "ConsumerNode",
                "inputs": {
                    "image": ["1", 0],
                    "value": 42,
                    "name": "test",
                },
            },
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

    def test_ephemeral_node_invalidates_cache(self):
        """Test that adding ephemeral nodes updates expected outputs."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)

        # Initially only output 0 is connected
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

        # Add an ephemeral node that connects to output 1
        dynprompt.add_ephemeral_node(
            "eph_1",
            {"class_type": "EphemeralNode", "inputs": {"data": ["1", 1]}},
            parent_id="2",
            display_id="2",
        )

        # Now both outputs 0 and 1 should be expected
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0, 1})


class TestExecutionContext:
    """Tests for ExecutionContext with expected_outputs field."""

    def test_context_with_expected_outputs(self):
        """Test creating ExecutionContext with expected_outputs."""
        ctx = ExecutionContext(
            prompt_id="prompt-123", node_id="node-456", list_index=0, expected_outputs=frozenset({0, 2})
        )
        assert ctx.prompt_id == "prompt-123"
        assert ctx.node_id == "node-456"
        assert ctx.list_index == 0
        assert ctx.expected_outputs == frozenset({0, 2})

    def test_context_without_expected_outputs(self):
        """Test ExecutionContext defaults to None for expected_outputs."""
        ctx = ExecutionContext(prompt_id="prompt-123", node_id="node-456", list_index=0)
        assert ctx.expected_outputs is None

    def test_context_empty_expected_outputs(self):
        """Test ExecutionContext with empty expected_outputs set."""
        ctx = ExecutionContext(
            prompt_id="prompt-123", node_id="node-456", list_index=None, expected_outputs=frozenset()
        )
        assert ctx.expected_outputs == frozenset()
        assert len(ctx.expected_outputs) == 0


class TestCurrentNodeContext:
    """Tests for CurrentNodeContext context manager with expected_outputs."""

    def test_context_manager_with_expected_outputs(self):
        """Test CurrentNodeContext sets and resets context correctly."""
        assert get_executing_context() is None

        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})):
            ctx = get_executing_context()
            assert ctx is not None
            assert ctx.prompt_id == "prompt-1"
            assert ctx.node_id == "node-1"
            assert ctx.list_index == 0
            assert ctx.expected_outputs == frozenset({0, 1})

        assert get_executing_context() is None

    def test_context_manager_without_expected_outputs(self):
        """Test CurrentNodeContext works without expected_outputs (backwards compatible)."""
        with CurrentNodeContext("prompt-1", "node-1"):
            ctx = get_executing_context()
            assert ctx is not None
            assert ctx.expected_outputs is None

    def test_nested_context_managers(self):
        """Test nested CurrentNodeContext managers."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})):
            ctx1 = get_executing_context()
            assert ctx1.expected_outputs == frozenset({0})

            with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})):
                ctx2 = get_executing_context()
                assert ctx2.expected_outputs == frozenset({1, 2})
                assert ctx2.node_id == "node-2"

            # After inner context exits, should be back to outer context
            ctx1_again = get_executing_context()
            assert ctx1_again.expected_outputs == frozenset({0})
            assert ctx1_again.node_id == "node-1"

    def test_output_check_pattern(self):
        """Test the typical pattern nodes will use to check expected outputs."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            ctx = get_executing_context()

            # Typical usage pattern
            if ctx and ctx.expected_outputs is not None:
                should_compute_0 = 0 in ctx.expected_outputs
                should_compute_1 = 1 in ctx.expected_outputs
                should_compute_2 = 2 in ctx.expected_outputs
            else:
                # Fallback when info not available
                should_compute_0 = should_compute_1 = should_compute_2 = True

            assert should_compute_0 is True
            assert should_compute_1 is False  # Not in expected_outputs
            assert should_compute_2 is True


class TestSchemaLazyOutputs:
    """Tests for lazy_outputs in V3 Schema."""

    def test_schema_lazy_outputs_default(self):
        """Test that lazy_outputs defaults to False."""
        schema = IO.Schema(
            node_id="TestNode",
            inputs=[],
            outputs=[IO.Float.Output()],
        )
        assert schema.lazy_outputs is False

    def test_schema_lazy_outputs_true(self):
        """Test setting lazy_outputs to True."""
        schema = IO.Schema(
            node_id="TestNode",
            lazy_outputs=True,
            inputs=[],
            outputs=[IO.Float.Output()],
        )
        assert schema.lazy_outputs is True

    def test_v3_node_lazy_outputs_property(self):
        """Test that LAZY_OUTPUTS property works on V3 nodes."""

        class TestNodeWithLazyOutputs(IO.ComfyNode):
            @classmethod
            def define_schema(cls):
                return IO.Schema(
                    node_id="TestNodeWithLazyOutputs",
                    lazy_outputs=True,
                    inputs=[],
                    outputs=[IO.Float.Output()],
                )

            @classmethod
            def execute(cls):
                return IO.NodeOutput(1.0)

        assert TestNodeWithLazyOutputs.LAZY_OUTPUTS is True

    def test_v3_node_lazy_outputs_default(self):
        """Test that LAZY_OUTPUTS defaults to False on V3 nodes."""

        class TestNodeWithoutLazyOutputs(IO.ComfyNode):
            @classmethod
            def define_schema(cls):
                return IO.Schema(
                    node_id="TestNodeWithoutLazyOutputs",
                    inputs=[],
                    outputs=[IO.Float.Output()],
                )

            @classmethod
            def execute(cls):
                return IO.NodeOutput(1.0)

        assert TestNodeWithoutLazyOutputs.LAZY_OUTPUTS is False


class TestIsOutputNeeded:
    """Tests for is_output_needed() helper function."""

    def test_output_needed_when_in_expected(self):
        """Test that output is needed when in expected_outputs."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            assert is_output_needed(0) is True
            assert is_output_needed(2) is True

    def test_output_not_needed_when_not_in_expected(self):
        """Test that output is not needed when not in expected_outputs."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            assert is_output_needed(1) is False
            assert is_output_needed(3) is False

    def test_output_needed_when_no_context(self):
        """Test that output is needed when no context."""
        assert get_executing_context() is None
        assert is_output_needed(0) is True
        assert is_output_needed(1) is True

    def test_output_needed_when_expected_outputs_is_none(self):
        """Test that output is needed when expected_outputs is None."""
        with CurrentNodeContext("prompt-1", "node-1", 0, None):
            assert is_output_needed(0) is True
            assert is_output_needed(1) is True

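Taken together, the Schema and is_output_needed tests above describe the opt-in pattern for V3 nodes: declare lazy_outputs=True in the schema so caching accounts for output connectivity, then guard expensive branches with is_output_needed. A hedged sketch using only the names exercised by these tests; the node itself (CheapAndExpensive) is made up for illustration and is not part of the codebase.

from comfy_api.latest import IO
from comfy_execution.utils import is_output_needed


class CheapAndExpensive(IO.ComfyNode):  # hypothetical example node
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="CheapAndExpensive",
            lazy_outputs=True,  # opt in: re-execute when output connections change
            inputs=[],
            outputs=[IO.Float.Output(), IO.Float.Output()],
        )

    @classmethod
    def execute(cls):
        cheap = 1.0
        # Only pay for the second output when something downstream consumes it.
        expensive = sum(i * i for i in range(1_000_000)) if is_output_needed(1) else 0.0
        return IO.NodeOutput(cheap, expensive)
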
@@ -574,6 +574,104 @@ class TestExecution:
        else:
            assert result.did_run(test_node), "The execution should have been re-run"

    def test_expected_outputs_all_connected(self, client: ComfyClient, builder: GraphBuilder):
        """Test that expected_outputs contains all connected outputs."""
        g = builder
        # Create a node with 3 outputs, all connected
        expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)

        # Connect all 3 outputs to preview nodes
        output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
        output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
        output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))

        result = client.run(g)

        # All outputs should be white (255) since all are connected
        images0 = result.get_images(output0)
        images1 = result.get_images(output1)
        images2 = result.get_images(output2)

        assert len(images0) == 1, "Should have 1 image for output0"
        assert len(images1) == 1, "Should have 1 image for output1"
        assert len(images2) == 1, "Should have 1 image for output2"

        # White pixels = 255, meaning output was in expected_outputs
        assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
        assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
        assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"

    def test_expected_outputs_partial_connected(self, client: ComfyClient, builder: GraphBuilder):
        """Test that expected_outputs only contains connected outputs."""
        g = builder
        # Create a node with 3 outputs, only some connected
        expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)

        # Only connect outputs 0 and 2, leave output 1 disconnected
        output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
        # output1 is intentionally not connected
        output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))

        result = client.run(g)

        # Connected outputs should be white (255)
        images0 = result.get_images(output0)
        images2 = result.get_images(output2)

        assert len(images0) == 1, "Should have 1 image for output0"
        assert len(images2) == 1, "Should have 1 image for output2"

        # White = expected, output 1 is not connected so we can't verify it directly but outputs 0 and 2 should be white
        assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
        assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"

    def test_expected_outputs_single_connected(self, client: ComfyClient, builder: GraphBuilder):
        """Test that expected_outputs works with single connected output."""
        g = builder
        # Create a node with 3 outputs, only one connected
        expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)

        # Only connect output 1
        output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))

        result = client.run(g)

        images1 = result.get_images(output1)
        assert len(images1) == 1, "Should have 1 image for output1"

        # Output 1 should be white (connected), others are not visible in this test
        assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"

    def test_expected_outputs_cache_invalidation(self, client: ComfyClient, builder: GraphBuilder, server):
        """Test that cache invalidates when output connections change."""
        g = builder
        # Use unique dimensions to avoid cache collision with other expected_outputs tests
        expected_outputs_node = g.node("TestExpectedOutputs", height=32, width=32)

        # First run: only connect output 0
        output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))

        result1 = client.run(g)
        assert result1.did_run(expected_outputs_node), "First run should execute the node"

        # Second run: same connections, should be cached
        result2 = client.run(g)
        if server["should_cache_results"]:
            assert not result2.did_run(expected_outputs_node), "Second run should be cached"

        # Third run: add connection to output 2
        output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))

        result3 = client.run(g)
        # Because LAZY_OUTPUTS=True, changing connections should invalidate cache
        if server["should_cache_results"]:
            assert result3.did_run(expected_outputs_node), "Adding output connection should invalidate cache"

        # Verify both outputs are now white
        images0 = result3.get_images(output0)
        images2 = result3.get_images(output2)
        assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
        assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"

    def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
        # Warmup execution to ensure server is fully initialized

@@ -5,11 +5,8 @@ from comfy_execution.jobs import (
    is_previewable,
    normalize_queue_item,
    normalize_history_item,
    normalize_output_item,
    normalize_outputs,
    get_outputs_summary,
    apply_sorting,
    has_3d_extension,
)

@@ -38,8 +35,8 @@ class TestIsPreviewable:
    """Unit tests for is_previewable()"""

    def test_previewable_media_types(self):
        """Images, video, audio, 3d media types should be previewable."""
        for media_type in ['images', 'video', 'audio', '3d']:
        """Images, video, audio media types should be previewable."""
        for media_type in ['images', 'video', 'audio']:
            assert is_previewable(media_type, {}) is True

    def test_non_previewable_media_types(self):
@@ -49,7 +46,7 @@ class TestIsPreviewable:

    def test_3d_extensions_previewable(self):
        """3D file extensions should be previewable regardless of media_type."""
        for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
        for ext in ['.obj', '.fbx', '.gltf', '.glb']:
            item = {'filename': f'model{ext}'}
            assert is_previewable('files', item) is True

@@ -163,7 +160,7 @@ class TestGetOutputsSummary:

    def test_3d_files_previewable(self):
        """3D file extensions should be previewable."""
        for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
        for ext in ['.obj', '.fbx', '.gltf', '.glb']:
            outputs = {
                'node1': {
                    'files': [{'filename': f'model{ext}', 'type': 'output'}]
@@ -195,64 +192,6 @@ class TestGetOutputsSummary:
        assert preview['mediaType'] == 'images'
        assert preview['subfolder'] == 'outputs'

    def test_string_3d_filename_creates_preview(self):
        """String items with 3D extensions should synthesize a preview (Preview3D node output).
        Only the .glb counts — nulls and non-file strings are excluded."""
        outputs = {
            'node1': {
                'result': ['preview3d_abc123.glb', None, None]
            }
        }
        count, preview = get_outputs_summary(outputs)
        assert count == 1
        assert preview is not None
        assert preview['filename'] == 'preview3d_abc123.glb'
        assert preview['mediaType'] == '3d'
        assert preview['nodeId'] == 'node1'
        assert preview['type'] == 'output'

    def test_string_non_3d_filename_no_preview(self):
        """String items without 3D extensions should not create a preview."""
        outputs = {
            'node1': {
                'result': ['data.json', None]
            }
        }
        count, preview = get_outputs_summary(outputs)
        assert count == 0
        assert preview is None

    def test_string_3d_filename_used_as_fallback(self):
        """String 3D preview should be used when no dict items are previewable."""
        outputs = {
            'node1': {
                'latents': [{'filename': 'latent.safetensors'}],
            },
            'node2': {
                'result': ['model.glb', None]
            }
        }
        count, preview = get_outputs_summary(outputs)
        assert preview is not None
        assert preview['filename'] == 'model.glb'
        assert preview['mediaType'] == '3d'


class TestHas3DExtension:
    """Unit tests for has_3d_extension()"""

    def test_recognized_extensions(self):
        for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
            assert has_3d_extension(f'model{ext}') is True

    def test_case_insensitive(self):
        assert has_3d_extension('MODEL.GLB') is True
        assert has_3d_extension('Scene.GLTF') is True

    def test_non_3d_extensions(self):
        for name in ['photo.png', 'video.mp4', 'data.json', 'model']:
            assert has_3d_extension(name) is False


class TestApplySorting:
    """Unit tests for apply_sorting()"""
@@ -456,142 +395,3 @@ class TestNormalizeHistoryItem:
            'prompt': {'nodes': {'1': {}}},
            'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
        }

    def test_include_outputs_normalizes_3d_strings(self):
        """Detail view should transform string 3D filenames into file output dicts."""
        history_item = {
            'prompt': (
                5,
                'prompt-3d',
                {'nodes': {}},
                {'create_time': 1234567890},
                ['node1'],
            ),
            'status': {'status_str': 'success', 'completed': True, 'messages': []},
            'outputs': {
                'node1': {
                    'result': ['preview3d_abc123.glb', None, None]
                }
            },
        }
        job = normalize_history_item('prompt-3d', history_item, include_outputs=True)

        assert job['outputs_count'] == 1
        result_items = job['outputs']['node1']['result']
        assert len(result_items) == 1
        assert result_items[0] == {
            'filename': 'preview3d_abc123.glb',
            'type': 'output',
            'subfolder': '',
            'mediaType': '3d',
        }

    def test_include_outputs_preserves_dict_items(self):
        """Detail view normalization should pass dict items through unchanged."""
        history_item = {
            'prompt': (
                5,
                'prompt-img',
                {'nodes': {}},
                {'create_time': 1234567890},
                ['node1'],
            ),
            'status': {'status_str': 'success', 'completed': True, 'messages': []},
            'outputs': {
                'node1': {
                    'images': [
                        {'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
                    ]
                }
            },
        }
        job = normalize_history_item('prompt-img', history_item, include_outputs=True)

        assert job['outputs_count'] == 1
        assert job['outputs']['node1']['images'] == [
            {'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
        ]


class TestNormalizeOutputItem:
    """Unit tests for normalize_output_item()"""

    def test_none_returns_none(self):
        assert normalize_output_item(None) is None

    def test_string_3d_extension_synthesizes_dict(self):
        result = normalize_output_item('model.glb')
        assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}

    def test_string_non_3d_extension_returns_none(self):
        assert normalize_output_item('data.json') is None

    def test_string_no_extension_returns_none(self):
        assert normalize_output_item('camera_info_string') is None

    def test_dict_passes_through(self):
        item = {'filename': 'test.png', 'type': 'output'}
        assert normalize_output_item(item) is item

    def test_other_types_return_none(self):
        assert normalize_output_item(42) is None
        assert normalize_output_item(True) is None


class TestNormalizeOutputs:
    """Unit tests for normalize_outputs()"""

    def test_empty_outputs(self):
        assert normalize_outputs({}) == {}

    def test_dict_items_pass_through(self):
        outputs = {
            'node1': {
                'images': [{'filename': 'a.png', 'type': 'output'}],
            }
        }
        result = normalize_outputs(outputs)
        assert result == outputs

    def test_3d_string_synthesized(self):
        outputs = {
            'node1': {
                'result': ['model.glb', None, None],
            }
        }
        result = normalize_outputs(outputs)
        assert result == {
            'node1': {
                'result': [
                    {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
                ],
            }
        }

    def test_animated_key_preserved(self):
        outputs = {
            'node1': {
                'images': [{'filename': 'a.png', 'type': 'output'}],
                'animated': [True],
            }
        }
        result = normalize_outputs(outputs)
        assert result['node1']['animated'] == [True]

    def test_non_dict_node_outputs_preserved(self):
        outputs = {'node1': 'unexpected_value'}
        result = normalize_outputs(outputs)
        assert result == {'node1': 'unexpected_value'}

    def test_none_items_filtered_but_other_types_preserved(self):
        outputs = {
            'node1': {
                'result': ['data.json', None, [1, 2, 3]],
            }
        }
        result = normalize_outputs(outputs)
        assert result == {
            'node1': {
                'result': ['data.json', [1, 2, 3]],
            }
        }

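The tests listed above pin down the contract of normalize_output_item and has_3d_extension: dicts pass through by identity, strings survive only when they carry a recognized 3D extension (and are expanded into a file-output dict), and everything else collapses to None. A compact sketch written from those assertions rather than from the real implementation; the _sketch names are placeholders.

THREE_D_EXTENSIONS = ('.obj', '.fbx', '.gltf', '.glb', '.usdz')


def has_3d_extension_sketch(filename: str) -> bool:
    # Case-insensitive suffix match, per the TestHas3DExtension cases above.
    return filename.lower().endswith(THREE_D_EXTENSIONS)


def normalize_output_item_sketch(item):
    if isinstance(item, dict):
        return item  # structured file outputs pass through unchanged
    if isinstance(item, str) and has_3d_extension_sketch(item):
        # Synthesize the dict shape the frontend expects for 3D previews.
        return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
    return None  # None, plain strings without 3D extensions, numbers, bools, etc. are dropped


assert normalize_output_item_sketch('model.glb') == {
    'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
assert normalize_output_item_sketch('data.json') is None
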
@@ -6,6 +6,7 @@ from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
from comfy_execution.utils import get_executing_context

class TestLazyMixImages:
    @classmethod
@@ -482,6 +483,57 @@ class TestOutputNodeWithSocketOutput:
        result = image * value
        return (result,)


class TestExpectedOutputs:
    """Test node for the expected_outputs feature.

    This node has 3 IMAGE outputs that encode which outputs were expected:
    - White image (255) if the output was in expected_outputs
    - Black image (0) if the output was NOT in expected_outputs

    This allows integration tests to verify which outputs were expected by checking pixel values.
    """
    LAZY_OUTPUTS = True  # Opt into cache invalidation on output connection changes

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "height": ("INT", {"default": 64, "min": 1, "max": 1024}),
                "width": ("INT", {"default": 64, "min": 1, "max": 1024}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
    RETURN_NAMES = ("output0", "output1", "output2")
    FUNCTION = "execute"
    CATEGORY = "_for_testing"

    def execute(self, height, width):
        ctx = get_executing_context()

        # Default: assume all outputs are expected (backwards compatibility)
        output0_expected = True
        output1_expected = True
        output2_expected = True

        if ctx is not None and ctx.expected_outputs is not None:
            output0_expected = 0 in ctx.expected_outputs
            output1_expected = 1 in ctx.expected_outputs
            output2_expected = 2 in ctx.expected_outputs

        # Return white image if expected, black if not
        # This allows tests to verify which outputs were expected via pixel values
        white = torch.ones(1, height, width, 3)
        black = torch.zeros(1, height, width, 3)

        return (
            white if output0_expected else black,
            white if output1_expected else black,
            white if output2_expected else black,
        )


TEST_NODE_CLASS_MAPPINGS = {
    "TestLazyMixImages": TestLazyMixImages,
    "TestVariadicAverage": TestVariadicAverage,
@@ -498,6 +550,7 @@ TEST_NODE_CLASS_MAPPINGS = {
    "TestSleep": TestSleep,
    "TestParallelSleep": TestParallelSleep,
    "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
    "TestExpectedOutputs": TestExpectedOutputs,
}

TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -516,4 +569,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestSleep": "Test Sleep",
    "TestParallelSleep": "Test Parallel Sleep",
    "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
    "TestExpectedOutputs": "Test Expected Outputs",
}