mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 21:21:05 +00:00
Compare commits
22 Commits
v0.12.3
...
feat/core/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06408af600 | ||
|
|
a0302cc6a8 | ||
|
|
f350a84261 | ||
|
|
3760d74005 | ||
|
|
9bf5aa54db | ||
|
|
5ff4fdedba | ||
|
|
17e7df43d1 | ||
|
|
039955c527 | ||
|
|
6a26328842 | ||
|
|
204e65b8dc | ||
|
|
a831c19b70 | ||
|
|
eba6c940fd | ||
|
|
3902c86e83 | ||
|
|
01ef4e50ec | ||
|
|
a1c101f861 | ||
|
|
c2d7f07dbf | ||
|
|
458292fef0 | ||
|
|
6555dc65b8 | ||
|
|
50975a7a0d | ||
|
|
d987b0d32d | ||
|
|
2b70ab9ad0 | ||
|
|
00efcc6cd0 |
@@ -7,6 +7,67 @@ from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.flux.layers import timestep_embedding
|
||||
|
||||
def get_silence_latent(length, device):
    """Return the hard-coded "silence" latent tiled to ``length`` frames.

    A single 64-value frame is repeated ``length`` times along the last
    dimension, then the leading frames are overwritten with a distinct
    4-frame ``head``.

    Args:
        length: Number of frames (size of the last dimension) to produce.
        device: Torch device on which the tensors are created.

    Returns:
        A float tensor of shape ``(1, 64, length)``.
    """
    # Literal is (1, 4, 64); movedim(-1, 1) makes it (1, 64, 4).
    head = torch.tensor([[[ 0.5707,  0.0982,  0.6909, -0.5658,  0.6266,  0.6996, -0.1365, -0.1291,
                           -0.0776, -0.1171, -0.2743, -0.8422, -0.1168,  1.5539, -4.6936,  0.7436,
                           -1.1846, -0.2637,  0.6933, -6.7266,  0.0966, -0.1187, -0.3501, -1.1736,
                            0.0587, -2.0517, -1.3651,  0.7508, -0.2490, -1.3548, -0.1290, -0.7261,
                            1.1132, -0.3249,  0.2337,  0.3004,  0.6605, -0.0298, -0.1989, -0.4041,
                            0.2843, -1.0963, -0.5519,  0.2639, -1.0436, -0.1183,  0.0640,  0.4460,
                           -1.1001, -0.6172, -1.3241,  1.1379,  0.5623, -0.1507, -0.1963, -0.4742,
                           -2.4697,  0.5302,  0.5381,  0.4636, -0.1782, -0.0687,  1.0333,  0.4202],
                          [ 0.3040, -0.1367,  0.6200,  0.0665, -0.0642,  0.4655, -0.1187, -0.0440,
                            0.2941, -0.2753,  0.0173, -0.2421, -0.0147,  1.5603, -2.7025,  0.7907,
                           -0.9736, -0.0682,  0.1294, -5.0707, -0.2167,  0.3302, -0.1513, -0.8100,
                           -0.3894, -0.2884, -0.3149,  0.8660, -0.3817, -1.7061,  0.5824, -0.4840,
                            0.6938,  0.1859,  0.1753,  0.3081,  0.0195,  0.1403, -0.0754, -0.2091,
                            0.1251, -0.1578, -0.4968, -0.1052, -0.4554, -0.0320,  0.1284,  0.4974,
                           -1.1889, -0.0344, -0.8313,  0.2953,  0.5445, -0.6249, -0.1595, -0.0682,
                           -3.1412,  0.0484,  0.4153,  0.8260, -0.1526, -0.0625,  0.5366,  0.8473],
                          [ 5.3524e-02, -1.7534e-01,  5.4443e-01, -4.3501e-01, -2.1317e-03,
                            3.7200e-01, -4.0143e-03, -1.5516e-01, -1.2968e-01, -1.5375e-01,
                           -7.7107e-02, -2.0593e-01, -3.2780e-01,  1.5142e+00, -2.6101e+00,
                            5.8698e-01, -1.2716e+00, -2.4773e-01, -2.7933e-02, -5.0799e+00,
                            1.1601e-01,  4.0987e-01, -2.2030e-02, -6.6495e-01, -2.0995e-01,
                           -6.3474e-01, -1.5893e-01,  8.2745e-01, -2.2992e-01, -1.6816e+00,
                            5.4440e-01, -4.9579e-01,  5.5128e-01,  3.0477e-01,  8.3052e-02,
                           -6.1782e-02,  5.9036e-03,  2.9553e-01, -8.0645e-02, -1.0060e-01,
                            1.9144e-01, -3.8124e-01, -7.2949e-01,  2.4520e-02, -5.0814e-01,
                            2.3977e-01,  9.2943e-02,  3.9256e-01, -1.1993e+00, -3.2752e-01,
                           -7.2707e-01,  2.9476e-01,  4.3542e-01, -8.8597e-01, -4.1686e-01,
                           -8.5390e-02, -2.9018e+00,  6.4988e-02,  5.3945e-01,  9.1988e-01,
                            5.8762e-02, -7.0098e-02,  6.4772e-01,  8.9118e-01],
                          [-3.2225e-02, -1.3195e-01,  5.6411e-01, -5.4766e-01, -5.2170e-03,
                            3.1425e-01, -5.4367e-02, -1.9419e-01, -1.3059e-01, -1.3660e-01,
                           -9.0984e-02, -1.9540e-01, -2.5590e-01,  1.5440e+00, -2.6349e+00,
                            6.8273e-01, -1.2532e+00, -1.9810e-01, -2.2793e-02, -5.0506e+00,
                            1.8818e-01,  5.0109e-01,  7.3546e-03, -6.8771e-01, -3.0676e-01,
                           -7.3257e-01, -1.6687e-01,  9.2232e-01, -1.8987e-01, -1.7267e+00,
                            5.3355e-01, -5.3179e-01,  4.4953e-01,  2.8820e-01,  1.3012e-01,
                           -2.0943e-01, -1.1348e-01,  3.3929e-01, -1.5069e-01, -1.2919e-01,
                            1.8929e-01, -3.6166e-01, -8.0756e-01,  6.6387e-02, -5.8867e-01,
                            1.6978e-01,  1.0134e-01,  3.3877e-01, -1.2133e+00, -3.2492e-01,
                           -8.1237e-01,  3.8101e-01,  4.3765e-01, -8.0596e-01, -4.4531e-01,
                           -4.7513e-02, -2.9266e+00,  1.1741e-03,  4.5123e-01,  9.3075e-01,
                            5.3688e-02, -1.9621e-01,  6.4530e-01,  9.3870e-01]]], device=device).movedim(-1, 1)

    # Literal is (1, 1, 64); movedim(-1, 1) makes it (1, 64, 1), repeat tiles it to (1, 64, length).
    silence_latent = torch.tensor([[[-1.3672e-01, -1.5820e-01,  5.8594e-01, -5.7422e-01,  3.0273e-02,
                                      2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
                                     -2.7710e-02, -1.8066e-01, -2.9688e-01,  1.6016e+00, -2.6719e+00,
                                      7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
                                      2.4316e-01,  4.7266e-01,  4.6387e-02, -6.6406e-01, -2.1973e-01,
                                     -6.7578e-01, -1.5723e-01,  9.5312e-01, -2.0020e-01, -1.7109e+00,
                                      5.8984e-01, -5.7422e-01,  5.1562e-01,  2.8320e-01,  1.4551e-01,
                                     -1.8750e-01, -5.9814e-02,  3.6719e-01, -1.0059e-01, -1.5723e-01,
                                      2.0605e-01, -4.3359e-01, -8.2812e-01,  4.5654e-02, -6.6016e-01,
                                      1.4844e-01,  9.4727e-02,  3.8477e-01, -1.2578e+00, -3.3203e-01,
                                     -8.5547e-01,  4.3359e-01,  4.2383e-01, -8.9453e-01, -5.0391e-01,
                                     -5.6152e-02, -2.9219e+00, -2.4658e-02,  5.0391e-01,  9.8438e-01,
                                      7.2754e-02, -2.1582e-01,  6.3672e-01,  1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, length)

    # Clamp the copied span: when `length` is shorter than the head, assigning
    # the full head into the narrower slice would raise a shape-mismatch error.
    head_frames = min(head.shape[-1], silence_latent.shape[-1])
    silence_latent[:, :, :head_frames] = head[:, :, :head_frames]
    return silence_latent
|
||||
|
||||
|
||||
def get_layer_class(operations, layer_name):
|
||||
if operations is not None and hasattr(operations, layer_name):
|
||||
return getattr(operations, layer_name)
|
||||
@@ -677,7 +738,7 @@ class AttentionPooler(nn.Module):
|
||||
def forward(self, x):
|
||||
B, T, P, D = x.shape
|
||||
x = self.embed_tokens(x)
|
||||
special = self.special_token.expand(B, T, 1, -1)
|
||||
special = comfy.model_management.cast_to(self.special_token, device=x.device, dtype=x.dtype).expand(B, T, 1, -1)
|
||||
x = torch.cat([special, x], dim=2)
|
||||
x = x.view(B * T, P + 1, D)
|
||||
|
||||
@@ -728,7 +789,7 @@ class FSQ(nn.Module):
|
||||
self.register_buffer('implicit_codebook', implicit_codebook, persistent=False)
|
||||
|
||||
def bound(self, z):
|
||||
levels_minus_1 = (self._levels - 1).to(z.dtype)
|
||||
levels_minus_1 = (comfy.model_management.cast_to(self._levels, device=z.device, dtype=z.dtype) - 1)
|
||||
scale = 2. / levels_minus_1
|
||||
bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5
|
||||
|
||||
@@ -743,8 +804,8 @@ class FSQ(nn.Module):
|
||||
return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1.
|
||||
|
||||
def codes_to_indices(self, zhat):
|
||||
zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1))
|
||||
return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32)
|
||||
zhat_normalized = (zhat + 1.) / (2. / (comfy.model_management.cast_to(self._levels, device=zhat.device, dtype=zhat.dtype) - 1))
|
||||
return (zhat_normalized * comfy.model_management.cast_to(self._basis, device=zhat.device, dtype=zhat.dtype)).sum(dim=-1).round().to(torch.int32)
|
||||
|
||||
def forward(self, z):
|
||||
orig_dtype = z.dtype
|
||||
@@ -826,7 +887,7 @@ class ResidualFSQ(nn.Module):
|
||||
x = self.project_in(x)
|
||||
|
||||
if hasattr(self, 'soft_clamp_input_value'):
|
||||
sc_val = self.soft_clamp_input_value.to(x.dtype)
|
||||
sc_val = comfy.model_management.cast_to(self.soft_clamp_input_value, device=x.device, dtype=x.dtype)
|
||||
x = (x / sc_val).tanh() * sc_val
|
||||
|
||||
quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
|
||||
@@ -834,7 +895,7 @@ class ResidualFSQ(nn.Module):
|
||||
all_indices = []
|
||||
|
||||
for layer, scale in zip(self.layers, self.scales):
|
||||
scale = scale.to(residual.dtype)
|
||||
scale = comfy.model_management.cast_to(scale, device=x.device, dtype=x.dtype)
|
||||
|
||||
quantized, indices = layer(residual / scale)
|
||||
quantized = quantized * scale
|
||||
@@ -1040,22 +1101,21 @@ class AceStepConditionGenerationModel(nn.Module):
|
||||
lm_hints = self.detokenizer(lm_hints_5Hz)
|
||||
|
||||
lm_hints = lm_hints[:, :src_latents.shape[1], :]
|
||||
if is_covers is None:
|
||||
if is_covers is None or is_covers is True:
|
||||
src_latents = lm_hints
|
||||
else:
|
||||
src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents)
|
||||
elif is_covers is False:
|
||||
src_latents = refer_audio_acoustic_hidden_states_packed
|
||||
|
||||
context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1)
|
||||
|
||||
return encoder_hidden, encoder_mask, context_latents
|
||||
|
||||
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs):
|
||||
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs):
|
||||
text_attention_mask = None
|
||||
lyric_attention_mask = None
|
||||
refer_audio_order_mask = None
|
||||
attention_mask = None
|
||||
chunk_masks = None
|
||||
is_covers = None
|
||||
src_latents = None
|
||||
precomputed_lm_hints_25Hz = None
|
||||
lyric_hidden_states = lyric_embed
|
||||
@@ -1067,7 +1127,7 @@ class AceStepConditionGenerationModel(nn.Module):
|
||||
if refer_audio_order_mask is None:
|
||||
refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long)
|
||||
|
||||
if src_latents is None and is_covers is None:
|
||||
if src_latents is None:
|
||||
src_latents = x
|
||||
|
||||
if chunk_masks is None:
|
||||
@@ -1080,6 +1140,9 @@ class AceStepConditionGenerationModel(nn.Module):
|
||||
src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
|
||||
)
|
||||
|
||||
if replace_with_null_embeds:
|
||||
enc_hidden[:] = self.null_condition_emb.to(enc_hidden)
|
||||
|
||||
out = self.decoder(hidden_states=x,
|
||||
timestep=timestep,
|
||||
timestep_r=timestep,
|
||||
|
||||
@@ -335,7 +335,7 @@ class FinalLayer(nn.Module):
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
@@ -463,6 +463,8 @@ class Block(nn.Module):
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
residual_dtype = x_B_T_H_W_D.dtype
|
||||
compute_dtype = emb_B_T_D.dtype
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
|
||||
@@ -512,7 +514,7 @@ class Block(nn.Module):
|
||||
result_B_T_H_W_D = rearrange(
|
||||
self.self_attn(
|
||||
# normalized_x_B_T_HW_D,
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
@@ -522,7 +524,7 @@ class Block(nn.Module):
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
@@ -536,7 +538,7 @@ class Block(nn.Module):
|
||||
)
|
||||
_result_B_T_H_W_D = rearrange(
|
||||
self.cross_attn(
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
@@ -555,7 +557,7 @@ class Block(nn.Module):
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
@@ -563,8 +565,8 @@ class Block(nn.Module):
|
||||
scale_mlp_B_T_1_1_D,
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
@@ -876,6 +878,14 @@ class MiniTrainDIT(nn.Module):
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
"transformer_options": kwargs.get("transformer_options", {}),
|
||||
}
|
||||
|
||||
# The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
|
||||
# in fp32, but run attention and MLP modules in fp16.
|
||||
# An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
|
||||
# quality degradation and visual artifacts.
|
||||
if x_B_T_H_W_D.dtype == torch.float16:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
x_B_T_H_W_D,
|
||||
@@ -884,6 +894,6 @@ class MiniTrainDIT(nn.Module):
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
@@ -147,11 +147,11 @@ class BaseModel(torch.nn.Module):
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
logging.debug("using channels last mode for diffusion model")
|
||||
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
||||
comfy.model_management.archive_model_dtypes(self.diffusion_model)
|
||||
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
comfy.model_management.archive_model_dtypes(self.diffusion_model)
|
||||
|
||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||
if self.adm_channels is None:
|
||||
self.adm_channels = 0
|
||||
@@ -1552,6 +1552,8 @@ class ACEStep15(BaseModel):
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if torch.count_nonzero(cross_attn) == 0:
|
||||
out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
|
||||
@@ -1560,22 +1562,11 @@ class ACEStep15(BaseModel):
|
||||
|
||||
refer_audio = kwargs.get("reference_audio_timbre_latents", None)
|
||||
if refer_audio is None or len(refer_audio) == 0:
|
||||
refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
|
||||
2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
|
||||
-2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
|
||||
7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
|
||||
2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
|
||||
-6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
|
||||
5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
|
||||
-1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
|
||||
2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
|
||||
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
|
||||
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
|
||||
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
|
||||
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2])
|
||||
refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
|
||||
pass_audio_codes = True
|
||||
else:
|
||||
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
|
||||
out['is_covers'] = comfy.conds.CONDConstant(True)
|
||||
pass_audio_codes = False
|
||||
|
||||
if pass_audio_codes:
|
||||
@@ -1583,6 +1574,12 @@ class ACEStep15(BaseModel):
|
||||
if audio_codes is not None:
|
||||
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
|
||||
refer_audio = refer_audio[:, :, :750]
|
||||
else:
|
||||
out['is_covers'] = comfy.conds.CONDConstant(False)
|
||||
|
||||
if refer_audio.shape[2] < noise.shape[2]:
|
||||
pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
|
||||
refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
|
||||
|
||||
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
|
||||
return out
|
||||
|
||||
@@ -993,7 +993,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
@@ -1023,11 +1023,7 @@ class Anima(supported_models_base.BASE):
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Anima(self, device=device)
|
||||
@@ -1038,6 +1034,12 @@ class Anima(supported_models_base.BASE):
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
|
||||
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
    """Recompute ``memory_usage_factor`` for ``dtype``, then defer to the base class.

    The factor is ``(model_channels / 2048) * 0.95`` (``model_channels``
    defaults to 2048 when absent from ``unet_config``) and is multiplied by
    1.4 when ``dtype`` is ``torch.float16``.
    """
    channels = self.unet_config.get("model_channels", 2048)
    factor = (channels / 2048) * 0.95
    if dtype is torch.float16:
        factor *= 1.4
    self.memory_usage_factor = factor
    return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)
