Compare commits

...

21 Commits

Author SHA1 Message Date
bigcat88
ac1073be99 convert model_merging and video_model nodes to V3 schema 2026-02-06 17:23:42 +02:00
Jukka Seppänen
a1c101f861 EasyCache: Support LTX2 (#12231) 2026-02-06 00:43:09 -05:00
comfyanonymous
c2d7f07dbf Fix issue when using disable_unet_model_creation (#12315) 2026-02-05 19:24:09 -05:00
comfyanonymous
458292fef0 Fix some lowvram stuff with ace step 1.5 (#12312) 2026-02-05 19:15:04 -05:00
comfyanonymous
6555dc65b8 Make ace step 1.5 work without the llm. (#12311) 2026-02-05 16:43:45 -05:00
AustinMroz
2b70ab9ad0 Add a Create List node (#12173) 2026-02-05 01:18:21 -05:00
Comfy Org PR Bot
00efcc6cd0 Bump comfyui-frontend-package to 1.38.13 (#12238) 2026-02-05 01:17:37 -05:00
comfyanonymous
cb459573c8 ComfyUI v0.12.3 2026-02-05 01:13:35 -05:00
comfyanonymous
35183543e0 Add VAE tiled decode node for audio. (#12299) 2026-02-05 01:12:04 -05:00
blepping
a246cc02b2 Improvements to ACE-Steps 1.5 text encoding (#12283) 2026-02-05 00:17:37 -05:00
comfyanonymous
a50c32d63f Disable sage attention on ace step 1.5 (#12297) 2026-02-04 22:15:30 -05:00
comfyanonymous
6125b80979 Add llm sampling options and make reference audio work on ace step 1.5 (#12295) 2026-02-04 21:29:22 -05:00
comfyanonymous
c8fcbd66ee Try to fix ace text encoder slowness on some configs. (#12290) 2026-02-04 19:37:05 -05:00
comfyanonymous
26dd7eb421 Fix ace step nan issue on some hardware/pytorch configs. (#12289) 2026-02-04 18:25:06 -05:00
Alexander Piskun
e77b34dfea add File3DAny output to Load3D node; extend SaveGLB to accept File3DAny as input (#12276)
* add File3DAny output to Load3D node; extend SaveGLB node to accept File3DAny as input

* fix(grammar): capitalize letter
2026-02-04 11:35:38 -08:00
rattus
ef73070ea4 mp: Fix checkpoint saving (#12268)
Fix regression in the recent model saving refactor. Pass the non unet
pieces down the layers so that checkpoints are complete.
2026-02-04 02:08:45 -05:00
rattus
d30c609f5a utils: safetensors: dont slice data on torch level (#12266)
Torch has alignment enforcement when viewing with data type changes
but only relative to itself. Do all tensor constructions straight
off the memory-view individually so pytorch doesnt see an alignment
problem.

The is needed for handling misaligned safetensors weights, which are
reasonably common in third party models.

This limits usage of this safetensors loader to GPU compute only
as CPUs kernnel are very likely to bus error. But it works for
dynamic_vram, where we really dont want to take a deep copy and we
always use GPU copy_ which disentangles the misalignment.
2026-02-04 01:48:47 -05:00
comfyanonymous
5087f1d497 ComfyUI v0.12.2 2026-02-04 00:08:59 -05:00
comfyanonymous
a31681564d Fix crash with ace step 1.5 (#12264) 2026-02-04 00:03:21 -05:00
rattus
855849c658 mm: Remove Aimdo exemption for empty_cache (#12260)
Its more important to get the torch caching allocator GC up and running
than supporting the pyt2.7 bug. Switch it on.

Defeature dynamic_vram + pyt2.7.
2026-02-03 21:39:19 -05:00
comfyanonymous
fe2511468d Support the 4B ace step 1.5 lm model. (#12257)
Can be used as an alternative to the 1.7B
2026-02-03 19:01:38 -05:00
25 changed files with 1146 additions and 614 deletions

View File

@@ -7,6 +7,67 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
from comfy.ldm.flux.layers import timestep_embedding
def get_silence_latent(length, device):
head = torch.tensor([[[ 0.5707, 0.0982, 0.6909, -0.5658, 0.6266, 0.6996, -0.1365, -0.1291,
-0.0776, -0.1171, -0.2743, -0.8422, -0.1168, 1.5539, -4.6936, 0.7436,
-1.1846, -0.2637, 0.6933, -6.7266, 0.0966, -0.1187, -0.3501, -1.1736,
0.0587, -2.0517, -1.3651, 0.7508, -0.2490, -1.3548, -0.1290, -0.7261,
1.1132, -0.3249, 0.2337, 0.3004, 0.6605, -0.0298, -0.1989, -0.4041,
0.2843, -1.0963, -0.5519, 0.2639, -1.0436, -0.1183, 0.0640, 0.4460,
-1.1001, -0.6172, -1.3241, 1.1379, 0.5623, -0.1507, -0.1963, -0.4742,
-2.4697, 0.5302, 0.5381, 0.4636, -0.1782, -0.0687, 1.0333, 0.4202],
[ 0.3040, -0.1367, 0.6200, 0.0665, -0.0642, 0.4655, -0.1187, -0.0440,
0.2941, -0.2753, 0.0173, -0.2421, -0.0147, 1.5603, -2.7025, 0.7907,
-0.9736, -0.0682, 0.1294, -5.0707, -0.2167, 0.3302, -0.1513, -0.8100,
-0.3894, -0.2884, -0.3149, 0.8660, -0.3817, -1.7061, 0.5824, -0.4840,
0.6938, 0.1859, 0.1753, 0.3081, 0.0195, 0.1403, -0.0754, -0.2091,
0.1251, -0.1578, -0.4968, -0.1052, -0.4554, -0.0320, 0.1284, 0.4974,
-1.1889, -0.0344, -0.8313, 0.2953, 0.5445, -0.6249, -0.1595, -0.0682,
-3.1412, 0.0484, 0.4153, 0.8260, -0.1526, -0.0625, 0.5366, 0.8473],
[ 5.3524e-02, -1.7534e-01, 5.4443e-01, -4.3501e-01, -2.1317e-03,
3.7200e-01, -4.0143e-03, -1.5516e-01, -1.2968e-01, -1.5375e-01,
-7.7107e-02, -2.0593e-01, -3.2780e-01, 1.5142e+00, -2.6101e+00,
5.8698e-01, -1.2716e+00, -2.4773e-01, -2.7933e-02, -5.0799e+00,
1.1601e-01, 4.0987e-01, -2.2030e-02, -6.6495e-01, -2.0995e-01,
-6.3474e-01, -1.5893e-01, 8.2745e-01, -2.2992e-01, -1.6816e+00,
5.4440e-01, -4.9579e-01, 5.5128e-01, 3.0477e-01, 8.3052e-02,
-6.1782e-02, 5.9036e-03, 2.9553e-01, -8.0645e-02, -1.0060e-01,
1.9144e-01, -3.8124e-01, -7.2949e-01, 2.4520e-02, -5.0814e-01,
2.3977e-01, 9.2943e-02, 3.9256e-01, -1.1993e+00, -3.2752e-01,
-7.2707e-01, 2.9476e-01, 4.3542e-01, -8.8597e-01, -4.1686e-01,
-8.5390e-02, -2.9018e+00, 6.4988e-02, 5.3945e-01, 9.1988e-01,
5.8762e-02, -7.0098e-02, 6.4772e-01, 8.9118e-01],
[-3.2225e-02, -1.3195e-01, 5.6411e-01, -5.4766e-01, -5.2170e-03,
3.1425e-01, -5.4367e-02, -1.9419e-01, -1.3059e-01, -1.3660e-01,
-9.0984e-02, -1.9540e-01, -2.5590e-01, 1.5440e+00, -2.6349e+00,
6.8273e-01, -1.2532e+00, -1.9810e-01, -2.2793e-02, -5.0506e+00,
1.8818e-01, 5.0109e-01, 7.3546e-03, -6.8771e-01, -3.0676e-01,
-7.3257e-01, -1.6687e-01, 9.2232e-01, -1.8987e-01, -1.7267e+00,
5.3355e-01, -5.3179e-01, 4.4953e-01, 2.8820e-01, 1.3012e-01,
-2.0943e-01, -1.1348e-01, 3.3929e-01, -1.5069e-01, -1.2919e-01,
1.8929e-01, -3.6166e-01, -8.0756e-01, 6.6387e-02, -5.8867e-01,
1.6978e-01, 1.0134e-01, 3.3877e-01, -1.2133e+00, -3.2492e-01,
-8.1237e-01, 3.8101e-01, 4.3765e-01, -8.0596e-01, -4.4531e-01,
-4.7513e-02, -2.9266e+00, 1.1741e-03, 4.5123e-01, 9.3075e-01,
5.3688e-02, -1.9621e-01, 6.4530e-01, 9.3870e-01]]], device=device).movedim(-1, 1)
silence_latent = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
-2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
-6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
-1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, length)
silence_latent[:, :, :head.shape[-1]] = head
return silence_latent
def get_layer_class(operations, layer_name):
if operations is not None and hasattr(operations, layer_name):
return getattr(operations, layer_name)
@@ -183,7 +244,7 @@ class AceStepAttention(nn.Module):
else:
attn_bias = window_bias
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
attn_output = self.o_proj(attn_output)
return attn_output
@@ -677,7 +738,7 @@ class AttentionPooler(nn.Module):
def forward(self, x):
B, T, P, D = x.shape
x = self.embed_tokens(x)
special = self.special_token.expand(B, T, 1, -1)
special = comfy.model_management.cast_to(self.special_token, device=x.device, dtype=x.dtype).expand(B, T, 1, -1)
x = torch.cat([special, x], dim=2)
x = x.view(B * T, P + 1, D)
@@ -728,7 +789,7 @@ class FSQ(nn.Module):
self.register_buffer('implicit_codebook', implicit_codebook, persistent=False)
def bound(self, z):
levels_minus_1 = (self._levels - 1).to(z.dtype)
levels_minus_1 = (comfy.model_management.cast_to(self._levels, device=z.device, dtype=z.dtype) - 1)
scale = 2. / levels_minus_1
bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5
@@ -743,8 +804,8 @@ class FSQ(nn.Module):
return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1.
def codes_to_indices(self, zhat):
zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1))
return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32)
zhat_normalized = (zhat + 1.) / (2. / (comfy.model_management.cast_to(self._levels, device=zhat.device, dtype=zhat.dtype) - 1))
return (zhat_normalized * comfy.model_management.cast_to(self._basis, device=zhat.device, dtype=zhat.dtype)).sum(dim=-1).round().to(torch.int32)
def forward(self, z):
orig_dtype = z.dtype
@@ -826,7 +887,7 @@ class ResidualFSQ(nn.Module):
x = self.project_in(x)
if hasattr(self, 'soft_clamp_input_value'):
sc_val = self.soft_clamp_input_value.to(x.dtype)
sc_val = comfy.model_management.cast_to(self.soft_clamp_input_value, device=x.device, dtype=x.dtype)
x = (x / sc_val).tanh() * sc_val
quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
@@ -834,7 +895,7 @@ class ResidualFSQ(nn.Module):
all_indices = []
for layer, scale in zip(self.layers, self.scales):
scale = scale.to(residual.dtype)
scale = comfy.model_management.cast_to(scale, device=x.device, dtype=x.dtype)
quantized, indices = layer(residual / scale)
quantized = quantized * scale
@@ -1035,28 +1096,26 @@ class AceStepConditionGenerationModel(nn.Module):
audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
else:
assert False
# TODO ?
lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed)
lm_hints = self.detokenizer(lm_hints_5Hz)
lm_hints = lm_hints[:, :src_latents.shape[1], :]
if is_covers is None:
if is_covers is None or is_covers is True:
src_latents = lm_hints
else:
src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents)
elif is_covers is False:
src_latents = refer_audio_acoustic_hidden_states_packed
context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1)
return encoder_hidden, encoder_mask, context_latents
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs):
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs):
text_attention_mask = None
lyric_attention_mask = None
refer_audio_order_mask = None
attention_mask = None
chunk_masks = None
is_covers = None
src_latents = None
precomputed_lm_hints_25Hz = None
lyric_hidden_states = lyric_embed
@@ -1068,7 +1127,7 @@ class AceStepConditionGenerationModel(nn.Module):
if refer_audio_order_mask is None:
refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long)
if src_latents is None and is_covers is None:
if src_latents is None:
src_latents = x
if chunk_masks is None:

View File

@@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if kwargs.get("low_precision_attention", True) is False:
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape

View File

@@ -147,11 +147,11 @@ class BaseModel(torch.nn.Module):
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
comfy.model_management.archive_model_dtypes(self.diffusion_model)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
comfy.model_management.archive_model_dtypes(self.diffusion_model)
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
@@ -1548,6 +1548,7 @@ class ACEStep15(BaseModel):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
device = kwargs["device"]
noise = kwargs["noise"]
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
@@ -1559,27 +1560,22 @@ class ACEStep15(BaseModel):
refer_audio = kwargs.get("reference_audio_timbre_latents", None)
if refer_audio is None or len(refer_audio) == 0:
refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
-2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
-6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
-1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
pass_audio_codes = True
else:
refer_audio = refer_audio[-1]
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
out['is_covers'] = comfy.conds.CONDConstant(True)
pass_audio_codes = False
if pass_audio_codes:
audio_codes = kwargs.get("audio_codes", None)
if audio_codes is not None:
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
refer_audio = refer_audio[:, :, :750]
else:
out['is_covers'] = comfy.conds.CONDConstant(False)
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
audio_codes = kwargs.get("audio_codes", None)
if audio_codes is not None:
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
return out
class Omnigen2(BaseModel):

View File

@@ -1724,11 +1724,9 @@ def soft_empty_cache(force=False):
elif is_mlu():
torch.mlu.empty_cache()
elif torch.cuda.is_available():
if comfy.memory_management.aimdo_allocator is None:
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def unload_all_models():
free_memory(1e30, get_torch_device())

View File

@@ -1400,7 +1400,7 @@ class ModelPatcher:
continue
key = "diffusion_model." + k
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
return self.model.state_dict_for_saving(unet_state_dict)
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def __del__(self):
self.unpin_all_weights()

View File

@@ -54,6 +54,8 @@ try:
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
else:

View File

@@ -976,7 +976,7 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
if dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
@@ -1444,7 +1444,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
elif clip_type == CLIPType.ACE:
clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data))
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
if TEModel.QWEN3_4B in te_models:
model_type = "qwen3_4b"
else:
model_type = "qwen3_2b"
clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel

View File

@@ -1625,8 +1625,16 @@ class ACEStep15(supported_models_base.BASE):
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect))
detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
if "dtype_llama" in detect_2b:
detect = detect_2b
detect["lm_model"] = "qwen3_2b"
elif "dtype_llama" in detect_4b:
detect = detect_4b
detect["lm_model"] = "qwen3_4b"
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]

View File

@@ -3,6 +3,7 @@ import comfy.text_encoders.llama
from comfy import sd1_clip
import torch
import math
import yaml
import comfy.utils
@@ -19,6 +20,7 @@ def sample_manual_loop_no_classes(
min_tokens: int = 1,
max_new_tokens: int = 2048,
audio_start_id: int = 151669, # The cutoff ID for audio codes
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
device = model.execution_device
@@ -60,6 +62,7 @@ def sample_manual_loop_no_classes(
remove_logit_value = torch.finfo(cfg_logits.dtype).min
# Only generate audio tokens
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
cfg_logits[:, eos_token_id] = eos_score
@@ -99,9 +102,7 @@ def sample_manual_loop_no_classes(
return output_audio_codes
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
cfg_scale = 2.0
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
positive = [[token for token, _ in inner_list] for inner_list in positive]
negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
@@ -118,34 +119,80 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
positive = [model.special_tokens["pad"]] * pos_pad + positive
paddings = [pos_pad, neg_pad]
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
user_metas = {
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
for k, v in user_metas.items()
if v not in {"unspecified", None}
}
if len(user_metas):
meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip()
else:
meta_yaml = ""
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
def _metas_to_cap(self, **kwargs) -> str:
use_keys = ("bpm", "duration", "keyscale", "timesignature")
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
duration = user_metas["duration"]
if duration == "N/A":
user_metas["duration"] = "30 seconds"
elif isinstance(duration, (str, int, float)):
user_metas["duration"] = f"{math.ceil(float(duration))} seconds"
else:
raise TypeError("Unexpected type for duration key, must be str, int or float")
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
out = {}
lyrics = kwargs.get("lyrics", "")
bpm = kwargs.get("bpm", 120)
duration = kwargs.get("duration", 120)
keyscale = kwargs.get("keyscale", "C major")
timesignature = kwargs.get("timesignature", 2)
language = kwargs.get("language", "en")
language = kwargs.get("language")
seed = kwargs.get("seed", 0)
generate_audio_codes = kwargs.get("generate_audio_codes", True)
cfg_scale = kwargs.get("cfg_scale", 2.0)
temperature = kwargs.get("temperature", 0.85)
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
duration = math.ceil(duration)
meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n<think>\n{}\n</think>\n\n<|im_end|>\n"
kwargs["duration"] = duration
meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration)
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True)
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True)
cot_text = self._metas_to_cot(caption = text, **kwargs)
meta_cap = self._metas_to_cap(**kwargs)
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
out["lm_metadata"] = {"min_tokens": duration * 5,
"seed": seed,
"generate_audio_codes": generate_audio_codes,
"cfg_scale": cfg_scale,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
}
return out
@@ -162,14 +209,34 @@ class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Qwen3_4B_ACE15(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ACE15TEModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}):
def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}):
super().__init__()
if dtype_llama is None:
dtype_llama = dtype
model = None
self.constant = 0.4375
if lm_model == "qwen3_4b":
model = Qwen3_4B_ACE15
self.constant = 0.5625
elif lm_model == "qwen3_2b":
model = Qwen3_2B_ACE15
self.lm_model = lm_model
self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options)
if model is not None:
setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
self.dtypes = set([dtype, dtype_llama])
def encode_token_weights(self, token_weight_pairs):
@@ -181,18 +248,26 @@ class ACE15TEModel(torch.nn.Module):
self.qwen3_06b.set_clip_options({"layer": [0]})
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
lm_metadata = token_weight_pairs["lm_metadata"]
audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
out["audio_codes"] = [audio_codes]
return base_out, None, out
def set_clip_options(self, options):
self.qwen3_06b.set_clip_options(options)
self.qwen3_2b.set_clip_options(options)
lm_model = getattr(self, self.lm_model, None)
if lm_model is not None:
lm_model.set_clip_options(options)
def reset_clip_options(self):
self.qwen3_06b.reset_clip_options()
self.qwen3_2b.reset_clip_options()
lm_model = getattr(self, self.lm_model, None)
if lm_model is not None:
lm_model.reset_clip_options()
def load_sd(self, sd):
if "model.layers.0.post_attention_layernorm.weight" in sd:
@@ -200,11 +275,11 @@ class ACE15TEModel(torch.nn.Module):
if shape[0] == 1024:
return self.qwen3_06b.load_sd(sd)
else:
return self.qwen3_2b.load_sd(sd)
return getattr(self, self.lm_model).load_sd(sd)
def memory_estimation_function(self, token_weight_pairs, device=None):
lm_metadata = token_weight_pairs["lm_metadata"]
constant = 0.4375
constant = self.constant
if comfy.model_management.should_use_bf16(device):
constant *= 0.5
@@ -213,11 +288,11 @@ class ACE15TEModel(torch.nn.Module):
num_tokens += lm_metadata['min_tokens']
return num_tokens * constant * 1024 * 1024
def te(dtype_llama=None, llama_quantization_metadata=None):
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
class ACE15TEModel_(ACE15TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options)
super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
return ACE15TEModel_

