Compare commits

..

5 Commits

Author  SHA1  Message  Date
Alexander Piskun  06408af600  Merge branch 'master' into feat/core/expected_outputs  2026-02-09 20:31:18 +02:00
Alexander Piskun  3902c86e83  Merge branch 'master' into feat/core/expected_outputs  2026-02-06 09:01:47 +02:00
bigcat88  01ef4e50ec  fix: precompute expected_outputs map to avoid O(n²) graph traversal  2026-02-06 08:29:01 +02:00
bigcat88  50975a7a0d  add a helper function for easy use  2026-02-05 09:05:42 +02:00
bigcat88  d987b0d32d  feat: add expected_outputs feature for lazy output computation  2026-02-05 09:05:42 +02:00
33 changed files with 874 additions and 1600 deletions

View File

@@ -195,20 +195,8 @@ class Anima(MiniTrainDIT):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
def preprocess_text_embeds(self, text_embeds, text_ids):
if text_ids is not None:
out = self.llm_adapter(text_embeds, text_ids)
if t5xxl_weights is not None:
out = out * t5xxl_weights
if out.shape[1] < 512:
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
return out
return self.llm_adapter(text_embeds, text_ids)
else:
return text_embeds
def forward(self, x, timesteps, context, **kwargs):
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
if t5xxl_ids is not None:
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
return super().forward(x, timesteps, context, **kwargs)
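
The preprocess_text_embeds variant above right-pads the adapter output to 512 tokens with torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1])). A minimal sketch of that call with made-up shapes (the real hidden size is not shown in this hunk): F.pad reads the pad tuple for trailing dimensions first, so (0, 0, 0, n) leaves the last (channel) dimension untouched and appends n rows of zeros along the token dimension.

# Illustrative shapes only: (batch, tokens, channels)
import torch
import torch.nn.functional as F

out = torch.randn(1, 300, 4096)                      # hypothetical adapter output
padded = F.pad(out, (0, 0, 0, 512 - out.shape[1]))   # same call as in preprocess_text_embeds
print(padded.shape)                                   # torch.Size([1, 512, 4096])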

View File

@@ -29,34 +29,19 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.to(dtype=torch.float32, device=pos.device)
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
try:
import comfy.quant_ops
q_apply_rope = comfy.quant_ops.ck.apply_rope
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
def apply_rope(xq, xk, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope(xq, xk, freqs_cis)
else:
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
def apply_rope1(x, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope1(x, freqs_cis)
else:
return q_apply_rope1(x, freqs_cis)
apply_rope = comfy.quant_ops.ck.apply_rope
apply_rope1 = comfy.quant_ops.ck.apply_rope1
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
apply_rope = _apply_rope
apply_rope1 = _apply_rope1
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
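
For reference, the broadcasted multiply-add in apply_rope1 is a batched 2x2 matrix applied to each (even, odd) channel pair. A small self-contained check, under the assumption (not taken from this hunk) that rope() packs per-pair rotation matrices in Flux-style layout:

# Sketch: apply_rope1's arithmetic vs. an explicit matrix-vector form. Shapes are illustrative;
# the cos/sin packing of freqs_cis is an assumption about rope(), not confirmed by this diff.
import torch

B, T, D = 1, 4, 8
x = torch.randn(B, T, D)
ang = torch.randn(B, T, D // 2)
freqs_cis = torch.stack([
    torch.stack([ang.cos(), -ang.sin()], dim=-1),
    torch.stack([ang.sin(),  ang.cos()], dim=-1),
], dim=-2)                                            # (B, T, D/2, 2, 2)

x_ = x.reshape(*x.shape[:-1], -1, 1, 2)               # (B, T, D/2, 1, 2), as in apply_rope1
out_a = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
out_b = torch.einsum('...ij,...j->...i', freqs_cis, x_.squeeze(-2))
assert torch.allclose(out_a, out_b, atol=1e-6)        # each (x_even, x_odd) pair is rotated by a 2x2 matrix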

View File

@@ -1160,16 +1160,12 @@ class Anima(BaseModel):
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
if cross_attn.shape[1] < 512:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

View File

@@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import threading
import torch
import sys
@@ -55,11 +55,6 @@ cpu_state = CPUState.GPU
total_vram = 0
# Training Related State
in_training = False
def get_supported_float8_types():
float8_types = []
try:
@@ -656,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -752,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_loaded_models.insert(0, loaded_model)
return
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
with torch.inference_mode():
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
soft_empty_cache()
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
#Deliberately load models outside of the Aimdo mempool so they can be retained across
#nodes. Use a dummy thread to do it: PyTorch documents that mempool contexts are
#thread-local, so we exploit that to escape the active context.
if enables_dynamic_vram():
t = threading.Thread(
target=load_models_gpu_thread,
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
)
t.start()
t.join()
else:
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
def load_model_gpu(model):
return load_models_gpu([model])
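
The comment above relies on thread-local context: a context activated on the calling thread (such as a PyTorch mempool context, per the comment) does not apply inside a freshly spawned thread, so doing the load there and joining escapes the pool. A minimal illustration of the pattern, with threading.local standing in for the mempool state (names here are made up for illustration):

# Sketch of the "dummy thread" trick: thread-local state set on the main thread is not
# visible inside a new thread, so work done there runs outside the active context.
import threading

_state = threading.local()
_state.pool = "aimdo_mempool"            # pretend a mempool context is active on the main thread

def current_pool():
    return getattr(_state, "pool", "default_allocator")

def run_outside_pool():
    result = {}
    t = threading.Thread(target=lambda: result.setdefault("pool", current_pool()))
    t.start()
    t.join()                              # block until done, mirroring load_models_gpu above
    return result["pool"]

print(current_pool())      # "aimdo_mempool"
print(run_outside_pool())  # "default_allocator" -> the new thread starts with no pool set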
@@ -1211,16 +1226,21 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None:
dtype = weight._model_dtype
r = torch.empty_like(weight, dtype=dtype, device=device)
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
return v_tensor.to(dtype=dtype)
r = torch.empty_like(weight, dtype=dtype, device=device)
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
#a non comfy weight
r.copy_(v_tensor)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations

View File

@@ -19,6 +19,7 @@
from __future__ import annotations
import collections
import copy
import inspect
import logging
import math
@@ -316,7 +317,7 @@ class ModelPatcher:
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
@@ -1491,9 +1492,7 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
#We force reserve VRAM for the non comfy-weight so we don't have to deal
#with pin and unpin synchronization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#We have way more tools for acceleration on comfy weight offloading, so always
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
@@ -1542,7 +1541,6 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
allocated_size += v_weight_size
else:
@@ -1557,10 +1555,8 @@ class ModelPatcherDynamic(ModelPatcher):
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")

View File

@@ -87,7 +87,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
xfer_dest = s._v_tensor
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if not resident:
@@ -169,8 +169,8 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if orig.dtype == dtype and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
else:
y = x
if update_weight:
orig.copy_(y)
for f in fns:

View File

@@ -122,26 +122,20 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
memory_required = 1e20
minimum_memory_required = None
else:
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
real_model = model.model
return real_model, conds, models

View File

@@ -793,6 +793,8 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
model_management.archive_model_dtypes(self.first_stage_model)
if device is None:
device = model_management.vae_device()
self.device = device
@@ -801,7 +803,6 @@ class VAE:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
model_management.archive_model_dtypes(self.first_stage_model)
self.output_device = model_management.intermediate_device()
mp = comfy.model_patcher.CoreModelPatcher

View File

@@ -3,7 +3,6 @@ import comfy.text_encoders.llama
from comfy import sd1_clip
import torch
import math
from tqdm.auto import trange
import yaml
import comfy.utils
@@ -24,8 +23,6 @@ def sample_manual_loop_no_classes(
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
if ids is None:
return []
device = model.execution_device
if execution_dtype is None:
@@ -35,7 +32,6 @@ def sample_manual_loop_no_classes(
execution_dtype = torch.float32
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
embeds_batch = embeds.shape[0]
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
@@ -45,27 +41,22 @@ def sample_manual_loop_no_classes(
generator = torch.Generator(device=device)
generator.manual_seed(seed)
model_config = model.transformer.model.config
past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
for step in trange(max_new_tokens, desc="LM sampling"):
for step in range(max_new_tokens):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]
if cfg_scale != 1.0:
cond_logits = next_token_logits[0:1]
uncond_logits = next_token_logits[1:2]
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
else:
cfg_logits = next_token_logits[0:1]
cond_logits = next_token_logits[0:1]
uncond_logits = next_token_logits[1:2]
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
if use_eos_score:
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
eos_score = cfg_logits[:, eos_token_id].clone()
remove_logit_value = torch.finfo(cfg_logits.dtype).min
@@ -73,7 +64,7 @@ def sample_manual_loop_no_classes(
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value
if use_eos_score:
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
cfg_logits[:, eos_token_id] = eos_score
if top_k is not None and top_k > 0:
@@ -102,8 +93,8 @@ def sample_manual_loop_no_classes(
break
embed, _, _, _ = model.process_tokens([[token]], device)
embeds = embed.repeat(embeds_batch, 1, 1)
attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
embeds = embed.repeat(2, 1, 1)
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)
@@ -113,29 +104,22 @@ def sample_manual_loop_no_classes(
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
positive = [[token for token, _ in inner_list] for inner_list in positive]
negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
negative = negative[0]
if cfg_scale != 1.0:
negative = [[token for token, _ in inner_list] for inner_list in negative]
negative = negative[0]
neg_pad = 0
if len(negative) < len(positive):
neg_pad = (len(positive) - len(negative))
negative = [model.special_tokens["pad"]] * neg_pad + negative
neg_pad = 0
if len(negative) < len(positive):
neg_pad = (len(positive) - len(negative))
negative = [model.special_tokens["pad"]] * neg_pad + negative
pos_pad = 0
if len(negative) > len(positive):
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
pos_pad = 0
if len(negative) > len(positive):
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
paddings = [pos_pad, neg_pad]
ids = [positive, negative]
else:
paddings = []
ids = [positive]
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
paddings = [pos_pad, neg_pad]
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -145,12 +129,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
for k in ("bpm", "duration", "keyscale", "timesignature")
for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
user_metas["timesignature"] = timesignature[:-2]
user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
user_metas = {
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
for k, v in user_metas.items()
@@ -163,11 +147,8 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
def _metas_to_cap(self, **kwargs) -> str:
use_keys = ("bpm", "timesignature", "keyscale", "duration")
use_keys = ("bpm", "duration", "keyscale", "timesignature")
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
user_metas["timesignature"] = timesignature[:-2]
duration = user_metas["duration"]
if duration == "N/A":
user_metas["duration"] = "30 seconds"
@@ -178,13 +159,9 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
text = text.strip()
text_negative = kwargs.get("caption_negative", text).strip()
out = {}
lyrics = kwargs.get("lyrics", "")
lyrics_negative = kwargs.get("lyrics_negative", lyrics)
duration = kwargs.get("duration", 120)
if isinstance(duration, str):
duration = float(duration.split(None, 1)[0])
language = kwargs.get("language")
seed = kwargs.get("seed", 0)
@@ -194,46 +171,21 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
duration = math.ceil(duration)
kwargs["duration"] = duration
tokens_duration = duration * 5
min_tokens = int(kwargs.get("min_tokens", tokens_duration))
max_tokens = int(kwargs.get("max_tokens", tokens_duration))
metas_negative = {
k.rsplit("_", 1)[0]: kwargs.pop(k)
for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
if k in kwargs
}
if not kwargs.get("use_negative_caption"):
_ = metas_negative.pop("caption", None)
cot_text = self._metas_to_cot(caption=text, **kwargs)
cot_text_negative = "<think>\n\n</think>" if not metas_negative else self._metas_to_cot(**metas_negative)
cot_text = self._metas_to_cot(caption = text, **kwargs)
meta_cap = self._metas_to_cap(**kwargs)
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
llm_prompts = {
"lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
"lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
"lyrics": lyrics_template.format(language if language is not None else "", lyrics),
"qwen3_06b": qwen3_06b_template.format(text, meta_cap),
}
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)
out = {
prompt_key: self.qwen3_06b.tokenize_with_weights(
prompt,
prompt_key == "qwen3_06b" and return_word_ids,
disable_weights = True,
**kwargs,
)
for prompt_key, prompt in llm_prompts.items()
}
out["lm_metadata"] = {"min_tokens": min_tokens,
"max_tokens": max_tokens,
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
out["lm_metadata"] = {"min_tokens": duration * 5,
"seed": seed,
"generate_audio_codes": generate_audio_codes,
"cfg_scale": cfg_scale,
@@ -300,7 +252,7 @@ class ACE15TEModel(torch.nn.Module):
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
out["audio_codes"] = [audio_codes]
return base_out, None, out

View File

@@ -1376,21 +1376,3 @@ def string_to_seed(data):
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
def deepcopy_list_dict(obj, memo=None):
if memo is None:
memo = {}
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if isinstance(obj, dict):
res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
elif isinstance(obj, list):
res = [deepcopy_list_dict(i, memo) for i in obj]
else:
res = obj
memo[obj_id] = res
return res

View File

@@ -21,7 +21,6 @@ from typing import Optional, Union
import torch
import torch.nn as nn
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection
@@ -182,21 +181,18 @@ class BypassForwardHook:
)
return # Already injected
# Move adapter weights to compute device (GPU)
# Use get_torch_device() instead of module.weight.device because
# with offloading, module weights may be on CPU while compute happens on GPU
device = comfy.model_management.get_torch_device()
# Get dtype from module weight if available
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
device = None
dtype = None
if hasattr(self.module, "weight") and self.module.weight is not None:
device = self.module.weight.device
dtype = self.module.weight.dtype
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
device = self.module.W_q.device
dtype = self.module.W_q.dtype
# Only use dtype if it's a standard float type, not quantized
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
dtype = None
self._move_adapter_weights_to_device(device, dtype)
if device is not None:
self._move_adapter_weights_to_device(device, dtype)
self.original_forward = self.module.forward
self.module.forward = self._bypass_forward

View File

@@ -34,21 +34,6 @@ class VideoInput(ABC):
"""
pass
@abstractmethod
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = False,
) -> VideoInput | None:
"""
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
Returns:
A new VideoInput, or None if the result would have negative duration
"""
pass
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without

View File

@@ -6,7 +6,6 @@ from typing import Optional
from .._input import AudioInput, VideoInput
import av
import io
import itertools
import json
import numpy as np
import math
@@ -30,6 +29,7 @@ def container_to_output_format(container_format: str | None) -> str | None:
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
@@ -57,14 +57,12 @@ class VideoFromFile(VideoInput):
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
def __init__(self, file: str | io.BytesIO):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
self.__start_time = start_time
self.__duration = duration
def get_stream_source(self) -> str | io.BytesIO:
"""
@@ -98,16 +96,6 @@ class VideoFromFile(VideoInput):
Returns:
Duration in seconds
"""
raw_duration = self._get_raw_duration()
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
if self.__duration:
return min(self.__duration, duration_from_start)
return duration_from_start
def _get_raw_duration(self) -> float:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
@@ -125,13 +113,9 @@ class VideoFromFile(VideoInput):
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
frame_iterator = (
container.decode(video_stream)
if video_stream.codec.capabilities & 0x100
else container.demux(video_stream)
)
for packet in frame_iterator:
frame_count += 1
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
@@ -147,54 +131,36 @@ class VideoFromFile(VideoInput):
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available and usable
if (
video_stream.frames
and video_stream.frames > 0
and not self.__start_time
and not self.__duration
):
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
raw_duration = float(video_stream.duration * video_stream.time_base)
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
duration_seconds = min(self.__duration, duration_from_start)
duration_seconds = float(video_stream.duration * video_stream.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
frame_count = 1
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
frame_iterator = (
container.decode(video_stream)
if video_stream.codec.capabilities & 0x100
else container.demux(video_stream)
)
for frame in frame_iterator:
if frame.pts >= start_pts:
break
else:
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
for frame in frame_iterator:
if frame.pts >= end_pts:
break
frame_count += 1
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
return frame_count
def get_frame_rate(self) -> Fraction:
@@ -233,21 +199,9 @@ class VideoFromFile(VideoInput):
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
video_stream = self._get_first_video_stream(container)
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
# Get video frames
frames = []
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
for frame in container.decode(video_stream):
if frame.pts < start_pts:
continue
if self.__duration and frame.pts >= end_pts:
break
for frame in container.decode(video=0):
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
@@ -255,44 +209,31 @@ class VideoFromFile(VideoInput):
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
container.seek(start_pts, stream=video_stream)
# Use last stream for consistency
if len(container.streams.audio):
audio_stream = container.streams.audio[-1]
audio_frames = []
resample = av.audio.resampler.AudioResampler(format='fltp').resample
frames = itertools.chain.from_iterable(
map(resample, container.decode(audio_stream))
)
has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = int(offset_seconds * audio_stream.sample_rate)
if to_skip < frame.samples:
has_first_frame = True
break
if has_first_frame:
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
if self.__duration:
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
})
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
@@ -309,7 +250,7 @@ class VideoFromFile(VideoInput):
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
@@ -321,14 +262,15 @@ class VideoFromFile(VideoInput):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if self.__start_time or self.__duration:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path, format=format, codec=codec, metadata=metadata
path,
format=format,
codec=codec,
metadata=metadata
)
streams = container.streams
@@ -362,21 +304,10 @@ class VideoFromFile(VideoInput):
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
if len(container.streams.video):
return container.streams.video[0]
raise ValueError(f"No video stream found in file '{self.__file}'")
def as_trimmed(
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
) -> VideoInput | None:
trimmed = VideoFromFile(
self.get_stream_source(),
start_time=start_time + self.__start_time,
duration=duration,
)
if trimmed.get_duration() < duration and strict_duration:
return None
return trimmed
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream
class VideoFromComponents(VideoInput):
@@ -391,7 +322,7 @@ class VideoFromComponents(VideoInput):
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate,
frame_rate=self.__components.frame_rate
)
def save_to(
@@ -399,7 +330,7 @@ class VideoFromComponents(VideoInput):
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
metadata: Optional[dict] = None
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
@@ -426,10 +357,7 @@ class VideoFromComponents(VideoInput):
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
waveform = self.__components.audio['waveform']
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
# Encode video
for i, frame in enumerate(self.__components.images):
@@ -444,21 +372,12 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
# Flush encoder
output.mux(audio_stream.encode(None))
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = True,
) -> VideoInput | None:
if self.get_duration() < start_time + duration:
return None
#TODO Consider tracking duration and trimming at time of save?
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)

View File

@@ -1430,6 +1430,11 @@ class Schema:
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
lazy_outputs: bool=False
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.
Use this for nodes that can skip computing outputs that aren't connected downstream.
Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused."""
def validate(self):
'''Validate the schema:
@@ -1875,6 +1880,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._ACCEPT_ALL_INPUTS
_LAZY_OUTPUTS = None
@final
@classproperty
def LAZY_OUTPUTS(cls): # noqa
if cls._LAZY_OUTPUTS is None:
cls.GET_SCHEMA()
return cls._LAZY_OUTPUTS
@final
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
@@ -1917,6 +1930,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._NOT_IDEMPOTENT = schema.not_idempotent
if cls._ACCEPT_ALL_INPUTS is None:
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
if cls._LAZY_OUTPUTS is None:
cls._LAZY_OUTPUTS = schema.lazy_outputs
if cls._RETURN_TYPES is None:
output = []

View File

@@ -1197,6 +1197,12 @@ class KlingImageGenImageReferenceType(str, Enum):
face = 'face'
class KlingImageGenModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_5 = 'kling-v1-5'
kling_v2 = 'kling-v2'
class KlingImageGenerationsRequest(BaseModel):
aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
callback_url: Optional[AnyUrl] = Field(
@@ -1212,7 +1218,7 @@ class KlingImageGenerationsRequest(BaseModel):
0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
)
image_reference: Optional[KlingImageGenImageReferenceType] = None
model_name: str = Field(...)
model_name: Optional[KlingImageGenModelName] = 'kling-v1'
n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
negative_prompt: Optional[str] = Field(
None, description='Negative text prompt', max_length=200

View File

@@ -1,22 +1,12 @@
from pydantic import BaseModel, Field
class MultiPromptEntry(BaseModel):
index: int = Field(...)
prompt: str = Field(...)
duration: str = Field(...)
class OmniProText2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
sound: str = Field(..., description="'on' or 'off'")
class OmniParamImage(BaseModel):
@@ -36,10 +26,6 @@ class OmniProFirstLastFrameRequest(BaseModel):
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str | None = Field(None, description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class OmniProReferences2VideoRequest(BaseModel):
@@ -52,10 +38,6 @@ class OmniProReferences2VideoRequest(BaseModel):
duration: str | None = Field(..., description="From 3 to 10.")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str | None = Field(None, description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class TaskStatusVideoResult(BaseModel):
@@ -72,7 +54,6 @@ class TaskStatusImageResult(BaseModel):
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
series_images: list[TaskStatusImageResult] | None = Field(None)
class TaskStatusResponseData(BaseModel):
@@ -96,42 +77,31 @@ class OmniImageParamImage(BaseModel):
class OmniProImageRequest(BaseModel):
model_name: str = Field(...)
resolution: str = Field(...)
model_name: str = Field(..., description="kling-image-o1")
resolution: str = Field(..., description="'1k' or '2k'")
aspect_ratio: str | None = Field(...)
prompt: str = Field(...)
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
result_type: str | None = Field(None, description="Set to 'series' for series generation")
series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series")
class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(...)
model_name: str = Field(..., description="kling-v2-6")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(...)
prompt: str | None = Field(...)
negative_prompt: str | None = Field(None)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(...)
model_name: str = Field(..., description="kling-v2-6")
image: str = Field(...)
image_tail: str | None = Field(None)
duration: str = Field(...)
prompt: str | None = Field(...)
negative_prompt: str | None = Field(None)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class MotionControlRequest(BaseModel):

File diff suppressed because it is too large

View File

@@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
default=80,
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
default=33,
min=1,
max=100,
step=1,
tooltip="Number of denoising steps",
@@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
default=60,
min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24)
default=33,
min=1,
max=100,
step=1,
display_mode=IO.NumberDisplay.number,
@@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: int | None = 100,
steps=60,
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:
validated_video = validate_video_to_video_input(video)
@@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
default=80,
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
default=33,
min=1,
max=100,
step=1,
tooltip="Inference steps",

View File

@@ -143,9 +143,9 @@ async def poll_op(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 10,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 1.4,
retry_backoff_per_poll: float = 2.0,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -240,9 +240,9 @@ async def poll_op_raw(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 10,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 1.4,
retry_backoff_per_poll: float = 2.0,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
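
For context on the retry parameters changed above, a rough worst-case wait comparison, assuming each retry sleeps retry_delay_per_poll * retry_backoff_per_poll ** attempt (an assumption about poll_op's retry loop; that logic is not part of this hunk):

# Rough worst-case retry wait per poll under each parameter set (assumed geometric backoff).
def total_retry_wait(max_retries: int, delay: float, backoff: float) -> float:
    return sum(delay * backoff ** attempt for attempt in range(max_retries))

print(total_retry_wait(3, 1.0, 2.0))    # 7.0  seconds
print(total_retry_wait(10, 1.0, 1.4))   # ~69.8 seconds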

View File

@@ -5,7 +5,7 @@ import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from abc import ABC, abstractmethod
import nodes
@@ -115,6 +115,10 @@ class CacheKeySetInputSignature(CacheKeySet):
signature = [class_type, await self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
# Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS
if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS:
expected = get_expected_outputs_for_node(dynprompt, node_id)
signature.append(("expected_outputs", tuple(sorted(expected))))
inputs = node["inputs"]
for key in sorted(inputs.keys()):
if is_link(inputs[key]):

View File

@@ -19,6 +19,15 @@ class NodeInputError(Exception):
class NodeNotFoundError(Exception):
pass
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
"""Get the set of output indices that are connected downstream.
Returns outputs that MIGHT be used.
Outputs NOT in this set are DEFINITELY not used and safe to skip.
"""
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
@@ -27,6 +36,7 @@ class DynamicPrompt:
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
self._expected_outputs_map = None
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
@@ -42,6 +52,7 @@ class DynamicPrompt:
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id
self._expected_outputs_map = None
def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
@@ -59,6 +70,26 @@ class DynamicPrompt:
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def _build_expected_outputs_map(self):
result = {}
for node_id in self.all_node_ids():
try:
node_data = self.get_node(node_id)
except NodeNotFoundError:
continue
for value in node_data.get("inputs", {}).values():
if is_link(value):
from_node_id, from_socket = value
if from_node_id not in result:
result[from_node_id] = set()
result[from_node_id].add(from_socket)
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
def get_expected_outputs_map(self):
if self._expected_outputs_map is None:
self._build_expected_outputs_map()
return self._expected_outputs_map
def get_original_prompt(self):
return self.original_prompt
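
The precomputed map above simply inverts the prompt's link structure: every [node_id, output_index] input link marks that output as possibly used, which is what the "precompute expected_outputs map" commit replaces the per-node graph traversal with. A small sketch using only names this diff defines (node ids and class types are made up):

# Toy prompt: node "2" consumes output 0 of node "1"; nothing consumes node "2".
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node

prompt = {
    "1": {"class_type": "LoadImage", "inputs": {}},
    "2": {"class_type": "SaveImage", "inputs": {"images": ["1", 0]}},
}
dyn = DynamicPrompt(prompt)
print(get_expected_outputs_for_node(dyn, "1"))   # frozenset({0}) -> output 0 might be used
print(get_expected_outputs_for_node(dyn, "2"))   # frozenset()    -> no downstream consumers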

View File

@@ -20,60 +20,10 @@ class JobStatus:
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
def has_3d_extension(filename: str) -> bool:
lower = filename.lower()
return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
def normalize_output_item(item):
"""Normalize a single output list item for the jobs API.
Returns the normalized item, or None to exclude it.
String items with 3D extensions become {filename, type, subfolder} dicts.
"""
if item is None:
return None
if isinstance(item, str):
if has_3d_extension(item):
return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
return None
if isinstance(item, dict):
return item
return None
def normalize_outputs(outputs: dict) -> dict:
"""Normalize raw node outputs for the jobs API.
Transforms string 3D filenames into file output dicts and removes
None items. All other items (non-3D strings, dicts, etc.) are
preserved as-is.
"""
normalized = {}
for node_id, node_outputs in outputs.items():
if not isinstance(node_outputs, dict):
normalized[node_id] = node_outputs
continue
normalized_node = {}
for media_type, items in node_outputs.items():
if media_type == 'animated' or not isinstance(items, list):
normalized_node[media_type] = items
continue
normalized_items = []
for item in items:
if item is None:
continue
norm = normalize_output_item(item)
normalized_items.append(norm if norm is not None else item)
normalized_node[media_type] = normalized_items
normalized[node_id] = normalized_node
return normalized
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
@@ -95,9 +45,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
Maintains backwards compatibility with existing logic.
Priority:
1. media_type is 'images', 'video', 'audio', or '3d'
1. media_type is 'images', 'video', or 'audio'
2. format field starts with 'video/' or 'audio/'
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
"""
if media_type in PREVIEWABLE_MEDIA_TYPES:
return True
@@ -189,7 +139,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
})
if include_outputs:
job['outputs'] = normalize_outputs(outputs)
job['outputs'] = outputs
job['execution_status'] = status_info
job['workflow'] = {
'prompt': prompt,
@@ -221,23 +171,18 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
continue
for item in items:
normalized = normalize_output_item(item)
if normalized is None:
continue
count += 1
if preview_output is not None:
if not isinstance(item, dict):
continue
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
if preview_output is None and is_previewable(media_type, item):
enriched = {
**normalized,
**item,
'nodeId': node_id,
'mediaType': media_type
}
if 'mediaType' not in normalized:
enriched['mediaType'] = media_type
if normalized.get('type') == 'output':
if item.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched

View File

@@ -1,23 +1,41 @@
import contextvars
from typing import Optional, NamedTuple
from typing import NamedTuple, FrozenSet
class ExecutionContext(NamedTuple):
"""
Context information about the currently executing node.
Attributes:
prompt_id: The ID of the current prompt execution
node_id: The ID of the currently executing node
list_index: The index in a list being processed (for operations on batches/lists)
expected_outputs: Set of output indices that might be used downstream.
Outputs NOT in this set are definitely unused (safe to skip).
None means the information is not available.
"""
prompt_id: str
node_id: str
list_index: Optional[int]
list_index: int | None
expected_outputs: FrozenSet[int] | None = None
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None)
def get_executing_context() -> Optional[ExecutionContext]:
def get_executing_context() -> ExecutionContext | None:
return current_executing_context.get(None)
def is_output_needed(output_index: int) -> bool:
"""Check if an output at the given index is connected downstream.
Returns True if the output might be used (should be computed).
Returns False if the output is definitely not connected (safe to skip).
"""
ctx = get_executing_context()
if ctx is None or ctx.expected_outputs is None:
return True
return output_index in ctx.expected_outputs
class CurrentNodeContext:
"""
Context manager for setting the current executing node context.
@@ -25,15 +43,22 @@ class CurrentNodeContext:
Sets the current_executing_context on enter and resets it on exit.
Example:
with CurrentNodeContext(node_id="123", list_index=0):
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
# Code that should run with the current node context set
process_image()
"""
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
def __init__(
self,
prompt_id: str,
node_id: str,
list_index: int | None = None,
expected_outputs: FrozenSet[int] | None = None,
):
self.context = ExecutionContext(
prompt_id= prompt_id,
node_id= node_id,
list_index= list_index
prompt_id=prompt_id,
node_id=node_id,
list_index=list_index,
expected_outputs=expected_outputs,
)
self.token = None
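
Taken together with the graph changes above, a node can guard an expensive optional output with is_output_needed(). A minimal sketch using only names introduced in this file; the prompt/node ids and outputs are illustrative, and in real execution the executor, not the node, sets the context:

# Sketch: skip computing output index 1 when it is definitely not connected downstream.
from comfy_execution.utils import CurrentNodeContext, is_output_needed

def run_node(expected_outputs):
    with CurrentNodeContext(prompt_id="demo", node_id="42", expected_outputs=expected_outputs):
        image = "rendered-image"                                    # output 0, always produced
        mask = "expensive-mask" if is_output_needed(1) else None    # output 1, optional
        return image, mask

print(run_node(frozenset({0})))   # ('rendered-image', None)             -> mask skipped
print(run_node(None))             # ('rendered-image', 'expensive-mask') -> unknown, compute everything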

View File

@@ -4,7 +4,6 @@ import os
import numpy as np
import safetensors
import torch
import torch.nn as nn
import torch.utils.checkpoint
from tqdm.auto import trange
from PIL import Image, ImageDraw, ImageFont
@@ -28,11 +27,6 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
"""
CFGGuider with modifications for training specific logic
"""
def __init__(self, *args, offloading=False, **kwargs):
super().__init__(*args, **kwargs)
self.offloading = offloading
def outer_sample(
self,
noise,
@@ -51,11 +45,9 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
noise.shape,
self.conds,
self.model_options,
force_full_load=not self.offloading,
force_offload=self.offloading,
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
)
)
torch.cuda.empty_cache()
device = self.model_patcher.load_device
if denoise_mask is not None:
@@ -412,97 +404,16 @@ def find_all_highest_child_module_with_forward(
return result
def find_modules_at_depth(
model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
) -> list[nn.Module]:
"""
Find modules at a specific depth level for gradient checkpointing.
Args:
model: The model to search
depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
result: Accumulator for results
current_depth: Current recursion depth
name: Current module name for logging
Returns:
List of modules at the target depth
"""
if result is None:
result = []
name = name or "root"
# Skip container modules (they don't have meaningful forward)
is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
has_forward = hasattr(model, "forward") and not is_container
if has_forward:
current_depth += 1
if current_depth == depth:
result.append(model)
logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
return result
# Recurse into children
for next_name, child in model.named_children():
find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
return result
class OffloadCheckpointFunction(torch.autograd.Function):
"""
Gradient checkpointing that works with weight offloading.
Forward: no_grad -> compute -> weights can be freed
Backward: enable_grad -> recompute -> backward -> weights can be freed
For single input, single output modules (Linear, Conv*).
"""
@staticmethod
def forward(ctx, x: torch.Tensor, forward_fn):
ctx.save_for_backward(x)
ctx.forward_fn = forward_fn
with torch.no_grad():
return forward_fn(x)
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
x, = ctx.saved_tensors
forward_fn = ctx.forward_fn
# Clear context early
ctx.forward_fn = None
with torch.enable_grad():
x_detached = x.detach().requires_grad_(True)
y = forward_fn(x_detached)
y.backward(grad_out)
grad_x = x_detached.grad
# Explicit cleanup
del y, x_detached, forward_fn
return grad_x, None
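# Standalone sanity sketch (assumed usage, mirroring what patch() does below):
# checkpointing a Linear's forward should reproduce the gradients of a plain call,
# with the activation graph rebuilt only during backward.
lin = nn.Linear(8, 4)
x = torch.randn(2, 8, requires_grad=True)
y = OffloadCheckpointFunction.apply(x, lin.forward)
y.sum().backward()                  # recomputes lin.forward under enable_grad, then backprops
assert x.grad is not None and lin.weight.grad is not None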
def patch(m, offloading=False):
def patch(m):
if not hasattr(m, "forward"):
return
org_forward = m.forward
# Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
def checkpointing_fwd(x):
return OffloadCheckpointFunction.apply(x, org_forward)
# Branch 2: Others -> standard checkpoint
else:
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
m.org_forward = org_forward
m.forward = checkpointing_fwd
@@ -1025,18 +936,6 @@ class TrainLoraNode(io.ComfyNode):
default=True,
tooltip="Use gradient checkpointing for training.",
),
io.Int.Input(
"checkpoint_depth",
default=1,
min=1,
max=5,
tooltip="Depth level for gradient checkpointing.",
),
io.Boolean.Input(
"offloading",
default=False,
tooltip="Depth level for gradient checkpointing.",
),
io.Combo.Input(
"existing_lora",
options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -1083,8 +982,6 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype,
algorithm,
gradient_checkpointing,
checkpoint_depth,
offloading,
existing_lora,
bucket_mode,
bypass_mode,
@@ -1103,8 +1000,6 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
offloading = offloading[0]
checkpoint_depth = checkpoint_depth[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
@@ -1159,18 +1054,16 @@ class TrainLoraNode(io.ComfyNode):
# Setup gradient checkpointing
if gradient_checkpointing:
modules_to_patch = find_modules_at_depth(
mp.model.diffusion_model, depth=checkpoint_depth
)
logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
for m in modules_to_patch:
patch(m, offloading=offloading)
for m in find_all_highest_child_module_with_forward(
mp.model.diffusion_model
):
patch(m)
torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=not offloading
[mp], memory_required=1e20, force_full_load=True
)
torch.cuda.empty_cache()
@@ -1207,7 +1100,7 @@ class TrainLoraNode(io.ComfyNode):
)
# Setup guider
guider = TrainGuider(mp, offloading=offloading)
guider = TrainGuider(mp)
guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled
@@ -1220,7 +1113,6 @@ class TrainLoraNode(io.ComfyNode):
# Run training loop
try:
comfy.model_management.in_training = True
_run_training_loop(
guider,
train_sampler,
@@ -1231,7 +1123,6 @@ class TrainLoraNode(io.ComfyNode):
multi_res,
)
finally:
comfy.model_management.in_training = False
# Eject bypass hooks if they were injected
if bypass_injections is not None:
for injection in bypass_injections:
@@ -1241,20 +1132,19 @@ class TrainLoraNode(io.ComfyNode):
unpatch(m)
del train_sampler, optimizer
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
# Finalize adapters
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
del adapter
del all_weight_adapters
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
# mp in train node is highly specialized for training
# using it for inference will result in bad behavior, so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
@@ -1276,11 +1166,6 @@ class LoraModelLoader(io.ComfyNode):
max=100.0,
tooltip="How strongly to modify the diffusion model. This value can be negative.",
),
io.Boolean.Input(
"bypass",
default=False,
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
),
],
outputs=[
io.Model.Output(
@@ -1290,18 +1175,13 @@ class LoraModelLoader(io.ComfyNode):
)
@classmethod
def execute(cls, model, lora, strength_model, bypass=False):
def execute(cls, model, lora, strength_model):
if strength_model == 0:
return io.NodeOutput(model)
if bypass:
model_lora, _ = comfy.sd.load_bypass_lora_for_models(
model, None, lora, strength_model, 0
)
else:
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
return io.NodeOutput(model_lora)

View File

@@ -202,56 +202,6 @@ class LoadVideo(io.ComfyNode):
return True
class VideoSlice(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Video Slice",
display_name="Video Slice",
search_aliases=[
"trim video duration",
"skip first frames",
"frame load cap",
"start time",
],
category="image/video",
inputs=[
io.Video.Input("video"),
io.Float.Input(
"start_time",
default=0.0,
max=1e5,
min=-1e5,
step=0.001,
tooltip="Start time in seconds",
),
io.Float.Input(
"duration",
default=0.0,
min=0.0,
step=0.001,
tooltip="Duration in seconds, or 0 for unlimited duration",
),
io.Boolean.Input(
"strict_duration",
default=False,
tooltip="If True, when the specified duration is not possible, an error will be raised.",
),
],
outputs=[
io.Video.Output(),
],
)
@classmethod
def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
if trimmed is not None:
return io.NodeOutput(trimmed)
raise ValueError(
f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
)
class VideoExtension(ComfyExtension):
@override
@@ -262,7 +212,6 @@ class VideoExtension(ComfyExtension):
CreateVideo,
GetVideoComponents,
LoadVideo,
VideoSlice,
]
async def comfy_entrypoint() -> VideoExtension:

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.13.0"
__version__ = "0.12.3"

View File

@@ -13,11 +13,8 @@ from contextlib import nullcontext
import torch
from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
import comfy_aimdo.model_vbar
from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
@@ -34,6 +31,7 @@ from comfy_execution.graph import (
ExecutionBlocker,
ExecutionList,
get_input_info,
get_expected_outputs_for_node,
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
@@ -230,7 +228,18 @@ async def resolve_map_node_over_list_results(results):
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
async def _async_map_node_over_list(
prompt_id,
unique_id,
obj,
input_data_all,
func,
allow_interrupt=False,
execution_block_cb=None,
pre_execute_cb=None,
v3_data=None,
expected_outputs=None,
):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -280,10 +289,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
else:
f = getattr(obj, func)
if inspect.iscoroutinefunction(f):
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
with CurrentNodeContext(prompt_id, unique_id, list_index):
async def async_wrapper(f, prompt_id, unique_id, list_index, args, expected_outputs):
with CurrentNodeContext(prompt_id, unique_id, list_index, expected_outputs):
return await f(**args)
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
task = asyncio.create_task(
async_wrapper(f, prompt_id, unique_id, index, args=inputs, expected_outputs=expected_outputs)
)
# Give the task a chance to execute without yielding
await asyncio.sleep(0)
if task.done():
@@ -292,7 +303,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
else:
results.append(task)
else:
with CurrentNodeContext(prompt_id, unique_id, index):
with CurrentNodeContext(prompt_id, unique_id, index, expected_outputs):
result = f(**inputs)
results.append(result)
else:
@@ -330,8 +341,17 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
async def get_output_data(
prompt_id,
unique_id,
obj,
input_data_all,
execution_block_cb=None,
pre_execute_cb=None,
v3_data=None,
expected_outputs=None,
):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@@ -525,15 +545,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
#that we just want to cull out each model run.
allocator = comfy.memory_management.aimdo_allocator
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
try:
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
finally:
if allocator is not None:
if args.verbose == "DEBUG":
comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
torch.cuda.synchronize()
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.13.0"
version = "0.12.3"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.38.13
comfyui-workflow-templates==0.8.38
comfyui-workflow-templates==0.8.31
comfyui-embedded-docs==0.4.1
torch
torchsde
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.1.8
comfy-aimdo>=0.1.7
requests
#non essential dependencies:

View File

@@ -0,0 +1,322 @@
"""Unit tests for the expected_outputs feature.
This feature allows nodes to know at runtime which outputs are connected downstream,
enabling them to skip computing outputs that aren't needed.
"""
from comfy_api.latest import IO
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from comfy_execution.utils import (
CurrentNodeContext,
ExecutionContext,
get_executing_context,
is_output_needed,
)
class TestGetExpectedOutputsForNode:
"""Tests for get_expected_outputs_for_node() function."""
def test_single_output_connected(self):
"""Test node with single output connected to one downstream node."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
def test_multiple_outputs_partial_connected(self):
"""Test node with multiple outputs, only some connected."""
prompt = {
"1": {"class_type": "MultiOutputNode", "inputs": {}},
"2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
# Output 1 is not connected
"3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0, 2})
assert 1 not in expected # Output 1 is definitely unused
def test_no_outputs_connected(self):
"""Test node with no outputs connected."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "OtherNode", "inputs": {}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset()
def test_same_output_connected_multiple_times(self):
"""Test same output connected to multiple downstream nodes."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
"3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}},
"4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
def test_node_not_in_prompt(self):
"""Test getting expected outputs for a node not in the prompt."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "999")
assert expected == frozenset()
def test_chained_nodes(self):
"""Test expected outputs in a chain of nodes."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}},
"3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Node 1's output 0 is connected to node 2
expected_1 = get_expected_outputs_for_node(dynprompt, "1")
assert expected_1 == frozenset({0})
# Node 2's output 0 is connected to node 3
expected_2 = get_expected_outputs_for_node(dynprompt, "2")
assert expected_2 == frozenset({0})
# Node 3 has no downstream connections
expected_3 = get_expected_outputs_for_node(dynprompt, "3")
assert expected_3 == frozenset()
def test_complex_graph(self):
"""Test expected outputs in a complex graph with multiple connections."""
prompt = {
"1": {"class_type": "MultiOutputNode", "inputs": {}},
"2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}},
"3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}},
"4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Node 1 has outputs 0, 1, 2 all connected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0, 1, 2})
def test_constant_inputs_ignored(self):
"""Test that constant (non-link) inputs don't affect expected outputs."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {
"class_type": "ConsumerNode",
"inputs": {
"image": ["1", 0],
"value": 42,
"name": "test",
},
},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
def test_ephemeral_node_invalidates_cache(self):
"""Test that adding ephemeral nodes updates expected outputs."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Initially only output 0 is connected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
# Add an ephemeral node that connects to output 1
dynprompt.add_ephemeral_node(
"eph_1",
{"class_type": "EphemeralNode", "inputs": {"data": ["1", 1]}},
parent_id="2",
display_id="2",
)
# Now both outputs 0 and 1 should be expected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0, 1})
class TestExecutionContext:
"""Tests for ExecutionContext with expected_outputs field."""
def test_context_with_expected_outputs(self):
"""Test creating ExecutionContext with expected_outputs."""
ctx = ExecutionContext(
prompt_id="prompt-123", node_id="node-456", list_index=0, expected_outputs=frozenset({0, 2})
)
assert ctx.prompt_id == "prompt-123"
assert ctx.node_id == "node-456"
assert ctx.list_index == 0
assert ctx.expected_outputs == frozenset({0, 2})
def test_context_without_expected_outputs(self):
"""Test ExecutionContext defaults to None for expected_outputs."""
ctx = ExecutionContext(prompt_id="prompt-123", node_id="node-456", list_index=0)
assert ctx.expected_outputs is None
def test_context_empty_expected_outputs(self):
"""Test ExecutionContext with empty expected_outputs set."""
ctx = ExecutionContext(
prompt_id="prompt-123", node_id="node-456", list_index=None, expected_outputs=frozenset()
)
assert ctx.expected_outputs == frozenset()
assert len(ctx.expected_outputs) == 0
class TestCurrentNodeContext:
"""Tests for CurrentNodeContext context manager with expected_outputs."""
def test_context_manager_with_expected_outputs(self):
"""Test CurrentNodeContext sets and resets context correctly."""
assert get_executing_context() is None
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})):
ctx = get_executing_context()
assert ctx is not None
assert ctx.prompt_id == "prompt-1"
assert ctx.node_id == "node-1"
assert ctx.list_index == 0
assert ctx.expected_outputs == frozenset({0, 1})
assert get_executing_context() is None
def test_context_manager_without_expected_outputs(self):
"""Test CurrentNodeContext works without expected_outputs (backwards compatible)."""
with CurrentNodeContext("prompt-1", "node-1"):
ctx = get_executing_context()
assert ctx is not None
assert ctx.expected_outputs is None
def test_nested_context_managers(self):
"""Test nested CurrentNodeContext managers."""
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})):
ctx1 = get_executing_context()
assert ctx1.expected_outputs == frozenset({0})
with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})):
ctx2 = get_executing_context()
assert ctx2.expected_outputs == frozenset({1, 2})
assert ctx2.node_id == "node-2"
# After inner context exits, should be back to outer context
ctx1_again = get_executing_context()
assert ctx1_again.expected_outputs == frozenset({0})
assert ctx1_again.node_id == "node-1"
def test_output_check_pattern(self):
"""Test the typical pattern nodes will use to check expected outputs."""
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
ctx = get_executing_context()
# Typical usage pattern
if ctx and ctx.expected_outputs is not None:
should_compute_0 = 0 in ctx.expected_outputs
should_compute_1 = 1 in ctx.expected_outputs
should_compute_2 = 2 in ctx.expected_outputs
else:
# Fallback when info not available
should_compute_0 = should_compute_1 = should_compute_2 = True
assert should_compute_0 is True
assert should_compute_1 is False # Not in expected_outputs
assert should_compute_2 is True
class TestSchemaLazyOutputs:
"""Tests for lazy_outputs in V3 Schema."""
def test_schema_lazy_outputs_default(self):
"""Test that lazy_outputs defaults to False."""
schema = IO.Schema(
node_id="TestNode",
inputs=[],
outputs=[IO.Float.Output()],
)
assert schema.lazy_outputs is False
def test_schema_lazy_outputs_true(self):
"""Test setting lazy_outputs to True."""
schema = IO.Schema(
node_id="TestNode",
lazy_outputs=True,
inputs=[],
outputs=[IO.Float.Output()],
)
assert schema.lazy_outputs is True
def test_v3_node_lazy_outputs_property(self):
"""Test that LAZY_OUTPUTS property works on V3 nodes."""
class TestNodeWithLazyOutputs(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TestNodeWithLazyOutputs",
lazy_outputs=True,
inputs=[],
outputs=[IO.Float.Output()],
)
@classmethod
def execute(cls):
return IO.NodeOutput(1.0)
assert TestNodeWithLazyOutputs.LAZY_OUTPUTS is True
def test_v3_node_lazy_outputs_default(self):
"""Test that LAZY_OUTPUTS defaults to False on V3 nodes."""
class TestNodeWithoutLazyOutputs(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TestNodeWithoutLazyOutputs",
inputs=[],
outputs=[IO.Float.Output()],
)
@classmethod
def execute(cls):
return IO.NodeOutput(1.0)
assert TestNodeWithoutLazyOutputs.LAZY_OUTPUTS is False
class TestIsOutputNeeded:
"""Tests for is_output_needed() helper function."""
def test_output_needed_when_in_expected(self):
"""Test that output is needed when in expected_outputs."""
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
assert is_output_needed(0) is True
assert is_output_needed(2) is True
def test_output_not_needed_when_not_in_expected(self):
"""Test that output is not needed when not in expected_outputs."""
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
assert is_output_needed(1) is False
assert is_output_needed(3) is False
def test_output_needed_when_no_context(self):
"""Test that output is needed when no context."""
assert get_executing_context() is None
assert is_output_needed(0) is True
assert is_output_needed(1) is True
def test_output_needed_when_expected_outputs_is_none(self):
"""Test that output is needed when expected_outputs is None."""
with CurrentNodeContext("prompt-1", "node-1", 0, None):
assert is_output_needed(0) is True
assert is_output_needed(1) is True
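# Putting the pieces together: a hypothetical V3 node that opts into lazy outputs
# via the schema flag and only computes its second output when it is needed.
# compute_expensive_value() is an assumed helper; the 0.0 placeholder returned for
# a skipped output is illustrative only.
class StatsSketchNode(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="StatsSketchNode",
            lazy_outputs=True,                 # re-execute when output connections change
            inputs=[],
            outputs=[IO.Float.Output(), IO.Float.Output()],
        )
    @classmethod
    def execute(cls):
        cheap = 1.0
        expensive = 0.0
        if is_output_needed(1):
            expensive = compute_expensive_value()   # hypothetical expensive computation
        return IO.NodeOutput(cheap, expensive)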

View File

@@ -574,6 +574,104 @@ class TestExecution:
else:
assert result.did_run(test_node), "The execution should have been re-run"
def test_expected_outputs_all_connected(self, client: ComfyClient, builder: GraphBuilder):
"""Test that expected_outputs contains all connected outputs."""
g = builder
# Create a node with 3 outputs, all connected
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
# Connect all 3 outputs to preview nodes
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
result = client.run(g)
# All outputs should be white (255) since all are connected
images0 = result.get_images(output0)
images1 = result.get_images(output1)
images2 = result.get_images(output2)
assert len(images0) == 1, "Should have 1 image for output0"
assert len(images1) == 1, "Should have 1 image for output1"
assert len(images2) == 1, "Should have 1 image for output2"
# White pixels = 255, meaning output was in expected_outputs
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"
def test_expected_outputs_partial_connected(self, client: ComfyClient, builder: GraphBuilder):
"""Test that expected_outputs only contains connected outputs."""
g = builder
# Create a node with 3 outputs, only some connected
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
# Only connect outputs 0 and 2, leave output 1 disconnected
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
# output1 is intentionally not connected
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
result = client.run(g)
# Connected outputs should be white (255)
images0 = result.get_images(output0)
images2 = result.get_images(output2)
assert len(images0) == 1, "Should have 1 image for output0"
assert len(images2) == 1, "Should have 1 image for output2"
# White = expected. Output 1 is not connected, so we can't verify it directly, but outputs 0 and 2 should be white.
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"
def test_expected_outputs_single_connected(self, client: ComfyClient, builder: GraphBuilder):
"""Test that expected_outputs works with single connected output."""
g = builder
# Create a node with 3 outputs, only one connected
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
# Only connect output 1
output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
result = client.run(g)
images1 = result.get_images(output1)
assert len(images1) == 1, "Should have 1 image for output1"
# Output 1 should be white (connected), others are not visible in this test
assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
def test_expected_outputs_cache_invalidation(self, client: ComfyClient, builder: GraphBuilder, server):
"""Test that cache invalidates when output connections change."""
g = builder
# Use unique dimensions to avoid cache collision with other expected_outputs tests
expected_outputs_node = g.node("TestExpectedOutputs", height=32, width=32)
# First run: only connect output 0
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
result1 = client.run(g)
assert result1.did_run(expected_outputs_node), "First run should execute the node"
# Second run: same connections, should be cached
result2 = client.run(g)
if server["should_cache_results"]:
assert not result2.did_run(expected_outputs_node), "Second run should be cached"
# Third run: add connection to output 2
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
result3 = client.run(g)
# Because LAZY_OUTPUTS=True, changing connections should invalidate cache
if server["should_cache_results"]:
assert result3.did_run(expected_outputs_node), "Adding output connection should invalidate cache"
# Verify both outputs are now white
images0 = result3.get_images(output0)
images2 = result3.get_images(output2)
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
# Warmup execution to ensure server is fully initialized

View File

@@ -5,11 +5,8 @@ from comfy_execution.jobs import (
is_previewable,
normalize_queue_item,
normalize_history_item,
normalize_output_item,
normalize_outputs,
get_outputs_summary,
apply_sorting,
has_3d_extension,
)
@@ -38,8 +35,8 @@ class TestIsPreviewable:
"""Unit tests for is_previewable()"""
def test_previewable_media_types(self):
"""Images, video, audio, 3d media types should be previewable."""
for media_type in ['images', 'video', 'audio', '3d']:
"""Images, video, audio media types should be previewable."""
for media_type in ['images', 'video', 'audio']:
assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self):
@@ -49,7 +46,7 @@ class TestIsPreviewable:
def test_3d_extensions_previewable(self):
"""3D file extensions should be previewable regardless of media_type."""
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
item = {'filename': f'model{ext}'}
assert is_previewable('files', item) is True
@@ -163,7 +160,7 @@ class TestGetOutputsSummary:
def test_3d_files_previewable(self):
"""3D file extensions should be previewable."""
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
outputs = {
'node1': {
'files': [{'filename': f'model{ext}', 'type': 'output'}]
@@ -195,64 +192,6 @@ class TestGetOutputsSummary:
assert preview['mediaType'] == 'images'
assert preview['subfolder'] == 'outputs'
def test_string_3d_filename_creates_preview(self):
"""String items with 3D extensions should synthesize a preview (Preview3D node output).
Only the .glb counts — nulls and non-file strings are excluded."""
outputs = {
'node1': {
'result': ['preview3d_abc123.glb', None, None]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 1
assert preview is not None
assert preview['filename'] == 'preview3d_abc123.glb'
assert preview['mediaType'] == '3d'
assert preview['nodeId'] == 'node1'
assert preview['type'] == 'output'
def test_string_non_3d_filename_no_preview(self):
"""String items without 3D extensions should not create a preview."""
outputs = {
'node1': {
'result': ['data.json', None]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 0
assert preview is None
def test_string_3d_filename_used_as_fallback(self):
"""String 3D preview should be used when no dict items are previewable."""
outputs = {
'node1': {
'latents': [{'filename': 'latent.safetensors'}],
},
'node2': {
'result': ['model.glb', None]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None
assert preview['filename'] == 'model.glb'
assert preview['mediaType'] == '3d'
class TestHas3DExtension:
"""Unit tests for has_3d_extension()"""
def test_recognized_extensions(self):
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
assert has_3d_extension(f'model{ext}') is True
def test_case_insensitive(self):
assert has_3d_extension('MODEL.GLB') is True
assert has_3d_extension('Scene.GLTF') is True
def test_non_3d_extensions(self):
for name in ['photo.png', 'video.mp4', 'data.json', 'model']:
assert has_3d_extension(name) is False
class TestApplySorting:
"""Unit tests for apply_sorting()"""
@@ -456,142 +395,3 @@ class TestNormalizeHistoryItem:
'prompt': {'nodes': {'1': {}}},
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
}
def test_include_outputs_normalizes_3d_strings(self):
"""Detail view should transform string 3D filenames into file output dicts."""
history_item = {
'prompt': (
5,
'prompt-3d',
{'nodes': {}},
{'create_time': 1234567890},
['node1'],
),
'status': {'status_str': 'success', 'completed': True, 'messages': []},
'outputs': {
'node1': {
'result': ['preview3d_abc123.glb', None, None]
}
},
}
job = normalize_history_item('prompt-3d', history_item, include_outputs=True)
assert job['outputs_count'] == 1
result_items = job['outputs']['node1']['result']
assert len(result_items) == 1
assert result_items[0] == {
'filename': 'preview3d_abc123.glb',
'type': 'output',
'subfolder': '',
'mediaType': '3d',
}
def test_include_outputs_preserves_dict_items(self):
"""Detail view normalization should pass dict items through unchanged."""
history_item = {
'prompt': (
5,
'prompt-img',
{'nodes': {}},
{'create_time': 1234567890},
['node1'],
),
'status': {'status_str': 'success', 'completed': True, 'messages': []},
'outputs': {
'node1': {
'images': [
{'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
]
}
},
}
job = normalize_history_item('prompt-img', history_item, include_outputs=True)
assert job['outputs_count'] == 1
assert job['outputs']['node1']['images'] == [
{'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
]
class TestNormalizeOutputItem:
"""Unit tests for normalize_output_item()"""
def test_none_returns_none(self):
assert normalize_output_item(None) is None
def test_string_3d_extension_synthesizes_dict(self):
result = normalize_output_item('model.glb')
assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
def test_string_non_3d_extension_returns_none(self):
assert normalize_output_item('data.json') is None
def test_string_no_extension_returns_none(self):
assert normalize_output_item('camera_info_string') is None
def test_dict_passes_through(self):
item = {'filename': 'test.png', 'type': 'output'}
assert normalize_output_item(item) is item
def test_other_types_return_none(self):
assert normalize_output_item(42) is None
assert normalize_output_item(True) is None
class TestNormalizeOutputs:
"""Unit tests for normalize_outputs()"""
def test_empty_outputs(self):
assert normalize_outputs({}) == {}
def test_dict_items_pass_through(self):
outputs = {
'node1': {
'images': [{'filename': 'a.png', 'type': 'output'}],
}
}
result = normalize_outputs(outputs)
assert result == outputs
def test_3d_string_synthesized(self):
outputs = {
'node1': {
'result': ['model.glb', None, None],
}
}
result = normalize_outputs(outputs)
assert result == {
'node1': {
'result': [
{'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
],
}
}
def test_animated_key_preserved(self):
outputs = {
'node1': {
'images': [{'filename': 'a.png', 'type': 'output'}],
'animated': [True],
}
}
result = normalize_outputs(outputs)
assert result['node1']['animated'] == [True]
def test_non_dict_node_outputs_preserved(self):
outputs = {'node1': 'unexpected_value'}
result = normalize_outputs(outputs)
assert result == {'node1': 'unexpected_value'}
def test_none_items_filtered_but_other_types_preserved(self):
outputs = {
'node1': {
'result': ['data.json', None, [1, 2, 3]],
}
}
result = normalize_outputs(outputs)
assert result == {
'node1': {
'result': ['data.json', [1, 2, 3]],
}
}

View File

@@ -6,6 +6,7 @@ from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
from comfy_execution.utils import get_executing_context
class TestLazyMixImages:
@classmethod
@@ -482,6 +483,57 @@ class TestOutputNodeWithSocketOutput:
result = image * value
return (result,)
class TestExpectedOutputs:
"""Test node for the expected_outputs feature.
This node has 3 IMAGE outputs that encode which outputs were expected:
- White image (255) if the output was in expected_outputs
- Black image (0) if the output was NOT in expected_outputs
This allows integration tests to verify which outputs were expected by checking pixel values.
"""
LAZY_OUTPUTS = True # Opt into cache invalidation on output connection changes
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
RETURN_NAMES = ("output0", "output1", "output2")
FUNCTION = "execute"
CATEGORY = "_for_testing"
def execute(self, height, width):
ctx = get_executing_context()
# Default: assume all outputs are expected (backwards compatibility)
output0_expected = True
output1_expected = True
output2_expected = True
if ctx is not None and ctx.expected_outputs is not None:
output0_expected = 0 in ctx.expected_outputs
output1_expected = 1 in ctx.expected_outputs
output2_expected = 2 in ctx.expected_outputs
# Return white image if expected, black if not
# This allows tests to verify which outputs were expected via pixel values
white = torch.ones(1, height, width, 3)
black = torch.zeros(1, height, width, 3)
return (
white if output0_expected else black,
white if output1_expected else black,
white if output2_expected else black,
)
TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
@@ -498,6 +550,7 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestSleep": TestSleep,
"TestParallelSleep": TestParallelSleep,
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
"TestExpectedOutputs": TestExpectedOutputs,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -516,4 +569,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestSleep": "Test Sleep",
"TestParallelSleep": "Test Parallel Sleep",
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
"TestExpectedOutputs": "Test Expected Outputs",
}