|
||||
|
||||
class CosmosI2VPredict2(CosmosT2IPredict2):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
|
||||
@@ -23,7 +23,7 @@ class AnimaTokenizer:
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
|
||||
out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
|
||||
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
@@ -1430,6 +1430,11 @@ class Schema:
|
||||
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
|
||||
accept_all_inputs: bool=False
|
||||
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
|
||||
lazy_outputs: bool=False
|
||||
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.
|
||||
|
||||
Use this for nodes that can skip computing outputs that aren't connected downstream.
|
||||
Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused."""
|
||||
|
||||
def validate(self):
|
||||
'''Validate the schema:
|
||||
@@ -1875,6 +1880,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls.GET_SCHEMA()
|
||||
return cls._ACCEPT_ALL_INPUTS
|
||||
|
||||
_LAZY_OUTPUTS = None
|
||||
@final
|
||||
@classproperty
|
||||
def LAZY_OUTPUTS(cls): # noqa
|
||||
if cls._LAZY_OUTPUTS is None:
|
||||
cls.GET_SCHEMA()
|
||||
return cls._LAZY_OUTPUTS
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||
@@ -1917,6 +1930,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls._NOT_IDEMPOTENT = schema.not_idempotent
|
||||
if cls._ACCEPT_ALL_INPUTS is None:
|
||||
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
|
||||
if cls._LAZY_OUTPUTS is None:
|
||||
cls._LAZY_OUTPUTS = schema.lazy_outputs
|
||||
|
||||
if cls._RETURN_TYPES is None:
|
||||
output = []
|
||||
|
||||
@@ -5,7 +5,7 @@ import psutil
|
||||
import time
|
||||
import torch
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import nodes
|
||||
@@ -115,6 +115,10 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
# Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS
|
||||
if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS:
|
||||
expected = get_expected_outputs_for_node(dynprompt, node_id)
|
||||
signature.append(("expected_outputs", tuple(sorted(expected))))
|
||||
inputs = node["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
if is_link(inputs[key]):
|
||||
|
||||
@@ -19,6 +19,15 @@ class NodeInputError(Exception):
|
||||
class NodeNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
    """Look up which output indices of ``node_id`` are linked to another node's input.

    The result is conservative: an index in the set may or may not end up
    being consumed, but an index missing from the set is definitely unused
    and can be skipped. A node with no entry in the map yields an empty
    frozenset.
    """
    outputs_by_node = dynprompt.get_expected_outputs_map()
    no_outputs = frozenset()
    return outputs_by_node.get(node_id, no_outputs)
|
||||
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
@@ -27,6 +36,7 @@ class DynamicPrompt:
|
||||
self.ephemeral_prompt = {}
|
||||
self.ephemeral_parents = {}
|
||||
self.ephemeral_display = {}
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def get_node(self, node_id):
|
||||
if node_id in self.ephemeral_prompt:
|
||||
@@ -42,6 +52,7 @@ class DynamicPrompt:
|
||||
self.ephemeral_prompt[node_id] = node_info
|
||||
self.ephemeral_parents[node_id] = parent_id
|
||||
self.ephemeral_display[node_id] = display_id
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_parents:
|
||||
@@ -59,6 +70,26 @@ class DynamicPrompt:
|
||||
def all_node_ids(self):
    """Return the set of node ids present in the original or the ephemeral prompt."""
    node_ids = set(self.original_prompt.keys())
    node_ids.update(self.ephemeral_prompt.keys())
    return node_ids
|
||||
|
||||
def _build_expected_outputs_map(self):
    """Rebuild the cached map of source node id -> frozenset of linked output sockets.

    Every input of every known node is scanned; each input that is a link
    contributes its source socket to the source node's set. Nodes that
    cannot be resolved are skipped. The result is stored on
    ``self._expected_outputs_map``.
    """
    sockets_by_source = {}
    for node_id in self.all_node_ids():
        try:
            node_data = self.get_node(node_id)
        except NodeNotFoundError:
            continue
        for input_value in node_data.get("inputs", {}).values():
            if not is_link(input_value):
                continue
            source_id, source_socket = input_value
            sockets_by_source.setdefault(source_id, set()).add(source_socket)
    self._expected_outputs_map = {
        source_id: frozenset(sockets)
        for source_id, sockets in sockets_by_source.items()
    }
|
||||
|
||||
def get_expected_outputs_map(self):
    """Return the expected-outputs map, building it first if no cached copy exists."""
    cached = self._expected_outputs_map
    if cached is not None:
        return cached
    self._build_expected_outputs_map()
    return self._expected_outputs_map
|
||||
|
||||
def get_original_prompt(self):
|
||||
return self.original_prompt
|
||||
|
||||
|
||||
@@ -1,23 +1,41 @@
|
||||
import contextvars
|
||||
from typing import Optional, NamedTuple
|
||||
from typing import NamedTuple, FrozenSet
|
||||
|
||||
class ExecutionContext(NamedTuple):
|
||||
"""
|
||||
Context information about the currently executing node.
|
||||
|
||||
Attributes:
|
||||
prompt_id: The ID of the current prompt execution
|
||||
node_id: The ID of the currently executing node
|
||||
list_index: The index in a list being processed (for operations on batches/lists)
|
||||
expected_outputs: Set of output indices that might be used downstream.
|
||||
Outputs NOT in this set are definitely unused (safe to skip).
|
||||
None means the information is not available.
|
||||
"""
|
||||
prompt_id: str
|
||||
node_id: str
|
||||
list_index: Optional[int]
|
||||
list_index: int | None
|
||||
expected_outputs: FrozenSet[int] | None = None
|
||||
|
||||
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
|
||||
def get_executing_context() -> Optional[ExecutionContext]:
|
||||
def get_executing_context() -> ExecutionContext | None:
|
||||
return current_executing_context.get(None)
|
||||
|
||||
|
||||
def is_output_needed(output_index: int) -> bool:
    """Tell whether the output at ``output_index`` should be computed.

    Answers True when there is no executing context or the context carries
    no expected-outputs information, since nothing can then be ruled out.
    Otherwise answers whether ``output_index`` is among the context's
    expected outputs; False means the output is definitely not connected.
    """
    ctx = get_executing_context()
    expected = None if ctx is None else ctx.expected_outputs
    if expected is None:
        return True
    return output_index in expected
|
||||
|
||||
|
||||
class CurrentNodeContext:
|
||||
"""
|
||||
Context manager for setting the current executing node context.
|
||||
@@ -25,15 +43,22 @@ class CurrentNodeContext:
|
||||
Sets the current_executing_context on enter and resets it on exit.
|
||||
|
||||
Example:
|
||||
with CurrentNodeContext(node_id="123", list_index=0):
|
||||
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
|
||||
# Code that should run with the current node context set
|
||||
process_image()
|
||||
"""
|
||||
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_id: str,
|
||||
node_id: str,
|
||||
list_index: int | None = None,
|
||||
expected_outputs: FrozenSet[int] | None = None,
|
||||
):
|
||||
self.context = ExecutionContext(
|
||||
prompt_id= prompt_id,
|
||||
node_id= node_id,
|
||||
list_index= list_index
|
||||
prompt_id=prompt_id,
|
||||
node_id=node_id,
|
||||
list_index=list_index,
|
||||
expected_outputs=expected_outputs,
|
||||
)
|
||||
self.token = None
|
||||
|
||||
|
||||
@@ -622,6 +622,7 @@ class SamplerSASolver(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerSASolver",
|
||||
search_aliases=["sde"],
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@@ -666,6 +667,7 @@ class SamplerSEEDS2(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerSEEDS2",
|
||||
search_aliases=["sde", "exp heun"],
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
||||
|
||||
@@ -9,6 +9,14 @@ if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def _extract_tensor(data, output_channels):
|
||||
"""Extract tensor from data, handling both single tensors and lists."""
|
||||
if isinstance(data, list):
|
||||
# LTX2 AV tensors: [video, audio]
|
||||
return data[0][:, :output_channels], data[1][:, :output_channels]
|
||||
return data[:, :output_channels], None
|
||||
|
||||
|
||||
def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
transformer_options: dict[str] = args[-1]
|
||||
@@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
if not transformer_options:
|
||||
transformer_options = args[-2]
|
||||
easycache: EasyCacheHolder = transformer_options["easycache"]
|
||||
x: torch.Tensor = args[0][:, :easycache.output_channels]
|
||||
x, ax = _extract_tensor(args[0], easycache.output_channels)
|
||||
sigmas = transformer_options["sigmas"]
|
||||
uuids = transformer_options["uuids"]
|
||||
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
|
||||
@@ -35,7 +43,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
if easycache.skip_current_step and can_apply_cache_diff:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
result = easycache.apply_cache_diff(x, uuids)
|
||||
if ax is not None:
|
||||
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
|
||||
return [result, result_audio]
|
||||
return result
|
||||
if easycache.initial_step:
|
||||
easycache.first_cond_uuid = uuids[0]
|
||||
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
|
||||
@@ -51,13 +63,18 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
# other conds should also skip this step, and instead use their cached values
|
||||
easycache.skip_current_step = True
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
result = easycache.apply_cache_diff(x, uuids)
|
||||
if ax is not None:
|
||||
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
|
||||
return [result, result_audio]
|
||||
return result
|
||||
else:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
easycache.cumulative_change_rate = 0.0
|
||||
|
||||
output: torch.Tensor = executor(*args, **kwargs)
|
||||
full_output: torch.Tensor = executor(*args, **kwargs)
|
||||
output, audio_output = _extract_tensor(full_output, easycache.output_channels)
|
||||
if has_first_cond_uuid and easycache.has_output_prev_norm():
|
||||
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
@@ -74,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
|
||||
# TODO: allow cache_diff to be offloaded
|
||||
easycache.update_cache_diff(output, next_x_prev, uuids)
|
||||
if audio_output is not None:
|
||||
easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
|
||||
if has_first_cond_uuid:
|
||||
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
|
||||
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
|
||||
easycache.output_prev_norm = output.flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
||||
return output
|
||||
return full_output
|
||||
|
||||
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
@@ -89,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
||||
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
|
||||
if easycache.is_past_end_timestep(timestep):
|
||||
return executor(*args, **kwargs)
|
||||
# prepare next x_prev
|
||||
x: torch.Tensor = args[0][:, :easycache.output_channels]
|
||||
# prepare next x_prev
|
||||
next_x_prev = x
|
||||
input_change = None
|
||||
do_easycache = easycache.should_do_easycache(timestep)
|
||||
@@ -197,6 +216,7 @@ class EasyCacheHolder:
|
||||
self.output_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_norm: torch.Tensor = None
|
||||
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
|
||||
self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
|
||||
self.output_change_rates = []
|
||||
self.approx_output_change_rates = []
|
||||
self.total_steps_skipped = 0
|
||||
@@ -245,20 +265,21 @@ class EasyCacheHolder:
|
||||
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
    """Return True only if every uuid in `uuids` has an entry in `self.uuid_cache_diffs`."""
    for cond_uuid in uuids:
        if cond_uuid not in self.uuid_cache_diffs:
            return False
    # Also reached for an empty `uuids`, matching all() on an empty iterable.
    return True
|
||||
|
||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
|
||||
if self.first_cond_uuid in uuids:
|
||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
|
||||
if self.first_cond_uuid in uuids and not is_audio:
|
||||
self.total_steps_skipped += 1
|
||||
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
|
||||
batch_offset = x.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
# slice out only what is relevant to this cond
|
||||
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
|
||||
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
||||
if x.shape[1:] != cache_diffs[uuid].shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
|
||||
slicing = []
|
||||
skip_this_dim = True
|
||||
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
|
||||
for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
|
||||
if skip_this_dim:
|
||||
skip_this_dim = False
|
||||
continue
|
||||
@@ -270,10 +291,11 @@ class EasyCacheHolder:
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
batch_slice = batch_slice + slicing
|
||||
x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
|
||||
return x
|
||||
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
|
||||
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
|
||||
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if output.shape[1:] != x.shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
@@ -293,7 +315,7 @@ class EasyCacheHolder:
|
||||
diff = output - x
|
||||
batch_offset = diff.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
|
||||
cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
|
||||
|
||||
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
    """Report whether `self.first_cond_uuid` is one of the given `uuids`."""
    first_uuid = self.first_cond_uuid
    found = first_uuid in uuids
    return found
|
||||
@@ -324,6 +346,8 @@ class EasyCacheHolder:
|
||||
self.output_prev_norm = None
|
||||
del self.uuid_cache_diffs
|
||||
self.uuid_cache_diffs = {}
|
||||
del self.uuid_cache_diffs_audio
|
||||
self.uuid_cache_diffs_audio = {}
|
||||
self.total_steps_skipped = 0
|
||||
self.state_metadata = None
|
||||
return self
|
||||
|
||||
@@ -391,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
|
||||
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
||||
normalized_latent = latent / latent_vector_magnitude
|
||||
|
||||
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||
dims = list(range(1, latent_vector_magnitude.ndim))
|
||||
mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
|
||||
std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
|
||||
|
||||
top = (std * 5 + mean) * multiplier
|
||||
|
||||
|
||||
47
comfy_extras/nodes_toolkit.py
Normal file
47
comfy_extras/nodes_toolkit.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class CreateList(io.ComfyNode):
    """Node that concatenates all of its autogrow inputs into one output list.

    The schema sets ``is_input_list=True``, so each value passed to
    ``execute`` is presumably already a list (TODO confirm against the
    executor); the values are joined in the iteration order of ``inputs``.
    """

    @classmethod
    def define_schema(cls):
        """Build the schema: autogrow inputs of one matched type and a single list output."""
        # One shared template ties the type of every input to the output type.
        template_matchtype = io.MatchType.Template("type")
        template_autogrow = io.Autogrow.TemplatePrefix(
            input=io.MatchType.Input("input", template=template_matchtype),
            prefix="input",
        )
        return io.Schema(
            node_id="CreateList",
            display_name="Create List",
            category="logic",
            is_input_list=True,
            search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
            inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],
            outputs=[
                io.MatchType.Output(
                    template=template_matchtype,
                    is_output_list=True,
                    display_name="list",
                ),
            ],
        )

    @classmethod
    def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
        """Extend one list with each value of `inputs` and wrap it in a NodeOutput."""
        output_list = []
        # Named `values` rather than `input` so the builtin is not shadowed.
        for values in inputs.values():
            output_list += values
        return io.NodeOutput(output_list)
|
||||
|
||||
|
||||
class ToolkitExtension(ComfyExtension):
    """Extension exposing the node classes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the list of node classes provided by this extension."""
        node_classes = [CreateList]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ToolkitExtension:
    """Return a new ToolkitExtension instance."""
    extension = ToolkitExtension()
    return extension
|
||||
40
execution.py
40
execution.py
@@ -31,6 +31,7 @@ from comfy_execution.graph import (
|
||||
ExecutionBlocker,
|
||||
ExecutionList,
|
||||
get_input_info,
|
||||
get_expected_outputs_for_node,
|
||||
)
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
@@ -227,7 +228,18 @@ async def resolve_map_node_over_list_results(results):
|
||||
raise exc
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
|
||||
async def _async_map_node_over_list(
|
||||
prompt_id,
|
||||
unique_id,
|
||||
obj,
|
||||
input_data_all,
|
||||
func,
|
||||
allow_interrupt=False,
|
||||
execution_block_cb=None,
|
||||
pre_execute_cb=None,
|
||||
v3_data=None,
|
||||
expected_outputs=None,
|
||||
):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@@ -277,10 +289,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
else:
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args, expected_outputs):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index, expected_outputs):
|
||||
return await f(**args)
|
||||
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
|
||||
task = asyncio.create_task(
|
||||
async_wrapper(f, prompt_id, unique_id, index, args=inputs, expected_outputs=expected_outputs)
|
||||
)
|
||||
# Give the task a chance to execute without yielding
|
||||
await asyncio.sleep(0)
|
||||
if task.done():
|
||||
@@ -289,7 +303,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
else:
|
||||
results.append(task)
|
||||
else:
|
||||
with CurrentNodeContext(prompt_id, unique_id, index):
|
||||
with CurrentNodeContext(prompt_id, unique_id, index, expected_outputs):
|
||||
result = f(**inputs)
|
||||
results.append(result)
|
||||
else:
|
||||
@@ -327,8 +341,17 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||
async def get_output_data(
|
||||
prompt_id,
|
||||
unique_id,
|
||||
obj,
|
||||
input_data_all,
|
||||
execution_block_cb=None,
|
||||
pre_execute_cb=None,
|
||||
v3_data=None,
|
||||
expected_outputs=None,
|
||||
):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
|
||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||
if has_pending_task:
|
||||
return return_values, {}, False, has_pending_task
|
||||
@@ -522,9 +545,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
|
||||
#that we just want to cull out each model run.
|
||||
allocator = comfy.memory_management.aimdo_allocator
|
||||
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
|
||||
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
|
||||
try:
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
|
||||
finally:
|
||||
if allocator is not None:
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
|
||||
3
nodes.py
3
nodes.py
@@ -2433,7 +2433,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_image_compare.py",
|
||||
"nodes_zimage.py",
|
||||
"nodes_lora_debug.py",
|
||||
"nodes_color.py"
|
||||
"nodes_color.py",
|
||||
"nodes_toolkit.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.37.11
|
||||
comfyui-frontend-package==1.38.13
|
||||
comfyui-workflow-templates==0.8.31
|
||||
comfyui-embedded-docs==0.4.0
|
||||
comfyui-embedded-docs==0.4.1
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
||||
322
tests-unit/execution_test/expected_outputs_test.py
Normal file
322
tests-unit/execution_test/expected_outputs_test.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""Unit tests for the expected_outputs feature.
|
||||
|
||||
This feature allows nodes to know at runtime which outputs are connected downstream,
|
||||
enabling them to skip computing outputs that aren't needed.
|
||||
"""
|
||||
|
||||
from comfy_api.latest import IO
|
||||
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
|
||||
from comfy_execution.utils import (
|
||||
CurrentNodeContext,
|
||||
ExecutionContext,
|
||||
get_executing_context,
|
||||
is_output_needed,
|
||||
)
|
||||
|
||||
|
||||
class TestGetExpectedOutputsForNode:
    """Checks the output indices that get_expected_outputs_for_node() reports."""

    def test_single_output_connected(self):
        """A single consumer linked to output 0 yields exactly {0}."""
        graph = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
        }
        found = get_expected_outputs_for_node(DynamicPrompt(graph), "1")
        assert found == frozenset({0})

    def test_multiple_outputs_partial_connected(self):
        """Only the linked outputs (0 and 2) are reported; output 1 is absent."""
        graph = {
            "1": {"class_type": "MultiOutputNode", "inputs": {}},
            "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
            # Nothing links to output 1.
            "3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}},
        }
        found = get_expected_outputs_for_node(DynamicPrompt(graph), "1")
        assert found == frozenset({0, 2})
        assert 1 not in found  # output 1 is unused

    def test_no_outputs_connected(self):
        """A node nobody links to reports an empty set."""
        graph = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "OtherNode", "inputs": {}},
        }
        found = get_expected_outputs_for_node(DynamicPrompt(graph), "1")
        assert found == frozenset()

    def test_same_output_connected_multiple_times(self):
        """Several consumers of the same output still yield a single index."""
        graph = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}},
            "4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}},
        }
        found = get_expected_outputs_for_node(DynamicPrompt(graph), "1")
        assert found == frozenset({0})

    def test_node_not_in_prompt(self):
        """Asking about a node id absent from the prompt yields an empty set."""
        graph = {
            "1": {"class_type": "SourceNode", "inputs": {}},
        }
        found = get_expected_outputs_for_node(DynamicPrompt(graph), "999")
        assert found == frozenset()

    def test_chained_nodes(self):
        """Each node in a 1 -> 2 -> 3 chain reports only its own downstream links."""
        graph = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}},
        }
        dynprompt = DynamicPrompt(graph)

        # Node 1 feeds node 2, node 2 feeds node 3, node 3 feeds nothing.
        cases = (
            ("1", frozenset({0})),
            ("2", frozenset({0})),
            ("3", frozenset()),
        )
        for node_id, wanted in cases:
            assert get_expected_outputs_for_node(dynprompt, node_id) == wanted

    def test_complex_graph(self):
        """All three outputs of node 1 are consumed across two processors."""
        graph = {
            "1": {"class_type": "MultiOutputNode", "inputs": {}},
            "2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}},
            "3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}},
            "4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}},
        }
        found = get_expected_outputs_for_node(DynamicPrompt(graph), "1")
        assert found == frozenset({0, 1, 2})

    def test_constant_inputs_ignored(self):
        """Literal (non-link) input values do not contribute expected outputs."""
        consumer_inputs = {
            "image": ["1", 0],
            "value": 42,
            "name": "test",
        }
        graph = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerNode", "inputs": consumer_inputs},
        }
        found = get_expected_outputs_for_node(DynamicPrompt(graph), "1")
        assert found == frozenset({0})

    def test_ephemeral_node_invalidates_cache(self):
        """Adding an ephemeral consumer changes the reported set for its source."""
        graph = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(graph)

        # Before the ephemeral node exists only output 0 is linked.
        before = get_expected_outputs_for_node(dynprompt, "1")
        assert before == frozenset({0})

        # The ephemeral node consumes output 1 of node "1".
        ephemeral = {"class_type": "EphemeralNode", "inputs": {"data": ["1", 1]}}
        dynprompt.add_ephemeral_node("eph_1", ephemeral, parent_id="2", display_id="2")

        after = get_expected_outputs_for_node(dynprompt, "1")
        assert after == frozenset({0, 1})