View File

@@ -150,6 +150,29 @@ class Qwen3_2B_ACE15_lm_Config:
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_4B_ACE15_lm_Config:
vocab_size: int = 217204
hidden_size: int = 2560
intermediate_size: int = 9728
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
@@ -628,10 +651,10 @@ class Llama2_(nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min)
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
if seq_len > 1:
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1)
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
if mask is not None:
mask += causal_mask
else:
@@ -739,6 +762,21 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
module = self.model.embed_tokens
offload_stream = None
if module.comfy_cast_weights:
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
else:
weight = self.model.embed_tokens.weight.to(x)
x = torch.nn.functional.linear(input, weight, None)
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
class Llama2(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
@@ -767,7 +805,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_06B(BaseLlama, torch.nn.Module):
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
@@ -776,7 +814,7 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06B_ACE15_Config(**config_dict)
@@ -785,7 +823,7 @@ class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_2B_ACE15_lm_Config(**config_dict)
@@ -794,22 +832,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def logits(self, x):
input = x[:, -1:]
module = self.model.embed_tokens
offload_stream = None
if module.comfy_cast_weights:
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
else:
weight = self.model.embed_tokens.weight.to(x)
x = torch.nn.functional.linear(input, weight, None)
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
class Qwen3_4B(BaseLlama, torch.nn.Module):
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
@@ -818,7 +841,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_8B(BaseLlama, torch.nn.Module):
class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4B_ACE15_lm_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_8BConfig(**config_dict)

View File

@@ -82,14 +82,12 @@ _TYPES = {
def load_safetensors(ckpt):
f = open(ckpt, "rb")
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
mv = memoryview(mapping)
header_size = struct.unpack("<Q", mapping[:8])[0]
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
with warnings.catch_warnings():
#We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable")
data_area = torch.frombuffer(mapping, dtype=torch.uint8)[8 + header_size:]
mv = mv[8 + header_size:]
sd = {}
for name, info in header.items():
@@ -97,7 +95,13 @@ def load_safetensors(ckpt):
continue
start, end = info["data_offsets"]
sd[name] = data_area[start:end].view(_TYPES[info["dtype"]]).view(info["shape"])
if start == end:
sd[name] = torch.empty(info["shape"], dtype =_TYPES[info["dtype"]])
else:
with warnings.catch_warnings():
#We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable")
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
return sd, header.get("__metadata__", {}),

View File

@@ -44,13 +44,18 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)
@@ -100,14 +105,15 @@ class EmptyAceStep15LatentAudio(io.ComfyNode):
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent, "type": "audio"})
class ReferenceTimbreAudio(io.ComfyNode):
class ReferenceAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ReferenceTimbreAudio",
display_name="Reference Audio",
category="advanced/conditioning/audio",
is_experimental=True,
description="This node sets the reference audio for timbre (for ace step 1.5)",
description="This node sets the reference audio for ace step 1.5",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("latent", optional=True),
@@ -131,7 +137,7 @@ class AceExtension(ComfyExtension):
EmptyAceStepLatentAudio,
TextEncodeAceStepAudio15,
EmptyAceStep15LatentAudio,
ReferenceTimbreAudio,
ReferenceAudio,
]
async def comfy_entrypoint() -> AceExtension:

