Merge branch 'master' into worksplit-multigpu

Jedrzej Kosinski
2026-02-17 02:53:06 -08:00
340 changed files with 53273 additions and 14261 deletions

View File

@@ -28,12 +28,45 @@ class TextEncodeAceStepAudio(io.ComfyNode):
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
return io.NodeOutput(conditioning)
class TextEncodeAceStepAudio15(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeAceStepAudio1.5",
category="conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("tags", multiline=True, dynamic_prompts=True),
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
io.Int.Input("bpm", default=120, min=10, max=300),
io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)
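As a side note, the keyscale combo above is built with a nested comprehension whose outer loop is the quality; a quick sketch of what it expands to (plain Python, no ComfyUI imports needed):

    # Sketch: the keyscale options from define_schema above, expanded.
    roots = ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]
    options = [f"{root} {quality}" for quality in ["major", "minor"] for root in roots]
    # quality varies in the outer loop, so all 17 "major" entries come first:
    # ["C major", "C# major", ..., "B major", "C minor", ..., "B minor"], 34 entries in total.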
class EmptyAceStepLatentAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyAceStepLatentAudio",
display_name="Empty Ace Step 1.0 Latent Audio",
category="latent/audio",
inputs=[
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
@@ -51,12 +84,61 @@ class EmptyAceStepLatentAudio(io.ComfyNode):
return io.NodeOutput({"samples": latent, "type": "audio"})
class EmptyAceStep15LatentAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyAceStep1.5LatentAudio",
display_name="Empty Ace Step 1.5 Latent Audio",
category="latent/audio",
inputs=[
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
io.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
),
],
outputs=[io.Latent.Output()],
)
@classmethod
def execute(cls, seconds, batch_size) -> io.NodeOutput:
length = round((seconds * 48000 / 1920))
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent, "type": "audio"})
class ReferenceAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ReferenceTimbreAudio",
display_name="Reference Audio",
category="advanced/conditioning/audio",
is_experimental=True,
description="This node sets the reference audio for ace step 1.5",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Conditioning.Output(),
]
)
@classmethod
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
return io.NodeOutput(conditioning)
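For context, passing append=True makes node_helpers.conditioning_set_values extend list-valued keys instead of overwriting them, so chaining several ReferenceAudio nodes accumulates latents. A rough sketch of the effect, assuming ComfyUI's standard conditioning layout of [tensor, options] pairs (this is not the helper's actual implementation):

    # Hypothetical illustration of append semantics on a conditioning list.
    def append_timbre_latent(conditioning, latent_samples):
        for cond_tensor, options in conditioning:
            options.setdefault("reference_audio_timbre_latents", [])
            options["reference_audio_timbre_latents"].append(latent_samples)
        return conditioning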
class AceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeAceStepAudio,
EmptyAceStepLatentAudio,
TextEncodeAceStepAudio15,
EmptyAceStep15LatentAudio,
ReferenceAudio,
]
async def comfy_entrypoint() -> AceExtension:
return AceExtension()

View File

@@ -28,6 +28,7 @@ class AlignYourStepsScheduler(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AlignYourStepsScheduler",
search_aliases=["AYS scheduler"],
category="sampling/custom_sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),

View File

@@ -55,7 +55,8 @@ class APG(io.ComfyNode):
def pre_cfg_function(args):
nonlocal running_avg, prev_sigma
if len(args["conds_out"]) == 1: return args["conds_out"]
if len(args["conds_out"]) == 1:
return args["conds_out"]
cond = args["conds_out"][0]
uncond = args["conds_out"][1]

View File

@@ -71,6 +71,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPAttentionMultiply",
search_aliases=["clip attention scale", "text encoder attention"],
category="_for_testing/attention_experiments",
inputs=[
io.Clip.Input("clip"),

View File

@@ -6,279 +6,253 @@ import torch
import comfy.model_management
import folder_paths
import os
import io
import json
import random
import hashlib
import node_helpers
import logging
from comfy.cli_args import args
from comfy.comfy_types import FileLocator
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, UI
class EmptyLatentAudio:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
class EmptyLatentAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentAudio",
display_name="Empty Latent Audio",
category="latent/audio",
inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
),
],
outputs=[IO.Latent.Output()],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, seconds, batch_size):
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
length = round((seconds * 44100 / 2048) / 2) * 2
latent = torch.zeros([batch_size, 64, length], device=self.device)
return ({"samples":latent, "type": "audio"}, )
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
return IO.NodeOutput({"samples":latent, "type": "audio"})
class ConditioningStableAudio:
generate = execute # TODO: remove
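The Stable Audio latent length is rounded to an even frame count; with the default 47.6 seconds:

    seconds = 47.6
    raw = seconds * 44100 / 2048    # ~1024.98 latent frames
    length = round(raw / 2) * 2     # 1024, forced to be even
    # latent shape: [batch_size, 64, 1024]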
class ConditioningStableAudio(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
"seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
}}
def define_schema(cls):
return IO.Schema(
node_id="ConditioningStableAudio",
category="conditioning",
inputs=[
IO.Conditioning.Input("positive"),
IO.Conditioning.Input("negative"),
IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
],
)
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, positive, negative, seconds_start, seconds_total):
@classmethod
def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
return (positive, negative)
return IO.NodeOutput(positive, negative)
class VAEEncodeAudio:
append = execute # TODO: remove
class VAEEncodeAudio(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
def define_schema(cls):
return IO.Schema(
node_id="VAEEncodeAudio",
search_aliases=["audio to latent"],
display_name="VAE Encode Audio",
category="latent/audio",
inputs=[
IO.Audio.Input("audio"),
IO.Vae.Input("vae"),
],
outputs=[IO.Latent.Output()],
)
CATEGORY = "latent/audio"
def encode(self, vae, audio):
@classmethod
def execute(cls, vae, audio) -> IO.NodeOutput:
sample_rate = audio["sample_rate"]
if 44100 != sample_rate:
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
if vae_sample_rate != sample_rate:
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate)
else:
waveform = audio["waveform"]
t = vae.encode(waveform.movedim(1, -1))
return ({"samples":t}, )
return IO.NodeOutput({"samples": t})
class VAEDecodeAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "decode"
encode = execute # TODO: remove
CATEGORY = "latent/audio"
def decode(self, vae, samples):
def vae_decode_audio(vae, samples, tile=None, overlap=None):
if tile is not None:
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
else:
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return ({"waveform": audio, "sample_rate": 44100}, )
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results: list[FileLocator] = []
# Prepare metadata dictionary
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
# Opus supported sample rates
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially
sample_rate = audio["sample_rate"]
# Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
# Resample if necessary
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create output with specified format
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format=format)
# Set metadata on the container
for key, value in metadata.items():
output_container.metadata[key] = value
layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, 'wb') as f:
f.write(output_buffer.getbuffer())
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "audio": results } }
class SaveAudio:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
class VAEDecodeAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeAudio",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio",
category="latent/audio",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
],
outputs=[IO.Audio.Output()],
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
def execute(cls, vae, samples) -> IO.NodeOutput:
return IO.NodeOutput(vae_decode_audio(vae, samples))
RETURN_TYPES = ()
FUNCTION = "save_flac"
decode = execute # TODO: remove
OUTPUT_NODE = True
CATEGORY = "audio"
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
class SaveAudioMP3:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
class VAEDecodeAudioTiled(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeAudioTiled",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio (Tiled)",
category="latent/audio",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8),
IO.Int.Input("overlap", default=64, min=0, max=1024, step=8),
],
outputs=[IO.Audio.Output()],
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput:
return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap))
RETURN_TYPES = ()
FUNCTION = "save_mp3"
OUTPUT_NODE = True
CATEGORY = "audio"
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class SaveAudioOpus:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
class SaveAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveAudio",
search_aliases=["export flac"],
display_name="Save Audio (FLAC)",
category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
)
RETURN_TYPES = ()
FUNCTION = "save_opus"
save_flac = execute # TODO: remove
OUTPUT_NODE = True
CATEGORY = "audio"
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class PreviewAudio(SaveAudio):
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
class SaveAudioMP3(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioMP3",
search_aliases=["export mp3"],
display_name="Save Audio (MP3)",
category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required":
{"audio": ("AUDIO", ), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
)
)
save_mp3 = execute # TODO: remove
class SaveAudioOpus(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioOpus",
search_aliases=["export opus"],
display_name="Save Audio (Opus)",
category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
)
)
save_opus = execute # TODO: remove
class PreviewAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="PreviewAudio",
search_aliases=["play audio"],
display_name="Preview Audio",
category="audio",
inputs=[
IO.Audio.Input("audio"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
save_flac = execute # TODO: remove
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format."""
@@ -316,26 +290,31 @@ def load(filepath: str) -> tuple[torch.Tensor, int]:
wav = f32_pcm(wav)
return wav, sr
class LoadAudio:
class LoadAudio(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
def define_schema(cls):
input_dir = folder_paths.get_input_directory()
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
return IO.Schema(
node_id="LoadAudio",
search_aliases=["import audio", "open audio", "audio file"],
display_name="Load Audio",
category="audio",
inputs=[
IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
],
outputs=[IO.Audio.Output()],
)
CATEGORY = "audio"
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )
return IO.NodeOutput(audio)
@classmethod
def IS_CHANGED(s, audio):
def fingerprint_inputs(cls, audio):
image_path = folder_paths.get_annotated_filepath(audio)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
@@ -343,46 +322,71 @@ class LoadAudio:
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, audio):
def validate_inputs(cls, audio):
if not folder_paths.exists_annotated_filepath(audio):
return "Invalid audio file: {}".format(audio)
return True
class RecordAudio:
load = execute # TODO: remove
class RecordAudio(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"audio": ("AUDIO_RECORD", {})}}
def define_schema(cls):
return IO.Schema(
node_id="RecordAudio",
search_aliases=["microphone input", "audio capture", "voice input"],
display_name="Record Audio",
category="audio",
inputs=[
IO.Custom("AUDIO_RECORD").Input("audio"),
],
outputs=[IO.Audio.Output()],
)
CATEGORY = "audio"
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )
return IO.NodeOutput(audio)
load = execute # TODO: remove
class TrimAudioDuration:
class TrimAudioDuration(IO.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
},
}
def define_schema(cls):
return IO.Schema(
node_id="TrimAudioDuration",
search_aliases=["cut audio", "audio clip", "shorten audio"],
display_name="Trim Audio Duration",
description="Trim audio tensor into chosen time range.",
category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.Float.Input(
"start_index",
default=0.0,
min=-0xffffffffffffffff,
max=0xffffffffffffffff,
step=0.01,
tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
),
IO.Float.Input(
"duration",
default=60.0,
min=0.0,
step=0.01,
tooltip="Duration in seconds",
),
],
outputs=[IO.Audio.Output()],
)
FUNCTION = "trim"
RETURN_TYPES = ("AUDIO",)
CATEGORY = "audio"
DESCRIPTION = "Trim audio tensor into chosen time range."
def trim(self, audio, start_index, duration):
@classmethod
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
@@ -399,23 +403,31 @@ class TrimAudioDuration:
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
trim = execute # TODO: remove
class SplitAudioChannels:
class SplitAudioChannels(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio": ("AUDIO",),
}}
def define_schema(cls):
return IO.Schema(
node_id="SplitAudioChannels",
search_aliases=["stereo to mono"],
display_name="Split Audio Channels",
description="Separates the audio into left and right channels.",
category="audio",
inputs=[
IO.Audio.Input("audio"),
],
outputs=[
IO.Audio.Output(display_name="left"),
IO.Audio.Output(display_name="right"),
],
)
RETURN_TYPES = ("AUDIO", "AUDIO")
RETURN_NAMES = ("left", "right")
FUNCTION = "separate"
CATEGORY = "audio"
DESCRIPTION = "Separates the audio into left and right channels."
def separate(self, audio):
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
@@ -425,7 +437,61 @@ class SplitAudioChannels:
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
separate = execute # TODO: remove
class JoinAudioChannels(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="JoinAudioChannels",
display_name="Join Audio Channels",
description="Joins left and right mono audio channels into a stereo audio.",
category="audio",
inputs=[
IO.Audio.Input("audio_left"),
IO.Audio.Input("audio_right"),
],
outputs=[
IO.Audio.Output(display_name="audio"),
],
)
@classmethod
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
waveform_left = audio_left["waveform"]
sample_rate_left = audio_left["sample_rate"]
waveform_right = audio_right["waveform"]
sample_rate_right = audio_right["sample_rate"]
if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1:
raise ValueError("AudioJoin: Both input audios must be mono.")
# Handle different sample rates by resampling to the higher rate
waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates(
waveform_left, sample_rate_left, waveform_right, sample_rate_right
)
# Handle different lengths by trimming to the shorter length
length_left = waveform_left.shape[-1]
length_right = waveform_right.shape[-1]
if length_left != length_right:
min_length = min(length_left, length_right)
if length_left > min_length:
logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.")
waveform_left = waveform_left[..., :min_length]
if length_right > min_length:
logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.")
waveform_right = waveform_right[..., :min_length]
# Join the channels into stereo
left_channel = waveform_left[..., 0:1, :]
right_channel = waveform_right[..., 0:1, :]
stereo_waveform = torch.cat([left_channel, right_channel], dim=1)
return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate})
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
@@ -443,21 +509,30 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
return waveform_1, waveform_2, output_sample_rate
class AudioConcat:
class AudioConcat(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio1": ("AUDIO",),
"audio2": ("AUDIO",),
"direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
}}
def define_schema(cls):
return IO.Schema(
node_id="AudioConcat",
search_aliases=["join audio", "combine audio", "append audio"],
display_name="Audio Concat",
description="Concatenates the audio1 to audio2 in the specified direction.",
category="audio",
inputs=[
IO.Audio.Input("audio1"),
IO.Audio.Input("audio2"),
IO.Combo.Input(
"direction",
options=['after', 'before'],
default="after",
tooltip="Whether to append audio2 after or before audio1.",
)
],
outputs=[IO.Audio.Output()],
)
RETURN_TYPES = ("AUDIO",)
FUNCTION = "concat"
CATEGORY = "audio"
DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
def concat(self, audio1, audio2, direction):
@classmethod
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@@ -477,26 +552,34 @@ class AudioConcat:
elif direction == 'before':
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
concat = execute # TODO: remove
class AudioMerge:
class AudioMerge(IO.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio1": ("AUDIO",),
"audio2": ("AUDIO",),
"merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
},
}
def define_schema(cls):
return IO.Schema(
node_id="AudioMerge",
search_aliases=["mix audio", "overlay audio", "layer audio"],
display_name="Audio Merge",
description="Combine two audio tracks by overlaying their waveforms.",
category="audio",
inputs=[
IO.Audio.Input("audio1"),
IO.Audio.Input("audio2"),
IO.Combo.Input(
"merge_method",
options=["add", "mean", "subtract", "multiply"],
tooltip="The method used to combine the audio waveforms.",
)
],
outputs=[IO.Audio.Output()],
)
FUNCTION = "merge"
RETURN_TYPES = ("AUDIO",)
CATEGORY = "audio"
DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
def merge(self, audio1, audio2, merge_method):
@classmethod
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@@ -530,85 +613,114 @@ class AudioMerge:
if max_val > 1.0:
waveform = waveform / max_val
return ({"waveform": waveform, "sample_rate": output_sample_rate},)
return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
merge = execute # TODO: remove
class AudioAdjustVolume:
class AudioAdjustVolume(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio": ("AUDIO",),
"volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
}}
def define_schema(cls):
return IO.Schema(
node_id="AudioAdjustVolume",
search_aliases=["audio gain", "loudness", "audio level"],
display_name="Audio Adjust Volume",
category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.Int.Input(
"volume",
default=1,
min=-100,
max=100,
tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
)
],
outputs=[IO.Audio.Output()],
)
RETURN_TYPES = ("AUDIO",)
FUNCTION = "adjust_volume"
CATEGORY = "audio"
def adjust_volume(self, audio, volume):
@classmethod
def execute(cls, audio, volume) -> IO.NodeOutput:
if volume == 0:
return (audio,)
return IO.NodeOutput(audio)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
gain = 10 ** (volume / 20)
waveform = waveform * gain
return ({"waveform": waveform, "sample_rate": sample_rate},)
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
adjust_volume = execute # TODO: remove
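The gain formula is the standard amplitude-decibel conversion, which is why the tooltip's "+6 = double, -6 = half" holds only approximately:

    gain_plus6 = 10 ** (6 / 20)     # ~1.995, close to doubling the amplitude
    gain_minus6 = 10 ** (-6 / 20)   # ~0.501, close to halving it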
class EmptyAudio:
class EmptyAudio(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
"sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
"channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
}}
def define_schema(cls):
return IO.Schema(
node_id="EmptyAudio",
search_aliases=["blank audio"],
display_name="Empty Audio",
category="audio",
inputs=[
IO.Float.Input(
"duration",
default=60.0,
min=0.0,
max=0xffffffffffffffff,
step=0.01,
tooltip="Duration of the empty audio clip in seconds",
),
IO.Int.Input(
"sample_rate",
default=44100,
tooltip="Sample rate of the empty audio clip.",
min=1,
max=192000,
),
IO.Int.Input(
"channels",
default=2,
min=1,
max=2,
tooltip="Number of audio channels (1 for mono, 2 for stereo).",
),
],
outputs=[IO.Audio.Output()],
)
RETURN_TYPES = ("AUDIO",)
FUNCTION = "create_empty_audio"
CATEGORY = "audio"
def create_empty_audio(self, duration, sample_rate, channels):
@classmethod
def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
num_samples = int(round(duration * sample_rate))
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
return ({"waveform": waveform, "sample_rate": sample_rate},)
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
create_empty_audio = execute # TODO: remove
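Sample count is simply duration times rate; the defaults produce about 2.6 million zero samples per channel:

    duration, sample_rate, channels = 60.0, 44100, 2
    num_samples = int(round(duration * sample_rate))   # 2_646_000
    # waveform shape: (1, 2, 2646000), all zeros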
NODE_CLASS_MAPPINGS = {
"EmptyLatentAudio": EmptyLatentAudio,
"VAEEncodeAudio": VAEEncodeAudio,
"VAEDecodeAudio": VAEDecodeAudio,
"SaveAudio": SaveAudio,
"SaveAudioMP3": SaveAudioMP3,
"SaveAudioOpus": SaveAudioOpus,
"LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
"RecordAudio": RecordAudio,
"TrimAudioDuration": TrimAudioDuration,
"SplitAudioChannels": SplitAudioChannels,
"AudioConcat": AudioConcat,
"AudioMerge": AudioMerge,
"AudioAdjustVolume": AudioAdjustVolume,
"EmptyAudio": EmptyAudio,
}
class AudioExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
EmptyLatentAudio,
VAEEncodeAudio,
VAEDecodeAudio,
VAEDecodeAudioTiled,
SaveAudio,
SaveAudioMP3,
SaveAudioOpus,
LoadAudio,
PreviewAudio,
ConditioningStableAudio,
RecordAudio,
TrimAudioDuration,
SplitAudioChannels,
JoinAudioChannels,
AudioConcat,
AudioMerge,
AudioAdjustVolume,
EmptyAudio,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentAudio": "Empty Latent Audio",
"VAEEncodeAudio": "VAE Encode Audio",
"VAEDecodeAudio": "VAE Decode Audio",
"PreviewAudio": "Preview Audio",
"LoadAudio": "Load Audio",
"SaveAudio": "Save Audio (FLAC)",
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
"RecordAudio": "Record Audio",
"TrimAudioDuration": "Trim Audio Duration",
"SplitAudioChannels": "Split Audio Channels",
"AudioConcat": "Audio Concat",
"AudioMerge": "Audio Merge",
"AudioAdjustVolume": "Audio Adjust Volume",
"EmptyAudio": "Empty Audio",
}
async def comfy_entrypoint() -> AudioExtension:
return AudioExtension()

View File

@@ -10,6 +10,7 @@ class Canny(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Canny",
search_aliases=["edge detection", "outline", "contour detection", "line art"],
category="image/preprocessors",
inputs=[
io.Image.Input("image"),

View File

@@ -0,0 +1,42 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class ColorToRGBInt(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="ColorToRGBInt",
display_name="Color to RGB Int",
category="utils",
description="Convert a color to a RGB integer value.",
inputs=[
io.Color.Input("color"),
],
outputs=[
io.Int.Output(display_name="rgb_int"),
],
)
@classmethod
def execute(
cls,
color: str,
) -> io.NodeOutput:
# expect format #RRGGBB
if len(color) != 7 or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB")
r = int(color[1:3], 16)
g = int(color[3:5], 16)
b = int(color[5:7], 16)
return io.NodeOutput(r * 256 * 256 + g * 256 + b)
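The packing is the usual 0xRRGGBB layout; a worked example:

    color = "#FF8000"
    r, g, b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)  # 255, 128, 0
    rgb_int = r * 256 * 256 + g * 256 + b   # 16744448, i.e. 0xFF8000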
class ColorExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ColorToRGBInt]
async def comfy_entrypoint() -> ColorExtension:
return ColorExtension()

View File

@@ -109,6 +109,7 @@ class PorterDuffImageComposite(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PorterDuffImageComposite",
search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"],
display_name="Porter-Duff Image Composite",
category="mask/compositing",
inputs=[
@@ -165,6 +166,7 @@ class SplitImageWithAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SplitImageWithAlpha",
search_aliases=["extract alpha", "separate transparency", "remove alpha"],
display_name="Split Image with Alpha",
category="mask/compositing",
inputs=[
@@ -188,6 +190,7 @@ class JoinImageWithAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="JoinImageWithAlpha",
search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"],
display_name="Join Image with Alpha",
category="mask/compositing",
inputs=[

View File

@@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
],
outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."),
@@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
)
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode):
context_overlap=context_overlap,
context_stride=context_stride,
closed_loop=closed_loop,
dim=dim)
dim=dim,
freenoise=freenoise,
cond_retain_index_list=cond_retain_index_list,
split_conds_to_windows=split_conds_to_windows
)
# make memory usage calculation only take into account the context window latents
comfy.context_windows.create_prepare_sampling_wrapper(model)
if freenoise: # no other use for this wrapper at this time
comfy.context_windows.create_sampler_sample_wrapper(model)
return io.NodeOutput(model)
class WanContextWindowsManualNode(ContextWindowsManualNode):
@@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
]
return schema
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
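The Wan override converts pixel-frame settings into latent frames, since Wan's latent space compresses time by 4x; with hypothetical settings of 81 frames and 16 overlap:

    context_length, context_overlap = 81, 16                   # hypothetical pixel-frame inputs
    latent_length = max(((context_length - 1) // 4) + 1, 1)    # 21 latent frames
    latent_overlap = max(((context_overlap - 1) // 4) + 1, 0)  # 4 latent frames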
class ContextWindowsExtension(ComfyExtension):

View File

@@ -1,20 +1,26 @@
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
import nodes
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class SetUnionControlNetType:
class SetUnionControlNetType(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"control_net": ("CONTROL_NET", ),
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
}}
def define_schema(cls):
return io.Schema(
node_id="SetUnionControlNetType",
category="conditioning/controlnet",
inputs=[
io.ControlNet.Input("control_net"),
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
],
outputs=[
io.ControlNet.Output(),
],
)
CATEGORY = "conditioning/controlnet"
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "set_controlnet_type"
def set_controlnet_type(self, control_net, type):
@classmethod
def execute(cls, control_net, type) -> io.NodeOutput:
control_net = control_net.copy()
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
if type_number >= 0:
@@ -22,27 +28,37 @@ class SetUnionControlNetType:
else:
control_net.set_extra_arg("control_type", [])
return (control_net,)
return io.NodeOutput(control_net)
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
set_controlnet_type = execute # TODO: remove
class ControlNetInpaintingAliMamaApply(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"vae": ("VAE", ),
"image": ("IMAGE", ),
"mask": ("MASK", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
def define_schema(cls):
return io.Schema(
node_id="ControlNetInpaintingAliMamaApply",
search_aliases=["masked controlnet"],
category="conditioning/controlnet",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.ControlNet.Input("control_net"),
io.Vae.Input("vae"),
io.Image.Input("image"),
io.Mask.Input("mask"),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
FUNCTION = "apply_inpaint_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
@classmethod
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
extra_concat = []
if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
@@ -50,11 +66,20 @@ class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
extra_concat = [mask]
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
return io.NodeOutput(result[0], result[1])
apply_inpaint_controlnet = execute # TODO: remove
class ControlNetExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SetUnionControlNetType,
ControlNetInpaintingAliMamaApply,
]
NODE_CLASS_MAPPINGS = {
"SetUnionControlNetType": SetUnionControlNetType,
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
}
async def comfy_entrypoint() -> ControlNetExtension:
return ControlNetExtension()

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -11,6 +11,7 @@ class DifferentialDiffusion(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="DifferentialDiffusion",
search_aliases=["inpaint gradient", "variable denoise strength"],
display_name="Differential Diffusion",
category="_for_testing",
inputs=[

View File

@@ -9,15 +9,23 @@ if TYPE_CHECKING:
from uuid import UUID
def _extract_tensor(data, output_channels):
"""Extract tensor from data, handling both single tensors and lists."""
if isinstance(data, list):
# LTX2 AV tensors: [video, audio]
return data[0][:, :output_channels], data[1][:, :output_channels]
return data[:, :output_channels], None
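A sketch of both branches of _extract_tensor, the list case covering LTX2's paired video/audio tensors (the channel count here is a hypothetical value for illustration):

    import torch
    C = 16                                         # hypothetical latent channel count
    video = torch.zeros(1, 32, 8, 8, 8)
    audio = torch.zeros(1, 32, 100)
    x, ax = _extract_tensor([video, audio], C)     # both trimmed to 16 channels
    x2, ax2 = _extract_tensor(video, C)            # plain tensor: ax2 is None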
def easycache_forward_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
transformer_options: dict[str] = args[-1]
if not isinstance(transformer_options, dict):
transformer_options = kwargs.get("transformer_options")
if not transformer_options:
transformer_options = args[-2]
easycache: EasyCacheHolder = transformer_options["easycache"]
x, ax = _extract_tensor(args[0], easycache.output_channels)
sigmas = transformer_options["sigmas"]
uuids = transformer_options["uuids"]
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
@@ -29,11 +37,17 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
do_easycache = easycache.should_do_easycache(sigmas)
if do_easycache:
easycache.check_metadata(x)
# if there isn't a cache diff for current conds, we cannot skip this step
can_apply_cache_diff = easycache.can_apply_cache_diff(uuids)
# if first cond marked this step for skipping, skip it and use appropriate cached values
if easycache.skip_current_step:
if easycache.skip_current_step and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
return easycache.apply_cache_diff(x, uuids)
result = easycache.apply_cache_diff(x, uuids)
if ax is not None:
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
return [result, result_audio]
return result
if easycache.initial_step:
easycache.first_cond_uuid = uuids[0]
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
@@ -44,18 +58,23 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values
easycache.skip_current_step = True
return easycache.apply_cache_diff(x, uuids)
result = easycache.apply_cache_diff(x, uuids)
if ax is not None:
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
return [result, result_audio]
return result
else:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
output: torch.Tensor = executor(*args, **kwargs)
full_output: torch.Tensor = executor(*args, **kwargs)
output, audio_output = _extract_tensor(full_output, easycache.output_channels)
if has_first_cond_uuid and easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
@@ -72,22 +91,24 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
# TODO: allow cache_diff to be offloaded
easycache.update_cache_diff(output, next_x_prev, uuids)
if audio_output is not None:
easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
if has_first_cond_uuid:
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
return output
return full_output
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
timestep: float = args[1]
model_options: dict[str] = args[2]
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
x: torch.Tensor = args[0][:, :easycache.output_channels]
# prepare next x_prev
next_x_prev = x
input_change = None
@@ -173,7 +194,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
class EasyCacheHolder:
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
self.name = "EasyCache"
self.reuse_threshold = reuse_threshold
self.start_percent = start_percent
@@ -195,6 +216,7 @@ class EasyCacheHolder:
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
self.output_change_rates = []
self.approx_output_change_rates = []
self.total_steps_skipped = 0
@@ -202,6 +224,7 @@ class EasyCacheHolder:
self.allow_mismatch = True
self.cut_from_start = True
self.state_metadata = None
self.output_channels = output_channels
def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item()
@@ -239,18 +262,24 @@ class EasyCacheHolder:
return to_return.clone()
return to_return
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
if self.first_cond_uuid in uuids:
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
if self.first_cond_uuid in uuids and not is_audio:
self.total_steps_skipped += 1
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
batch_offset = x.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
# slice out only what is relevant to this cond
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
if x.shape[1:] != cache_diffs[uuid].shape[1:]:
if not self.allow_mismatch:
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
slicing = []
skip_this_dim = True
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
if skip_this_dim:
skip_this_dim = False
continue
@@ -261,12 +290,12 @@ class EasyCacheHolder:
slicing.append(slice(None, dim_u))
else:
slicing.append(slice(None))
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
x = x[slicing]
x += self.uuid_cache_diffs[uuid].to(x.device)
batch_slice = batch_slice + slicing
x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
return x
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if output.shape[1:] != x.shape[1:]:
if not self.allow_mismatch:
@@ -282,11 +311,11 @@ class EasyCacheHolder:
else:
slicing.append(slice(None))
skip_dim = False
x = x[slicing]
x = x[tuple(slicing)]
diff = output - x
batch_offset = diff.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
return self.first_cond_uuid in uuids
@@ -317,12 +346,14 @@ class EasyCacheHolder:
self.output_prev_norm = None
del self.uuid_cache_diffs
self.uuid_cache_diffs = {}
del self.uuid_cache_diffs_audio
self.uuid_cache_diffs_audio = {}
self.total_steps_skipped = 0
self.state_metadata = None
return self
def clone(self):
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
class EasyCacheNode(io.ComfyNode):
@@ -349,7 +380,7 @@ class EasyCacheNode(io.ComfyNode):
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
model = model.clone()
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
@@ -357,7 +388,7 @@ class EasyCacheNode(io.ComfyNode):
class LazyCacheHolder:
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
self.name = "LazyCache"
self.reuse_threshold = reuse_threshold
self.start_percent = start_percent
@@ -381,6 +412,7 @@ class LazyCacheHolder:
self.approx_output_change_rates = []
self.total_steps_skipped = 0
self.state_metadata = None
self.output_channels = output_channels
def has_cache_diff(self) -> bool:
return self.cache_diff is not None
@@ -455,7 +487,7 @@ class LazyCacheHolder:
return self
def clone(self):
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
class LazyCacheNode(io.ComfyNode):
@classmethod
@@ -481,7 +513,7 @@ class LazyCacheNode(io.ComfyNode):
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
model = model.clone()
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
return io.NodeOutput(model)

View File

@@ -2,7 +2,10 @@ import node_helpers
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import comfy.model_management
import torch
import math
import nodes
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
@@ -30,6 +33,27 @@ class CLIPTextEncodeFlux(io.ComfyNode):
encode = execute # TODO: remove
class EmptyFlux2LatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyFlux2LatentImage",
display_name="Empty Flux 2 Latent",
category="latent",
inputs=[
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
class FluxGuidance(io.ComfyNode):
@classmethod
@@ -130,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FluxKontextMultiReferenceLatentMethod",
display_name="Edit Model Reference Method",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno"],
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
),
],
outputs=[
@@ -154,6 +179,58 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
append = execute # TODO: remove
def generalized_time_snr_shift(t, mu: float, sigma: float):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
def get_schedule(num_steps: int, image_seq_len: int) -> torch.Tensor:
mu = compute_empirical_mu(image_seq_len, num_steps)
timesteps = torch.linspace(1, 0, num_steps + 1)
timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
return timesteps
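# Minimal sketch (example numbers assumed, not from the commit): the empirical
# mu interpolates between the 10-step and 200-step linear fits, and a larger mu
# keeps sigmas near 1.0 for longer on large images.
def _flux2_schedule_sketch():
    mu = compute_empirical_mu(image_seq_len=4096, num_steps=20)  # ~1024x1024
    sigmas = get_schedule(20, 4096)  # 21 values, monotonic from 1.0 to 0.0
    return mu, sigmas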
class Flux2Scheduler(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Flux2Scheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=4096),
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Sigmas.Output(),
],
)
@classmethod
def execute(cls, steps, width, height) -> io.NodeOutput:
seq_len = (width * height / (16 * 16))
sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas)
class FluxExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -163,6 +240,8 @@ class FluxExtension(ComfyExtension):
FluxDisableGuidance,
FluxKontextImageScale,
FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage,
Flux2Scheduler,
]

View File

@@ -2,6 +2,8 @@
import torch
import logging
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
def Fourier_filter(x, threshold, scale):
# FFT
@@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale):
return x_filtered.to(x.dtype)
class FreeU:
class FreeU(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return IO.Schema(
node_id="FreeU",
category="model_patches/unet",
inputs=[
IO.Model.Input("model"),
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
CATEGORY = "model_patches/unet"
def patch(self, model, b1, b2, s1, s2):
@classmethod
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
@@ -59,23 +66,31 @@ class FreeU:
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
return (m, )
return IO.NodeOutput(m)
class FreeU_V2:
patch = execute # TODO: remove
class FreeU_V2(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return IO.Schema(
node_id="FreeU_V2",
category="model_patches/unet",
inputs=[
IO.Model.Input("model"),
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
CATEGORY = "model_patches/unet"
def patch(self, model, b1, b2, s1, s2):
@classmethod
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
@@ -105,9 +120,19 @@ class FreeU_V2:
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
return (m, )
return IO.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"FreeU": FreeU,
"FreeU_V2": FreeU_V2,
}
patch = execute # TODO: remove
class FreelunchExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
FreeU,
FreeU_V2,
]
async def comfy_entrypoint() -> FreelunchExtension:
return FreelunchExtension()

View File

@@ -58,6 +58,7 @@ class FreSca(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FreSca",
search_aliases=["frequency guidance"],
display_name="FreSca",
category="_for_testing",
description="Applies frequency-dependent scaling to the guidance",

View File

@@ -38,6 +38,7 @@ class CLIPTextEncodeHiDream(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeHiDream",
search_aliases=["hidream prompt"],
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),

View File

@@ -259,6 +259,7 @@ class SetClipHooks:
return (clip,)
class ConditioningTimestepsRange:
SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"]
NodeId = 'ConditioningTimestepsRange'
NodeName = 'Timesteps Range'
@classmethod
@@ -468,6 +469,7 @@ class SetHookKeyframes:
return (hooks,)
class CreateHookKeyframe:
SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"]
NodeId = 'CreateHookKeyframe'
NodeName = 'Create Hook Keyframe'
@classmethod
@@ -497,6 +499,7 @@ class CreateHookKeyframe:
return (prev_hook_kf,)
class CreateHookKeyframesInterpolated:
SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"]
NodeId = 'CreateHookKeyframesInterpolated'
NodeName = 'Create Hook Keyframes Interp.'
@classmethod
@@ -544,6 +547,7 @@ class CreateHookKeyframesInterpolated:
return (prev_hook_kf,)
class CreateHookKeyframesFromFloats:
SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"]
NodeId = 'CreateHookKeyframesFromFloats'
NodeName = 'Create Hook Keyframes From Floats'
@classmethod
@@ -618,6 +622,7 @@ class SetModelHooksOnCond:
# Combine Hooks
#------------------------------------------
class CombineHooks:
SEARCH_ALIASES = ["merge hooks"]
NodeId = 'CombineHooks2'
NodeName = 'Combine Hooks [2]'
@classmethod

View File

@@ -4,7 +4,10 @@ import torch
import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler
import folder_paths
import json
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
@classmethod
@@ -37,6 +40,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="EmptyHunyuanLatentVideo",
display_name="Empty HunyuanVideo 1.0 Latent",
category="latent/video",
inputs=[
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
@@ -52,11 +56,208 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples":latent})
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
generate = execute # TODO: remove
class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
@classmethod
def define_schema(cls):
schema = super().define_schema()
schema.node_id = "EmptyHunyuanVideo15Latent"
schema.display_name = "Empty HunyuanVideo 1.5 Latent"
return schema
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
# Using scale factor of 16 instead of 8
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16})
class HunyuanVideo15ImageToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="HunyuanVideo15ImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image", optional=True),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encoded = vae.encode(start_image[:, :, :, :3])
concat_latent_image = torch.zeros((latent.shape[0], 32, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device())
concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
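# Mask sketch (hypothetical sizes): the concat mask is 1.0 where the model must
# denoise and 0.0 over latent frames covered by the start image; a single start
# frame with length=33 zeroes only the first of nine latent frames.
def _hv15_i2v_mask_sketch():
    length, start_frames = 33, 1
    t_latent = ((length - 1) // 4) + 1  # 9 latent frames
    mask = torch.ones((1, 1, t_latent, 480 // 16, 848 // 16))
    mask[:, :, :((start_frames - 1) // 4) + 1] = 0.0  # frame 0 is given
    return mask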
class HunyuanVideo15SuperResolution(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="HunyuanVideo15SuperResolution",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae", optional=True),
io.Image.Input("start_image", optional=True),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Latent.Input("latent"),
io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, latent, noise_augmentation, vae=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
in_latent = latent["samples"]
in_channels = in_latent.shape[1]
cond_latent = torch.zeros([in_latent.shape[0], in_channels * 2 + 2, in_latent.shape[-3], in_latent.shape[-2], in_latent.shape[-1]], device=comfy.model_management.intermediate_device())
cond_latent[:, in_channels + 1 : 2 * in_channels + 1] = in_latent
cond_latent[:, 2 * in_channels + 1] = 1
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image.movedim(-1, 1), in_latent.shape[-1] * 16, in_latent.shape[-2] * 16, "bilinear", "center").movedim(1, -1)
encoded = vae.encode(start_image[:, :, :, :3])
cond_latent[:, :in_channels, :encoded.shape[2], :, :] = encoded
cond_latent[:, in_channels + 1, 0] = 1
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
return io.NodeOutput(positive, negative, latent)
class LatentUpscaleModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LatentUpscaleModelLoader",
display_name="Load Latent Upscale Model",
category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")),
],
outputs=[
io.LatentUpscaleModel.Output(),
],
)
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True)
if "blocks.0.block.0.conv.weight" in sd:
config = {
"in_channels": sd["in_conv.conv.weight"].shape[1],
"out_channels": sd["out_conv.conv.weight"].shape[0],
"hidden_channels": sd["in_conv.conv.weight"].shape[0],
"num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]),
"global_residual": False,
}
model_type = "720p"
model = HunyuanVideo15SRModel(model_type, config)
model.load_sd(sd)
elif "up.0.block.0.conv1.conv.weight" in sd:
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
config = {
"z_channels": sd["conv_in.conv.weight"].shape[1],
"out_channels": sd["conv_out.conv.weight"].shape[0],
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
}
model_type = "1080p"
model = HunyuanVideo15SRModel(model_type, config)
model.load_sd(sd)
elif "post_upsample_res_blocks.0.conv2.bias" in sd:
config = json.loads(metadata["config"])
model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32]))
model.load_state_dict(sd)
return io.NodeOutput(model)
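# Loader sketch (hypothetical checkpoint, invented shapes): the variant is
# picked from which keys exist, and the config is read off tensor shapes rather
# than stored next to the weights, e.g. for the "720p" block-style layout.
def _upscale_config_sketch():
    sd = {
        "in_conv.conv.weight": torch.zeros(128, 32, 3, 3, 3),
        "out_conv.conv.weight": torch.zeros(32, 128, 3, 3, 3),
        "blocks.0.block.0.conv.weight": torch.zeros(128, 128, 3, 3, 3),
    }
    return {
        "in_channels": sd["in_conv.conv.weight"].shape[1],  # 32
        "out_channels": sd["out_conv.conv.weight"].shape[0],  # 32
        "hidden_channels": sd["in_conv.conv.weight"].shape[0],  # 128
        "num_blocks": len([k for k in sd if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]),  # 1
    }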
class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="HunyuanVideo15LatentUpscaleWithModel",
display_name="Hunyuan Video 15 Latent Upscale With Model",
category="latent",
inputs=[
io.LatentUpscaleModel.Input("model"),
io.Latent.Input("samples"),
io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"),
io.Int.Input("width", default=1280, min=0, max=16384, step=8),
io.Int.Input("height", default=720, min=0, max=16384, step=8),
io.Combo.Input("crop", options=["disabled", "center"]),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput:
if width == 0 and height == 0:
return io.NodeOutput(samples)
else:
if width == 0:
height = max(64, height)
width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
elif height == 0:
width = max(64, width)
height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
else:
width = max(64, width)
height = max(64, height)
s = comfy.utils.common_upscale(samples["samples"], width // 16, height // 16, upscale_method, crop)
s = model.resample_latent(s)
return io.NodeOutput({"samples": s.cpu().float()})
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
@@ -210,6 +411,11 @@ class HunyuanExtension(ComfyExtension):
CLIPTextEncodeHunyuanDiT,
TextEncodeHunyuanVideo_ImageToVideo,
EmptyHunyuanLatentVideo,
EmptyHunyuanVideo15Latent,
HunyuanVideo15ImageToVideo,
HunyuanVideo15SuperResolution,
HunyuanVideo15LatentUpscaleWithModel,
LatentUpscaleModelLoader,
HunyuanImageToVideo,
EmptyHunyuanImageLatent,
HunyuanRefinerLatent,

View File

@@ -7,63 +7,79 @@ from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_fro
import folder_paths
import comfy.model_management
from comfy.cli_args import args
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types
from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa
class EmptyLatentHunyuan3Dv2:
class EmptyLatentHunyuan3Dv2(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}
}
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentHunyuan3Dv2",
category="latent/3d",
inputs=[
IO.Int.Input("resolution", default=3072, min=1, max=8192),
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
],
outputs=[
IO.Latent.Output(),
]
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/3d"
def generate(self, resolution, batch_size):
@classmethod
def execute(cls, resolution, batch_size) -> IO.NodeOutput:
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
return ({"samples": latent, "type": "hunyuan3dv2"}, )
return IO.NodeOutput({"samples": latent, "type": "hunyuan3dv2"})
class Hunyuan3Dv2Conditioning:
generate = execute # TODO: remove
class Hunyuan3Dv2Conditioning(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
}}
def define_schema(cls):
return IO.Schema(
node_id="Hunyuan3Dv2Conditioning",
category="conditioning/video_models",
inputs=[
IO.ClipVisionOutput.Input("clip_vision_output"),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
]
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision_output):
@classmethod
def execute(cls, clip_vision_output) -> IO.NodeOutput:
embeds = clip_vision_output.last_hidden_state
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return (positive, negative)
return IO.NodeOutput(positive, negative)
encode = execute # TODO: remove
class Hunyuan3Dv2ConditioningMultiView:
class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {},
"optional": {"front": ("CLIP_VISION_OUTPUT",),
"left": ("CLIP_VISION_OUTPUT",),
"back": ("CLIP_VISION_OUTPUT",),
"right": ("CLIP_VISION_OUTPUT",), }}
def define_schema(cls):
return IO.Schema(
node_id="Hunyuan3Dv2ConditioningMultiView",
category="conditioning/video_models",
inputs=[
IO.ClipVisionOutput.Input("front", optional=True),
IO.ClipVisionOutput.Input("left", optional=True),
IO.ClipVisionOutput.Input("back", optional=True),
IO.ClipVisionOutput.Input("right", optional=True),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
]
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, front=None, left=None, back=None, right=None):
@classmethod
def execute(cls, front=None, left=None, back=None, right=None) -> IO.NodeOutput:
all_embeds = [front, left, back, right]
out = []
pos_embeds = None
@@ -76,29 +92,35 @@ class Hunyuan3Dv2ConditioningMultiView:
embeds = torch.cat(out, dim=1)
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return (positive, negative)
return IO.NodeOutput(positive, negative)
encode = execute # TODO: remove
class VOXEL:
def __init__(self, data):
self.data = data
class VAEDecodeHunyuan3D:
class VAEDecodeHunyuan3D(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ),
"vae": ("VAE", ),
"num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
"octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
}}
RETURN_TYPES = ("VOXEL",)
FUNCTION = "decode"
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeHunyuan3D",
category="latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.Int.Input("num_chunks", default=8000, min=1000, max=500000),
IO.Int.Input("octree_resolution", default=256, min=16, max=512),
],
outputs=[
IO.Voxel.Output(),
]
)
CATEGORY = "latent/3d"
@classmethod
def execute(cls, vae, samples, num_chunks, octree_resolution) -> IO.NodeOutput:
voxels = Types.VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return IO.NodeOutput(voxels)
decode = execute # TODO: remove
def decode(self, vae, samples, num_chunks, octree_resolution):
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return (voxels, )
def voxel_to_mesh(voxels, threshold=0.5, device=None):
if device is None:
@@ -396,24 +418,24 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
return final_vertices, faces
class MESH:
def __init__(self, vertices, faces):
self.vertices = vertices
self.faces = faces
class VoxelToMeshBasic:
class VoxelToMeshBasic(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"voxel": ("VOXEL", ),
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MESH",)
FUNCTION = "decode"
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMeshBasic",
category="3d",
inputs=[
IO.Voxel.Input("voxel"),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[
IO.Mesh.Output(),
]
)
CATEGORY = "3d"
def decode(self, voxel, threshold):
@classmethod
def execute(cls, voxel, threshold) -> IO.NodeOutput:
vertices = []
faces = []
for x in voxel.data:
@@ -421,21 +443,29 @@ class VoxelToMeshBasic:
vertices.append(v)
faces.append(f)
return (MESH(torch.stack(vertices), torch.stack(faces)), )
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
class VoxelToMesh:
decode = execute # TODO: remove
class VoxelToMesh(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"voxel": ("VOXEL", ),
"algorithm": (["surface net", "basic"], ),
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MESH",)
FUNCTION = "decode"
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMesh",
category="3d",
inputs=[
IO.Voxel.Input("voxel"),
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[
IO.Mesh.Output(),
]
)
CATEGORY = "3d"
def decode(self, voxel, algorithm, threshold):
@classmethod
def execute(cls, voxel, algorithm, threshold) -> IO.NodeOutput:
vertices = []
faces = []
@@ -449,7 +479,9 @@ class VoxelToMesh:
vertices.append(v)
faces.append(f)
return (MESH(torch.stack(vertices), torch.stack(faces)), )
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
decode = execute # TODO: remove
def save_glb(vertices, faces, filepath, metadata=None):
@@ -581,50 +613,84 @@ def save_glb(vertices, faces, filepath, metadata=None):
return filepath
class SaveGLB:
class SaveGLB(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"mesh": ("MESH", ),
"filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
def define_schema(cls):
return IO.Schema(
node_id="SaveGLB",
display_name="Save 3D Model",
search_aliases=["export 3d model", "save mesh"],
category="3d",
is_output_node=True,
inputs=[
IO.MultiType.Input(
IO.Mesh.Input("mesh"),
types=[
IO.File3DGLB,
IO.File3DGLTF,
IO.File3DOBJ,
IO.File3DFBX,
IO.File3DSTL,
IO.File3DUSDZ,
IO.File3DAny,
],
tooltip="Mesh or 3D file to save",
),
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
)
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "3d"
def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
@classmethod
def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = []
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.prompt is not None:
metadata["prompt"] = json.dumps(cls.hidden.prompt)
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
for i in range(mesh.vertices.shape[0]):
f = f"{filename}_{counter:05}_.glb"
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
if isinstance(mesh, Types.File3D):
# Handle File3D input - save BytesIO data to output folder
ext = mesh.format or "glb"
f = f"{filename}_{counter:05}_.{ext}"
mesh.save_to(os.path.join(full_output_folder, f))
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
counter += 1
return {"ui": {"3d": results}}
else:
# Handle Mesh input - save vertices and faces as GLB
for i in range(mesh.vertices.shape[0]):
f = f"{filename}_{counter:05}_.glb"
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
counter += 1
return IO.NodeOutput(ui={"3d": results})
NODE_CLASS_MAPPINGS = {
"EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
"Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
"VoxelToMeshBasic": VoxelToMeshBasic,
"VoxelToMesh": VoxelToMesh,
"SaveGLB": SaveGLB,
}
class Hunyuan3dExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
EmptyLatentHunyuan3Dv2,
Hunyuan3Dv2Conditioning,
Hunyuan3Dv2ConditioningMultiView,
VAEDecodeHunyuan3D,
VoxelToMeshBasic,
VoxelToMesh,
SaveGLB,
]
async def comfy_entrypoint() -> Hunyuan3dExtension:
return Hunyuan3dExtension()

View File

@@ -2,6 +2,9 @@ import comfy.utils
import folder_paths
import torch
import logging
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override
def load_hypernetwork_patch(path, strength):
sd = comfy.utils.load_torch_file(path, safe_load=True)
@@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):
return hypernetwork_patch(out, strength)
class HypernetworkLoader:
class HypernetworkLoader(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_hypernetwork"
def define_schema(cls):
return IO.Schema(
node_id="HypernetworkLoader",
category="loaders",
inputs=[
IO.Model.Input("model"),
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength):
@classmethod
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:
model_hypernetwork.set_model_attn1_patch(patch)
model_hypernetwork.set_model_attn2_patch(patch)
return (model_hypernetwork,)
return IO.NodeOutput(model_hypernetwork)
NODE_CLASS_MAPPINGS = {
"HypernetworkLoader": HypernetworkLoader
}
load_hypernetwork = execute # TODO: remove
class HyperNetworkExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
HypernetworkLoader,
]
async def comfy_entrypoint() -> HyperNetworkExtension:
return HyperNetworkExtension()

View File

@@ -0,0 +1,53 @@
import nodes
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
class ImageCompare(IO.ComfyNode):
"""Compares two images with a slider interface."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageCompare",
display_name="Image Compare",
description="Compares two images side by side with a slider.",
category="image",
is_experimental=True,
is_output_node=True,
inputs=[
IO.Image.Input("image_a", optional=True),
IO.Image.Input("image_b", optional=True),
IO.ImageCompare.Input("compare_view"),
],
outputs=[],
)
@classmethod
def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput:
result = {"a_images": [], "b_images": []}
preview_node = nodes.PreviewImage()
if image_a is not None and len(image_a) > 0:
saved = preview_node.save_images(image_a, "comfy.compare.a")
result["a_images"] = saved["ui"]["images"]
if image_b is not None and len(image_b) > 0:
saved = preview_node.save_images(image_b, "comfy.compare.b")
result["b_images"] = saved["ui"]["images"]
return IO.NodeOutput(ui=result)
class ImageCompareExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ImageCompare,
]
async def comfy_entrypoint() -> ImageCompareExtension:
return ImageCompareExtension()

View File

@@ -2,280 +2,235 @@ from __future__ import annotations
import nodes
import folder_paths
from comfy.cli_args import args
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
import re
from io import BytesIO
from inspect import cleandoc
import torch
import comfy.utils
from comfy.comfy_types import FileLocator, IO
from server import PromptServer
from comfy_api.latest import ComfyExtension, IO, UI
from typing_extensions import override
SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop:
class ImageCrop(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "crop"
def define_schema(cls):
return IO.Schema(
node_id="ImageCrop",
search_aliases=["trim"],
display_name="Image Crop",
category="image/transform",
inputs=[
IO.Image.Input("image"),
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/transform"
def crop(self, image, width, height, x, y):
@classmethod
def execute(cls, image, width, height, x, y) -> IO.NodeOutput:
x = min(x, image.shape[2] - 1)
y = min(y, image.shape[1] - 1)
to_x = width + x
to_y = height + y
img = image[:,y:to_y, x:to_x, :]
return (img,)
return IO.NodeOutput(img)
class RepeatImageBatch:
crop = execute # TODO: remove
class RepeatImageBatch(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
def define_schema(cls):
return IO.Schema(
node_id="RepeatImageBatch",
search_aliases=["duplicate image", "clone image"],
category="image/batch",
inputs=[
IO.Image.Input("image"),
IO.Int.Input("amount", default=1, min=1, max=4096),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/batch"
def repeat(self, image, amount):
@classmethod
def execute(cls, image, amount) -> IO.NodeOutput:
        s = image.repeat((amount, 1, 1, 1))
return (s,)
return IO.NodeOutput(s)
class ImageFromBatch:
repeat = execute # TODO: remove
class ImageFromBatch(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
"length": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "frombatch"
def define_schema(cls):
return IO.Schema(
node_id="ImageFromBatch",
search_aliases=["select image", "pick from batch", "extract image"],
category="image/batch",
inputs=[
IO.Image.Input("image"),
IO.Int.Input("batch_index", default=0, min=0, max=4095),
IO.Int.Input("length", default=1, min=1, max=4096),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/batch"
def frombatch(self, image, batch_index, length):
@classmethod
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
s_in = image
batch_index = min(s_in.shape[0] - 1, batch_index)
length = min(s_in.shape[0] - batch_index, length)
s = s_in[batch_index:batch_index + length].clone()
return (s,)
return IO.NodeOutput(s)
frombatch = execute # TODO: remove
class ImageAddNoise:
class ImageAddNoise(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
def define_schema(cls):
return IO.Schema(
node_id="ImageAddNoise",
search_aliases=["film grain"],
category="image",
inputs=[
IO.Image.Input("image"),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image"
def repeat(self, image, seed, strength):
@classmethod
def execute(cls, image, seed, strength) -> IO.NodeOutput:
generator = torch.manual_seed(seed)
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
return (s,)
return IO.NodeOutput(s)
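# Noise sketch (assumed 1x8x8x3 image): a CPU generator keyed by the seed makes
# the grain reproducible before it is scaled and clipped back into [0, 1].
def _image_add_noise_sketch():
    img = torch.zeros(1, 8, 8, 3) + 0.5
    gen = torch.manual_seed(42)
    noisy = img + 0.25 * torch.randn(img.size(), generator=gen, device="cpu")
    return torch.clip(noisy, min=0.0, max=1.0)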
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
repeat = execute # TODO: remove
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": True}),
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "image/animation"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results: list[FileLocator] = []
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = pil_images[0].getexif()
if not args.disable_metadata:
if prompt is not None:
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
if extra_pnginfo is not None:
inital_exif = 0x010f
for x in extra_pnginfo:
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
inital_exif -= 1
if num_frames == 0:
num_frames = len(pil_images)
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,) } }
class SaveAnimatedPNG:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
class SaveAnimatedWEBP(IO.ComfyNode):
COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
def define_schema(cls):
return IO.Schema(
node_id="SaveAnimatedWEBP",
category="image/animation",
inputs=[
IO.Image.Input("images"),
IO.String.Input("filename_prefix", default="ComfyUI"),
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
IO.Boolean.Input("lossless", default=True),
IO.Int.Input("quality", default=80, min=0, max=100),
IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
RETURN_TYPES = ()
FUNCTION = "save_images"
@classmethod
def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.ImageSaveHelper.get_save_animated_webp_ui(
images=images,
filename_prefix=filename_prefix,
cls=cls,
fps=fps,
lossless=lossless,
quality=quality,
method=cls.COMPRESS_METHODS.get(method)
)
)
OUTPUT_NODE = True
CATEGORY = "image/animation"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
file = f"{filename}_{counter:05}_.png"
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
return { "ui": { "images": results, "animated": (True,)} }
class SVG:
"""
Stores SVG representations via a list of BytesIO objects.
"""
def __init__(self, data: list[BytesIO]):
self.data = data
def combine(self, other: 'SVG') -> 'SVG':
return SVG(self.data + other.data)
@staticmethod
def combine_all(svgs: list['SVG']) -> 'SVG':
all_svgs_list: list[BytesIO] = []
for svg_item in svgs:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
save_images = execute # TODO: remove
class ImageStitch:
class SaveAnimatedPNG(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveAnimatedPNG",
category="image/animation",
inputs=[
IO.Image.Input("images"),
IO.String.Input("filename_prefix", default="ComfyUI"),
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
IO.Int.Input("compress_level", default=4, min=0, max=9),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.ImageSaveHelper.get_save_animated_png_ui(
images=images,
filename_prefix=filename_prefix,
cls=cls,
fps=fps,
compress_level=compress_level,
)
)
save_images = execute # TODO: remove
class ImageStitch(IO.ComfyNode):
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageStitch",
search_aliases=["combine images", "join images", "concatenate images", "side by side"],
display_name="Image Stitch",
description="Stitches image2 to image1 in the specified direction.\n"
"If image2 is not provided, returns image1 unchanged.\n"
"Optional spacing can be added between images.",
category="image/transform",
inputs=[
IO.Image.Input("image1"),
IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"),
IO.Boolean.Input("match_image_size", default=True),
IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2),
IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"),
IO.Image.Input("image2", optional=True),
],
outputs=[IO.Image.Output()],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"direction": (["right", "down", "left", "up"], {"default": "right"}),
"match_image_size": ("BOOLEAN", {"default": True}),
"spacing_width": (
"INT",
{"default": 0, "min": 0, "max": 1024, "step": 2},
),
"spacing_color": (
["white", "black", "red", "green", "blue"],
{"default": "white"},
),
},
"optional": {
"image2": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stitch"
CATEGORY = "image/transform"
DESCRIPTION = """
Stitches image2 to image1 in the specified direction.
If image2 is not provided, returns image1 unchanged.
Optional spacing can be added between images.
"""
def stitch(
self,
def execute(
cls,
image1,
direction,
match_image_size,
spacing_width,
spacing_color,
image2=None,
):
) -> IO.NodeOutput:
if image2 is None:
return (image1,)
return IO.NodeOutput(image1)
# Handle batch size differences
if image1.shape[0] != image2.shape[0]:
@@ -412,36 +367,30 @@ Optional spacing can be added between images.
images.insert(1, spacing)
concat_dim = 2 if direction in ["left", "right"] else 1
return (torch.cat(images, dim=concat_dim),)
return IO.NodeOutput(torch.cat(images, dim=concat_dim))
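# Stitch sketch (hypothetical tensors): images are [B, H, W, C], so "left" and
# "right" concatenate along dim 2 (width) while "up" and "down" use dim 1.
def _stitch_dim_sketch():
    a = torch.zeros(1, 64, 64, 3)
    b = torch.ones(1, 64, 32, 3)
    return torch.cat([a, b], dim=2).shape  # -> [1, 64, 96, 3]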
class ResizeAndPadImage:
stitch = execute # TODO: remove
class ResizeAndPadImage(IO.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"target_width": ("INT", {
"default": 512,
"min": 1,
"max": MAX_RESOLUTION,
"step": 1
}),
"target_height": ("INT", {
"default": 512,
"min": 1,
"max": MAX_RESOLUTION,
"step": 1
}),
"padding_color": (["white", "black"],),
"interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
}
}
def define_schema(cls):
return IO.Schema(
node_id="ResizeAndPadImage",
search_aliases=["fit to size"],
category="image/transform",
inputs=[
IO.Image.Input("image"),
IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
IO.Combo.Input("padding_color", options=["white", "black"]),
IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]),
],
outputs=[IO.Image.Output()],
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "resize_and_pad"
CATEGORY = "image/transform"
def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
@classmethod
def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput:
batch_size, orig_height, orig_width, channels = image.shape
scale_w = target_width / orig_width
@@ -469,52 +418,47 @@ class ResizeAndPadImage:
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
output = padded.permute(0, 2, 3, 1)
return (output,)
return IO.NodeOutput(output)
class SaveSVGNode:
"""
Save SVG files on disk.
"""
resize_and_pad = execute # TODO: remove
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
RETURN_TYPES = ()
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "save_svg"
CATEGORY = "image/save" # Changed
OUTPUT_NODE = True
class SaveSVGNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveSVGNode",
search_aliases=["export vector", "save vector graphics"],
description="Save SVG files on disk.",
category="image/save",
inputs=[
IO.SVG.Input("svg"),
IO.String.Input(
"filename_prefix",
default="svg/ComfyUI",
tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"svg": ("SVG",), # Changed
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
}
}
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results = list()
def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results: list[UI.SavedResult] = []
# Prepare metadata JSON
metadata_dict = {}
if prompt is not None:
metadata_dict["prompt"] = prompt
if extra_pnginfo is not None:
metadata_dict.update(extra_pnginfo)
if cls.hidden.prompt is not None:
metadata_dict["prompt"] = cls.hidden.prompt
if cls.hidden.extra_pnginfo is not None:
metadata_dict.update(cls.hidden.extra_pnginfo)
# Convert metadata to JSON string
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
for batch_number, svg_bytes in enumerate(svg.data):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.svg"
@@ -544,57 +488,64 @@ class SaveSVGNode:
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
svg_file.write(svg_content.encode('utf-8'))
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output))
counter += 1
return { "ui": { "images": results } }
return IO.NodeOutput(ui={"images": results})
class GetImageSize:
save_svg = execute # TODO: remove
class GetImageSize(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GetImageSize",
search_aliases=["dimensions", "resolution", "image info"],
display_name="Get Image Size",
description="Returns width and height of the image, and passes it through unchanged.",
category="image",
inputs=[
IO.Image.Input("image"),
],
outputs=[
IO.Int.Output(display_name="width"),
IO.Int.Output(display_name="height"),
IO.Int.Output(display_name="batch_size"),
],
hidden=[IO.Hidden.unique_id],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
},
"hidden": {
"unique_id": "UNIQUE_ID",
}
}
RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
RETURN_NAMES = ("width", "height", "batch_size")
FUNCTION = "get_size"
CATEGORY = "image"
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
def get_size(self, image, unique_id=None) -> tuple[int, int]:
def execute(cls, image) -> IO.NodeOutput:
height = image.shape[1]
width = image.shape[2]
batch_size = image.shape[0]
# Send progress text to display size on the node
if unique_id:
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
return width, height, batch_size
return IO.NodeOutput(width, height, batch_size)
class ImageRotate:
get_size = execute # TODO: remove
class ImageRotate(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": (IO.IMAGE,),
"rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
}}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "rotate"
def define_schema(cls):
return IO.Schema(
node_id="ImageRotate",
search_aliases=["turn", "flip orientation"],
category="image/transform",
inputs=[
IO.Image.Input("image"),
IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/transform"
def rotate(self, image, rotation):
@classmethod
def execute(cls, image, rotation) -> IO.NodeOutput:
rotate_by = 0
if rotation.startswith("90"):
rotate_by = 1
@@ -604,41 +555,57 @@ class ImageRotate:
rotate_by = 3
image = torch.rot90(image, k=rotate_by, dims=[2, 1])
return (image,)
return IO.NodeOutput(image)
class ImageFlip:
rotate = execute # TODO: remove
class ImageFlip(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": (IO.IMAGE,),
"flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
}}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "flip"
def define_schema(cls):
return IO.Schema(
node_id="ImageFlip",
search_aliases=["mirror", "reflect"],
category="image/transform",
inputs=[
IO.Image.Input("image"),
IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/transform"
def flip(self, image, flip_method):
@classmethod
def execute(cls, image, flip_method) -> IO.NodeOutput:
if flip_method.startswith("x"):
image = torch.flip(image, dims=[1])
elif flip_method.startswith("y"):
image = torch.flip(image, dims=[2])
return (image,)
return IO.NodeOutput(image)
class ImageScaleToMaxDimension:
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
flip = execute # TODO: remove
class ImageScaleToMaxDimension(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE",),
"upscale_method": (s.upscale_methods,),
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
def define_schema(cls):
return IO.Schema(
node_id="ImageScaleToMaxDimension",
category="image/upscaling",
inputs=[
IO.Image.Input("image"),
IO.Combo.Input(
"upscale_method",
options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"],
),
IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/upscaling"
def upscale(self, image, upscale_method, largest_size):
@classmethod
def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
height = image.shape[1]
width = image.shape[2]
@@ -655,20 +622,30 @@ class ImageScaleToMaxDimension:
samples = image.movedim(-1, 1)
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1, -1)
return (s,)
return IO.NodeOutput(s)
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"ImageFromBatch": ImageFromBatch,
"ImageAddNoise": ImageAddNoise,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
"SaveSVGNode": SaveSVGNode,
"ImageStitch": ImageStitch,
"ResizeAndPadImage": ResizeAndPadImage,
"GetImageSize": GetImageSize,
"ImageRotate": ImageRotate,
"ImageFlip": ImageFlip,
"ImageScaleToMaxDimension": ImageScaleToMaxDimension,
}
upscale = execute # TODO: remove
class ImagesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ImageCrop,
RepeatImageBatch,
ImageFromBatch,
ImageAddNoise,
SaveAnimatedWEBP,
SaveAnimatedPNG,
SaveSVGNode,
ImageStitch,
ResizeAndPadImage,
GetImageSize,
ImageRotate,
ImageFlip,
ImageScaleToMaxDimension,
]
async def comfy_entrypoint() -> ImagesExtension:
return ImagesExtension()

View File

@@ -0,0 +1,137 @@
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class Kandinsky5ImageToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Kandinsky5ImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond_latent_out = {}
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encoded = vae.encode(start_image[:, :, :, :3])
cond_latent_out["samples"] = encoded
mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
def adaptive_mean_std_normalization(source, reference, clamp_mean_low=0.3, clamp_mean_high=0.35, clamp_std_low=0.35, clamp_std_high=0.5):
source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
reference_mean = torch.clamp(reference.mean(), source_mean - clamp_mean_low, source_mean + clamp_mean_high)
reference_std = torch.clamp(reference.std(), source_std - clamp_std_low, source_std + clamp_std_high)
# normalization
normalized = (source - source_mean) / (source_std + 1e-8)
normalized = normalized * reference_std + reference_mean
return normalized
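# A minimal usage sketch (illustrative only, not part of the diff), assuming
# 5D BCTHW latents; shapes and values here are hypothetical:
#
#   src = torch.randn(1, 16, 4, 8, 8) * 2.0 + 1.0  # start frames, mean ~1, std ~2
#   ref = torch.randn(1, 16, 5, 8, 8)              # reference frames, mean ~0, std ~1
#   out = adaptive_mean_std_normalization(src, ref)
#
# out keeps src's content, but its per-sample mean/std over (C, H, W) are pulled
# toward ref's global stats, clamped to stay within the clamp_* windows around
# src's own statistics.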
class NormalizeVideoLatentStart(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="NormalizeVideoLatentStart",
category="conditioning/video_models",
description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
inputs=[
io.Latent.Input("latent"),
io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
],
outputs=[
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
if latent["samples"].shape[2] <= 1:
return io.NodeOutput(latent)
s = latent.copy()
samples = latent["samples"].clone()
first_frames = samples[:, :, :start_frame_count]
reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
samples[:, :, :start_frame_count] = normalized_first_frames
s["samples"] = samples
return io.NodeOutput(s)
class CLIPTextEncodeKandinsky5(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeKandinsky5",
search_aliases=["kandinsky prompt"],
category="advanced/conditioning/kandinsky5",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
tokens = clip.tokenize(clip_l)
tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class Kandinsky5Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Kandinsky5ImageToVideo,
NormalizeVideoLatentStart,
CLIPTextEncodeKandinsky5,
]
async def comfy_entrypoint() -> Kandinsky5Extension:
return Kandinsky5Extension()

View File

@@ -4,7 +4,8 @@ import torch
import nodes
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import logging
import math
def reshape_latent_to(target_shape, latent, repeat_batch=True):
if latent.shape[1:] != target_shape[1:]:
@@ -20,6 +21,7 @@ class LatentAdd(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentAdd",
search_aliases=["combine latents", "sum latents"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@@ -46,6 +48,7 @@ class LatentSubtract(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentSubtract",
search_aliases=["difference latent", "remove features"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@@ -72,6 +75,7 @@ class LatentMultiply(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentMultiply",
search_aliases=["scale latent", "amplify latent", "latent gain"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
@@ -95,6 +99,7 @@ class LatentInterpolate(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentInterpolate",
search_aliases=["blend latent", "mix latent", "lerp latent", "transition"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@@ -133,6 +138,7 @@ class LatentConcat(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentConcat",
search_aliases=["join latents", "stitch latents"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@@ -172,6 +178,7 @@ class LatentCut(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentCut",
search_aliases=["crop latent", "slice latent", "extract region"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
@@ -207,12 +214,56 @@ class LatentCut(io.ComfyNode):
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return io.NodeOutput(samples_out)
class LatentCutToBatch(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LatentCutToBatch",
search_aliases=["slice to batch", "split latent", "tile latent"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["t", "x", "y"]),
io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
if "x" in dim:
dim = s1.ndim - 1
elif "y" in dim:
dim = s1.ndim - 2
elif "t" in dim:
dim = s1.ndim - 3
if dim < 2:
return io.NodeOutput(samples)
s = s1.movedim(dim, 1)
if s.shape[1] < slice_size:
slice_size = s.shape[1]
elif s.shape[1] % slice_size != 0:
s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
new_shape = [-1, slice_size] + list(s.shape[2:])
samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
return io.NodeOutput(samples_out)
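# Worked example (illustrative): a (1, C, 8, H, W) latent cut along "t" with
# slice_size=2 reshapes to (4, C, 2, H, W): four batch entries of two frames
# each. A trailing remainder that does not fill a whole slice is dropped.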
class LatentBatch(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LatentBatch",
search_aliases=["combine latents", "merge latents", "join latents"],
category="latent/batch",
is_deprecated=True,
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
@@ -267,6 +318,7 @@ class LatentApplyOperation(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperation",
search_aliases=["transform latent"],
category="latent/advanced/operations",
is_experimental=True,
inputs=[
@@ -322,6 +374,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentOperationTonemapReinhard",
search_aliases=["hdr latent"],
category="latent/advanced/operations",
is_experimental=True,
inputs=[
@@ -338,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
dims = list(range(1, latent_vector_magnitude.ndim))
mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
top = (std * 5 + mean) * multiplier
@@ -388,6 +442,42 @@ class LatentOperationSharpen(io.ComfyNode):
return luminance * sharpened
return io.NodeOutput(sharpen)
class ReplaceVideoLatentFrames(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ReplaceVideoLatentFrames",
category="latent/batch",
inputs=[
io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, destination, index, source=None) -> io.NodeOutput:
if source is None:
return io.NodeOutput(destination)
dest_frames = destination["samples"].shape[2]
source_frames = source["samples"].shape[2]
if index < 0:
index = dest_frames + index
if index > dest_frames:
logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
return io.NodeOutput(destination)
if index + source_frames > dest_frames:
logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
return io.NodeOutput(destination)
s = source.copy()
s_source = source["samples"]
s_destination = destination["samples"].clone()
s_destination[:, :, index:index + s_source.shape[2]] = s_source
s["samples"] = s_destination
return io.NodeOutput(s)
class LatentExtension(ComfyExtension):
@override
@@ -399,12 +489,14 @@ class LatentExtension(ComfyExtension):
LatentInterpolate,
LatentConcat,
LatentCut,
LatentCutToBatch,
LatentBatch,
LatentBatchSeedBehavior,
LatentApplyOperation,
LatentApplyOperationCFG,
LatentOperationTonemapReinhard,
LatentOperationSharpen,
ReplaceVideoLatentFrames
]

View File

@@ -1,9 +1,10 @@
import nodes
import folder_paths
import os
import uuid
from comfy.comfy_types import IO
from comfy_api.input_impl import VideoFromFile
from typing_extensions import override
from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types
from pathlib import Path
@@ -11,9 +12,9 @@ from pathlib import Path
def normalize_path(path):
return path.replace('\\', '/')
class Load3D():
class Load3D(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
@@ -24,77 +25,32 @@ class Load3D():
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'}
]
return IO.Schema(
node_id="Load3D",
display_name="Load 3D & Animation",
category="3d",
is_experimental=True,
inputs=[
IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
IO.Load3D.Input("image"),
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
],
outputs=[
IO.Image.Output(display_name="image"),
IO.Mask.Output(display_name="mask"),
IO.String.Output(display_name="mesh_path"),
IO.Image.Output(display_name="normal"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Video.Output(display_name="recording_video"),
IO.File3DAny.Output(display_name="model_3d"),
],
)
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),
"image": ("LOAD_3D", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
class Load3DAnimation():
@classmethod
def INPUT_TYPES(s):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),
"image": ("LOAD_3D_ANIMATION", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
@@ -109,74 +65,66 @@ class Load3DAnimation():
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
video = InputImpl.VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, image['camera_info'], video
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
class Preview3D():
process = execute # TODO: remove
class Preview3D(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
def define_schema(cls):
return IO.Schema(
node_id="Preview3D",
search_aliases=["view mesh", "3d viewer"],
display_name="Preview 3D & Animation",
category="3d",
is_experimental=True,
is_output_node=True,
inputs=[
IO.MultiType.Input(
IO.String.Input("model_file", default="", multiline=False),
types=[
IO.File3DGLB,
IO.File3DGLTF,
IO.File3DFBX,
IO.File3DOBJ,
IO.File3DSTL,
IO.File3DUSDZ,
IO.File3DAny,
],
tooltip="3D model file or path string",
),
IO.Load3DCamera.Input("camera_info", optional=True),
IO.Image.Input("bg_image", optional=True),
],
outputs=[],
)
OUTPUT_NODE = True
RETURN_TYPES = ()
CATEGORY = "3d"
FUNCTION = "process"
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
camera_info = kwargs.get("camera_info", None)
return {
"ui": {
"result": [model_file, camera_info]
}
}
class Preview3DAnimation():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
OUTPUT_NODE = True
RETURN_TYPES = ()
CATEGORY = "3d"
FUNCTION = "process"
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput:
if isinstance(model_file, Types.File3D):
filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
else:
filename = model_file
camera_info = kwargs.get("camera_info", None)
bg_image = kwargs.get("bg_image", None)
return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image))
return {
"ui": {
"result": [model_file, camera_info]
}
}
process = execute # TODO: remove
NODE_CLASS_MAPPINGS = {
"Load3D": Load3D,
"Load3DAnimation": Load3DAnimation,
"Preview3D": Preview3D,
"Preview3DAnimation": Preview3DAnimation
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Load3D": "Load 3D",
"Load3DAnimation": "Load 3D - Animation",
"Preview3D": "Preview 3D",
"Preview3DAnimation": "Preview 3D - Animation"
}
class Load3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Load3D,
Preview3D,
]
async def comfy_entrypoint() -> Load3DExtension:
return Load3DExtension()

comfy_extras/nodes_logic.py (new file)
View File

@@ -0,0 +1,274 @@
from __future__ import annotations
from typing import TypedDict
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_api.latest import _io
# sentinel for missing inputs
MISSING = object()
class SwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.MatchType.Template("switch")
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
category="logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
io.MatchType.Input("on_false", template=template, lazy=True),
io.MatchType.Input("on_true", template=template, lazy=True),
],
outputs=[
io.MatchType.Output(template=template, display_name="output"),
],
)
@classmethod
def check_lazy_status(cls, switch, on_false=None, on_true=None):
if switch and on_true is None:
return ["on_true"]
if not switch and on_false is None:
return ["on_false"]
@classmethod
def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
return io.NodeOutput(on_true if switch else on_false)
class SoftSwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.MatchType.Template("switch")
return io.Schema(
node_id="ComfySoftSwitchNode",
display_name="Soft Switch",
category="logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
],
outputs=[
io.MatchType.Output(template=template, display_name="output"),
],
)
@classmethod
def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
# We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs.
# This trick allows us to ignore the value of the switch and still be able to run execute().
# One of the inputs may be missing, in which case we need to evaluate the other input
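# Input states, summarized (illustrative note, not part of the diff):
#   MISSING -> the input socket is not connected at all
#   None    -> the socket is connected but its upstream node has not run yet
#   a value -> the upstream node has already produced a result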
if on_false is MISSING:
return ["on_true"]
if on_true is MISSING:
return ["on_false"]
# Normal lazy switch operation
if switch and on_true is None:
return ["on_true"]
if not switch and on_false is None:
return ["on_false"]
@classmethod
def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING):
# This check happens before check_lazy_status(), so we can eliminate the case where
# both inputs are missing.
if on_false is MISSING and on_true is MISSING:
return "At least one of on_false or on_true must be connected to Switch node"
return True
@classmethod
def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
if on_true is MISSING:
return io.NodeOutput(on_false)
if on_false is MISSING:
return io.NodeOutput(on_true)
return io.NodeOutput(on_true if switch else on_false)
class CustomComboNode(io.ComfyNode):
"""
Frontend node that lets the user write their own options for a combo.
It exists to give the node a backend representation and avoid some annoyances.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CustomCombo",
display_name="Custom Combo",
category="utils",
is_experimental=True,
inputs=[io.Combo.Input("choice", options=[])],
outputs=[
io.String.Output(display_name="STRING"),
io.Int.Output(display_name="INDEX"),
],
accept_all_inputs=True,
)
@classmethod
def validate_inputs(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> bool:
# NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
# I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
# I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
return True
@classmethod
def execute(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> io.NodeOutput:
return io.NodeOutput(choice, index)
class DCTestNode(io.ComfyNode):
class DCValues(TypedDict):
combo: str
string: str
integer: int
image: io.Image.Type
subcombo: dict[str, object]
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DCTestNode",
display_name="DCTest",
category="logic",
is_output_node=True,
inputs=[io.DynamicCombo.Input("combo", options=[
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
io.DynamicCombo.Option("option4", [
io.DynamicCombo.Input("subcombo", options=[
io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
])
])]
)],
outputs=[io.AnyType.Output()],
)
@classmethod
def execute(cls, combo: DCValues) -> io.NodeOutput:
combo_val = combo["combo"]
if combo_val == "option1":
return io.NodeOutput(combo["string"])
elif combo_val == "option2":
return io.NodeOutput(combo["integer"])
elif combo_val == "option3":
return io.NodeOutput(combo["image"])
elif combo_val == "option4":
return io.NodeOutput(f"{combo['subcombo']}")
else:
raise ValueError(f"Invalid combo: {combo_val}")
class AutogrowNamesTestNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"])
return io.Schema(
node_id="AutogrowNamesTestNode",
display_name="AutogrowNamesTest",
category="logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
outputs=[io.String.Output()],
)
@classmethod
def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
vals = list(autogrow.values())
combined = ",".join([str(x) for x in vals])
return io.NodeOutput(combined)
class AutogrowPrefixTestNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10)
return io.Schema(
node_id="AutogrowPrefixTestNode",
display_name="AutogrowPrefixTest",
category="logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
outputs=[io.String.Output()],
)
@classmethod
def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
vals = list(autogrow.values())
combined = ",".join([str(x) for x in vals])
return io.NodeOutput(combined)
class ComboOutputTestNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ComboOptionTestNode",
display_name="ComboOptionTest",
category="logic",
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
outputs=[io.Combo.Output(), io.Combo.Output()],
)
@classmethod
def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput:
return io.NodeOutput(combo, combo2)
class ConvertStringToComboNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ConvertStringToComboNode",
search_aliases=["string to dropdown", "text to combo"],
display_name="Convert String to Combo",
category="logic",
inputs=[io.String.Input("string")],
outputs=[io.Combo.Output()],
)
@classmethod
def execute(cls, string: str) -> io.NodeOutput:
return io.NodeOutput(string)
class InvertBooleanNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="InvertBooleanNode",
search_aliases=["not", "toggle", "negate", "flip boolean"],
display_name="Invert Boolean",
category="logic",
inputs=[io.Boolean.Input("boolean")],
outputs=[io.Boolean.Output()],
)
@classmethod
def execute(cls, boolean: bool) -> io.NodeOutput:
return io.NodeOutput(not boolean)
class LogicExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SwitchNode,
CustomComboNode,
# SoftSwitchNode,
# ConvertStringToComboNode,
# DCTestNode,
# AutogrowNamesTestNode,
# AutogrowPrefixTestNode,
# ComboOutputTestNode,
# InvertBooleanNode,
]
async def comfy_entrypoint() -> LogicExtension:
return LogicExtension()

View File

@@ -0,0 +1,79 @@
import folder_paths
import comfy.utils
import comfy.sd
class LoraLoaderBypass:
"""
Apply LoRA in bypass mode without modifying base model weights.
Bypass mode computes: output = base_forward(x) + lora_path(x)
This is useful for training and when model weights are offloaded.
"""
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP")
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
FUNCTION = "load_lora"
CATEGORY = "loaders"
DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios."
EXPERIMENTAL = True
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
return (model, clip)
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
self.loaded_lora = None
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip)
return (model_lora, clip_lora)
class LoraLoaderBypassModelOnly(LoraLoaderBypass):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_lora_model_only"
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
NODE_CLASS_MAPPINGS = {
"LoraLoaderBypass": LoraLoaderBypass,
"LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LoraLoaderBypass": "Load LoRA (Bypass) (For debugging)",
"LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only) (for debugging)",
}
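# A minimal sketch of the bypass computation described in the docstring above,
# assuming a plain linear layer; lora_A, lora_B and strength are hypothetical
# stand-ins (the real wiring lives in comfy.sd.load_bypass_lora_for_models):
#
#   def bypass_forward(x, base_linear, lora_A, lora_B, strength):
#       # base weights stay untouched; the low-rank path is added at forward time
#       return base_linear(x) + strength * ((x @ lora_A.T) @ lora_B.T)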

View File

@@ -7,6 +7,7 @@ import logging
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from tqdm.auto import trange
CLAMP_QUANTILE = 0.99
@@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
comfy.model_management.load_models_gpu([model_diff])
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
sd_keys = list(sd.keys())
for index in trange(len(sd_keys), unit="weight"):
k = sd_keys[index]
op_keys = k.rsplit('.', 1)
if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
continue
op = comfy.utils.get_attr(model_diff.model, op_keys[0])
if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
else:
weight_diff = sd[k]
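# Note on the branch above (summary, not part of the diff): weights wrapped in
# comfy's cast ops may still be offloaded/unpatched, so the diff is materialized
# on the load device via patch_weight_to_device; plain patched weights can be
# read straight out of the filtered state dict.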
if op_keys[1] == "weight":
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
@@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
elif bias_diff and op_keys[1] == "bias":
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
return output_sd
class LoraSave(io.ComfyNode):
@@ -78,6 +89,7 @@ class LoraSave(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoraSave",
search_aliases=["export lora"],
display_name="Extract and Save Lora",
category="_for_testing",
inputs=[

View File

@@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode):
generate = execute # TODO: remove
class LTXVImgToVideoInplace(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVImgToVideoInplace",
category="conditioning/video_models",
inputs=[
io.Vae.Input("vae"),
io.Image.Input("image"),
io.Latent.Input("latent"),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.")
],
outputs=[
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
if bypass:
return io.NodeOutput(latent)
samples = latent["samples"]
_, height_scale_factor, width_scale_factor = (
vae.downscale_index_formula
)
batch, _, latent_frames, latent_height, latent_width = samples.shape
width = latent_width * width_scale_factor
height = latent_height * height_scale_factor
if image.shape[1] != height or image.shape[2] != width:
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
else:
pixels = image
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
samples[:, :, :t.shape[2]] = t
conditioning_latent_frames_mask = torch.ones(
(batch, 1, latent_frames, 1, 1),
dtype=torch.float32,
device=samples.device,
)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
generate = execute # TODO: remove
def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning:
if key in t[1]:
@@ -106,12 +159,12 @@ def get_keyframe_idxs(cond):
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
if keyframe_idxs is None:
return None, 0
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
# keyframe_idxs contains start/end positions (last dimension); we check for unique values only on the start positions
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
return keyframe_idxs, num_keyframes
class LTXVAddGuide(io.ComfyNode):
NUM_PREFIX_FRAMES = 2
PATCHIFIER = SymmetricPatchifier(1)
PATCHIFIER = SymmetricPatchifier(1, start_end=True)
@classmethod
def define_schema(cls):
@@ -170,11 +223,24 @@ class LTXVAddGuide(io.ComfyNode):
return frame_idx, latent_idx
@classmethod
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1):
keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
pixel_coords[:, 0] += frame_idx
# The following adjusts keyframe end positions for small grid IC-LoRA.
# After dilation, the small grid has the same size and position as the large grid,
# but each token encodes a larger image patch. We adjust the end position (not start)
# so that RoPE represents the correct middle point of each token.
# keyframe_idxs dims: (batch, spatial_dim [t,h,w], token_id, [start, end])
# We only adjust h,w (not t) in dim 1, and only end (not start) in dim 3.
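# Worked example (illustrative values): with scale_factors = (8, 32, 32) and
# latent_downscale_factor = 2, spatial_end_offset = (2 - 1) * [32, 32], so each
# token's h/w end coordinate moves 32 pixels outward while the start
# coordinates and the time axis stay untouched.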
spatial_end_offset = (latent_downscale_factor - 1) * torch.tensor(
scale_factors[1:],
device=pixel_coords.device,
).view(1, -1, 1, 1)
pixel_coords[:, 1:, :, 1:] += spatial_end_offset.to(pixel_coords.dtype)
if keyframe_idxs is None:
keyframe_idxs = pixel_coords
else:
@@ -182,26 +248,35 @@ class LTXVAddGuide(io.ComfyNode):
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
@classmethod
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
_, latent_idx = cls.get_latent_index(
cond=positive,
latent_length=latent_image.shape[2],
guide_length=guiding_latent.shape[2],
frame_idx=frame_idx,
scale_factors=scale_factors,
)
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1):
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
raise ValueError("Adding guide to a combined AV latent is not supported.")
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
if guide_mask is not None:
target_h = max(noise_mask.shape[3], guide_mask.shape[3])
target_w = max(noise_mask.shape[4], guide_mask.shape[4])
if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1:
noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w)
if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1:
guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w)
mask = guide_mask - strength
else:
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
# This solves audio video combined latent case where latent_image has audio latent concatenated
# in channel dimension with video latent. The solution is to pad guiding latent accordingly.
if latent_image.shape[1] > guiding_latent.shape[1]:
pad_len = latent_image.shape[1] - guiding_latent.shape[1]
guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0)
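# Worked example (illustrative): a combined AV latent with 160 channels and a
# 128-channel video guide gives pad_len = 32; the guide gains 32 zero channels
# so the concatenation along the time axis below lines up.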
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
noise_mask = torch.cat([noise_mask, mask], dim=2)
return positive, negative, latent_image, noise_mask
@@ -238,33 +313,17 @@ class LTXVAddGuide(io.ComfyNode):
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
positive, negative, latent_image, noise_mask = cls.append_keyframe(
positive,
negative,
frame_idx,
latent_image,
noise_mask,
t[:, :, :num_prefix_frames],
t,
strength,
scale_factors,
)
latent_idx += num_prefix_frames
t = t[:, :, num_prefix_frames:]
if t.shape[2] == 0:
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
latent_image, noise_mask = cls.replace_latent_frames(
latent_image,
noise_mask,
t,
latent_idx,
strength,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove
@@ -507,18 +566,90 @@ class LTXVPreprocess(io.ComfyNode):
preprocess = execute # TODO: remove
import comfy.nested_tensor
class LTXVConcatAVLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVConcatAVLatent",
category="latent/video/ltxv",
inputs=[
io.Latent.Input("video_latent"),
io.Latent.Input("audio_latent"),
],
outputs=[
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, video_latent, audio_latent) -> io.NodeOutput:
output = {}
output.update(video_latent)
output.update(audio_latent)
video_noise_mask = video_latent.get("noise_mask", None)
audio_noise_mask = audio_latent.get("noise_mask", None)
if video_noise_mask is not None or audio_noise_mask is not None:
if video_noise_mask is None:
video_noise_mask = torch.ones_like(video_latent["samples"])
if audio_noise_mask is None:
audio_noise_mask = torch.ones_like(audio_latent["samples"])
output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask))
output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"]))
return io.NodeOutput(output)
class LTXVSeparateAVLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVSeparateAVLatent",
category="latent/video/ltxv",
description="LTXV Separate AV Latent",
inputs=[
io.Latent.Input("av_latent"),
],
outputs=[
io.Latent.Output(display_name="video_latent"),
io.Latent.Output(display_name="audio_latent"),
],
)
@classmethod
def execute(cls, av_latent) -> io.NodeOutput:
latents = av_latent["samples"].unbind()
video_latent = av_latent.copy()
video_latent["samples"] = latents[0]
audio_latent = av_latent.copy()
audio_latent["samples"] = latents[1]
if "noise_mask" in av_latent:
masks = av_latent["noise_mask"]
if masks is not None:
masks = masks.unbind()
video_latent["noise_mask"] = masks[0]
audio_latent["noise_mask"] = masks[1]
return io.NodeOutput(video_latent, audio_latent)
class LtxvExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyLTXVLatentVideo,
LTXVImgToVideo,
LTXVImgToVideoInplace,
ModelSamplingLTXV,
LTXVConditioning,
LTXVScheduler,
LTXVAddGuide,
LTXVPreprocess,
LTXVCropGuides,
LTXVConcatAVLatent,
LTXVSeparateAVLatent,
]

View File

@@ -0,0 +1,224 @@
import folder_paths
import comfy.utils
import comfy.model_management
import comfy.sd
import torch
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class LTXVAudioVAELoader(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVAudioVAELoader",
display_name="LTXV Audio VAE Loader",
category="audio",
inputs=[
io.Combo.Input(
"ckpt_name",
options=folder_paths.get_filename_list("checkpoints"),
tooltip="Audio VAE checkpoint to load.",
)
],
outputs=[io.Vae.Output(display_name="Audio VAE")],
)
@classmethod
def execute(cls, ckpt_name: str) -> io.NodeOutput:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
return io.NodeOutput(AudioVAE(sd, metadata))
class LTXVAudioVAEEncode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVAudioVAEEncode",
display_name="LTXV Audio VAE Encode",
category="audio",
inputs=[
io.Audio.Input("audio", tooltip="The audio to be encoded."),
io.Vae.Input(
id="audio_vae",
display_name="Audio VAE",
tooltip="The Audio VAE model to use for encoding.",
),
],
outputs=[io.Latent.Output(display_name="Audio Latent")],
)
@classmethod
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
audio_latents = audio_vae.encode(audio)
return io.NodeOutput(
{
"samples": audio_latents,
"sample_rate": int(audio_vae.sample_rate),
"type": "audio",
}
)
class LTXVAudioVAEDecode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVAudioVAEDecode",
display_name="LTXV Audio VAE Decode",
category="audio",
inputs=[
io.Latent.Input("samples", tooltip="The latent to be decoded."),
io.Vae.Input(
id="audio_vae",
display_name="Audio VAE",
tooltip="The Audio VAE model used for decoding the latent.",
),
],
outputs=[io.Audio.Output(display_name="Audio")],
)
@classmethod
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
audio_latent = samples["samples"]
if audio_latent.is_nested:
audio_latent = audio_latent.unbind()[-1]
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
output_audio_sample_rate = audio_vae.output_sample_rate
return io.NodeOutput(
{
"waveform": audio,
"sample_rate": int(output_audio_sample_rate),
}
)
class LTXVEmptyLatentAudio(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVEmptyLatentAudio",
display_name="LTXV Empty Latent Audio",
category="latent/audio",
inputs=[
io.Int.Input(
"frames_number",
default=97,
min=1,
max=1000,
step=1,
display_mode=io.NumberDisplay.number,
tooltip="Number of frames.",
),
io.Int.Input(
"frame_rate",
default=25,
min=1,
max=1000,
step=1,
display_mode=io.NumberDisplay.number,
tooltip="Number of frames per second.",
),
io.Int.Input(
"batch_size",
default=1,
min=1,
max=4096,
display_mode=io.NumberDisplay.number,
tooltip="The number of latent audio samples in the batch.",
),
io.Vae.Input(
id="audio_vae",
display_name="Audio VAE",
tooltip="The Audio VAE model to get configuration from.",
),
],
outputs=[io.Latent.Output(display_name="Latent")],
)
@classmethod
def execute(
cls,
frames_number: int,
frame_rate: int,
batch_size: int,
audio_vae: AudioVAE,
) -> io.NodeOutput:
"""Generate empty audio latents matching the reference pipeline structure."""
assert audio_vae is not None, "Audio VAE model is required"
z_channels = audio_vae.latent_channels
audio_freq = audio_vae.latent_frequency_bins
sampling_rate = int(audio_vae.sample_rate)
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
audio_latents = torch.zeros(
(batch_size, z_channels, num_audio_latents, audio_freq),
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput(
{
"samples": audio_latents,
"sample_rate": sampling_rate,
"type": "audio",
}
)
class LTXAVTextEncoderLoader(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXAVTextEncoderLoader",
display_name="LTXV Audio Text Encoder Loader",
category="advanced/loaders",
description="[Recipes]\n\nltxav: gemma 3 12B",
inputs=[
io.Combo.Input(
"text_encoder",
options=folder_paths.get_filename_list("text_encoders"),
),
io.Combo.Input(
"ckpt_name",
options=folder_paths.get_filename_list("checkpoints"),
),
io.Combo.Input(
"device",
options=["default", "cpu"],
)
],
outputs=[io.Clip.Output()],
)
@classmethod
def execute(cls, text_encoder, ckpt_name, device="default"):
clip_type = comfy.sd.CLIPType.LTXV
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return io.NodeOutput(clip)
class LTXVAudioExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LTXVAudioVAELoader,
LTXVAudioVAEEncode,
LTXVAudioVAEDecode,
LTXVEmptyLatentAudio,
LTXAVTextEncoderLoader,
]
async def comfy_entrypoint() -> ComfyExtension:
return LTXVAudioExtension()

View File

@@ -0,0 +1,75 @@
from comfy import model_management
import math
class LTXVLatentUpsampler:
"""
Upsamples a video latent by a factor of 2.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT",),
"upscale_model": ("LATENT_UPSCALE_MODEL",),
"vae": ("VAE",),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "upsample_latent"
CATEGORY = "latent/video"
EXPERIMENTAL = True
def upsample_latent(
self,
samples: dict,
upscale_model,
vae,
) -> tuple:
"""
Upsample the input latent using the provided model.
Args:
samples (dict): Input latent samples
upscale_model (LatentUpsampler): Loaded upscale model
vae: VAE model for normalization
Returns:
tuple: Tuple containing the upsampled latent
"""
device = model_management.get_torch_device()
memory_required = model_management.module_size(upscale_model)
model_dtype = next(upscale_model.parameters()).dtype
latents = samples["samples"]
input_dtype = latents.dtype
memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate
model_management.free_memory(memory_required, device)
try:
upscale_model.to(device) # TODO: use the comfy model management system.
latents = latents.to(dtype=model_dtype, device=device)
"""Upsample latents without tiling."""
latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents)
upsampled_latents = upscale_model(latents)
finally:
upscale_model.cpu()
upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize(
upsampled_latents
)
upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device())
return_dict = samples.copy()
return_dict["samples"] = upsampled_latents
return_dict.pop("noise_mask", None)
return (return_dict,)
NODE_CLASS_MAPPINGS = {
"LTXVLatentUpsampler": LTXVLatentUpsampler,
}

View File

@@ -79,6 +79,7 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeLumina2",
search_aliases=["lumina prompt"],
display_name="CLIP Text Encode for Lumina2",
category="conditioning",
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "

View File

@@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Mahiro",
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
display_name="Mahiro CFG",
category="_for_testing",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[

View File

@@ -3,11 +3,10 @@ import scipy.ndimage
import torch
import comfy.utils
import node_helpers
import folder_paths
import random
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, UI
import nodes
from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
source = source.to(destination.device)
@@ -46,202 +45,221 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
class LatentCompositeMasked:
class LatentCompositeMasked(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"destination": ("LATENT",),
"source": ("LATENT",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"resize_source": ("BOOLEAN", {"default": False}),
},
"optional": {
"mask": ("MASK",),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "composite"
def define_schema(cls):
return IO.Schema(
node_id="LatentCompositeMasked",
search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"],
category="latent",
inputs=[
IO.Latent.Input("destination"),
IO.Latent.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
IO.Boolean.Input("resize_source", default=False),
IO.Mask.Input("mask", optional=True),
],
outputs=[IO.Latent.Output()],
)
CATEGORY = "latent"
def composite(self, destination, source, x, y, resize_source, mask = None):
@classmethod
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
output = destination.copy()
destination = destination["samples"].clone()
source = source["samples"]
output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
return (output,)
return IO.NodeOutput(output)
class ImageCompositeMasked:
composite = execute # TODO: remove
class ImageCompositeMasked(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"destination": ("IMAGE",),
"source": ("IMAGE",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"resize_source": ("BOOLEAN", {"default": False}),
},
"optional": {
"mask": ("MASK",),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "composite"
def define_schema(cls):
return IO.Schema(
node_id="ImageCompositeMasked",
search_aliases=["paste image", "overlay", "layer"],
category="image",
inputs=[
IO.Image.Input("destination"),
IO.Image.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("resize_source", default=False),
IO.Mask.Input("mask", optional=True),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image"
def composite(self, destination, source, x, y, resize_source, mask = None):
@classmethod
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
return (output,)
return IO.NodeOutput(output)
class MaskToImage:
composite = execute # TODO: remove
class MaskToImage(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mask": ("MASK",),
}
}
def define_schema(cls):
return IO.Schema(
node_id="MaskToImage",
search_aliases=["convert mask"],
display_name="Convert Mask to Image",
category="mask",
inputs=[
IO.Mask.Input("mask"),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "mask"
RETURN_TYPES = ("IMAGE",)
FUNCTION = "mask_to_image"
def mask_to_image(self, mask):
@classmethod
def execute(cls, mask) -> IO.NodeOutput:
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
return (result,)
return IO.NodeOutput(result)
class ImageToMask:
mask_to_image = execute # TODO: remove
class ImageToMask(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"channel": (["red", "green", "blue", "alpha"],),
}
}
def define_schema(cls):
return IO.Schema(
node_id="ImageToMask",
search_aliases=["extract channel", "channel to mask"],
display_name="Convert Image to Mask",
category="mask",
inputs=[
IO.Image.Input("image"),
IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "image_to_mask"
def image_to_mask(self, image, channel):
@classmethod
def execute(cls, image, channel) -> IO.NodeOutput:
channels = ["red", "green", "blue", "alpha"]
mask = image[:, :, :, channels.index(channel)]
return (mask,)
return IO.NodeOutput(mask)
class ImageColorToMask:
image_to_mask = execute # TODO: remove
class ImageColorToMask(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
}
}
def define_schema(cls):
return IO.Schema(
node_id="ImageColorToMask",
search_aliases=["color keying", "chroma key"],
category="mask",
inputs=[
IO.Image.Input("image"),
IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "image_to_mask"
def image_to_mask(self, image, color):
@classmethod
def execute(cls, image, color) -> IO.NodeOutput:
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
mask = torch.where(temp == color, 1.0, 0).float()
return (mask,)
return IO.NodeOutput(mask)
class SolidMask:
image_to_mask = execute # TODO: remove
class SolidMask(IO.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
}
}
def define_schema(cls):
return IO.Schema(
node_id="SolidMask",
category="mask",
inputs=[
IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01),
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "solid"
def solid(self, value, width, height):
@classmethod
def execute(cls, value, width, height) -> IO.NodeOutput:
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
return (out,)
return IO.NodeOutput(out)
class InvertMask:
solid = execute # TODO: remove
class InvertMask(IO.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
}
}
def define_schema(cls):
return IO.Schema(
node_id="InvertMask",
search_aliases=["reverse mask", "flip mask"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "invert"
def invert(self, mask):
@classmethod
def execute(cls, mask) -> IO.NodeOutput:
out = 1.0 - mask
return (out,)
return IO.NodeOutput(out)
class CropMask:
invert = execute # TODO: remove
class CropMask(IO.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
}
}
def define_schema(cls):
return IO.Schema(
node_id="CropMask",
search_aliases=["cut mask", "extract mask region", "mask slice"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "crop"
def crop(self, mask, x, y, width, height):
@classmethod
def execute(cls, mask, x, y, width, height) -> IO.NodeOutput:
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
out = mask[:, y:y + height, x:x + width]
return (out,)
return IO.NodeOutput(out)
class MaskComposite:
crop = execute # TODO: remove
class MaskComposite(IO.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"destination": ("MASK",),
"source": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
}
}
def define_schema(cls):
return IO.Schema(
node_id="MaskComposite",
search_aliases=["combine masks", "blend masks", "layer masks"],
category="mask",
inputs=[
IO.Mask.Input("destination"),
IO.Mask.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "combine"
def combine(self, destination, source, x, y, operation):
@classmethod
def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
@@ -267,28 +285,30 @@ class MaskComposite:
output = torch.clamp(output, 0.0, 1.0)
return (output,)
return IO.NodeOutput(output)
class FeatherMask:
combine = execute # TODO: remove
class FeatherMask(IO.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}
}
def define_schema(cls):
return IO.Schema(
node_id="FeatherMask",
search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "feather"
def feather(self, mask, left, top, right, bottom):
@classmethod
def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput:
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
left = min(left, output.shape[-1])
@@ -312,26 +332,29 @@ class FeatherMask:
feather_rate = (y + 1) / bottom
output[:, -y, :] *= feather_rate
return (output,)
return IO.NodeOutput(output)
class GrowMask:
feather = execute # TODO: remove
class GrowMask(IO.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
"tapered_corners": ("BOOLEAN", {"default": True}),
},
}
def define_schema(cls):
return IO.Schema(
node_id="GrowMask",
search_aliases=["expand mask", "shrink mask"],
display_name="Grow Mask",
category="mask",
inputs=[
IO.Mask.Input("mask"),
IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("tapered_corners", default=True),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "expand_mask"
def expand_mask(self, mask, expand, tapered_corners):
@classmethod
def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput:
c = 0 if tapered_corners else 1
kernel = np.array([[c, 1, c],
[1, 1, 1],
@@ -347,69 +370,76 @@ class GrowMask:
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
output = torch.from_numpy(output)
out.append(output)
return (torch.stack(out, dim=0),)
return IO.NodeOutput(torch.stack(out, dim=0))
class ThresholdMask:
expand_mask = execute # TODO: remove
class ThresholdMask(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mask": ("MASK",),
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
def define_schema(cls):
return IO.Schema(
node_id="ThresholdMask",
search_aliases=["binary mask"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "image_to_mask"
def image_to_mask(self, mask, value):
@classmethod
def execute(cls, mask, value) -> IO.NodeOutput:
mask = (mask > value).float()
return (mask,)
return IO.NodeOutput(mask)
image_to_mask = execute # TODO: remove
# Mask Preview - original implementation from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
class MaskPreview(nodes.SaveImage):
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for x in range(5))
self.compress_level = 4
class MaskPreview(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="MaskPreview",
search_aliases=["show mask", "view mask", "inspect mask", "debug mask"],
display_name="Preview Mask",
category="mask",
description="Previews the input mask.",
inputs=[
IO.Mask.Input("mask"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {"mask": ("MASK",), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
FUNCTION = "execute"
CATEGORY = "mask"
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput:
return IO.NodeOutput(ui=UI.PreviewMask(mask))
NODE_CLASS_MAPPINGS = {
"LatentCompositeMasked": LatentCompositeMasked,
"ImageCompositeMasked": ImageCompositeMasked,
"MaskToImage": MaskToImage,
"ImageToMask": ImageToMask,
"ImageColorToMask": ImageColorToMask,
"SolidMask": SolidMask,
"InvertMask": InvertMask,
"CropMask": CropMask,
"MaskComposite": MaskComposite,
"FeatherMask": FeatherMask,
"GrowMask": GrowMask,
"ThresholdMask": ThresholdMask,
"MaskPreview": MaskPreview
}
class MaskExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
LatentCompositeMasked,
ImageCompositeMasked,
MaskToImage,
ImageToMask,
ImageColorToMask,
SolidMask,
InvertMask,
CropMask,
MaskComposite,
FeatherMask,
GrowMask,
ThresholdMask,
MaskPreview,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageToMask": "Convert Image to Mask",
"MaskToImage": "Convert Mask to Image",
}
async def comfy_entrypoint() -> MaskExtension:
return MaskExtension()
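The pattern of this migration is the same in every class above: the legacy INPUT_TYPES / CATEGORY / RETURN_TYPES / FUNCTION attributes collapse into a single define_schema classmethod, and the instance method becomes a classmethod execute returning IO.NodeOutput. A minimal sketch of a new-style node following that pattern (the ScaleMask node is hypothetical, not part of this diff, and assumes this file's imports of torch and the IO alias from comfy_api.latest):

# Hypothetical node for illustration only -- follows the schema API used above.
class ScaleMask(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ScaleMask",                 # replaces the NODE_CLASS_MAPPINGS key
            category="mask",                     # replaces the CATEGORY attribute
            inputs=[                             # replaces INPUT_TYPES
                IO.Mask.Input("mask"),
                IO.Float.Input("factor", default=1.0, min=0.0, max=10.0, step=0.01),
            ],
            outputs=[IO.Mask.Output()],          # replaces RETURN_TYPES
        )
    @classmethod
    def execute(cls, mask, factor) -> IO.NodeOutput:   # replaces FUNCTION + instance method
        return IO.NodeOutput(torch.clamp(mask * factor, 0.0, 1.0))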

View File

@@ -299,6 +299,7 @@ class RescaleCFG:
return (m, )
class ModelComputeDtype:
SEARCH_ALIASES = ["model precision", "change dtype"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),

View File

@@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode):
return io.NodeOutput(m)
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"PatchModelAddDownscale": "",
}
class ModelDownscaleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:

View File

@@ -91,6 +91,7 @@ class CLIPMergeSimple:
class CLIPSubtract:
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
@@ -113,6 +114,7 @@ class CLIPSubtract:
class CLIPAdd:
SEARCH_ALIASES = ["combine clip"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
@@ -225,6 +227,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
class CheckpointSave:
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@@ -337,6 +340,7 @@ class VAESave:
return {}
class ModelSave:
SEARCH_ALIASES = ["export model", "checkpoint save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()

View File

@@ -6,6 +6,8 @@ import comfy.ops
import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
import comfy.ldm.lumina.controlnet
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
class BlockWiseControlBlock(torch.nn.Module):
@@ -189,6 +191,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):
return embedding
def z_image_convert(sd):
replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
".attention.norm_k.weight": ".attention.k_norm.weight",
".attention.norm_q.weight": ".attention.q_norm.weight",
".attention.to_out.0.weight": ".attention.out.weight"
}
out_sd = {}
for k in sorted(sd.keys()):
w = sd[k]
k_out = k
if k_out.endswith(".attention.to_k.weight"):
cc = [w]
continue
if k_out.endswith(".attention.to_q.weight"):
cc = [w] + cc
continue
if k_out.endswith(".attention.to_v.weight"):
cc = cc + [w]
w = torch.cat(cc, dim=0)
k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
for r, rr in replace_keys.items():
k_out = k_out.replace(r, rr)
out_sd[k_out] = w
return out_sd
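z_image_convert leans on sorted(sd.keys()) for correctness: within each attention block, ".to_k" sorts before ".to_q", which sorts before ".to_v", so the cc buffer is assembled as [q, k, v] before the concat. A standalone sketch of that fusion step with dummy weights (shapes illustrative):

# Illustration of the q/k/v -> qkv fusion performed above (dummy tensors).
import torch
dim = 8
w_q, w_k, w_v = torch.randn(dim, dim), torch.randn(dim, dim), torch.randn(dim, dim)
cc = [w_k]            # seen first (sorted order: to_k < to_q < to_v)
cc = [w_q] + cc       # prepend q -> [q, k]
cc = cc + [w_v]       # append v  -> [q, k, v]
qkv = torch.cat(cc, dim=0)   # shape [3 * dim, dim], stored under ".attention.qkv.weight"
assert torch.equal(qkv[:dim], w_q) and torch.equal(qkv[2 * dim:], w_v)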
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
@@ -211,10 +242,34 @@ class ModelPatchLoader:
elif 'feature_embedder.mid_layer_norm.bias' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
sd = z_image_convert(sd)
config = {}
if 'control_layers.4.adaLN_modulation.0.weight' not in sd:
config['n_control_layers'] = 3
config['additional_in_dim'] = 17
config['refiner_control'] = True
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
config['n_control_layers'] = 15
config['additional_in_dim'] = 17
config['refiner_control'] = True
ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None)
if ref_weight is not None:
if torch.count_nonzero(ref_weight) == 0:
config['broken'] = True
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
elif "audio_proj.proj1.weight" in sd:
model = MultiTalkModelPatch(
audio_window=5, context_tokens=32, vae_scale=4,
in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0],
intermediate_dim=sd["audio_proj.proj1.weight"].shape[0],
out_dim=sd["audio_proj.norm.weight"].shape[0],
device=comfy.model_management.unet_offload_device(),
operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
return (model_patcher,)
class DiffSynthCnetPatch:
@@ -263,6 +318,129 @@ class DiffSynthCnetPatch:
def models(self):
return [self.model_patch]
class ZImageControlPatch:
def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
self.model_patch = model_patch
self.vae = vae
self.image = image
self.inpaint_image = inpaint_image
self.mask = mask
self.strength = strength
self.is_inpaint = self.model_patch.model.additional_in_dim > 0
skip_encoding = False
if self.image is not None and self.inpaint_image is not None:
if self.image.shape != self.inpaint_image.shape:
skip_encoding = True
if skip_encoding:
self.encoded_image = None
else:
self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image)
if self.image is None:
self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2])
else:
self.encoded_image_size = (self.image.shape[1], self.image.shape[2])
self.temp_data = None
def encode_latent_cond(self, control_image=None, inpaint_image=None):
latent_image = None
if control_image is not None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
if self.is_inpaint:
if inpaint_image is None:
inpaint_image = torch.ones_like(control_image) * 0.5
if self.mask is not None:
mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
if self.mask is None:
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
else:
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
if latent_image is None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
else:
return latent_image
def __call__(self, kwargs):
x = kwargs.get("x")
img = kwargs.get("img")
img_input = kwargs.get("img_input")
txt = kwargs.get("txt")
pe = kwargs.get("pe")
vec = kwargs.get("vec")
block_index = kwargs.get("block_index")
block_type = kwargs.get("block_type", "")
spacial_compression = self.vae.spacial_compression_encode()
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
image_scaled = None
if self.image is not None:
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2])
inpaint_scaled = None
if self.inpaint_image is not None:
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2])
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled)
comfy.model_management.load_models_gpu(loaded_models)
cnet_blocks = self.model_patch.model.n_control_layers
div = round(30 / cnet_blocks)
cnet_index = (block_index // div)
cnet_index_float = (block_index / div)
kwargs.pop("img") # we do ops in place
kwargs.pop("txt")
if cnet_index_float > (cnet_blocks - 1):
self.temp_data = None
return kwargs
if self.temp_data is None or self.temp_data[0] > cnet_index:
if block_type == "noise_refiner":
self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
else:
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
if block_type == "noise_refiner":
next_layer = self.temp_data[0] + 1
self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
if self.temp_data[1][0] is not None:
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
else:
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
next_layer = self.temp_data[0] + 1
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
if cnet_index_float == self.temp_data[0]:
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
if cnet_blocks == self.temp_data[0] + 1:
self.temp_data = None
return kwargs
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
if self.encoded_image is not None:
self.encoded_image = self.encoded_image.to(device_or_dtype)
self.temp_data = None
return self
def models(self):
return [self.model_patch]
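The block scheduling in __call__ assumes a 30-layer host transformer (the constant in div = round(30 / cnet_blocks), which appears to be the Z-Image depth): control layer k is advanced lazily through temp_data, and its output is injected only at the transformer block where block_index / div lands exactly on k. An illustrative mapping under that assumption:

# Illustrative mapping (assuming 30 transformer blocks, as the constant suggests):
#   cnet_blocks = 15 -> div = 2:  control layer k injects at transformer block 2 * k
#   cnet_blocks = 3  -> div = 10: control layers 0, 1, 2 inject at blocks 0, 10, 20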
class QwenImageDiffsynthControlnet:
@classmethod
def INPUT_TYPES(s):
@@ -279,9 +457,12 @@ class QwenImageDiffsynthControlnet:
CATEGORY = "advanced/loaders/qwen"
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None):
model_patched = model.clone()
image = image[:, :, :, :3]
if image is not None:
image = image[:, :, :, :3]
if inpaint_image is not None:
inpaint_image = inpaint_image[:, :, :, :3]
if mask is not None:
if mask.ndim == 3:
mask = mask.unsqueeze(1)
@@ -289,9 +470,25 @@ class QwenImageDiffsynthControlnet:
mask = mask.unsqueeze(2)
mask = 1.0 - mask
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask)
model_patched.set_model_noise_refiner_patch(patch)
model_patched.set_model_double_block_patch(patch)
else:
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,)
class ZImageFunControlnet(QwenImageDiffsynthControlnet):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"vae": ("VAE",),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
},
"optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}}
CATEGORY = "advanced/loaders/zimage"
class UsoStyleProjectorPatch:
def __init__(self, model_patch, encoded_image):
@@ -336,8 +533,41 @@ class USOStyleReference:
return (model_patched,)
class MultiTalkModelPatch(torch.nn.Module):
def __init__(
self,
audio_window: int = 5,
intermediate_dim: int = 512,
in_dim: int = 5120,
out_dim: int = 768,
context_tokens: int = 32,
vae_scale: int = 4,
num_layers: int = 40,
device=None, dtype=None, operations=None
):
super().__init__()
self.audio_proj = MultiTalkAudioProjModel(
seq_len=audio_window,
seq_len_vf=audio_window+vae_scale-1,
intermediate_dim=intermediate_dim,
out_dim=out_dim,
context_tokens=context_tokens,
device=device,
dtype=dtype,
operations=operations
)
self.blocks = torch.nn.ModuleList(
[
WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
]
)
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"ZImageFunControlnet": ZImageFunControlnet,
"USOStyleReference": USOStyleReference,
}

View File

@@ -12,6 +12,7 @@ class Morphology(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Morphology",
search_aliases=["erode", "dilate"],
display_name="ImageMorphology",
category="image/postprocessing",
inputs=[
@@ -57,6 +58,7 @@ class ImageRGBToYUV(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageRGBToYUV",
search_aliases=["color space conversion"],
category="image/batch",
inputs=[
io.Image.Input("image"),
@@ -78,6 +80,7 @@ class ImageYUVToRGB(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageYUVToRGB",
search_aliases=["color space conversion"],
category="image/batch",
inputs=[
io.Image.Input("Y"),

99 comfy_extras/nodes_nag.py Normal file
View File

@@ -0,0 +1,99 @@
import torch
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class NAGuidance(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="NAGuidance",
display_name="Normalized Attention Guidance",
description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
category="",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to apply NAG to."),
io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
# io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
# io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
],
outputs=[
io.Model.Output(tooltip="The patched model with NAG enabled."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
m = model.clone()
# sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
# sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)
def nag_attention_output_patch(out, extra_options):
cond_or_uncond = extra_options.get("cond_or_uncond", None)
if cond_or_uncond is None:
return out
if not (1 in cond_or_uncond and 0 in cond_or_uncond):
return out
# sigma = extra_options.get("sigmas", None)
# if sigma is not None and len(sigma) > 0:
# sigma = sigma[0].item()
# if sigma > sigma_start or sigma < sigma_end:
# return out
img_slice = extra_options.get("img_slice", None)
if img_slice is not None:
orig_out = out
out = out[:, img_slice[0]:img_slice[1]] # only apply on img part
batch_size = out.shape[0]
half_size = batch_size // len(cond_or_uncond)
ind_neg = cond_or_uncond.index(1)
ind_pos = cond_or_uncond.index(0)
z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]
guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
eps = 1e-6
norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
ratio = norm_guided / norm_pos
scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
guided_normalized = guided * scale_factor
z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)
if img_slice is not None:
orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
return orig_out
else:
out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
return out
m.set_model_attn1_output_patch(nag_attention_output_patch)
m.disable_model_cfg1_optimization()
return io.NodeOutput(m)
class NagExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
NAGuidance,
]
async def comfy_entrypoint() -> NagExtension:
return NagExtension()
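The arithmetic at the heart of nag_attention_output_patch is a norm-capped extrapolation: push the positive attention output away from the negative one by nag_scale, cap the L1-norm growth at nag_tau relative to the positive branch, then blend by nag_alpha. A standalone sketch of just that step on dummy tensors, using the node's default parameter values:

# Standalone sketch of the NAG normalization above; tensors are dummies.
import torch
nag_scale, nag_alpha, nag_tau, eps = 5.0, 0.5, 1.5, 1e-6
z_pos = torch.randn(2, 16, 64)   # attention output for the positive prompt
z_neg = torch.randn(2, 16, 64)   # attention output for the negative prompt
guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
ratio = norm_guided / norm_pos
# cap the norm growth at nag_tau, then blend with the positive branch
guided = guided * (torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio)
z_final = guided * nag_alpha + z_pos * (1.0 - nag_alpha)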

39 comfy_extras/nodes_nop.py Normal file
View File

@@ -0,0 +1,39 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
# If you write a node so useless that it breaks ComfyUI, it will be featured in this exclusive list.
# "native" block swap nodes are placebo at best and break the ComfyUI memory management system.
# They are also considered harmful because instead of reporting issues with the built-in
# memory management, users install these stupid nodes and complain even harder. Now it completely
# breaks with some of the new ComfyUI memory optimizations, so I have made the decision to NOP it
# out of all workflows.
class wanBlockSwap(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="wanBlockSwap",
category="",
description="NOP",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(),
],
is_deprecated=True,
)
@classmethod
def execute(cls, model) -> io.NodeOutput:
return io.NodeOutput(model)
class NopExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
wanBlockSwap
]
async def comfy_entrypoint() -> NopExtension:
return NopExtension()

View File

@@ -7,6 +7,7 @@ class CLIPTextEncodePixArtAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodePixArtAlpha",
search_aliases=["pixart prompt"],
category="advanced/conditioning",
description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
inputs=[

View File

@@ -4,11 +4,15 @@ import torch
import torch.nn.functional as F
from PIL import Image
import math
from enum import Enum
from typing import TypedDict, Literal
import comfy.utils
import comfy.model_management
from comfy_extras.nodes_latent import reshape_latent_to
import node_helpers
from comfy_api.latest import ComfyExtension, io
from nodes import MAX_RESOLUTION
class Blend(io.ComfyNode):
@classmethod
@@ -221,6 +225,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
io.Image.Input("image"),
io.Combo.Input("upscale_method", options=cls.upscale_methods),
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
io.Int.Input("resolution_steps", default=1, min=1, max=256),
],
outputs=[
io.Image.Output(),
@@ -228,18 +233,429 @@ class ImageScaleToTotalPixels(io.ComfyNode):
)
@classmethod
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
samples = image.movedim(-1,1)
total = int(megapixels * 1024 * 1024)
total = megapixels * 1024 * 1024
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
s = s.movedim(1,-1)
return io.NodeOutput(s)
class ResizeType(str, Enum):
SCALE_BY = "scale by multiplier"
SCALE_DIMENSIONS = "scale dimensions"
SCALE_LONGER_DIMENSION = "scale longer dimension"
SCALE_SHORTER_DIMENSION = "scale shorter dimension"
SCALE_WIDTH = "scale width"
SCALE_HEIGHT = "scale height"
SCALE_TOTAL_PIXELS = "scale total pixels"
MATCH_SIZE = "match size"
SCALE_TO_MULTIPLE = "scale to multiple"
def is_image(input: torch.Tensor) -> bool:
# images have 4 dimensions: [batch, height, width, channels]
# masks have 3 dimensions: [batch, height, width]
return len(input.shape) == 4
def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
if is_type_image:
input = input.movedim(-1, 1)
else:
input = input.unsqueeze(1)
return input
def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
if is_type_image:
input = input.movedim(1, -1)
else:
input = input.squeeze(1)
return input
def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
width = round(input.shape[-1] * multiplier)
height = round(input.shape[-2] * multiplier)
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor:
if width == 0 and height == 0:
return input
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
if width == 0:
width = max(1, round(input.shape[-1] * height / input.shape[-2]))
elif height == 0:
height = max(1, round(input.shape[-2] * width / input.shape[-1]))
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
width = input.shape[-1]
height = input.shape[-2]
if height > width:
width = round((width / height) * longer_size)
height = longer_size
elif width > height:
height = round((height / width) * longer_size)
width = longer_size
else:
height = longer_size
width = longer_size
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
width = input.shape[-1]
height = input.shape[-2]
if height < width:
width = round((width / height) * shorter_size)
height = shorter_size
elif width < height:
height = round((height / width) * shorter_size)
width = shorter_size
else:
height = shorter_size
width = shorter_size
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
total = int(megapixels * 1024 * 1024)
scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2]))
width = round(input.shape[-1] * scale_by)
height = round(input.shape[-2] * scale_by)
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
match = init_image_mask_input(match, is_image(match))
width = match.shape[-1]
height = match.shape[-2]
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: str) -> torch.Tensor:
if multiple <= 1:
return input
is_type_image = is_image(input)
if is_type_image:
_, height, width, _ = input.shape
else:
_, height, width = input.shape
target_w = (width // multiple) * multiple
target_h = (height // multiple) * multiple
if target_w == 0 or target_h == 0:
return input
if target_w == width and target_h == height:
return input
s_w = target_w / width
s_h = target_h / height
if s_w >= s_h:
scaled_w = target_w
scaled_h = int(math.ceil(height * s_w))
if scaled_h < target_h:
scaled_h = target_h
else:
scaled_h = target_h
scaled_w = int(math.ceil(width * s_h))
if scaled_w < target_w:
scaled_w = target_w
input = init_image_mask_input(input, is_type_image)
input = comfy.utils.common_upscale(input, scaled_w, scaled_h, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
x0 = (scaled_w - target_w) // 2
y0 = (scaled_h - target_h) // 2
x1 = x0 + target_w
y1 = y0 + target_h
if is_type_image:
return input[:, y0:y1, x0:x1, :]
return input[:, y0:y1, x0:x1]
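scale_to_multiple_cover rounds the target size down to the nearest multiples, scales the input by the larger of the two axis ratios so the result fully covers that target, then center-crops the overhang. A worked example with illustrative numbers:

# Worked example (illustrative numbers): 1023 x 517 input, multiple = 64
#   target_w = (1023 // 64) * 64 = 960;  target_h = (517 // 64) * 64 = 512
#   s_w = 960 / 1023 ~= 0.938;  s_h = 512 / 517 ~= 0.990  ->  s_h is larger, so it wins
#   scaled_h = 512;  scaled_w = ceil(1023 * 512 / 517) = 1014
#   center crop: x0 = (1014 - 960) // 2 = 27  ->  final output is 960 x 512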
class ResizeImageMaskNode(io.ComfyNode):
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"]
class ResizeTypedDict(TypedDict):
resize_type: ResizeType
scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop: Literal["disabled", "center"]
multiplier: float
width: int
height: int
longer_size: int
shorter_size: int
megapixels: float
multiple: int
@classmethod
def define_schema(cls):
template = io.MatchType.Template("input_type", [io.Image, io.Mask])
crop_combo = io.Combo.Input(
"crop",
options=cls.crop_methods,
default="center",
tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.",
)
return io.Schema(
node_id="ResizeImageMaskNode",
display_name="Resize Image/Mask",
description="Resize an image or mask using various scaling methods.",
category="transform",
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
inputs=[
io.MatchType.Input("input", template=template),
io.DynamicCombo.Input(
"resize_type",
tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.",
options=[
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."),
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."),
crop_combo,
]),
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."),
]),
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The longer edge will be resized to this value. Aspect ratio is preserved."),
]),
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The shorter edge will be resized to this value. Aspect ratio is preserved."),
]),
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Height auto-adjusts to preserve aspect ratio."),
]),
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Width auto-adjusts to preserve aspect ratio."),
]),
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."),
]),
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."),
crop_combo,
]),
io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."),
]),
],
),
io.Combo.Input(
"scale_method",
options=cls.scale_methods,
default="area",
tooltip="Interpolation algorithm. 'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.",
),
],
outputs=[io.MatchType.Output(template=template, display_name="resized")]
)
@classmethod
def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput:
selected_type = resize_type["resize_type"]
if selected_type == ResizeType.SCALE_BY:
return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method))
elif selected_type == ResizeType.SCALE_DIMENSIONS:
return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"]))
elif selected_type == ResizeType.SCALE_LONGER_DIMENSION:
return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method))
elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION:
return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method))
elif selected_type == ResizeType.SCALE_WIDTH:
return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method))
elif selected_type == ResizeType.SCALE_HEIGHT:
return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method))
elif selected_type == ResizeType.SCALE_TOTAL_PIXELS:
return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
elif selected_type == ResizeType.MATCH_SIZE:
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
elif selected_type == ResizeType.SCALE_TO_MULTIPLE:
return io.NodeOutput(scale_to_multiple_cover(input, resize_type["multiple"], scale_method))
raise ValueError(f"Unsupported resize type: {selected_type}")
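The DynamicCombo input reaches execute as a dict: the selected option sits under the input's own key ("resize_type") and the widgets declared by that option sit alongside it. Roughly what the node receives when "scale by multiplier" is selected (values illustrative):

# Illustrative payload for the `resize_type` argument:
resize_type = {
    "resize_type": ResizeType.SCALE_BY,   # selected option
    "multiplier": 2.0,                    # widget declared by that option
}
# execute() then dispatches to scale_by(input, 2.0, scale_method)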
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
if len(images) == 0:
return None
# first, get the max channels count
max_channels = max(image.shape[-1] for image in images)
# then, pad all images to have the same channels count
padded_images: list[torch.Tensor] = []
for image in images:
if image.shape[-1] < max_channels:
padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0))
else:
padded_images.append(image)
# resize all images to be the same size as the first image
resized_images: list[torch.Tensor] = []
first_image_shape = padded_images[0].shape
for image in padded_images:
if image.shape[1:] != first_image_shape[1:]:
resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1))
else:
resized_images.append(image)
# batch the images in the format [b, h, w, c]
return torch.cat(resized_images, dim=0)
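Note that the padding step in batch_images pads the channel axis by exactly one, which handles the common case of an RGB image joining an RGBA batch (value=1.0 yields an opaque alpha channel). A quick illustration with dummy tensors:

# Illustration: an RGB image joining an RGBA batch (dummy tensors).
import torch
rgb = torch.rand(1, 64, 64, 3)
rgba = torch.rand(1, 64, 64, 4)
padded = torch.nn.functional.pad(rgb, (0, 1), mode='constant', value=1.0)  # add opaque alpha
batch = torch.cat([padded, rgba], dim=0)        # shape [2, 64, 64, 4]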
def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None:
if len(masks) == 0:
return None
# resize all masks to be the same size as the first mask
resized_masks: list[torch.Tensor] = []
first_mask_shape = masks[0].shape
for mask in masks:
if mask.shape[1:] != first_mask_shape[1:]:
mask = init_image_mask_input(mask, is_type_image=False)
mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center")
resized_masks.append(finalize_image_mask_input(mask, is_type_image=False))
else:
resized_masks.append(mask)
# batch the masks in the format [b, h, w]
return torch.cat(resized_masks, dim=0)
def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None:
if len(latents) == 0:
return None
samples_out = latents[0].copy()
samples_out["batch_index"] = []
first_samples = latents[0]["samples"]
tensors: list[torch.Tensor] = []
for latent in latents:
# first, deal with latent tensors
tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False))
# next, deal with batch_index
samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])]))
samples_out["samples"] = torch.cat(tensors, dim=0)
return samples_out
class BatchImagesNode(io.ComfyNode):
@classmethod
def define_schema(cls):
autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50)
return io.Schema(
node_id="BatchImagesNode",
display_name="Batch Images",
category="image",
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
inputs=[
io.Autogrow.Input("images", template=autogrow_template)
],
outputs=[
io.Image.Output()
]
)
@classmethod
def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(batch_images(list(images.values())))
class BatchMasksNode(io.ComfyNode):
@classmethod
def define_schema(cls):
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
return io.Schema(
node_id="BatchMasksNode",
search_aliases=["combine masks", "stack masks", "merge masks"],
display_name="Batch Masks",
category="mask",
inputs=[
io.Autogrow.Input("masks", template=autogrow_template)
],
outputs=[
io.Mask.Output()
]
)
@classmethod
def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(batch_masks(list(masks.values())))
class BatchLatentsNode(io.ComfyNode):
@classmethod
def define_schema(cls):
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
return io.Schema(
node_id="BatchLatentsNode",
search_aliases=["combine latents", "stack latents", "merge latents"],
display_name="Batch Latents",
category="latent",
inputs=[
io.Autogrow.Input("latents", template=autogrow_template)
],
outputs=[
io.Latent.Output()
]
)
@classmethod
def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(batch_latents(list(latents.values())))
class BatchImagesMasksLatentsNode(io.ComfyNode):
@classmethod
def define_schema(cls):
matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent])
autogrow_template = io.Autogrow.TemplatePrefix(
io.MatchType.Input("input", matchtype_template),
prefix="input", min=1, max=50)
return io.Schema(
node_id="BatchImagesMasksLatentsNode",
search_aliases=["combine batch", "merge batch", "stack inputs"],
display_name="Batch Images/Masks/Latents",
category="util",
inputs=[
io.Autogrow.Input("inputs", template=autogrow_template)
],
outputs=[
io.MatchType.Output(id=None, template=matchtype_template)
]
)
@classmethod
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
batched = None
values = list(inputs.values())
# latents
if isinstance(values[0], dict):
batched = batch_latents(values)
# images
elif is_image(values[0]):
batched = batch_images(values)
# masks
else:
batched = batch_masks(values)
return io.NodeOutput(batched)
class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -249,6 +665,11 @@ class PostProcessingExtension(ComfyExtension):
Quantize,
Sharpen,
ImageScaleToTotalPixels,
ResizeImageMaskNode,
BatchImagesNode,
BatchMasksNode,
BatchLatentsNode,
# BatchImagesMasksLatentsNode,
]
async def comfy_entrypoint() -> PostProcessingExtension:

View File

@@ -16,6 +16,7 @@ class PreviewAny():
OUTPUT_NODE = True
CATEGORY = "utils"
SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"]
def main(self, source=None):
value = 'None'
@@ -39,5 +40,5 @@ NODE_CLASS_MAPPINGS = {
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PreviewAny": "Preview Any",
"PreviewAny": "Preview as Text",
}

View File

@@ -66,7 +66,7 @@ class Float(io.ComfyNode):
display_name="Float",
category="utils/primitive",
inputs=[
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1),
],
outputs=[io.Float.Output()],
)

View File

@@ -3,7 +3,9 @@ import comfy.utils
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import comfy.model_management
import torch
import nodes
class TextEncodeQwenImageEdit(io.ComfyNode):
@classmethod
@@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
return io.NodeOutput(conditioning)
class EmptyQwenImageLayeredLatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyQwenImageLayeredLatentImage",
display_name="Empty Qwen Image Layered Latent",
category="latent/qwen",
inputs=[
io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
class QwenExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeQwenImageEdit,
TextEncodeQwenImageEditPlus,
EmptyQwenImageLayeredLatentImage,
]

View File

@@ -0,0 +1,103 @@
from comfy_api.latest import ComfyExtension, io, ComfyAPI
api = ComfyAPI()
async def register_replacements():
"""Register all built-in node replacements."""
await register_replacements_longeredge()
await register_replacements_batchimages()
await register_replacements_upscaleimage()
await register_replacements_controlnet()
await register_replacements_load3d()
await register_replacements_preview3d()
await register_replacements_svdimg2vid()
await register_replacements_conditioningavg()
async def register_replacements_longeredge():
# No dynamic inputs here
await api.node_replacement.register(io.NodeReplace(
new_node_id="ImageScaleToMaxDimension",
old_node_id="ResizeImagesByLongerEdge",
old_widget_ids=["longer_edge"],
input_mapping=[
{"new_id": "image", "old_id": "images"},
{"new_id": "largest_size", "old_id": "longer_edge"},
{"new_id": "upscale_method", "set_value": "lanczos"},
],
# just to test the frontend output_mapping code, does nothing really here
output_mapping=[{"new_idx": 0, "old_idx": 0}],
))
async def register_replacements_batchimages():
# BatchImages node uses Autogrow
await api.node_replacement.register(io.NodeReplace(
new_node_id="BatchImagesNode",
old_node_id="ImageBatch",
input_mapping=[
{"new_id": "images.image0", "old_id": "image1"},
{"new_id": "images.image1", "old_id": "image2"},
],
))
async def register_replacements_upscaleimage():
# ResizeImageMaskNode uses DynamicCombo
await api.node_replacement.register(io.NodeReplace(
new_node_id="ResizeImageMaskNode",
old_node_id="ImageScaleBy",
old_widget_ids=["upscale_method", "scale_by"],
input_mapping=[
{"new_id": "input", "old_id": "image"},
{"new_id": "resize_type", "set_value": "scale by multiplier"},
{"new_id": "resize_type.multiplier", "old_id": "scale_by"},
{"new_id": "scale_method", "old_id": "upscale_method"},
],
))
async def register_replacements_controlnet():
# T2IAdapterLoader → ControlNetLoader
await api.node_replacement.register(io.NodeReplace(
new_node_id="ControlNetLoader",
old_node_id="T2IAdapterLoader",
input_mapping=[
{"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
],
))
async def register_replacements_load3d():
# Load3DAnimation merged into Load3D
await api.node_replacement.register(io.NodeReplace(
new_node_id="Load3D",
old_node_id="Load3DAnimation",
))
async def register_replacements_preview3d():
# Preview3DAnimation merged into Preview3D
await api.node_replacement.register(io.NodeReplace(
new_node_id="Preview3D",
old_node_id="Preview3DAnimation",
))
async def register_replacements_svdimg2vid():
# Typo fix: SDV → SVD
await api.node_replacement.register(io.NodeReplace(
new_node_id="SVD_img2vid_Conditioning",
old_node_id="SDV_img2vid_Conditioning",
))
async def register_replacements_conditioningavg():
# Typo fix: trailing space in node name
await api.node_replacement.register(io.NodeReplace(
new_node_id="ConditioningAverage",
old_node_id="ConditioningAverage ",
))
class NodeReplacementsExtension(ComfyExtension):
async def on_load(self) -> None:
await register_replacements()
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return []
async def comfy_entrypoint() -> NodeReplacementsExtension:
return NodeReplacementsExtension()

View File

@@ -0,0 +1,47 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class ScaleROPE(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ScaleROPE",
category="advanced/model_patches",
description="Scale and shift the ROPE of the model.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
m = model.clone()
m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
return io.NodeOutput(m)
class RopeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ScaleROPE
]
async def comfy_entrypoint() -> RopeExtension:
return RopeExtension()

View File

@@ -55,7 +55,7 @@ class EmptySD3LatentImage(io.ComfyNode):
@classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples":latent})
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
generate = execute # TODO: remove
@@ -65,6 +65,7 @@ class CLIPTextEncodeSD3(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeSD3",
search_aliases=["sd3 prompt"],
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),

View File

@@ -11,6 +11,7 @@ class StringConcatenate(io.ComfyNode):
node_id="StringConcatenate",
display_name="Concatenate",
category="utils/string",
search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"],
inputs=[
io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True),
@@ -31,6 +32,7 @@ class StringSubstring(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringSubstring",
search_aliases=["extract text", "text portion"],
display_name="Substring",
category="utils/string",
inputs=[
@@ -53,6 +55,7 @@ class StringLength(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringLength",
search_aliases=["character count", "text size"],
display_name="Length",
category="utils/string",
inputs=[
@@ -73,6 +76,7 @@ class CaseConverter(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CaseConverter",
search_aliases=["text case", "uppercase", "lowercase", "capitalize"],
display_name="Case Converter",
category="utils/string",
inputs=[
@@ -105,6 +109,7 @@ class StringTrim(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringTrim",
search_aliases=["clean whitespace", "remove whitespace"],
display_name="Trim",
category="utils/string",
inputs=[
@@ -135,6 +140,7 @@ class StringReplace(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringReplace",
search_aliases=["find and replace", "substitute", "swap text"],
display_name="Replace",
category="utils/string",
inputs=[
@@ -157,6 +163,7 @@ class StringContains(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringContains",
search_aliases=["text includes", "string includes"],
display_name="Contains",
category="utils/string",
inputs=[
@@ -184,6 +191,7 @@ class StringCompare(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringCompare",
search_aliases=["text match", "string equals", "starts with", "ends with"],
display_name="Compare",
category="utils/string",
inputs=[
@@ -219,6 +227,7 @@ class RegexMatch(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="RegexMatch",
search_aliases=["pattern match", "text contains", "string match"],
display_name="Regex Match",
category="utils/string",
inputs=[
@@ -259,6 +268,7 @@ class RegexExtract(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="RegexExtract",
search_aliases=["pattern extract", "text parser", "parse text"],
display_name="Regex Extract",
category="utils/string",
inputs=[
@@ -333,6 +343,7 @@ class RegexReplace(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="RegexReplace",
search_aliases=["pattern replace", "find and replace", "substitution"],
display_name="Regex Replace",
category="utils/string",
description="Find and replace text using regex patterns.",

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CreateList(io.ComfyNode):
@classmethod
def define_schema(cls):
template_matchtype = io.MatchType.Template("type")
template_autogrow = io.Autogrow.TemplatePrefix(
input=io.MatchType.Input("input", template=template_matchtype),
prefix="input",
)
return io.Schema(
node_id="CreateList",
display_name="Create List",
category="logic",
is_input_list=True,
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],
outputs=[
io.MatchType.Output(
template=template_matchtype,
is_output_list=True,
display_name="list",
),
],
)
@classmethod
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
output_list = []
for input in inputs.values():
output_list += input
return io.NodeOutput(output_list)
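Because the schema sets is_input_list=True, each connected input arrives wrapped in a Python list, and the Autogrow input hands execute a dict keyed by the template prefix. Roughly what execute receives (key names follow the TemplatePrefix above; values are illustrative):

# Illustrative shape of `inputs` under is_input_list=True:
inputs = {"input0": ["a", "b"], "input1": ["c"]}
output_list = []
for value in inputs.values():
    output_list += value          # -> ["a", "b", "c"], emitted as an output list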
class ToolkitExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CreateList,
]
async def comfy_entrypoint() -> ToolkitExtension:
return ToolkitExtension()

View File

@@ -2,6 +2,8 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_api.torch_helpers import set_torch_compile_wrapper
def skip_torch_compile_dict(guard_entries):
return [("transformer_options" not in entry.name) for entry in guard_entries]
class TorchCompileModel(io.ComfyNode):
@classmethod
@@ -23,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
@classmethod
def execute(cls, model, backend) -> io.NodeOutput:
m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend)
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
return io.NodeOutput(m)

File diff suppressed because it is too large

View File

@@ -53,6 +53,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
node_id="ImageUpscaleWithModel",
display_name="Upscale Image (using Model)",
category="image/upscaling",
search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"],
inputs=[
io.UpscaleModel.Input("upscale_model"),
io.Image.Input("image"),
@@ -78,18 +79,20 @@ class ImageUpscaleWithModel(io.ComfyNode):
overlap = 32
oom = True
while oom:
try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
tile //= 2
if tile < 128:
raise e
try:
while oom:
try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
tile //= 2
if tile < 128:
raise e
finally:
upscale_model.to("cpu")
upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return io.NodeOutput(s)

View File

@@ -8,10 +8,7 @@ import json
from typing import Optional
from typing_extensions import override
from fractions import Fraction
from comfy_api.input import AudioInput, ImageInput, VideoInput
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
from comfy_api.latest import ComfyExtension, io, ui
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
from comfy.cli_args import args
class SaveWEBM(io.ComfyNode):
@@ -19,6 +16,7 @@ class SaveWEBM(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveWEBM",
search_aliases=["export webm"],
category="image/video",
is_experimental=True,
inputs=[
@@ -28,7 +26,6 @@ class SaveWEBM(io.ComfyNode):
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@@ -73,22 +70,22 @@ class SaveVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveVideo",
search_aliases=["export video"],
display_name="Save Video",
category="image/video",
description="Saves the input images to your ComfyUI output directory.",
inputs=[
io.Video.Input("video", tooltip="The video to save."),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
@@ -105,10 +102,10 @@ class SaveVideo(io.ComfyNode):
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
format=format,
format=Types.VideoContainer(format),
codec=codec,
metadata=saved_metadata
)
@@ -121,6 +118,7 @@ class CreateVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CreateVideo",
search_aliases=["images to video"],
display_name="Create Video",
category="image/video",
description="Create a video from images.",
@@ -135,9 +133,9 @@ class CreateVideo(io.ComfyNode):
)
@classmethod
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
return io.NodeOutput(
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
)
class GetVideoComponents(io.ComfyNode):
@@ -145,6 +143,7 @@ class GetVideoComponents(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="GetVideoComponents",
search_aliases=["extract frames", "split video", "video to images", "demux"],
display_name="Get Video Components",
category="image/video",
description="Extracts all components from a video: frames, audio, and framerate.",
@@ -159,11 +158,11 @@ class GetVideoComponents(io.ComfyNode):
)
@classmethod
def execute(cls, video: VideoInput) -> io.NodeOutput:
def execute(cls, video: Input.Video) -> io.NodeOutput:
components = video.get_components()
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class LoadVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -172,6 +171,7 @@ class LoadVideo(io.ComfyNode):
files = folder_paths.filter_files_content_types(files, ["video"])
return io.Schema(
node_id="LoadVideo",
search_aliases=["import video", "open video", "video file"],
display_name="Load Video",
category="image/video",
inputs=[
@@ -185,7 +185,7 @@ class LoadVideo(io.ComfyNode):
@classmethod
def execute(cls, file) -> io.NodeOutput:
video_path = folder_paths.get_annotated_filepath(file)
return io.NodeOutput(VideoFromFile(video_path))
return io.NodeOutput(InputImpl.VideoFromFile(video_path))
@classmethod
def fingerprint_inputs(s, file):
@@ -202,6 +202,56 @@ class LoadVideo(io.ComfyNode):
return True
class VideoSlice(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Video Slice",
display_name="Video Slice",
search_aliases=[
"trim video duration",
"skip first frames",
"frame load cap",
"start time",
],
category="image/video",
inputs=[
io.Video.Input("video"),
io.Float.Input(
"start_time",
default=0.0,
min=-1e5,
max=1e5,
step=0.001,
tooltip="Start time in seconds",
),
io.Float.Input(
"duration",
default=0.0,
min=0.0,
step=0.001,
tooltip="Duration in seconds, or 0 for unlimited duration",
),
io.Boolean.Input(
"strict_duration",
default=False,
tooltip="If True, when the specified duration is not possible, an error will be raised.",
),
],
outputs=[
io.Video.Output(),
],
)
@classmethod
def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
if trimmed is not None:
return io.NodeOutput(trimmed)
raise ValueError(
f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
)
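The node defers entirely to the video object's as_trimmed contract: a trimmed video when the slice is possible, None otherwise (unless strict_duration made the backend raise first). A sketch of that contract with a stub object; the stub's internals are assumptions, only the method signature mirrors the node above:

class StubVideo:
    # minimal stand-in mirroring the as_trimmed contract used above
    def __init__(self, duration):
        self.duration = duration
    def get_duration(self):
        return self.duration
    def as_trimmed(self, start_time, duration, strict_duration=False):
        if start_time >= self.duration:
            if strict_duration:
                raise ValueError("slice outside the source duration")
            return None  # the node turns this into its own ValueError
        end = self.duration if duration == 0 else min(start_time + duration, self.duration)
        return StubVideo(end - start_time)

print(StubVideo(10.0).as_trimmed(2.0, 5.0).get_duration())  # 5.0
print(StubVideo(10.0).as_trimmed(15.0, 5.0))                # None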
class VideoExtension(ComfyExtension):
@override
@@ -212,6 +262,7 @@ class VideoExtension(ComfyExtension):
CreateVideo,
GetVideoComponents,
LoadVideo,
VideoSlice,
]
async def comfy_entrypoint() -> VideoExtension:

View File

@@ -8,9 +8,10 @@ import comfy.latent_formats
import comfy.clip_vision
import json
import numpy as np
from typing import Tuple
from typing import Tuple, TypedDict
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import logging
class WanImageToVideo(io.ComfyNode):
@classmethod
@@ -286,6 +287,7 @@ class WanVaceToVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="WanVaceToVideo",
search_aliases=["video conditioning", "video control"],
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
@@ -704,6 +706,7 @@ class WanTrackToVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="WanTrackToVideo",
search_aliases=["motion tracking", "trajectory video", "point tracking", "keypoint animation"],
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
@@ -817,7 +820,7 @@ def get_sample_indices(original_fps,
if required_duration > total_frames / original_fps:
raise ValueError("required_duration must be less than video length")
if not fixed_start is None and fixed_start >= 0:
if fixed_start is not None and fixed_start >= 0:
start_frame = fixed_start
else:
max_start = total_frames - required_origin_frames
@@ -1288,6 +1291,171 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
return io.NodeOutput(out_latent)
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
class WanInfiniteTalkToVideo(io.ComfyNode):
class DCValues(TypedDict):
mode: str
audio_encoder_output_2: io.AudioEncoderOutput.Type
mask: io.Mask.Type
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanInfiniteTalkToVideo",
category="conditioning/video_models",
inputs=[
io.DynamicCombo.Input("mode", options=[
io.DynamicCombo.Option("single_speaker", []),
io.DynamicCombo.Option("two_speakers", [
io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
]),
]),
io.Model.Input("model"),
io.ModelPatch.Input("model_patch"),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
io.AudioEncoderOutput.Input("audio_encoder_output_1"),
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
io.Image.Input("previous_frames", optional=True),
],
outputs=[
io.Model.Output(display_name="model"),
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Int.Output(display_name="trim_image"),
],
)
@classmethod
def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
raise ValueError("Not enough previous frames provided.")
if mode["mode"] == "two_speakers":
audio_encoder_output_2 = mode["audio_encoder_output_2"]
mask_1 = mode["mask_1"]
mask_2 = mode["mask_2"]
if audio_encoder_output_2 is not None:
if mask_1 is None or mask_2 is None:
raise ValueError("Masks must be provided if two audio encoder outputs are used.")
ref_masks = None
if mask_1 is not None and mask_2 is not None:
if audio_encoder_output_2 is None:
raise ValueError("Second audio encoder output must be provided if two masks are used.")
ref_masks = torch.cat([mask_1, mask_2])
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
model_patched = model.clone()
encoded_audio_list = []
seq_lengths = []
for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]:
if audio_encoder_output is None:
continue
all_layers = audio_encoder_output["encoded_audio_all_layers"]
encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512]
encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512]
encoded_audio_list.append(encoded_audio)
seq_lengths.append(encoded_audio.shape[0])
# Pad / combine depending on multi_audio_type
multi_audio_type = "add"
if len(encoded_audio_list) > 1:
if multi_audio_type == "para":
max_len = max(seq_lengths)
padded = []
for emb in encoded_audio_list:
if emb.shape[0] < max_len:
pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype)
emb = torch.cat([emb, pad], dim=0)
padded.append(emb)
encoded_audio_list = padded
elif multi_audio_type == "add":
total_len = sum(seq_lengths)
full_list = []
offset = 0
for emb, seq_len in zip(encoded_audio_list, seq_lengths):
full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype)
full[offset:offset+seq_len] = emb
full_list.append(full)
offset += seq_len
encoded_audio_list = full_list
token_ref_target_masks = None
if ref_masks is not None:
token_ref_target_masks = torch.nn.functional.interpolate(
ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)
# when extending from previous frames
if previous_frames is not None:
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
frame_offset = previous_frames.shape[0] - motion_frame_count
audio_start = frame_offset
audio_end = audio_start + length
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
trim_image = motion_frame_count
else:
audio_start = trim_image = 0
audio_end = length
motion_frames_latent = concat_latent_image[:, :, :1]
audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
# add outer sample wrapper
model_patched.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
"infinite_talk_outer_sample",
InfiniteTalkOuterSampleWrapper(
motion_frames_latent,
model_patch,
is_extend=previous_frames is not None,
))
# add cross-attention patch
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch")
if token_ref_target_masks is not None:
model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch")
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
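In two-speaker ("add") mode, each speaker's embedding is written into its own time window of a zero tensor spanning the combined length, so the speakers are sequenced back-to-back rather than padded in parallel. Just that layout step, extracted as a sketch:

import torch

def layout_add(embeddings):
    # each [T_i, D] embedding becomes [sum(T_i), D], non-zero only in its window
    seq_lengths = [e.shape[0] for e in embeddings]
    total_len = sum(seq_lengths)
    out, offset = [], 0
    for emb, seq_len in zip(embeddings, seq_lengths):
        full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype)
        full[offset:offset + seq_len] = emb
        out.append(full)
        offset += seq_len
    return out

a, b = torch.ones(3, 4), 2 * torch.ones(2, 4)
print([tuple(t.shape) for t in layout_add([a, b])])  # [(5, 4), (5, 4)]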
class WanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -1307,6 +1475,7 @@ class WanExtension(ComfyExtension):
WanHuMoImageToVideo,
WanAnimateToVideo,
Wan22ImageToVideoLatent,
WanInfiniteTalkToVideo,
]
async def comfy_entrypoint() -> WanExtension:

View File

@@ -0,0 +1,536 @@
import nodes
import node_helpers
import torch
import torchvision.transforms.functional as TF
import comfy.model_management
import comfy.utils
import numpy as np
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_wan import parse_json_tracks
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
from PIL import Image, ImageDraw
SKIP_ZERO = False
def get_pos_emb(
pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
pos_emb_dim: int,
theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), #Function to compute thetas based on position and embedding dimensions.
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim)
assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."
pos_k = pos_k.to(device, dtype)
if SKIP_ZERO:
pos_k = pos_k + 1
batch_size = pos_k.size(0)
denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype)
# Expand denominator to match the shape needed for broadcasting
denominator_expanded = denominator.view(1, -1).expand(batch_size, -1)
thetas = theta_func(denominator_expanded, pos_emb_dim)
# Ensure pos_k is in the correct shape for broadcasting
pos_k_expanded = pos_k.view(-1, 1).to(dtype)
sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas))
cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas))
# Concatenate sine and cosine embeddings along the last dimension
pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1)
return pos_emb
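get_pos_emb is the usual sinusoidal embedding: for position p and frequency index i it concatenates sin(p / theta_i) and cos(p / theta_i), with theta_i = 10000^(2i/d). A quick sanity check against the function above (position 0 must give all-zero sines and all-one cosines):

import torch

# assumes get_pos_emb from above is in scope; SKIP_ZERO is left False
emb = get_pos_emb(torch.tensor([0.0, 1.0, 2.0]), pos_emb_dim=8)
print(emb.shape)  # torch.Size([3, 8])
print(emb[0])     # tensor([0., 0., 0., 0., 1., 1., 1., 1.])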
def create_pos_embeddings(
pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
pred_visibility: torch.Tensor, # the predicted visibility [T, N]
downsample_ratios: list[int], # the ratios for downsampling time, height, and width
height: int, # the height of the feature map
width: int, # the width of the feature map
track_num: int = -1, # the number of tracks to use
t_down_strategy: str = "sample", # the strategy for downsampling time dimension
):
assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."
t, n, _ = pred_tracks.shape
t_down, h_down, w_down = downsample_ratios
track_pos = -torch.ones(n, (t - 1) // t_down + 1, 2, dtype=torch.long)
if track_num == -1:
track_num = n
tracks_idx = torch.randperm(n)[:track_num]
tracks = pred_tracks[:, tracks_idx]
visibility = pred_visibility[:, tracks_idx]
for t_idx in range(0, t, t_down):
if t_down_strategy == "sample" or t_idx == 0:
cur_tracks = tracks[t_idx] # [N, 2]
cur_visibility = visibility[t_idx] # [N]
else:
cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0)
cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0)
for i in range(track_num):
if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height:
continue
x, y = cur_tracks[i]
x, y = int(x // w_down), int(y // h_down)
track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x
return track_pos # the position embeddings, [N, T', 2], 2 = height, width
def replace_feature(
vae_feature: torch.Tensor, # [B, C', T', H', W']
track_pos: torch.Tensor, # [B, N, T', 2]
strength: float = 1.0
) -> torch.Tensor:
b, _, t, h, w = vae_feature.shape
assert b == track_pos.shape[0], "Batch size mismatch."
n = track_pos.shape[1]
# Shuffle the trajectory order
track_pos = track_pos[:, torch.randperm(n), :, :]
# Extract coordinates at time steps ≥ 1 and generate a valid mask
current_pos = track_pos[:, :, 1:, :] # [B, N, T-1, 2]
mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) # [B, N, T-1]
# Get all valid indices
valid_indices = mask.nonzero(as_tuple=False) # [num_valid, 3]
num_valid = valid_indices.shape[0]
if num_valid == 0:
return vae_feature
# Decompose valid indices into each dimension
batch_idx = valid_indices[:, 0]
track_idx = valid_indices[:, 1]
t_rel = valid_indices[:, 2]
t_target = t_rel + 1 # Convert to original time step indices
# Extract target position coordinates
h_target = current_pos[batch_idx, track_idx, t_rel, 0].long() # Ensure integer indices
w_target = current_pos[batch_idx, track_idx, t_rel, 1].long()
# Extract source position coordinates (t=0)
h_source = track_pos[batch_idx, track_idx, 0, 0].long()
w_source = track_pos[batch_idx, track_idx, 0, 1].long()
# Get source features and assign to target positions
src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
return vae_feature
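replace_feature copies each track's t=0 feature into that track's later positions with a single advanced-indexing gather/scatter, blended by strength. A tiny check of the indexing trick on a toy tensor (coordinates chosen arbitrarily for illustration):

import torch

feat = torch.zeros(1, 1, 2, 2, 2)   # [B, C, T, H, W]
feat[0, 0, 0, 1, 1] = 5.0           # source feature at t=0, (h=1, w=1)
b, t = torch.tensor([0]), torch.tensor([1])
hs, ws = torch.tensor([1]), torch.tensor([1])   # source coords at t=0
ht, wt = torch.tensor([0]), torch.tensor([0])   # target coords at t=1
src = feat[b, :, 0, hs, ws]          # gather source features, shape [1, C]
dst = feat[b, :, t, ht, wt]          # gather target features, shape [1, C]
feat[b, :, t, ht, wt] = dst + (src - dst) * 1.0  # scatter at full strength
print(feat[0, 0, 1, 0, 0])  # tensor(5.)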
# Visualize functions
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
draw = ImageDraw.Draw(overlay, 'RGBA')
points = points[::-1]
# Compute total length
total_length = 0
segment_lengths = []
for i in range(len(points) - 1):
dx = points[i + 1][0] - points[i][0]
dy = points[i + 1][1] - points[i][1]
length = (dx * dx + dy * dy) ** 0.5
segment_lengths.append(length)
total_length += length
if total_length == 0:
return
accumulated_length = 0
# Draw the gradient polyline
for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])):
segment_length = segment_lengths[idx]
steps = max(int(segment_length), 1)
for i in range(steps):
current_length = accumulated_length + (i / steps) * segment_length
ratio = current_length / total_length
alpha = int(255 * (1 - ratio) * opacity)
color = (*start_color, alpha)
x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps)
y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps)
dynamic_line_width = max(int(line_width * (1 - ratio)), 1)
draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width)
accumulated_length += segment_length
def add_weighted(rgb, track):
rgb = np.array(rgb) # [H, W, C] "RGB"
track = np.array(track) # [H, W, C] "RGBA"
alpha = track[:, :, 3] / 255.0
alpha = np.stack([alpha] * 3, axis=-1)
blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha)
return Image.fromarray(blend_img.astype(np.uint8))
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
video = video.byte().cpu().numpy() # (81, 480, 832, 3)
tracks = tracks[0].long().detach().cpu().numpy()
if visibility is not None:
visibility = visibility[0].detach().cpu().numpy()
num_frames, height, width = video.shape[:3]
num_tracks = tracks.shape[1]
alpha_opacity = int(255 * opacity)
output_frames = []
for t in range(num_frames):
frame_rgb = video[t].astype(np.float32)
# Create a single RGBA overlay for all tracks in this frame
overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
draw_overlay = ImageDraw.Draw(overlay)
polyline_data = []
# Draw all circles on a single overlay
for n in range(num_tracks):
if visibility is not None and visibility[t, n] == 0:
continue
track_coord = tracks[t, n]
color = color_map[n % len(color_map)]
circle_color = color + (alpha_opacity,)
draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size),
fill=circle_color
)
# Store polyline data for batch processing
tracks_coord = tracks[max(t - track_frame, 0):t + 1, n]
if len(tracks_coord) > 1:
polyline_data.append((tracks_coord, color))
# Blend circles overlay once
overlay_np = np.array(overlay)
alpha = overlay_np[:, :, 3:4] / 255.0
frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
# Draw all polylines on a single overlay
if polyline_data:
polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
for tracks_coord, color in polyline_data:
_draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity)
# Blend polylines overlay once
polyline_np = np.array(polyline_overlay)
alpha = polyline_np[:, :, 3:4] / 255.0
frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8)))
return output_frames
class WanMoveVisualizeTracks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveVisualizeTracks",
category="conditioning/video_models",
inputs=[
io.Image.Input("images"),
io.Tracks.Input("tracks", optional=True),
io.Int.Input("line_resolution", default=24, min=1, max=1024),
io.Int.Input("circle_size", default=12, min=1, max=128),
io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
io.Int.Input("line_width", default=16, min=1, max=128),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput:
if tracks is None:
return io.NodeOutput(images)
track_path = tracks["track_path"].unsqueeze(0)
track_visibility = tracks["track_visibility"].unsqueeze(0)
images_in = images * 255.0
if images_in.shape[0] != track_path.shape[1]:
repeat_count = track_path.shape[1] // images.shape[0]
images_in = images_in.repeat(repeat_count, 1, 1, 1)
track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float()
return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device()))
class WanMoveTracksFromCoords(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveTracksFromCoords",
category="conditioning/video_models",
inputs=[
io.String.Input("track_coords", force_input=True, default="[]", optional=True),
io.Mask.Input("track_mask", optional=True),
],
outputs=[
io.Tracks.Output(),
io.Int.Output(display_name="track_length"),
],
)
@classmethod
def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
device = comfy.model_management.intermediate_device()
tracks_data = parse_json_tracks(track_coords)
track_length = len(tracks_data[0])
track_list = [
[[track[frame]['x'], track[frame]['y']] for track in tracks_data]
for frame in range(len(tracks_data[0]))
]
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
num_tracks = tracks.shape[-2]
if track_mask is None:
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
else:
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
out_track_info = {}
out_track_info["track_path"] = tracks
out_track_info["track_visibility"] = track_visibility
return io.NodeOutput(out_track_info, track_length)
class GenerateTracks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GenerateTracks",
search_aliases=["motion paths", "camera movement", "trajectory"],
category="conditioning/video_models",
inputs=[
io.Int.Input("width", default=832, min=16, max=4096, step=16),
io.Int.Input("height", default=480, min=16, max=4096, step=16),
io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
io.Int.Input("num_frames", default=81, min=1, max=1024),
io.Int.Input("num_tracks", default=5, min=1, max=100),
io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
io.Combo.Input(
"interpolation",
options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
tooltip="Controls the timing/speed of movement along the path.",
),
io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
],
outputs=[
io.Tracks.Output(),
io.Int.Output(display_name="track_length"),
],
)
@classmethod
def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
device = comfy.model_management.intermediate_device()
track_length = num_frames
# normalized coordinates to pixel coordinates
start_x_px = start_x * width
start_y_px = start_y * height
mid_x_px = mid_x * width
mid_y_px = mid_y * height
end_x_px = end_x * width
end_y_px = end_y * height
track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional
t = torch.linspace(0, 1, num_frames, device=device)
if interpolation == "constant": # All points stay at start position
interp_values = torch.zeros_like(t)
elif interpolation == "linear":
interp_values = t
elif interpolation == "ease_in":
interp_values = t ** 2
elif interpolation == "ease_out":
interp_values = 1 - (1 - t) ** 2
elif interpolation == "ease_in_out":
interp_values = t * t * (3 - 2 * t)
if bezier: # apply interpolation to t for timing control along the bezier path
t_interp = interp_values
one_minus_t = 1 - t_interp
x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
else: # calculate base x and y positions for each frame (center track)
x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
# For non-bezier, tangent is constant (direction from start to end)
tangent_x = torch.full_like(t, end_x_px - start_x_px)
tangent_y = torch.full_like(t, end_y_px - start_y_px)
track_list = []
for frame_idx in range(num_frames):
# Calculate perpendicular direction at this frame
tx = tangent_x[frame_idx].item()
ty = tangent_y[frame_idx].item()
length = (tx ** 2 + ty ** 2) ** 0.5
if length > 0: # Perpendicular unit vector (rotate 90 degrees)
perp_x = -ty / length
perp_y = tx / length
else: # If tangent is zero, spread horizontally
perp_x = 1.0
perp_y = 0.0
frame_tracks = []
for track_idx in range(num_tracks):  # center tracks around the main path; offsets range from -(num_tracks-1)/2 to +(num_tracks-1)/2
offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
track_x = x_positions[frame_idx].item() + perp_x * offset
track_y = y_positions[frame_idx].item() + perp_y * offset
frame_tracks.append([track_x, track_y])
track_list.append(frame_tracks)
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
if track_mask is None:
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
else:
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
out_track_info = {}
out_track_info["track_path"] = tracks
out_track_info["track_visibility"] = track_visibility
return io.NodeOutput(out_track_info, track_length)
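The bezier branch evaluates a quadratic Bezier curve at the easing-remapped time t', and uses its analytic derivative as the tangent that orients the perpendicular track spread; in the code above, P_0, P_1, P_2 are the start, mid, and end points in pixels:

B(t') = (1 - t')^2 P_0 + 2(1 - t')\,t'\,P_1 + (t')^2 P_2
B'(t') = 2(1 - t')(P_1 - P_0) + 2t'(P_2 - P_1)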
class WanMoveConcatTrack(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveConcatTrack",
category="conditioning/video_models",
inputs=[
io.Tracks.Input("tracks_1"),
io.Tracks.Input("tracks_2", optional=True),
],
outputs=[
io.Tracks.Output(),
],
)
@classmethod
def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput:
if tracks_2 is None:
return io.NodeOutput(tracks_1)
tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1) # Concatenate along the track dimension
mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1)
out_track_info = {}
out_track_info["track_path"] = tracks_out
out_track_info["track_visibility"] = mask_out
return io.NodeOutput(out_track_info)
class WanMoveTrackToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveTrackToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Tracks.Input("tracks", optional=True),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image"),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
device = comfy.model_management.intermediate_device()
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
if tracks is not None and strength > 0.0:
tracks_path = tracks["track_path"][:length] # [T, N, 2]
num_tracks = tracks_path.shape[-2]
track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device))
track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)
else:
concat_latent_image_pos = concat_latent_image
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanMoveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanMoveTrackToVideo,
WanMoveTracksFromCoords,
WanMoveConcatTrack,
WanMoveVisualizeTracks,
GenerateTracks,
]
async def comfy_entrypoint() -> WanMoveExtension:
return WanMoveExtension()

View File

@@ -5,6 +5,7 @@ MAX_RESOLUTION = nodes.MAX_RESOLUTION
class WebcamCapture(nodes.LoadImage):
SEARCH_ALIASES = ["camera input", "live capture", "camera feed", "snapshot"]
@classmethod
def INPUT_TYPES(s):
return {

View File

@@ -0,0 +1,88 @@
import node_helpers
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import math
import comfy.utils
class TextEncodeZImageOmni(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeZImageOmni",
category="advanced/conditioning",
is_experimental=True,
inputs=[
io.Clip.Input("clip"),
io.ClipVision.Input("image_encoder", optional=True),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Boolean.Input("auto_resize_images", default=True),
io.Vae.Input("vae", optional=True),
io.Image.Input("image1", optional=True),
io.Image.Input("image2", optional=True),
io.Image.Input("image3", optional=True),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
ref_latents = []
images = list(filter(lambda a: a is not None, [image1, image2, image3]))
prompt_list = []
template = None
if len(images) > 0:
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|vision_start|>"] * (len(images) - 1)
prompt_list += ["<|vision_end|><|im_end|>"]
template = "<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"
encoded_images = []
for i, image in enumerate(images):
if image_encoder is not None:
encoded_images.append(image_encoder.encode_image(image))
if vae is not None:
if auto_resize_images:
samples = image.movedim(-1, 1)
total = int(1024 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by / 8.0) * 8
height = round(samples.shape[2] * scale_by / 8.0) * 8
image = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
ref_latents.append(vae.encode(image))
tokens = clip.tokenize(prompt, llama_template=template)
conditioning = clip.encode_from_tokens_scheduled(tokens)
extra_text_embeds = []
for p in prompt_list:
tokens = clip.tokenize(p, llama_template="{}")
text_embeds = clip.encode_from_tokens_scheduled(tokens)
extra_text_embeds.append(text_embeds[0][0])
if len(ref_latents) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
if len(encoded_images) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"clip_vision_outputs": encoded_images}, append=True)
if len(extra_text_embeds) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents_text_embeds": extra_text_embeds}, append=True)
return io.NodeOutput(conditioning)
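auto_resize_images rescales each reference image to roughly one megapixel while preserving aspect ratio, then snaps both sides to multiples of 8 for the VAE. The arithmetic, as a standalone sketch:

import math

def target_size(w, h, total=1024 * 1024):
    # area-preserving scale factor, then both sides rounded to a multiple of 8
    scale_by = math.sqrt(total / (w * h))
    return round(w * scale_by / 8.0) * 8, round(h * scale_by / 8.0) * 8

print(target_size(1920, 1080))  # (1368, 768): ~1.05MP, aspect ratio kept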
class ZImageExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeZImageOmni,
]
async def comfy_entrypoint() -> ZImageExtension:
return ZImageExtension()