|
||||
|
||||
|
||||
class TestExecutionContext:
    """Checks the expected_outputs field carried by ExecutionContext."""

    def test_context_with_expected_outputs(self):
        """All four fields are stored as given."""
        wanted = frozenset({0, 2})
        context = ExecutionContext(
            prompt_id="prompt-123",
            node_id="node-456",
            list_index=0,
            expected_outputs=wanted,
        )
        assert context.prompt_id == "prompt-123"
        assert context.node_id == "node-456"
        assert context.list_index == 0
        assert context.expected_outputs == frozenset({0, 2})

    def test_context_without_expected_outputs(self):
        """Omitting expected_outputs leaves it as None."""
        context = ExecutionContext(prompt_id="prompt-123", node_id="node-456", list_index=0)
        assert context.expected_outputs is None

    def test_context_empty_expected_outputs(self):
        """An empty frozenset is kept as-is (distinct from None)."""
        context = ExecutionContext(
            prompt_id="prompt-123",
            node_id="node-456",
            list_index=None,
            expected_outputs=frozenset(),
        )
        assert context.expected_outputs == frozenset()
        assert len(context.expected_outputs) == 0
|
||||
|
||||
|
||||
class TestCurrentNodeContext:
    """Checks that the CurrentNodeContext manager exposes expected_outputs."""

    def test_context_manager_with_expected_outputs(self):
        """The context is visible inside the `with` block and gone afterwards."""
        assert get_executing_context() is None

        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})):
            active = get_executing_context()
            assert active is not None
            assert active.prompt_id == "prompt-1"
            assert active.node_id == "node-1"
            assert active.list_index == 0
            assert active.expected_outputs == frozenset({0, 1})

        assert get_executing_context() is None

    def test_context_manager_without_expected_outputs(self):
        """Constructing with only prompt and node ids leaves expected_outputs as None."""
        with CurrentNodeContext("prompt-1", "node-1"):
            active = get_executing_context()
            assert active is not None
            assert active.expected_outputs is None

    def test_nested_context_managers(self):
        """The inner context wins while active; the outer one returns on exit."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})):
            outer = get_executing_context()
            assert outer.expected_outputs == frozenset({0})

            with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})):
                inner = get_executing_context()
                assert inner.expected_outputs == frozenset({1, 2})
                assert inner.node_id == "node-2"

            # Leaving the inner block must bring back the outer context.
            restored = get_executing_context()
            assert restored.expected_outputs == frozenset({0})
            assert restored.node_id == "node-1"

    def test_output_check_pattern(self):
        """Exercise the membership check a node would perform on expected_outputs."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            ctx = get_executing_context()

            if not ctx or ctx.expected_outputs is None:
                # No information available: treat every output as needed.
                should_compute_0 = should_compute_1 = should_compute_2 = True
            else:
                should_compute_0 = 0 in ctx.expected_outputs
                should_compute_1 = 1 in ctx.expected_outputs
                should_compute_2 = 2 in ctx.expected_outputs

            assert should_compute_0 is True
            assert should_compute_1 is False  # 1 is not in expected_outputs
            assert should_compute_2 is True