View File

@@ -94,6 +94,19 @@ class VAEEncodeAudio(IO.ComfyNode):
encode = execute # TODO: remove
def vae_decode_audio(vae, samples, tile=None, overlap=None):
if tile is not None:
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
else:
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
class VAEDecodeAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
@@ -111,16 +124,33 @@ class VAEDecodeAudio(IO.ComfyNode):
@classmethod
def execute(cls, vae, samples) -> IO.NodeOutput:
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]})
return IO.NodeOutput(vae_decode_audio(vae, samples))
decode = execute # TODO: remove
class VAEDecodeAudioTiled(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeAudioTiled",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio (Tiled)",
category="latent/audio",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8),
IO.Int.Input("overlap", default=64, min=0, max=1024, step=8),
],
outputs=[IO.Audio.Output()],
)
@classmethod
def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput:
return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap))
class SaveAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
@@ -675,6 +705,7 @@ class AudioExtension(ComfyExtension):
EmptyLatentAudio,
VAEEncodeAudio,
VAEDecodeAudio,
VAEDecodeAudioTiled,
SaveAudio,
SaveAudioMP3,
SaveAudioOpus,

View File

@@ -9,6 +9,14 @@ if TYPE_CHECKING:
from uuid import UUID
def _extract_tensor(data, output_channels):
"""Extract tensor from data, handling both single tensors and lists."""
if isinstance(data, list):
# LTX2 AV tensors: [video, audio]
return data[0][:, :output_channels], data[1][:, :output_channels]
return data[:, :output_channels], None
def easycache_forward_wrapper(executor, *args, **kwargs):
# get values from args
transformer_options: dict[str] = args[-1]
@@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
if not transformer_options:
transformer_options = args[-2]
easycache: EasyCacheHolder = transformer_options["easycache"]
x: torch.Tensor = args[0][:, :easycache.output_channels]
x, ax = _extract_tensor(args[0], easycache.output_channels)
sigmas = transformer_options["sigmas"]
uuids = transformer_options["uuids"]
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
@@ -35,7 +43,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
if easycache.skip_current_step and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
return easycache.apply_cache_diff(x, uuids)
result = easycache.apply_cache_diff(x, uuids)
if ax is not None:
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
return [result, result_audio]
return result
if easycache.initial_step:
easycache.first_cond_uuid = uuids[0]
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
@@ -51,13 +63,18 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values
easycache.skip_current_step = True
return easycache.apply_cache_diff(x, uuids)
result = easycache.apply_cache_diff(x, uuids)
if ax is not None:
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
return [result, result_audio]
return result
else:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
output: torch.Tensor = executor(*args, **kwargs)
full_output: torch.Tensor = executor(*args, **kwargs)
output, audio_output = _extract_tensor(full_output, easycache.output_channels)
if has_first_cond_uuid and easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
@@ -74,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
# TODO: allow cache_diff to be offloaded
easycache.update_cache_diff(output, next_x_prev, uuids)
if audio_output is not None:
easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
if has_first_cond_uuid:
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
return output
return full_output
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args
@@ -89,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels)
# prepare next x_prev
x: torch.Tensor = args[0][:, :easycache.output_channels]
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
@@ -197,6 +216,7 @@ class EasyCacheHolder:
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
self.output_change_rates = []
self.approx_output_change_rates = []
self.total_steps_skipped = 0
@@ -245,20 +265,21 @@ class EasyCacheHolder:
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
if self.first_cond_uuid in uuids:
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
if self.first_cond_uuid in uuids and not is_audio:
self.total_steps_skipped += 1
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
batch_offset = x.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
# slice out only what is relevant to this cond
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
if x.shape[1:] != cache_diffs[uuid].shape[1:]:
if not self.allow_mismatch:
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
slicing = []
skip_this_dim = True
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
if skip_this_dim:
skip_this_dim = False
continue
@@ -270,10 +291,11 @@ class EasyCacheHolder:
else:
slicing.append(slice(None))
batch_slice = batch_slice + slicing
x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
return x
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if output.shape[1:] != x.shape[1:]:
if not self.allow_mismatch:
@@ -293,7 +315,7 @@ class EasyCacheHolder:
diff = output - x
batch_offset = diff.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
return self.first_cond_uuid in uuids
@@ -324,6 +346,8 @@ class EasyCacheHolder:
self.output_prev_norm = None
del self.uuid_cache_diffs
self.uuid_cache_diffs = {}
del self.uuid_cache_diffs_audio
self.uuid_cache_diffs_audio = {}
self.total_steps_skipped = 0
self.state_metadata = None
return self

