Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-13 03:30:01 +00:00)

Compare commits: 1 commit on branch feat/core/.../v3/model_m

Commit: ac1073be99
@@ -1110,7 +1110,7 @@ class AceStepConditionGenerationModel(nn.Module):

return encoder_hidden, encoder_mask, context_latents

def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs):
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs):
text_attention_mask = None
lyric_attention_mask = None
refer_audio_order_mask = None
@@ -1140,9 +1140,6 @@ class AceStepConditionGenerationModel(nn.Module):
src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
)

if replace_with_null_embeds:
enc_hidden[:] = self.null_condition_emb.to(enc_hidden)

out = self.decoder(hidden_states=x,
timestep=timestep,
timestep_r=timestep,

@@ -335,7 +335,7 @@ class FinalLayer(nn.Module):
device=None, dtype=None, operations=None
):
super().__init__()
self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
)
@@ -463,8 +463,6 @@ class Block(nn.Module):
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
residual_dtype = x_B_T_H_W_D.dtype
compute_dtype = emb_B_T_D.dtype
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb

@@ -514,7 +512,7 @@ class Block(nn.Module):
result_B_T_H_W_D = rearrange(
self.self_attn(
# normalized_x_B_T_HW_D,
rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
@@ -524,7 +522,7 @@ class Block(nn.Module):
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D

def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
@@ -538,7 +536,7 @@ class Block(nn.Module):
)
_result_B_T_H_W_D = rearrange(
self.cross_attn(
rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
@@ -557,7 +555,7 @@ class Block(nn.Module):
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D

normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
@@ -565,8 +563,8 @@ class Block(nn.Module):
scale_mlp_B_T_1_1_D,
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
return x_B_T_H_W_D


@@ -878,14 +876,6 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
"transformer_options": kwargs.get("transformer_options", {}),
}

# The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
# in fp32, but run attention and MLP modules in fp16.
# An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
# quality degradation and visual artifacts.
if x_B_T_H_W_D.dtype == torch.float16:
x_B_T_H_W_D = x_B_T_H_W_D.float()

for block in self.blocks:
x_B_T_H_W_D = block(
x_B_T_H_W_D,
@@ -894,6 +884,6 @@ class MiniTrainDIT(nn.Module):
**block_kwargs,
)

x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
return x_B_C_Tt_Hp_Wp

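The comment in this hunk documents the mixed-precision trick used by MiniTrainDIT: the residual stream carries large values, so it stays in fp32 while attention and MLP modules run in reduced precision. A minimal standalone sketch of the same pattern (plain PyTorch, bfloat16 here so it also runs on CPU; not the actual model code):

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Sketch: fp32 residual stream, low-precision module compute."""
    def __init__(self, dim, compute_dtype=torch.bfloat16):
        super().__init__()
        self.mlp = nn.Linear(dim, dim, dtype=compute_dtype)
        self.compute_dtype = compute_dtype

    def forward(self, x):
        residual_dtype = x.dtype                    # fp32 keeps the large residual exact
        out = self.mlp(x.to(self.compute_dtype))    # heavy math in low precision
        return x + out.to(residual_dtype)           # accumulate back in fp32

x = torch.randn(2, 64)                              # float32 by default
if x.dtype == torch.float16:                        # mirrors the entry-point cast above
    x = x.float()
print(Block(64)(x).dtype)                           # torch.float32
```
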
@@ -1552,8 +1552,6 @@ class ACEStep15(BaseModel):

cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if torch.count_nonzero(cross_attn) == 0:
out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
@@ -1577,10 +1575,6 @@ class ACEStep15(BaseModel):
else:
out['is_covers'] = comfy.conds.CONDConstant(False)

if refer_audio.shape[2] < noise.shape[2]:
pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)

out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
return out


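Two conditioning details appear in these ACEStep15 hunks: a fully zeroed cross-attention conditioning is flagged so the model can substitute its learned null-condition embedding, and a reference-audio latent shorter than the noise is right-padded with a silence latent. A hedged sketch of both ideas in plain torch (`silence_latent` stands in for the model's own silence tensor):

```python
import torch

def prepare_extra_conds(cross_attn, refer_audio, noise, silence_latent):
    out = {"c_crossattn": cross_attn}
    # An all-zero conditioning means "no prompt": flag it so the forward pass
    # can swap in the learned null_condition_emb instead.
    if torch.count_nonzero(cross_attn) == 0:
        out["replace_with_null_embeds"] = True
    # Right-pad a short reference latent with "silence" up to the noise length.
    if refer_audio.shape[2] < noise.shape[2]:
        pad = silence_latent[:, :, refer_audio.shape[2]:noise.shape[2]]
        refer_audio = torch.cat([refer_audio.to(pad), pad], dim=2)
    out["refer_audio"] = refer_audio
    return out
```
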
@@ -993,7 +993,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):

memory_usage_factor = 1.0

supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
supported_inference_dtypes = [torch.bfloat16, torch.float32]

def __init__(self, unet_config):
super().__init__(unet_config)
@@ -1023,7 +1023,11 @@ class Anima(supported_models_base.BASE):

memory_usage_factor = 1.0

supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
supported_inference_dtypes = [torch.bfloat16, torch.float32]

def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95

def get_model(self, state_dict, prefix="", device=None):
out = model_base.Anima(self, device=device)
@@ -1034,12 +1038,6 @@ class Anima(supported_models_base.BASE):
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))

def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
if dtype is torch.float16:
self.memory_usage_factor *= 1.4
return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)

class CosmosI2VPredict2(CosmosT2IPredict2):
unet_config = {
"image_model": "cosmos_predict2",

@@ -23,7 +23,7 @@ class AnimaTokenizer:
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out


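For context on the AnimaTokenizer change above: `tokenize_with_weights` yields lists of `(token, weight)` pairs, or `(token, weight, word_id)` triples when `return_word_ids` is set. The two comprehensions differ in how they treat word ids: one assumes pairs and drops any third element, the other preserves the word id while forcing every weight to 1.0. A toy illustration of the shapes (made-up token ids):

```python
qwen_ids = [[(101, 1.2, 0), (7592, 0.8, 1)]]  # (token, weight, word_id) triples
out = [[(k[0], 1.0, k[2]) for k in inner] for inner in qwen_ids]
assert out == [[(101, 1.0, 0), (7592, 1.0, 1)]]  # weights neutralized, word ids kept
```
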
@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)

def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

@@ -1430,11 +1430,6 @@ class Schema:
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
lazy_outputs: bool=False
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.

Use this for nodes that can skip computing outputs that aren't connected downstream.
Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused."""

def validate(self):
'''Validate the schema:
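
A sketch of how a node could opt in to `lazy_outputs` (a hypothetical node: the schema fields follow this hunk, and `get_executing_context` is the helper defined in the execution-context hunk further below; its import path is assumed):

```python
from comfy_api.latest import io
# Path assumed; get_executing_context is shown in the execution-context hunk below.
from comfy_execution.utils import get_executing_context

class ExpensiveAnalysis(io.ComfyNode):  # hypothetical example node
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ExpensiveAnalysis",
            category="example",
            inputs=[io.Image.Input("image")],
            outputs=[io.Image.Output(), io.Mask.Output()],
            lazy_outputs=True,  # expected_outputs becomes available while executing
        )

    @classmethod
    def execute(cls, image) -> io.NodeOutput:
        ctx = get_executing_context()
        mask = None
        # Output index 1 (the mask) is skipped when nothing downstream consumes it.
        if ctx is None or ctx.expected_outputs is None or 1 in ctx.expected_outputs:
            mask = image.mean(dim=-1)  # placeholder for an expensive computation
        return io.NodeOutput(image, mask)
```
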
@@ -1880,14 +1875,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._ACCEPT_ALL_INPUTS

_LAZY_OUTPUTS = None
@final
@classproperty
def LAZY_OUTPUTS(cls): # noqa
if cls._LAZY_OUTPUTS is None:
cls.GET_SCHEMA()
return cls._LAZY_OUTPUTS

@final
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
@@ -1930,8 +1917,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._NOT_IDEMPOTENT = schema.not_idempotent
if cls._ACCEPT_ALL_INPUTS is None:
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
if cls._LAZY_OUTPUTS is None:
cls._LAZY_OUTPUTS = schema.lazy_outputs

if cls._RETURN_TYPES is None:
output = []

@@ -5,7 +5,7 @@ import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod

import nodes
@@ -115,10 +115,6 @@ class CacheKeySetInputSignature(CacheKeySet):
signature = [class_type, await self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
# Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS
if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS:
expected = get_expected_outputs_for_node(dynprompt, node_id)
signature.append(("expected_outputs", tuple(sorted(expected))))
inputs = node["inputs"]
for key in sorted(inputs.keys()):
if is_link(inputs[key]):

@@ -19,15 +19,6 @@ class NodeInputError(Exception):
class NodeNotFoundError(Exception):
pass


def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
"""Get the set of output indices that are connected downstream.
Returns outputs that MIGHT be used.
Outputs NOT in this set are DEFINITELY not used and safe to skip.
"""
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())


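The map consulted here is built by `DynamicPrompt._build_expected_outputs_map` (below): it inverts the prompt's links, recording each consumed socket index against the producing node. A toy illustration (node class names are made up):

```python
# Node "2" consumes output 0 of node "1"; output 1 of node "1" is never linked.
prompt = {
    "1": {"class_type": "ProduceImageAndMask", "inputs": {}},
    "2": {"class_type": "SaveImage", "inputs": {"images": ["1", 0]}},
}
# _build_expected_outputs_map therefore produces:
expected_outputs_map = {"1": frozenset({0})}
# get_expected_outputs_for_node(dynprompt, "1") -> frozenset({0})  (output 1 safe to skip)
# get_expected_outputs_for_node(dynprompt, "9") -> frozenset()     (unknown node: nothing expected)
```
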
class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
@@ -36,7 +27,6 @@ class DynamicPrompt:
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
self._expected_outputs_map = None

def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
@@ -52,7 +42,6 @@ class DynamicPrompt:
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id
self._expected_outputs_map = None

def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
@@ -70,26 +59,6 @@ class DynamicPrompt:
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))

def _build_expected_outputs_map(self):
result = {}
for node_id in self.all_node_ids():
try:
node_data = self.get_node(node_id)
except NodeNotFoundError:
continue
for value in node_data.get("inputs", {}).values():
if is_link(value):
from_node_id, from_socket = value
if from_node_id not in result:
result[from_node_id] = set()
result[from_node_id].add(from_socket)
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}

def get_expected_outputs_map(self):
if self._expected_outputs_map is None:
self._build_expected_outputs_map()
return self._expected_outputs_map

def get_original_prompt(self):
return self.original_prompt


@@ -1,41 +1,23 @@
import contextvars
from typing import NamedTuple, FrozenSet
from typing import Optional, NamedTuple

class ExecutionContext(NamedTuple):
"""
Context information about the currently executing node.

Attributes:
prompt_id: The ID of the current prompt execution
node_id: The ID of the currently executing node
list_index: The index in a list being processed (for operations on batches/lists)
expected_outputs: Set of output indices that might be used downstream.
Outputs NOT in this set are definitely unused (safe to skip).
None means the information is not available.
"""
prompt_id: str
node_id: str
list_index: int | None
expected_outputs: FrozenSet[int] | None = None
list_index: Optional[int]

current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None)
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)

def get_executing_context() -> ExecutionContext | None:
def get_executing_context() -> Optional[ExecutionContext]:
return current_executing_context.get(None)


def is_output_needed(output_index: int) -> bool:
"""Check if an output at the given index is connected downstream.

Returns True if the output might be used (should be computed).
Returns False if the output is definitely not connected (safe to skip).
"""
ctx = get_executing_context()
if ctx is None or ctx.expected_outputs is None:
return True
return output_index in ctx.expected_outputs


class CurrentNodeContext:
"""
Context manager for setting the current executing node context.
@@ -43,22 +25,15 @@ class CurrentNodeContext:
Sets the current_executing_context on enter and resets it on exit.

Example:
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
with CurrentNodeContext(node_id="123", list_index=0):
# Code that should run with the current node context set
process_image()
"""
def __init__(
self,
prompt_id: str,
node_id: str,
list_index: int | None = None,
expected_outputs: FrozenSet[int] | None = None,
):
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
self.context = ExecutionContext(
prompt_id=prompt_id,
node_id=node_id,
list_index=list_index,
expected_outputs=expected_outputs,
prompt_id= prompt_id,
node_id= node_id,
list_index= list_index
)
self.token = None


@@ -622,7 +622,6 @@ class SamplerSASolver(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerSASolver",
search_aliases=["sde"],
category="sampling/custom_sampling/samplers",
inputs=[
io.Model.Input("model"),
@@ -667,7 +666,6 @@ class SamplerSEEDS2(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
search_aliases=["sde", "exp heun"],
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),

@@ -108,7 +108,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
x: torch.Tensor = args[0][:, :easycache.output_channels]
x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels)
# prepare next x_prev
next_x_prev = x
input_change = None

@@ -391,9 +391,8 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude

dims = list(range(1, latent_vector_magnitude.ndim))
mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)

top = (std * 5 + mean) * multiplier


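Both variants of this tonemap hunk normalize each latent vector to unit length and compute a soft ceiling `top` from the magnitude statistics; a Reinhard curve x / (1 + x) then compresses magnitudes above `top` instead of clipping them. A standalone sketch of the whole operation (the tail after `top` follows the shape of ComfyUI's existing tonemap node and should be read as illustrative):

```python
import torch

def tonemap_reinhard(latent, multiplier=1.0):
    # Split each latent vector into direction and magnitude.
    magnitude = torch.linalg.vector_norm(latent, dim=1, keepdim=True) + 1e-10
    normalized = latent / magnitude

    dims = list(range(1, magnitude.ndim))           # rank-agnostic, as in the new variant
    mean = magnitude.mean(dim=dims, keepdim=True)
    std = magnitude.std(dim=dims, keepdim=True)
    top = (std * 5 + mean) * multiplier

    # Reinhard: x / (1 + x) relative to `top`, so outliers compress smoothly.
    scaled = magnitude / top
    return normalized * (scaled / (scaled + 1.0) * top)

print(tonemap_reinhard(torch.randn(1, 4, 32, 32) * 10).shape)  # torch.Size([1, 4, 32, 32])
```
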
@@ -10,146 +10,198 @@ import json
import os

from comfy.cli_args import args
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override

class ModelMergeSimple:

class ModelMergeSimple(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSimple",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)

CATEGORY = "advanced/model_merging"

def merge(self, model1, model2, ratio):
@classmethod
def execute(cls, model1, model2, ratio) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)

class ModelSubtract:
merge = execute # TODO: remove


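The merge itself is expressed through ModelPatcher patches rather than materialized tensors. As raw arithmetic, and assuming `add_patches(patches, strength_patch, strength_model)` argument order, the call `add_patches({k: kp[k]}, 1.0 - ratio, ratio)` blends each weight as `model1 * ratio + model2 * (1 - ratio)`. A standalone state-dict sketch of that blend (not the patch mechanism itself):

```python
import torch

def lerp_state_dicts(sd1, sd2, ratio):
    """Weighted blend matching the patch-based merge above
    (under the assumed add_patches argument order)."""
    return {k: sd1[k] * ratio + sd2[k] * (1.0 - ratio) for k in sd1 if k in sd2}

sd1 = {"w": torch.zeros(2)}
sd2 = {"w": torch.ones(2)}
print(lerp_state_dicts(sd1, sd2, 0.75)["w"])  # tensor([0.2500, 0.2500])
```
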
class ModelSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSubtract",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)

CATEGORY = "advanced/model_merging"

def merge(self, model1, model2, multiplier):
@classmethod
def execute(cls, model1, model2, multiplier) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
return io.NodeOutput(m)

class ModelAdd:
merge = execute # TODO: remove


class ModelAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeAdd",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
],
outputs=[
io.Model.Output(),
],
)

CATEGORY = "advanced/model_merging"

def merge(self, model1, model2):
@classmethod
def execute(cls, model1, model2) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
return io.NodeOutput(m)

merge = execute # TODO: remove


class CLIPMergeSimple:
class CLIPMergeSimple(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSimple",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Clip.Output(),
],
)

CATEGORY = "advanced/model_merging"

def merge(self, clip1, clip2, ratio):
@classmethod
def execute(cls, clip1, clip2, ratio) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)

merge = execute # TODO: remove


class CLIPSubtract:
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
class CLIPSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSubtract",
search_aliases=["clip difference", "text encoder subtract"],
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Clip.Output(),
],
)

CATEGORY = "advanced/model_merging"

def merge(self, clip1, clip2, multiplier):
@classmethod
def execute(cls, clip1, clip2, multiplier) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
return io.NodeOutput(m)

merge = execute # TODO: remove


class CLIPAdd:
SEARCH_ALIASES = ["combine clip"]
class CLIPAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeAdd",
search_aliases=["combine clip"],
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
],
outputs=[
io.Clip.Output(),
],
)

CATEGORY = "advanced/model_merging"

def merge(self, clip1, clip2):
@classmethod
def execute(cls, clip1, clip2) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
return io.NodeOutput(m)

merge = execute # TODO: remove


class ModelMergeBlocks:
class ModelMergeBlocks(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeBlocks",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("input", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("middle", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)

CATEGORY = "advanced/model_merging"

def merge(self, model1, model2, **kwargs):
@classmethod
def execute(cls, model1, model2, **kwargs) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
default_ratio = next(iter(kwargs.values()))
@@ -165,7 +217,10 @@ class ModelMergeBlocks:
last_arg_size = len(arg)

m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)

merge = execute # TODO: remove


def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
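
ModelMergeBlocks resolves a ratio per weight key by scanning the float inputs for the longest argument name that prefixes the key (that is what the `last_arg_size` bookkeeping above tracks), falling back to the first argument's value. A hedged re-statement of that lookup:

```python
def ratio_for_key(key, ratios, default_ratio):
    """Pick the value whose argument name is the longest prefix of `key`."""
    best, last_arg_size = default_ratio, 0
    for arg, value in ratios.items():
        if key.startswith(arg) and len(arg) > last_arg_size:
            best, last_arg_size = value, len(arg)
    return best

ratios = {"input": 1.0, "middle": 0.5, "out": 0.0}
assert ratio_for_key("middle_block.1.weight", ratios, 1.0) == 0.5
assert ratio_for_key("time_embed.weight", ratios, 1.0) == 1.0  # no prefix match
```
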
@@ -226,59 +281,65 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi

comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)

class CheckpointSave:
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()

class CheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CheckpointSave",
display_name="Save Checkpoint",
search_aliases=["save model", "export checkpoint", "merge save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.Clip.Input("clip"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)

@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
def execute(cls, model, clip, vae, filename_prefix) -> io.NodeOutput:
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()

CATEGORY = "advanced/model_merging"
save = execute # TODO: remove

def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}

class CLIPSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class CLIPSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPSave",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip"),
io.String.Input("filename_prefix", default="clip/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)

@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True

CATEGORY = "advanced/model_merging"

def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
def execute(cls, clip, filename_prefix) -> io.NodeOutput:
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)

metadata = {}
if not args.disable_metadata:
metadata["format"] = "pt"
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])

comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()

output_dir = folder_paths.get_output_directory()
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
@@ -295,7 +356,7 @@ class CLIPSave:
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""

full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, output_dir)

output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
@@ -303,76 +364,88 @@ class CLIPSave:
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)

comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}
return io.NodeOutput()

class VAESave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
save = execute # TODO: remove


class VAESave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VAESave",
category="advanced/model_merging",
inputs=[
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="vae/ComfyUI_vae"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)

@classmethod
def INPUT_TYPES(s):
return {"required": { "vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True

CATEGORY = "advanced/model_merging"

def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
def execute(cls, vae, filename_prefix) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)

metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])

output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
return io.NodeOutput()

class ModelSave:
SEARCH_ALIASES = ["export model", "checkpoint save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
save = execute # TODO: remove


class ModelSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelSave",
search_aliases=["export model", "checkpoint save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.String.Input("filename_prefix", default="diffusion_models/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)

@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
def execute(cls, model, filename_prefix) -> io.NodeOutput:
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()

CATEGORY = "advanced/model_merging"
save = execute # TODO: remove

def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}

NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks,
"ModelMergeSubtract": ModelSubtract,
"ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple,
"CLIPMergeSubtract": CLIPSubtract,
"CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
"ModelSave": ModelSave,
}
class ModelMergingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ModelMergeSimple,
ModelMergeBlocks,
ModelSubtract,
ModelAdd,
CheckpointSave,
CLIPMergeSimple,
CLIPSubtract,
CLIPAdd,
CLIPSave,
VAESave,
ModelSave,
]

NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointSave": "Save Checkpoint",
}

async def comfy_entrypoint() -> ModelMergingExtension:
return ModelMergingExtension()

@@ -1,356 +1,455 @@
import comfy_extras.nodes_model_merging

from comfy_api.latest import io, ComfyExtension
from typing_extensions import override


class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
inputs.append(io.Float.Input("time_embed.", **argument))
inputs.append(io.Float.Input("label_emb.", **argument))

for i in range(12):
arg_dict["input_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("input_blocks.{}.".format(i), **argument))

for i in range(3):
arg_dict["middle_block.{}.".format(i)] = argument
inputs.append(io.Float.Input("middle_block.{}.".format(i), **argument))

for i in range(12):
arg_dict["output_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("output_blocks.{}.".format(i), **argument))

arg_dict["out."] = argument
inputs.append(io.Float.Input("out.", **argument))

return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeSD1",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)


class ModelMergeSD2(ModelMergeSD1):
# SD1 and SD2 have the same blocks
@classmethod
def define_schema(cls):
schema = ModelMergeSD1.define_schema()
schema.node_id = "ModelMergeSD2"
return schema


class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
inputs.append(io.Float.Input("time_embed.", **argument))
inputs.append(io.Float.Input("label_emb.", **argument))

for i in range(9):
arg_dict["input_blocks.{}".format(i)] = argument
inputs.append(io.Float.Input("input_blocks.{}".format(i), **argument))

for i in range(3):
arg_dict["middle_block.{}".format(i)] = argument
inputs.append(io.Float.Input("middle_block.{}".format(i), **argument))

for i in range(9):
arg_dict["output_blocks.{}".format(i)] = argument
inputs.append(io.Float.Input("output_blocks.{}".format(i), **argument))

arg_dict["out."] = argument
inputs.append(io.Float.Input("out.", **argument))

return io.Schema(
node_id="ModelMergeSDXL",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)

return {"required": arg_dict}

class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("pos_embed.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("context_embedder.", **argument))
inputs.append(io.Float.Input("y_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))

for i in range(24):
arg_dict["joint_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))

arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))

return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeSD3_2B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)


class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

arg_dict["init_x_linear."] = argument
arg_dict["positional_encoding"] = argument
arg_dict["cond_seq_linear."] = argument
arg_dict["register_tokens"] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("init_x_linear.", **argument))
inputs.append(io.Float.Input("positional_encoding", **argument))
inputs.append(io.Float.Input("cond_seq_linear.", **argument))
inputs.append(io.Float.Input("register_tokens", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))

for i in range(4):
arg_dict["double_layers.{}.".format(i)] = argument
inputs.append(io.Float.Input("double_layers.{}.".format(i), **argument))

for i in range(32):
arg_dict["single_layers.{}.".format(i)] = argument
inputs.append(io.Float.Input("single_layers.{}.".format(i), **argument))

arg_dict["modF."] = argument
arg_dict["final_linear."] = argument
inputs.append(io.Float.Input("modF.", **argument))
inputs.append(io.Float.Input("final_linear.", **argument))

return io.Schema(
node_id="ModelMergeAuraflow",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)

return {"required": arg_dict}

class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

arg_dict["img_in."] = argument
arg_dict["time_in."] = argument
arg_dict["guidance_in"] = argument
arg_dict["vector_in."] = argument
arg_dict["txt_in."] = argument
inputs.append(io.Float.Input("img_in.", **argument))
inputs.append(io.Float.Input("time_in.", **argument))
inputs.append(io.Float.Input("guidance_in", **argument))
inputs.append(io.Float.Input("vector_in.", **argument))
inputs.append(io.Float.Input("txt_in.", **argument))

for i in range(19):
arg_dict["double_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("double_blocks.{}.".format(i), **argument))

for i in range(38):
arg_dict["single_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("single_blocks.{}.".format(i), **argument))

arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))

return io.Schema(
node_id="ModelMergeFlux1",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)

return {"required": arg_dict}

class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("pos_embed.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("context_embedder.", **argument))
inputs.append(io.Float.Input("y_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))

for i in range(38):
arg_dict["joint_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))

arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))

return io.Schema(
node_id="ModelMergeSD35_Large",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)

return {"required": arg_dict}

class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

arg_dict["pos_frequencies."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t5_y_embedder."] = argument
arg_dict["t5_yproj."] = argument
inputs.append(io.Float.Input("pos_frequencies.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t5_y_embedder.", **argument))
inputs.append(io.Float.Input("t5_yproj.", **argument))

for i in range(48):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))

arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))

return io.Schema(
node_id="ModelMergeMochiPreview",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)

return {"required": arg_dict}

class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

arg_dict["patchify_proj."] = argument
arg_dict["adaln_single."] = argument
arg_dict["caption_projection."] = argument
inputs.append(io.Float.Input("patchify_proj.", **argument))
inputs.append(io.Float.Input("adaln_single.", **argument))
inputs.append(io.Float.Input("caption_projection.", **argument))

for i in range(28):
arg_dict["transformer_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))

arg_dict["scale_shift_table"] = argument
arg_dict["proj_out."] = argument
inputs.append(io.Float.Input("scale_shift_table", **argument))
inputs.append(io.Float.Input("proj_out.", **argument))

return io.Schema(
node_id="ModelMergeLTXV",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)

return {"required": arg_dict}

class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("affline_norm.", **argument))

for i in range(28):
arg_dict["blocks.block{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))

arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))

return io.Schema(
node_id="ModelMergeCosmos7B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)

return {"required": arg_dict}

class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("affline_norm.", **argument))

for i in range(36):
arg_dict["blocks.block{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))

arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))

return io.Schema(
node_id="ModelMergeCosmos14B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)

return {"required": arg_dict}

class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument
inputs.append(io.Float.Input("patch_embedding.", **argument))
inputs.append(io.Float.Input("time_embedding.", **argument))
inputs.append(io.Float.Input("time_projection.", **argument))
inputs.append(io.Float.Input("text_embedding.", **argument))
inputs.append(io.Float.Input("img_emb.", **argument))

for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))

arg_dict["head."] = argument
inputs.append(io.Float.Input("head.", **argument))

return io.Schema(
node_id="ModelMergeWAN2_1",
category="advanced/model_merging/model_specific",
description="1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb.",
inputs=inputs,
outputs=[io.Model.Output()],
)

return {"required": arg_dict}

class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}
    def define_schema(cls):
        inputs = [
            io.Model.Input("model1"),
            io.Model.Input("model2"),
        ]

        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["pos_embedder."] = argument
        arg_dict["x_embedder."] = argument
        arg_dict["t_embedder."] = argument
        arg_dict["t_embedding_norm."] = argument
        argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

        inputs.append(io.Float.Input("pos_embedder.", **argument))
        inputs.append(io.Float.Input("x_embedder.", **argument))
        inputs.append(io.Float.Input("t_embedder.", **argument))
        inputs.append(io.Float.Input("t_embedding_norm.", **argument))

        for i in range(28):
            arg_dict["blocks.{}.".format(i)] = argument
            inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))

        arg_dict["final_layer."] = argument
        inputs.append(io.Float.Input("final_layer.", **argument))

        return io.Schema(
            node_id="ModelMergeCosmosPredict2_2B",
            category="advanced/model_merging/model_specific",
            inputs=inputs,
            outputs=[io.Model.Output()],
        )

        return {"required": arg_dict}

class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}
    def define_schema(cls):
        inputs = [
            io.Model.Input("model1"),
            io.Model.Input("model2"),
        ]

        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["pos_embedder."] = argument
        arg_dict["x_embedder."] = argument
        arg_dict["t_embedder."] = argument
        arg_dict["t_embedding_norm."] = argument
        argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

        inputs.append(io.Float.Input("pos_embedder.", **argument))
        inputs.append(io.Float.Input("x_embedder.", **argument))
        inputs.append(io.Float.Input("t_embedder.", **argument))
        inputs.append(io.Float.Input("t_embedding_norm.", **argument))

        for i in range(36):
            arg_dict["blocks.{}.".format(i)] = argument
            inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))

        arg_dict["final_layer."] = argument
        inputs.append(io.Float.Input("final_layer.", **argument))

        return io.Schema(
            node_id="ModelMergeCosmosPredict2_14B",
            category="advanced/model_merging/model_specific",
            inputs=inputs,
            outputs=[io.Model.Output()],
        )

        return {"required": arg_dict}

class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}
    def define_schema(cls):
        inputs = [
            io.Model.Input("model1"),
            io.Model.Input("model2"),
        ]

        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
        argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)

        arg_dict["pos_embeds."] = argument
        arg_dict["img_in."] = argument
        arg_dict["txt_norm."] = argument
        arg_dict["txt_in."] = argument
        arg_dict["time_text_embed."] = argument
        inputs.append(io.Float.Input("pos_embeds.", **argument))
        inputs.append(io.Float.Input("img_in.", **argument))
        inputs.append(io.Float.Input("txt_norm.", **argument))
        inputs.append(io.Float.Input("txt_in.", **argument))
        inputs.append(io.Float.Input("time_text_embed.", **argument))

        for i in range(60):
            arg_dict["transformer_blocks.{}.".format(i)] = argument
            inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))

        arg_dict["proj_out."] = argument
        inputs.append(io.Float.Input("proj_out.", **argument))

        return {"required": arg_dict}
        return io.Schema(
            node_id="ModelMergeQwenImage",
            category="advanced/model_merging/model_specific",
            inputs=inputs,
            outputs=[io.Model.Output()],
        )

NODE_CLASS_MAPPINGS = {
    "ModelMergeSD1": ModelMergeSD1,
    "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
    "ModelMergeSDXL": ModelMergeSDXL,
    "ModelMergeSD3_2B": ModelMergeSD3_2B,
    "ModelMergeAuraflow": ModelMergeAuraflow,
    "ModelMergeFlux1": ModelMergeFlux1,
    "ModelMergeSD35_Large": ModelMergeSD35_Large,
    "ModelMergeMochiPreview": ModelMergeMochiPreview,
    "ModelMergeLTXV": ModelMergeLTXV,
    "ModelMergeCosmos7B": ModelMergeCosmos7B,
    "ModelMergeCosmos14B": ModelMergeCosmos14B,
    "ModelMergeWAN2_1": ModelMergeWAN2_1,
    "ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
    "ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
    "ModelMergeQwenImage": ModelMergeQwenImage,
}

class ModelMergingModelSpecificExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            ModelMergeSD1,
            ModelMergeSD2,
            ModelMergeSDXL,
            ModelMergeSD3_2B,
            ModelMergeAuraflow,
            ModelMergeFlux1,
            ModelMergeSD35_Large,
            ModelMergeMochiPreview,
            ModelMergeLTXV,
            ModelMergeCosmos7B,
            ModelMergeCosmos14B,
            ModelMergeWAN2_1,
            ModelMergeCosmosPredict2_2B,
            ModelMergeCosmosPredict2_14B,
            ModelMergeQwenImage,
        ]


async def comfy_entrypoint() -> ModelMergingModelSpecificExtension:
    return ModelMergingModelSpecificExtension()

@@ -6,44 +6,62 @@ import folder_paths
import comfy_extras.nodes_model_merging
import node_helpers

from comfy_api.latest import io, ComfyExtension
from typing_extensions import override

class ImageOnlyCheckpointLoader:

class ImageOnlyCheckpointLoader(io.ComfyNode):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                            }}
    RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
    FUNCTION = "load_checkpoint"
    def define_schema(cls):
        return io.Schema(
            node_id="ImageOnlyCheckpointLoader",
            display_name="Image Only Checkpoint Loader (img2vid model)",
            category="loaders/video_models",
            inputs=[
                io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("checkpoints")),
            ],
            outputs=[
                io.Model.Output(),
                io.ClipVision.Output(),
                io.Vae.Output(),
            ],
        )

    CATEGORY = "loaders/video_models"

    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
    @classmethod
    def execute(cls, ckpt_name, output_vae=True, output_clip=True) -> io.NodeOutput:
        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return (out[0], out[3], out[2])
        return io.NodeOutput(out[0], out[3], out[2])

    load_checkpoint = execute # TODO: remove


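The index shuffle in execute above assumes comfy.sd.load_checkpoint_guess_config returns a (model, clip, vae, clipvision) tuple, with clip left as None because output_clip=False; the node exposes (MODEL, CLIP_VISION, VAE), hence out[0], out[3], out[2]. A tiny sketch of that remapping, under the same assumed tuple layout:

def reorder_checkpoint_outputs(out):
    # Assumed loader tuple layout: (model, clip, vae, clipvision).
    model, _clip, vae, clipvision = out
    return model, clipvision, vae

assert reorder_checkpoint_outputs(("model", None, "vae", "clipvision")) == ("model", "clipvision", "vae")
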
class SVD_img2vid_Conditioning:
class SVD_img2vid_Conditioning(io.ComfyNode):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_vision": ("CLIP_VISION",),
                              "init_image": ("IMAGE",),
                              "vae": ("VAE",),
                              "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
                              "motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
                              "fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
                              "augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
                            }}
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    def define_schema(cls):
        return io.Schema(
            node_id="SVD_img2vid_Conditioning",
            category="conditioning/video_models",
            inputs=[
                io.ClipVision.Input("clip_vision"),
                io.Image.Input("init_image"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("video_frames", default=14, min=1, max=4096),
                io.Int.Input("motion_bucket_id", default=127, min=1, max=1023),
                io.Int.Input("fps", default=6, min=1, max=1024),
                io.Float.Input("augmentation_level", default=0.0, min=0.0, max=10.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    FUNCTION = "encode"

    CATEGORY = "conditioning/video_models"

    def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
    @classmethod
    def execute(cls, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level) -> io.NodeOutput:
        output = clip_vision.encode_image(init_image)
        pooled = output.image_embeds.unsqueeze(0)
        pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@@ -54,20 +72,28 @@ class SVD_img2vid_Conditioning:
        positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
        negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
        latent = torch.zeros([video_frames, 4, height // 8, width // 8])
        return (positive, negative, {"samples":latent})
        return io.NodeOutput(positive, negative, {"samples":latent})

class VideoLinearCFGGuidance:
    encode = execute # TODO: remove


class VideoLinearCFGGuidance(io.ComfyNode):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
                            }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    def define_schema(cls):
        return io.Schema(
            node_id="VideoLinearCFGGuidance",
            category="sampling/video_models",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    CATEGORY = "sampling/video_models"

    def patch(self, model, min_cfg):
    @classmethod
    def execute(cls, model, min_cfg) -> io.NodeOutput:
        def linear_cfg(args):
            cond = args["cond"]
            uncond = args["uncond"]
@@ -78,20 +104,28 @@ class VideoLinearCFGGuidance:

        m = model.clone()
        m.set_model_sampler_cfg_function(linear_cfg)
        return (m, )
        return io.NodeOutput(m)

class VideoTriangleCFGGuidance:
    patch = execute # TODO: remove


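The hunks above elide the body of linear_cfg, but the idea is a per-frame guidance scale that ramps linearly from min_cfg on the first frame up to the sampler's cond_scale on the last frame; the triangle variant below peaks mid-clip instead. A hedged sketch of the linear schedule (assuming the scale broadcasts over a (T, C, H, W) batch of video latents):

import torch

def linear_cfg_scales(min_cfg: float, cond_scale: float, video_frames: int) -> torch.Tensor:
    # One guidance scale per frame, shaped (T, 1, 1, 1) so it broadcasts
    # across each frame's latent channels and spatial dims.
    return torch.linspace(min_cfg, cond_scale, video_frames).reshape(-1, 1, 1, 1)

# Applied frame-wise in the usual CFG formula:
# denoised = uncond + linear_cfg_scales(min_cfg, cond_scale, T) * (cond - uncond)
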
class VideoTriangleCFGGuidance(io.ComfyNode):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
                            }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    def define_schema(cls):
        return io.Schema(
            node_id="VideoTriangleCFGGuidance",
            category="sampling/video_models",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    CATEGORY = "sampling/video_models"

    def patch(self, model, min_cfg):
    @classmethod
    def execute(cls, model, min_cfg) -> io.NodeOutput:
        def linear_cfg(args):
            cond = args["cond"]
            uncond = args["uncond"]
@@ -105,57 +139,79 @@ class VideoTriangleCFGGuidance:

        m = model.clone()
        m.set_model_sampler_cfg_function(linear_cfg)
        return (m, )
        return io.NodeOutput(m)

class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
    CATEGORY = "advanced/model_merging"
    patch = execute # TODO: remove


class ImageOnlyCheckpointSave(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageOnlyCheckpointSave",
            search_aliases=["save model", "export checkpoint", "merge save"],
            category="advanced/model_merging",
            inputs=[
                io.Model.Input("model"),
                io.ClipVision.Input("clip_vision"),
                io.Vae.Input("vae"),
                io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
            ],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
            is_output_node=True,
        )

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "clip_vision": ("CLIP_VISION",),
                              "vae": ("VAE",),
                              "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
    def execute(cls, model, clip_vision, vae, filename_prefix) -> io.NodeOutput:
        comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
        return io.NodeOutput()

    def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
        return {}
    save = execute # TODO: remove


class ConditioningSetAreaPercentageVideo:
class ConditioningSetAreaPercentageVideo(io.ComfyNode):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
                             "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
                             "temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
                             "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
                             "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
                             "z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                            }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"
    def define_schema(cls):
        return io.Schema(
            node_id="ConditioningSetAreaPercentageVideo",
            category="conditioning",
            inputs=[
                io.Conditioning.Input("conditioning"),
                io.Float.Input("width", default=1.0, min=0.0, max=1.0, step=0.01),
                io.Float.Input("height", default=1.0, min=0.0, max=1.0, step=0.01),
                io.Float.Input("temporal", default=1.0, min=0.0, max=1.0, step=0.01),
                io.Float.Input("x", default=0.0, min=0.0, max=1.0, step=0.01),
                io.Float.Input("y", default=0.0, min=0.0, max=1.0, step=0.01),
                io.Float.Input("z", default=0.0, min=0.0, max=1.0, step=0.01),
                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
        )

    CATEGORY = "conditioning"

    def append(self, conditioning, width, height, temporal, x, y, z, strength):
    @classmethod
    def execute(cls, conditioning, width, height, temporal, x, y, z, strength) -> io.NodeOutput:
        c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
                                                                "strength": strength,
                                                                "set_area_to_bounds": False})
        return (c, )
        return io.NodeOutput(c)

    append = execute # TODO: remove


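node_helpers.conditioning_set_values, used by execute above, is assumed to copy each [tensor, options] conditioning pair with the new keys merged in rather than mutating the inputs. A minimal sketch of that behavior:

def conditioning_set_values_sketch(conditioning, values):
    # Each conditioning entry is a [tensor, options-dict] pair; return new
    # pairs with `values` merged in so the original list stays untouched.
    return [[tensor, {**options, **values}] for tensor, options in conditioning]

cond = [["<pooled tensor>", {"strength": 1.0}]]
updated = conditioning_set_values_sketch(cond, {"set_area_to_bounds": False})
assert cond[0][1] == {"strength": 1.0}          # original options unchanged
assert updated[0][1]["set_area_to_bounds"] is False
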
NODE_CLASS_MAPPINGS = {
    "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
    "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
    "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
    "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
    "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
    "ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo,
}
class VideoModelExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            ImageOnlyCheckpointLoader,
            SVD_img2vid_Conditioning,
            VideoLinearCFGGuidance,
            VideoTriangleCFGGuidance,
            ImageOnlyCheckpointSave,
            ConditioningSetAreaPercentageVideo,
        ]

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}

async def comfy_entrypoint() -> VideoModelExtension:
    return VideoModelExtension()

execution.py
@@ -31,7 +31,6 @@ from comfy_execution.graph import (
    ExecutionBlocker,
    ExecutionList,
    get_input_info,
    get_expected_outputs_for_node,
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
@@ -228,18 +227,7 @@ async def resolve_map_node_over_list_results(results):
            raise exc
    return [x.result() if isinstance(x, asyncio.Task) else x for x in results]

async def _async_map_node_over_list(
    prompt_id,
    unique_id,
    obj,
    input_data_all,
    func,
    allow_interrupt=False,
    execution_block_cb=None,
    pre_execute_cb=None,
    v3_data=None,
    expected_outputs=None,
):
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
    # check if node wants the lists
    input_is_list = getattr(obj, "INPUT_IS_LIST", False)

@@ -289,12 +277,10 @@ async def _async_map_node_over_list(
        else:
            f = getattr(obj, func)
            if inspect.iscoroutinefunction(f):
                async def async_wrapper(f, prompt_id, unique_id, list_index, args, expected_outputs):
                    with CurrentNodeContext(prompt_id, unique_id, list_index, expected_outputs):
                async def async_wrapper(f, prompt_id, unique_id, list_index, args):
                    with CurrentNodeContext(prompt_id, unique_id, list_index):
                        return await f(**args)
                task = asyncio.create_task(
                    async_wrapper(f, prompt_id, unique_id, index, args=inputs, expected_outputs=expected_outputs)
                )
                task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
                # Give the task a chance to execute without yielding
                await asyncio.sleep(0)
                if task.done():
@@ -303,7 +289,7 @@ async def _async_map_node_over_list(
            else:
                results.append(task)
        else:
            with CurrentNodeContext(prompt_id, unique_id, index, expected_outputs):
            with CurrentNodeContext(prompt_id, unique_id, index):
                result = f(**inputs)
            results.append(result)
    else:
@@ -341,17 +327,8 @@ def merge_result_data(results, obj):
            output.append([o[i] for o in results])
    return output

async def get_output_data(
    prompt_id,
    unique_id,
    obj,
    input_data_all,
    execution_block_cb=None,
    pre_execute_cb=None,
    v3_data=None,
    expected_outputs=None,
):
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
    has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
    if has_pending_task:
        return return_values, {}, False, has_pending_task
@@ -545,10 +522,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
        #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
        #that we just want to cull out each model run.
        allocator = comfy.memory_management.aimdo_allocator
        expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
        with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
            try:
                output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
                output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
            finally:
                if allocator is not None:
                    comfy.model_management.reset_cast_buffers()

@@ -1,6 +1,6 @@
comfyui-frontend-package==1.38.13
comfyui-workflow-templates==0.8.31
comfyui-embedded-docs==0.4.1
comfyui-embedded-docs==0.4.0
torch
torchsde
torchvision

@@ -1,322 +0,0 @@
"""Unit tests for the expected_outputs feature.

This feature allows nodes to know at runtime which outputs are connected downstream,
enabling them to skip computing outputs that aren't needed.
"""

from comfy_api.latest import IO
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from comfy_execution.utils import (
    CurrentNodeContext,
    ExecutionContext,
    get_executing_context,
    is_output_needed,
)


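The tests below pin down the lookup's semantics: scan every node's link-style inputs, which are ["source_node_id", output_slot] pairs, and collect the slots that reference the target node. A pure-Python sketch consistent with those tests (an illustration, not the actual comfy_execution.graph implementation):

def expected_outputs_sketch(prompt: dict, node_id: str) -> frozenset:
    slots = set()
    for node in prompt.values():
        for value in node.get("inputs", {}).values():
            # Links look like ["source_node_id", slot]; constants (42, "test") don't.
            if isinstance(value, list) and len(value) == 2 and value[0] == node_id:
                slots.add(value[1])
    return frozenset(slots)

prompt = {
    "1": {"class_type": "Source", "inputs": {}},
    "2": {"class_type": "Consumer", "inputs": {"image": ["1", 0], "value": 42}},
}
assert expected_outputs_sketch(prompt, "1") == frozenset({0})
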
class TestGetExpectedOutputsForNode:
    """Tests for get_expected_outputs_for_node() function."""

    def test_single_output_connected(self):
        """Test node with single output connected to one downstream node."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

    def test_multiple_outputs_partial_connected(self):
        """Test node with multiple outputs, only some connected."""
        prompt = {
            "1": {"class_type": "MultiOutputNode", "inputs": {}},
            "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
            # Output 1 is not connected
            "3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0, 2})
        assert 1 not in expected  # Output 1 is definitely unused

    def test_no_outputs_connected(self):
        """Test node with no outputs connected."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "OtherNode", "inputs": {}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset()

    def test_same_output_connected_multiple_times(self):
        """Test same output connected to multiple downstream nodes."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}},
            "4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

    def test_node_not_in_prompt(self):
        """Test getting expected outputs for a node not in the prompt."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "999")
        assert expected == frozenset()

    def test_chained_nodes(self):
        """Test expected outputs in a chain of nodes."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)

        # Node 1's output 0 is connected to node 2
        expected_1 = get_expected_outputs_for_node(dynprompt, "1")
        assert expected_1 == frozenset({0})

        # Node 2's output 0 is connected to node 3
        expected_2 = get_expected_outputs_for_node(dynprompt, "2")
        assert expected_2 == frozenset({0})

        # Node 3 has no downstream connections
        expected_3 = get_expected_outputs_for_node(dynprompt, "3")
        assert expected_3 == frozenset()

    def test_complex_graph(self):
        """Test expected outputs in a complex graph with multiple connections."""
        prompt = {
            "1": {"class_type": "MultiOutputNode", "inputs": {}},
            "2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}},
            "3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}},
            "4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)

        # Node 1 has outputs 0, 1, 2 all connected
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0, 1, 2})

    def test_constant_inputs_ignored(self):
        """Test that constant (non-link) inputs don't affect expected outputs."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {
                "class_type": "ConsumerNode",
                "inputs": {
                    "image": ["1", 0],
                    "value": 42,
                    "name": "test",
                },
            },
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

    def test_ephemeral_node_invalidates_cache(self):
        """Test that adding ephemeral nodes updates expected outputs."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)

        # Initially only output 0 is connected
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

        # Add an ephemeral node that connects to output 1
        dynprompt.add_ephemeral_node(
            "eph_1",
            {"class_type": "EphemeralNode", "inputs": {"data": ["1", 1]}},
            parent_id="2",
            display_id="2",
        )

        # Now both outputs 0 and 1 should be expected
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0, 1})


class TestExecutionContext:
    """Tests for ExecutionContext with expected_outputs field."""

    def test_context_with_expected_outputs(self):
        """Test creating ExecutionContext with expected_outputs."""
        ctx = ExecutionContext(
            prompt_id="prompt-123", node_id="node-456", list_index=0, expected_outputs=frozenset({0, 2})
        )
        assert ctx.prompt_id == "prompt-123"
        assert ctx.node_id == "node-456"
        assert ctx.list_index == 0
        assert ctx.expected_outputs == frozenset({0, 2})

    def test_context_without_expected_outputs(self):
        """Test ExecutionContext defaults to None for expected_outputs."""
        ctx = ExecutionContext(prompt_id="prompt-123", node_id="node-456", list_index=0)
        assert ctx.expected_outputs is None

    def test_context_empty_expected_outputs(self):
        """Test ExecutionContext with empty expected_outputs set."""
        ctx = ExecutionContext(
            prompt_id="prompt-123", node_id="node-456", list_index=None, expected_outputs=frozenset()
        )
        assert ctx.expected_outputs == frozenset()
        assert len(ctx.expected_outputs) == 0


class TestCurrentNodeContext:
    """Tests for CurrentNodeContext context manager with expected_outputs."""

    def test_context_manager_with_expected_outputs(self):
        """Test CurrentNodeContext sets and resets context correctly."""
        assert get_executing_context() is None

        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})):
            ctx = get_executing_context()
            assert ctx is not None
            assert ctx.prompt_id == "prompt-1"
            assert ctx.node_id == "node-1"
            assert ctx.list_index == 0
            assert ctx.expected_outputs == frozenset({0, 1})

        assert get_executing_context() is None

    def test_context_manager_without_expected_outputs(self):
        """Test CurrentNodeContext works without expected_outputs (backwards compatible)."""
        with CurrentNodeContext("prompt-1", "node-1"):
            ctx = get_executing_context()
            assert ctx is not None
            assert ctx.expected_outputs is None

    def test_nested_context_managers(self):
        """Test nested CurrentNodeContext managers."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})):
            ctx1 = get_executing_context()
            assert ctx1.expected_outputs == frozenset({0})

            with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})):
                ctx2 = get_executing_context()
                assert ctx2.expected_outputs == frozenset({1, 2})
                assert ctx2.node_id == "node-2"

            # After inner context exits, should be back to outer context
            ctx1_again = get_executing_context()
            assert ctx1_again.expected_outputs == frozenset({0})
            assert ctx1_again.node_id == "node-1"

    def test_output_check_pattern(self):
        """Test the typical pattern nodes will use to check expected outputs."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            ctx = get_executing_context()

            # Typical usage pattern
            if ctx and ctx.expected_outputs is not None:
                should_compute_0 = 0 in ctx.expected_outputs
                should_compute_1 = 1 in ctx.expected_outputs
                should_compute_2 = 2 in ctx.expected_outputs
            else:
                # Fallback when info not available
                should_compute_0 = should_compute_1 = should_compute_2 = True

            assert should_compute_0 is True
            assert should_compute_1 is False  # Not in expected_outputs
            assert should_compute_2 is True


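One plausible implementation consistent with the nesting behavior tested above keeps the context in a contextvars.ContextVar and restores the outer value on exit via the reset token. This is a sketch under that assumption, not the actual comfy_execution.utils code:

import contextvars
from dataclasses import dataclass
from typing import Optional

_current = contextvars.ContextVar("executing_context", default=None)

@dataclass(frozen=True)
class ExecutionContextSketch:
    prompt_id: str
    node_id: str
    list_index: Optional[int] = None
    expected_outputs: Optional[frozenset] = None

class CurrentNodeContextSketch:
    def __init__(self, prompt_id, node_id, list_index=None, expected_outputs=None):
        self._ctx = ExecutionContextSketch(prompt_id, node_id, list_index, expected_outputs)
        self._token = None

    def __enter__(self):
        self._token = _current.set(self._ctx)
        return self

    def __exit__(self, exc_type, exc, tb):
        _current.reset(self._token)  # restores the outer context, enabling nesting

def get_executing_context_sketch():
    return _current.get()
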
class TestSchemaLazyOutputs:
    """Tests for lazy_outputs in V3 Schema."""

    def test_schema_lazy_outputs_default(self):
        """Test that lazy_outputs defaults to False."""
        schema = IO.Schema(
            node_id="TestNode",
            inputs=[],
            outputs=[IO.Float.Output()],
        )
        assert schema.lazy_outputs is False

    def test_schema_lazy_outputs_true(self):
        """Test setting lazy_outputs to True."""
        schema = IO.Schema(
            node_id="TestNode",
            lazy_outputs=True,
            inputs=[],
            outputs=[IO.Float.Output()],
        )
        assert schema.lazy_outputs is True

    def test_v3_node_lazy_outputs_property(self):
        """Test that LAZY_OUTPUTS property works on V3 nodes."""

        class TestNodeWithLazyOutputs(IO.ComfyNode):
            @classmethod
            def define_schema(cls):
                return IO.Schema(
                    node_id="TestNodeWithLazyOutputs",
                    lazy_outputs=True,
                    inputs=[],
                    outputs=[IO.Float.Output()],
                )

            @classmethod
            def execute(cls):
                return IO.NodeOutput(1.0)

        assert TestNodeWithLazyOutputs.LAZY_OUTPUTS is True

    def test_v3_node_lazy_outputs_default(self):
        """Test that LAZY_OUTPUTS defaults to False on V3 nodes."""

        class TestNodeWithoutLazyOutputs(IO.ComfyNode):
            @classmethod
            def define_schema(cls):
                return IO.Schema(
                    node_id="TestNodeWithoutLazyOutputs",
                    inputs=[],
                    outputs=[IO.Float.Output()],
                )

            @classmethod
            def execute(cls):
                return IO.NodeOutput(1.0)

        assert TestNodeWithoutLazyOutputs.LAZY_OUTPUTS is False


class TestIsOutputNeeded:
    """Tests for is_output_needed() helper function."""

    def test_output_needed_when_in_expected(self):
        """Test that output is needed when in expected_outputs."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            assert is_output_needed(0) is True
            assert is_output_needed(2) is True

    def test_output_not_needed_when_not_in_expected(self):
        """Test that output is not needed when not in expected_outputs."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            assert is_output_needed(1) is False
            assert is_output_needed(3) is False

    def test_output_needed_when_no_context(self):
        """Test that output is needed when no context."""
        assert get_executing_context() is None
        assert is_output_needed(0) is True
        assert is_output_needed(1) is True

    def test_output_needed_when_expected_outputs_is_none(self):
        """Test that output is needed when expected_outputs is None."""
        with CurrentNodeContext("prompt-1", "node-1", 0, None):
            assert is_output_needed(0) is True
            assert is_output_needed(1) is True
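These last tests reduce is_output_needed to a small rule: with no executing context, or a context whose expected_outputs is None, every output is assumed needed; otherwise set membership decides. A sketch on top of the contextvars helper above:

def is_output_needed_sketch(index: int) -> bool:
    ctx = get_executing_context_sketch()
    if ctx is None or ctx.expected_outputs is None:
        return True  # no information: compute everything (backwards compatible)
    return index in ctx.expected_outputs
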
@@ -574,104 +574,6 @@ class TestExecution:
        else:
            assert result.did_run(test_node), "The execution should have been re-run"

    def test_expected_outputs_all_connected(self, client: ComfyClient, builder: GraphBuilder):
        """Test that expected_outputs contains all connected outputs."""
        g = builder
        # Create a node with 3 outputs, all connected
        expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)

        # Connect all 3 outputs to preview nodes
        output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
        output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
        output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))

        result = client.run(g)

        # All outputs should be white (255) since all are connected
        images0 = result.get_images(output0)
        images1 = result.get_images(output1)
        images2 = result.get_images(output2)

        assert len(images0) == 1, "Should have 1 image for output0"
        assert len(images1) == 1, "Should have 1 image for output1"
        assert len(images2) == 1, "Should have 1 image for output2"

        # White pixels = 255, meaning output was in expected_outputs
        assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
        assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
        assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"

    def test_expected_outputs_partial_connected(self, client: ComfyClient, builder: GraphBuilder):
        """Test that expected_outputs only contains connected outputs."""
        g = builder
        # Create a node with 3 outputs, only some connected
        expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)

        # Only connect outputs 0 and 2, leave output 1 disconnected
        output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
        # output1 is intentionally not connected
        output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))

        result = client.run(g)

        # Connected outputs should be white (255)
        images0 = result.get_images(output0)
        images2 = result.get_images(output2)

        assert len(images0) == 1, "Should have 1 image for output0"
        assert len(images2) == 1, "Should have 1 image for output2"

        # White = expected. Output 1 is not connected, so it cannot be verified
        # directly, but outputs 0 and 2 should be white.
        assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
        assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"

    def test_expected_outputs_single_connected(self, client: ComfyClient, builder: GraphBuilder):
        """Test that expected_outputs works with single connected output."""
        g = builder
        # Create a node with 3 outputs, only one connected
        expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)

        # Only connect output 1
        output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))

        result = client.run(g)

        images1 = result.get_images(output1)
        assert len(images1) == 1, "Should have 1 image for output1"

        # Output 1 should be white (connected), others are not visible in this test
        assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"

    def test_expected_outputs_cache_invalidation(self, client: ComfyClient, builder: GraphBuilder, server):
        """Test that cache invalidates when output connections change."""
        g = builder
        # Use unique dimensions to avoid cache collision with other expected_outputs tests
        expected_outputs_node = g.node("TestExpectedOutputs", height=32, width=32)

        # First run: only connect output 0
        output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))

        result1 = client.run(g)
        assert result1.did_run(expected_outputs_node), "First run should execute the node"

        # Second run: same connections, should be cached
        result2 = client.run(g)
        if server["should_cache_results"]:
            assert not result2.did_run(expected_outputs_node), "Second run should be cached"

        # Third run: add connection to output 2
        output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))

        result3 = client.run(g)
        # Because LAZY_OUTPUTS=True, changing connections should invalidate cache
        if server["should_cache_results"]:
            assert result3.did_run(expected_outputs_node), "Adding output connection should invalidate cache"

        # Verify both outputs are now white
        images0 = result3.get_images(output0)
        images2 = result3.get_images(output2)
        assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
        assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"

    def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
        # Warmup execution to ensure server is fully initialized

@@ -6,7 +6,6 @@ from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
from comfy_execution.utils import get_executing_context

class TestLazyMixImages:
    @classmethod
@@ -483,57 +482,6 @@ class TestOutputNodeWithSocketOutput:
        result = image * value
        return (result,)


class TestExpectedOutputs:
    """Test node for the expected_outputs feature.

    This node has 3 IMAGE outputs that encode which outputs were expected:
    - White image (255) if the output was in expected_outputs
    - Black image (0) if the output was NOT in expected_outputs

    This allows integration tests to verify which outputs were expected by checking pixel values.
    """
    LAZY_OUTPUTS = True  # Opt into cache invalidation on output connection changes

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "height": ("INT", {"default": 64, "min": 1, "max": 1024}),
                "width": ("INT", {"default": 64, "min": 1, "max": 1024}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
    RETURN_NAMES = ("output0", "output1", "output2")
    FUNCTION = "execute"
    CATEGORY = "_for_testing"

    def execute(self, height, width):
        ctx = get_executing_context()

        # Default: assume all outputs are expected (backwards compatibility)
        output0_expected = True
        output1_expected = True
        output2_expected = True

        if ctx is not None and ctx.expected_outputs is not None:
            output0_expected = 0 in ctx.expected_outputs
            output1_expected = 1 in ctx.expected_outputs
            output2_expected = 2 in ctx.expected_outputs

        # Return white image if expected, black if not
        # This allows tests to verify which outputs were expected via pixel values
        white = torch.ones(1, height, width, 3)
        black = torch.zeros(1, height, width, 3)

        return (
            white if output0_expected else black,
            white if output1_expected else black,
            white if output2_expected else black,
        )


TEST_NODE_CLASS_MAPPINGS = {
    "TestLazyMixImages": TestLazyMixImages,
    "TestVariadicAverage": TestVariadicAverage,
@@ -550,7 +498,6 @@ TEST_NODE_CLASS_MAPPINGS = {
    "TestSleep": TestSleep,
    "TestParallelSleep": TestParallelSleep,
    "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
    "TestExpectedOutputs": TestExpectedOutputs,
}

TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -569,5 +516,4 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestSleep": "Test Sleep",
    "TestParallelSleep": "Test Parallel Sleep",
    "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
    "TestExpectedOutputs": "Test Expected Outputs",
}