|
||||
|
||||
|
||||
class TestSchemaLazyOutputs:
    """Checks the lazy_outputs flag on IO.Schema and LAZY_OUTPUTS on node classes."""

    def test_schema_lazy_outputs_default(self):
        """A schema built without lazy_outputs reports False."""
        built = IO.Schema(node_id="TestNode", inputs=[], outputs=[IO.Float.Output()])
        assert built.lazy_outputs is False

    def test_schema_lazy_outputs_true(self):
        """Passing lazy_outputs=True is reflected on the schema."""
        built = IO.Schema(
            node_id="TestNode",
            lazy_outputs=True,
            inputs=[],
            outputs=[IO.Float.Output()],
        )
        assert built.lazy_outputs is True

    def test_v3_node_lazy_outputs_property(self):
        """A node whose schema sets lazy_outputs=True exposes LAZY_OUTPUTS as True."""

        class TestNodeWithLazyOutputs(IO.ComfyNode):
            @classmethod
            def define_schema(cls):
                float_output = IO.Float.Output()
                return IO.Schema(
                    node_id="TestNodeWithLazyOutputs",
                    lazy_outputs=True,
                    inputs=[],
                    outputs=[float_output],
                )

            @classmethod
            def execute(cls):
                return IO.NodeOutput(1.0)

        assert TestNodeWithLazyOutputs.LAZY_OUTPUTS is True

    def test_v3_node_lazy_outputs_default(self):
        """A node whose schema omits lazy_outputs exposes LAZY_OUTPUTS as False."""

        class TestNodeWithoutLazyOutputs(IO.ComfyNode):
            @classmethod
            def define_schema(cls):
                float_output = IO.Float.Output()
                return IO.Schema(
                    node_id="TestNodeWithoutLazyOutputs",
                    inputs=[],
                    outputs=[float_output],
                )

            @classmethod
            def execute(cls):
                return IO.NodeOutput(1.0)

        assert TestNodeWithoutLazyOutputs.LAZY_OUTPUTS is False