View File

@@ -618,6 +618,7 @@ class SaveGLB(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveGLB",
display_name="Save 3D Model",
search_aliases=["export 3d model", "save mesh"],
category="3d",
is_output_node=True,
@@ -626,8 +627,14 @@ class SaveGLB(IO.ComfyNode):
IO.Mesh.Input("mesh"),
types=[
IO.File3DGLB,
IO.File3DGLTF,
IO.File3DOBJ,
IO.File3DFBX,
IO.File3DSTL,
IO.File3DUSDZ,
IO.File3DAny,
],
tooltip="Mesh or GLB file to save",
tooltip="Mesh or 3D file to save",
),
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
],
@@ -649,7 +656,8 @@ class SaveGLB(IO.ComfyNode):
if isinstance(mesh, Types.File3D):
# Handle File3D input - save BytesIO data to output folder
f = f"{filename}_{counter:05}_.glb"
ext = mesh.format or "glb"
f = f"{filename}_{counter:05}_.{ext}"
mesh.save_to(os.path.join(full_output_folder, f))
results.append({
"filename": f,

View File

@@ -45,6 +45,7 @@ class Load3D(IO.ComfyNode):
IO.Image.Output(display_name="normal"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Video.Output(display_name="recording_video"),
IO.File3DAny.Output(display_name="model_3d"),
],
)
@@ -66,7 +67,8 @@ class Load3D(IO.ComfyNode):
video = InputImpl.VideoFromFile(recording_video_path)
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
process = execute # TODO: remove

View File

@@ -10,146 +10,198 @@ import json
import os
from comfy.cli_args import args
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ModelMergeSimple:
class ModelMergeSimple(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSimple",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, ratio):
@classmethod
def execute(cls, model1, model2, ratio) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
class ModelSubtract:
merge = execute # TODO: remove
class ModelSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSubtract",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, multiplier):
@classmethod
def execute(cls, model1, model2, multiplier) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
return io.NodeOutput(m)
class ModelAdd:
merge = execute # TODO: remove
class ModelAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeAdd",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2):
@classmethod
def execute(cls, model1, model2) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPMergeSimple:
class CLIPMergeSimple(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSimple",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, ratio):
@classmethod
def execute(cls, clip1, clip2, ratio) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPSubtract:
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
class CLIPSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSubtract",
search_aliases=["clip difference", "text encoder subtract"],
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, multiplier):
@classmethod
def execute(cls, clip1, clip2, multiplier) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPAdd:
SEARCH_ALIASES = ["combine clip"]
class CLIPAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeAdd",
search_aliases=["combine clip"],
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2):
@classmethod
def execute(cls, clip1, clip2) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class ModelMergeBlocks:
class ModelMergeBlocks(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeBlocks",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("input", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("middle", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, **kwargs):
@classmethod
def execute(cls, model1, model2, **kwargs) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
default_ratio = next(iter(kwargs.values()))
@@ -165,7 +217,10 @@ class ModelMergeBlocks:
last_arg_size = len(arg)
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
@@ -226,59 +281,65 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
class CheckpointSave:
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class CheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CheckpointSave",
display_name="Save Checkpoint",
search_aliases=["save model", "export checkpoint", "merge save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.Clip.Input("clip"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
def execute(cls, model, clip, vae, filename_prefix) -> io.NodeOutput:
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
CATEGORY = "advanced/model_merging"
save = execute # TODO: remove
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
class CLIPSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class CLIPSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPSave",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip"),
io.String.Input("filename_prefix", default="clip/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
def execute(cls, clip, filename_prefix) -> io.NodeOutput:
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)
metadata = {}
if not args.disable_metadata:
metadata["format"] = "pt"
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
output_dir = folder_paths.get_output_directory()
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
@@ -295,7 +356,7 @@ class CLIPSave:
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, output_dir)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
@@ -303,76 +364,88 @@ class CLIPSave:
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}
return io.NodeOutput()
class VAESave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
save = execute # TODO: remove
class VAESave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VAESave",
category="advanced/model_merging",
inputs=[
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="vae/ComfyUI_vae"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
def execute(cls, vae, filename_prefix) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
return io.NodeOutput()
class ModelSave:
SEARCH_ALIASES = ["export model", "checkpoint save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
save = execute # TODO: remove
class ModelSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelSave",
search_aliases=["export model", "checkpoint save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.String.Input("filename_prefix", default="diffusion_models/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
def execute(cls, model, filename_prefix) -> io.NodeOutput:
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
CATEGORY = "advanced/model_merging"
save = execute # TODO: remove
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks,
"ModelMergeSubtract": ModelSubtract,
"ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple,
"CLIPMergeSubtract": CLIPSubtract,
"CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
"ModelSave": ModelSave,
}
class ModelMergingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ModelMergeSimple,
ModelMergeBlocks,
ModelSubtract,
ModelAdd,
CheckpointSave,
CLIPMergeSimple,
CLIPSubtract,
CLIPAdd,
CLIPSave,
VAESave,
ModelSave,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointSave": "Save Checkpoint",
}
async def comfy_entrypoint() -> ModelMergingExtension:
return ModelMergingExtension()

View File

@@ -1,356 +1,455 @@
import comfy_extras.nodes_model_merging
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
inputs.append(io.Float.Input("time_embed.", **argument))
inputs.append(io.Float.Input("label_emb.", **argument))
for i in range(12):
arg_dict["input_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("input_blocks.{}.".format(i), **argument))
for i in range(3):
arg_dict["middle_block.{}.".format(i)] = argument
inputs.append(io.Float.Input("middle_block.{}.".format(i), **argument))
for i in range(12):
arg_dict["output_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("output_blocks.{}.".format(i), **argument))
arg_dict["out."] = argument
inputs.append(io.Float.Input("out.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeSD1",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
class ModelMergeSD2(ModelMergeSD1):
# SD1 and SD2 have the same blocks
@classmethod
def define_schema(cls):
schema = ModelMergeSD1.define_schema()
schema.node_id = "ModelMergeSD2"
return schema
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
inputs.append(io.Float.Input("time_embed.", **argument))
inputs.append(io.Float.Input("label_emb.", **argument))
for i in range(9):
arg_dict["input_blocks.{}".format(i)] = argument
inputs.append(io.Float.Input("input_blocks.{}".format(i), **argument))
for i in range(3):
arg_dict["middle_block.{}".format(i)] = argument
inputs.append(io.Float.Input("middle_block.{}".format(i), **argument))
for i in range(9):
arg_dict["output_blocks.{}".format(i)] = argument
inputs.append(io.Float.Input("output_blocks.{}".format(i), **argument))
arg_dict["out."] = argument
inputs.append(io.Float.Input("out.", **argument))
return io.Schema(
node_id="ModelMergeSDXL",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("pos_embed.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("context_embedder.", **argument))
inputs.append(io.Float.Input("y_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(24):
arg_dict["joint_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeSD3_2B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["init_x_linear."] = argument
arg_dict["positional_encoding"] = argument
arg_dict["cond_seq_linear."] = argument
arg_dict["register_tokens"] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("init_x_linear.", **argument))
inputs.append(io.Float.Input("positional_encoding", **argument))
inputs.append(io.Float.Input("cond_seq_linear.", **argument))
inputs.append(io.Float.Input("register_tokens", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(4):
arg_dict["double_layers.{}.".format(i)] = argument
inputs.append(io.Float.Input("double_layers.{}.".format(i), **argument))
for i in range(32):
arg_dict["single_layers.{}.".format(i)] = argument
inputs.append(io.Float.Input("single_layers.{}.".format(i), **argument))
arg_dict["modF."] = argument
arg_dict["final_linear."] = argument
inputs.append(io.Float.Input("modF.", **argument))
inputs.append(io.Float.Input("final_linear.", **argument))
return io.Schema(
node_id="ModelMergeAuraflow",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["img_in."] = argument
arg_dict["time_in."] = argument
arg_dict["guidance_in"] = argument
arg_dict["vector_in."] = argument
arg_dict["txt_in."] = argument
inputs.append(io.Float.Input("img_in.", **argument))
inputs.append(io.Float.Input("time_in.", **argument))
inputs.append(io.Float.Input("guidance_in", **argument))
inputs.append(io.Float.Input("vector_in.", **argument))
inputs.append(io.Float.Input("txt_in.", **argument))
for i in range(19):
arg_dict["double_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("double_blocks.{}.".format(i), **argument))
for i in range(38):
arg_dict["single_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("single_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeFlux1",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("pos_embed.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("context_embedder.", **argument))
inputs.append(io.Float.Input("y_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(38):
arg_dict["joint_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeSD35_Large",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_frequencies."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t5_y_embedder."] = argument
arg_dict["t5_yproj."] = argument
inputs.append(io.Float.Input("pos_frequencies.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t5_y_embedder.", **argument))
inputs.append(io.Float.Input("t5_yproj.", **argument))
for i in range(48):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeMochiPreview",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["patchify_proj."] = argument
arg_dict["adaln_single."] = argument
arg_dict["caption_projection."] = argument
inputs.append(io.Float.Input("patchify_proj.", **argument))
inputs.append(io.Float.Input("adaln_single.", **argument))
inputs.append(io.Float.Input("caption_projection.", **argument))
for i in range(28):
arg_dict["transformer_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))
arg_dict["scale_shift_table"] = argument
arg_dict["proj_out."] = argument
inputs.append(io.Float.Input("scale_shift_table", **argument))
inputs.append(io.Float.Input("proj_out.", **argument))
return io.Schema(
node_id="ModelMergeLTXV",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("affline_norm.", **argument))
for i in range(28):
arg_dict["blocks.block{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmos7B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("affline_norm.", **argument))
for i in range(36):
arg_dict["blocks.block{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmos14B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument
inputs.append(io.Float.Input("patch_embedding.", **argument))
inputs.append(io.Float.Input("time_embedding.", **argument))
inputs.append(io.Float.Input("time_projection.", **argument))
inputs.append(io.Float.Input("text_embedding.", **argument))
inputs.append(io.Float.Input("img_emb.", **argument))
for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["head."] = argument
inputs.append(io.Float.Input("head.", **argument))
return io.Schema(
node_id="ModelMergeWAN2_1",
category="advanced/model_merging/model_specific",
description="1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb.",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t_embedding_norm.", **argument))
for i in range(28):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmosPredict2_2B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t_embedding_norm.", **argument))
for i in range(36):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmosPredict2_14B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embeds."] = argument
arg_dict["img_in."] = argument
arg_dict["txt_norm."] = argument
arg_dict["txt_in."] = argument
arg_dict["time_text_embed."] = argument
inputs.append(io.Float.Input("pos_embeds.", **argument))
inputs.append(io.Float.Input("img_in.", **argument))
inputs.append(io.Float.Input("txt_norm.", **argument))
inputs.append(io.Float.Input("txt_in.", **argument))
inputs.append(io.Float.Input("time_text_embed.", **argument))
for i in range(60):
arg_dict["transformer_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))
arg_dict["proj_out."] = argument
inputs.append(io.Float.Input("proj_out.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeQwenImage",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL,
"ModelMergeSD3_2B": ModelMergeSD3_2B,
"ModelMergeAuraflow": ModelMergeAuraflow,
"ModelMergeFlux1": ModelMergeFlux1,
"ModelMergeSD35_Large": ModelMergeSD35_Large,
"ModelMergeMochiPreview": ModelMergeMochiPreview,
"ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
"ModelMergeQwenImage": ModelMergeQwenImage,
}
class ModelMergingModelSpecificExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ModelMergeSD1,
ModelMergeSD2,
ModelMergeSDXL,
ModelMergeSD3_2B,
ModelMergeAuraflow,
ModelMergeFlux1,
ModelMergeSD35_Large,
ModelMergeMochiPreview,
ModelMergeLTXV,
ModelMergeCosmos7B,
ModelMergeCosmos14B,
ModelMergeWAN2_1,
ModelMergeCosmosPredict2_2B,
ModelMergeCosmosPredict2_14B,
ModelMergeQwenImage,
]
async def comfy_entrypoint() -> ModelMergingModelSpecificExtension:
return ModelMergingModelSpecificExtension()

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CreateList(io.ComfyNode):
@classmethod
def define_schema(cls):
template_matchtype = io.MatchType.Template("type")
template_autogrow = io.Autogrow.TemplatePrefix(
input=io.MatchType.Input("input", template=template_matchtype),
prefix="input",
)
return io.Schema(
node_id="CreateList",
display_name="Create List",
category="logic",
is_input_list=True,
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],
outputs=[
io.MatchType.Output(
template=template_matchtype,
is_output_list=True,
display_name="list",
),
],
)
@classmethod
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
output_list = []
for input in inputs.values():
output_list += input
return io.NodeOutput(output_list)
class ToolkitExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CreateList,
]
async def comfy_entrypoint() -> ToolkitExtension:
return ToolkitExtension()

View File

@@ -6,44 +6,62 @@ import folder_paths
import comfy_extras.nodes_model_merging
import node_helpers
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ImageOnlyCheckpointLoader:
class ImageOnlyCheckpointLoader(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
def define_schema(cls):
return io.Schema(
node_id="ImageOnlyCheckpointLoader",
display_name="Image Only Checkpoint Loader (img2vid model)",
category="loaders/video_models",
inputs=[
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("checkpoints")),
],
outputs=[
io.Model.Output(),
io.ClipVision.Output(),
io.Vae.Output(),
],
)
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
@classmethod
def execute(cls, ckpt_name, output_vae=True, output_clip=True) -> io.NodeOutput:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
return io.NodeOutput(out[0], out[3], out[2])
load_checkpoint = execute # TODO: remove
class SVD_img2vid_Conditioning:
class SVD_img2vid_Conditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
def define_schema(cls):
return io.Schema(
node_id="SVD_img2vid_Conditioning",
category="conditioning/video_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("video_frames", default=14, min=1, max=4096),
io.Int.Input("motion_bucket_id", default=127, min=1, max=1023),
io.Int.Input("fps", default=6, min=1, max=1024),
io.Float.Input("augmentation_level", default=0.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level) -> io.NodeOutput:
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@@ -54,20 +72,28 @@ class SVD_img2vid_Conditioning:
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
return io.NodeOutput(positive, negative, {"samples":latent})
class VideoLinearCFGGuidance:
encode = execute # TODO: remove
class VideoLinearCFGGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="VideoLinearCFGGuidance",
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
@classmethod
def execute(cls, model, min_cfg) -> io.NodeOutput:
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
@@ -78,20 +104,28 @@ class VideoLinearCFGGuidance:
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
return io.NodeOutput(m)
class VideoTriangleCFGGuidance:
patch = execute # TODO: remove
class VideoTriangleCFGGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="VideoTriangleCFGGuidance",
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
@classmethod
def execute(cls, model, min_cfg) -> io.NodeOutput:
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
@@ -105,57 +139,79 @@ class VideoTriangleCFGGuidance:
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
return io.NodeOutput(m)
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "advanced/model_merging"
patch = execute # TODO: remove
class ImageOnlyCheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageOnlyCheckpointSave",
search_aliases=["save model", "export checkpoint", "merge save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.ClipVision.Input("clip_vision"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
def execute(cls, model, clip_vision, vae, filename_prefix) -> io.NodeOutput:
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
save = execute # TODO: remove
class ConditioningSetAreaPercentageVideo:
class ConditioningSetAreaPercentageVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
def define_schema(cls):
return io.Schema(
node_id="ConditioningSetAreaPercentageVideo",
category="conditioning",
inputs=[
io.Conditioning.Input("conditioning"),
io.Float.Input("width", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("height", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("temporal", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("x", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("y", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("z", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "conditioning"
def append(self, conditioning, width, height, temporal, x, y, z, strength):
@classmethod
def execute(cls, conditioning, width, height, temporal, x, y, z, strength) -> io.NodeOutput:
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
"strength": strength,
"set_area_to_bounds": False})
return (c, )
return io.NodeOutput(c)
append = execute # TODO: remove
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
"ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo,
}
class VideoModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ImageOnlyCheckpointLoader,
SVD_img2vid_Conditioning,
VideoLinearCFGGuidance,
VideoTriangleCFGGuidance,
ImageOnlyCheckpointSave,
ConditioningSetAreaPercentageVideo,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}
async def comfy_entrypoint() -> VideoModelExtension:
return VideoModelExtension()

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.12.1"
__version__ = "0.12.3"

View File

@@ -192,7 +192,10 @@ import comfy_aimdo.control
import comfy_aimdo.torch
if enables_dynamic_vram():
if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if comfy.model_management.torch_version_numeric < (2, 8):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG':
comfy_aimdo.control.set_log_debug()
elif args.verbose == 'CRITICAL':
@@ -208,7 +211,7 @@ if enables_dynamic_vram():
comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator()
logging.info("DynamicVRAM support detected and enabled")
else:
logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None

View File

@@ -2433,7 +2433,8 @@ async def init_builtin_extra_nodes():
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_lora_debug.py",
"nodes_color.py"
"nodes_color.py",
"nodes_toolkit.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.12.1"
version = "0.12.3"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.37.11
comfyui-frontend-package==1.38.13
comfyui-workflow-templates==0.8.31
comfyui-embedded-docs==0.4.0
torch