|
||||
|
||||
|
||||
class TestIsOutputNeeded:
    """Checks the is_output_needed() helper under different contexts."""

    def test_output_needed_when_in_expected(self):
        """Indices present in expected_outputs are reported as needed."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            for index in (0, 2):
                assert is_output_needed(index) is True

    def test_output_not_needed_when_not_in_expected(self):
        """Indices missing from expected_outputs are reported as not needed."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            for index in (1, 3):
                assert is_output_needed(index) is False

    def test_output_needed_when_no_context(self):
        """With no executing context every output counts as needed."""
        assert get_executing_context() is None
        for index in (0, 1):
            assert is_output_needed(index) is True

    def test_output_needed_when_expected_outputs_is_none(self):
        """With expected_outputs=None every output counts as needed."""
        with CurrentNodeContext("prompt-1", "node-1", 0, None):
            for index in (0, 1):
                assert is_output_needed(index) is True
|
||||
@@ -574,6 +574,104 @@ class TestExecution:
|
||||
else:
|
||||
assert result.did_run(test_node), "The execution should have been re-run"
|
||||
|
||||
def test_expected_outputs_all_connected(self, client: ComfyClient, builder: GraphBuilder):
    """With all three outputs wired to previews, every preview image is fully white."""
    graph = builder
    source = graph.node("TestExpectedOutputs", height=64, width=64)

    # One preview per output, created in output order.
    preview0 = graph.node("PreviewImage", images=source.out(0))
    preview1 = graph.node("PreviewImage", images=source.out(1))
    preview2 = graph.node("PreviewImage", images=source.out(2))

    run = client.run(graph)

    shots0 = run.get_images(preview0)
    shots1 = run.get_images(preview1)
    shots2 = run.get_images(preview2)

    assert len(shots0) == 1, "Should have 1 image for output0"
    assert len(shots1) == 1, "Should have 1 image for output1"
    assert len(shots2) == 1, "Should have 1 image for output2"

    # The test node emits white (255) for outputs found in expected_outputs.
    darkest0 = numpy.array(shots0[0]).min()
    assert darkest0 == 255, "Output 0 should be white (was expected)"
    darkest1 = numpy.array(shots1[0]).min()
    assert darkest1 == 255, "Output 1 should be white (was expected)"
    darkest2 = numpy.array(shots2[0]).min()
    assert darkest2 == 255, "Output 2 should be white (was expected)"
|
||||
|
||||
def test_expected_outputs_partial_connected(self, client: ComfyClient, builder: GraphBuilder):
    """With only outputs 0 and 2 wired, both of their previews are fully white."""
    graph = builder
    source = graph.node("TestExpectedOutputs", height=64, width=64)

    # Output 1 is deliberately left without a consumer.
    preview0 = graph.node("PreviewImage", images=source.out(0))
    preview2 = graph.node("PreviewImage", images=source.out(2))

    run = client.run(graph)

    shots0 = run.get_images(preview0)
    shots2 = run.get_images(preview2)

    assert len(shots0) == 1, "Should have 1 image for output0"
    assert len(shots2) == 1, "Should have 1 image for output2"

    # Output 1 has no preview, so only the wired outputs can be inspected here.
    darkest0 = numpy.array(shots0[0]).min()
    assert darkest0 == 255, "Output 0 should be white (was expected)"
    darkest2 = numpy.array(shots2[0]).min()
    assert darkest2 == 255, "Output 2 should be white (was expected)"
|
||||
|
||||
def test_expected_outputs_single_connected(self, client: ComfyClient, builder: GraphBuilder):
    """With only output 1 wired, its preview image is fully white."""
    graph = builder
    source = graph.node("TestExpectedOutputs", height=64, width=64)

    # Outputs 0 and 2 get no consumer in this graph.
    preview1 = graph.node("PreviewImage", images=source.out(1))

    run = client.run(graph)

    shots1 = run.get_images(preview1)
    assert len(shots1) == 1, "Should have 1 image for output1"

    darkest1 = numpy.array(shots1[0]).min()
    assert darkest1 == 255, "Output 1 should be white (was expected)"
|
||||
|
||||
def test_expected_outputs_cache_invalidation(self, client: ComfyClient, builder: GraphBuilder, server):
    """Re-running after wiring an extra output must execute the node again."""
    graph = builder
    # 32x32 differs from the other expected_outputs tests so their cached results don't collide.
    source = graph.node("TestExpectedOutputs", height=32, width=32)
    caching = server["should_cache_results"]

    # Run 1: only output 0 has a consumer.
    preview0 = graph.node("PreviewImage", images=source.out(0))
    first_run = client.run(graph)
    assert first_run.did_run(source), "First run should execute the node"

    # Run 2: identical graph, so a caching server should not re-execute.
    second_run = client.run(graph)
    if caching:
        assert not second_run.did_run(source), "Second run should be cached"

    # Run 3: output 2 gains a consumer. The test node sets LAZY_OUTPUTS = True,
    # so the changed connections are expected to invalidate the cached result.
    preview2 = graph.node("PreviewImage", images=source.out(2))
    third_run = client.run(graph)
    if caching:
        assert third_run.did_run(source), "Adding output connection should invalidate cache"

    # Both wired outputs must now come back white.
    shots0 = third_run.get_images(preview0)
    shots2 = third_run.get_images(preview2)
    assert numpy.array(shots0[0]).min() == 255, "Output 0 should be white"
    assert numpy.array(shots2[0]).min() == 255, "Output 2 should be white"
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
|
||||
@@ -6,6 +6,7 @@ from .tools import VariantSupport
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
from comfy_execution.utils import get_executing_context
|
||||
|
||||
class TestLazyMixImages:
|
||||
@classmethod
|
||||
@@ -482,6 +483,57 @@ class TestOutputNodeWithSocketOutput:
|
||||
result = image * value
|
||||
return (result,)
|
||||
|
||||
|
||||
class TestExpectedOutputs:
    """Test node that reports, per output, whether that output was expected.

    Each of the 3 IMAGE outputs is:
    - all ones (white) if its index was in expected_outputs
    - all zeros (black) if its index was NOT in expected_outputs

    Integration tests read the pixel values to learn which outputs were expected.
    """
    LAZY_OUTPUTS = True  # Opt into cache invalidation on output connection changes

    @classmethod
    def INPUT_TYPES(cls):
        size_options = {"default": 64, "min": 1, "max": 1024}
        return {
            "required": {
                "height": ("INT", dict(size_options)),
                "width": ("INT", dict(size_options)),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
    RETURN_NAMES = ("output0", "output1", "output2")
    FUNCTION = "execute"
    CATEGORY = "_for_testing"

    def execute(self, height, width):
        """Return three images, white where the output index is expected, else black."""
        ctx = get_executing_context()

        if ctx is None or ctx.expected_outputs is None:
            # No information available: treat all outputs as expected.
            flags = (True, True, True)
        else:
            flags = tuple(index in ctx.expected_outputs for index in range(3))

        white = torch.ones(1, height, width, 3)
        black = torch.zeros(1, height, width, 3)

        return tuple(white if expected else black for expected in flags)
|
||||
|
||||
|
||||
TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
@@ -498,6 +550,7 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestSleep": TestSleep,
|
||||
"TestParallelSleep": TestParallelSleep,
|
||||
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
||||
"TestExpectedOutputs": TestExpectedOutputs,
|
||||
}
|
||||
|
||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -516,4 +569,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestSleep": "Test Sleep",
|
||||
"TestParallelSleep": "Test Parallel Sleep",
|
||||
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
||||
"TestExpectedOutputs": "Test Expected Outputs",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user