mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-09 23:30:02 +00:00
Merge branch 'master' into worksplit-multigpu
This commit is contained in:
@@ -28,12 +28,45 @@ class TextEncodeAceStepAudio(io.ComfyNode):
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
class TextEncodeAceStepAudio15(io.ComfyNode):
    """Text-encode node for ACE-Step 1.5.

    Tokenizes the tags together with lyrics, musical metadata and sampling
    settings, then encodes the tokens into a conditioning output.
    """

    @classmethod
    def define_schema(cls):
        """Return the node schema: id, category, input widgets and one conditioning output."""
        language_codes = ["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]
        note_names = ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]
        # Outer loop is the mode, so every "major" key is listed before any "minor" key.
        keyscale_options = [f"{note} {mode}" for mode in ["major", "minor"] for note in note_names]

        prompt_inputs = [
            io.Clip.Input("clip"),
            io.String.Input("tags", multiline=True, dynamic_prompts=True),
            io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
            io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
        ]
        music_inputs = [
            io.Int.Input("bpm", default=120, min=10, max=300),
            io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
            io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
            io.Combo.Input("language", options=language_codes),
            io.Combo.Input("keyscale", options=keyscale_options),
        ]
        advanced_inputs = [
            io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
            io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
            io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
            # NOTE(review): max=2000.0 equals the "duration" limit and looks copy-pasted
            # for a probability-style parameter -- confirm before tightening.
            io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
            io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
            io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
        ]

        return io.Schema(
            node_id="TextEncodeAceStepAudio1.5",
            category="conditioning",
            # Concatenation keeps the original widget order.
            inputs=prompt_inputs + music_inputs + advanced_inputs,
            outputs=[io.Conditioning.Output()],
        )

    @classmethod
    def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
        """Tokenize the prompt with all metadata and return the encoded conditioning.

        ``timesignature`` arrives as one of the combo strings and is converted
        to ``int`` before being handed to the tokenizer.
        """
        tokenize_options = {
            "lyrics": lyrics,
            "bpm": bpm,
            "duration": duration,
            "timesignature": int(timesignature),
            "language": language,
            "keyscale": keyscale,
            "seed": seed,
            "generate_audio_codes": generate_audio_codes,
            "cfg_scale": cfg_scale,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "min_p": min_p,
        }
        tokens = clip.tokenize(tags, **tokenize_options)
        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||
|
||||
|
||||
class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyAceStepLatentAudio",
|
||||
display_name="Empty Ace Step 1.0 Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
@@ -51,12 +84,61 @@ class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
|
||||
class EmptyAceStep15LatentAudio(io.ComfyNode):
    """Node that produces an all-zero audio latent for ACE-Step 1.5."""

    @classmethod
    def define_schema(cls):
        """Return the node schema: a duration in seconds and a batch size in, one latent out."""
        seconds_input = io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01)
        batch_input = io.Int.Input(
            "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
        )
        return io.Schema(
            node_id="EmptyAceStep1.5LatentAudio",
            display_name="Empty Ace Step 1.5 Latent Audio",
            category="latent/audio",
            inputs=[seconds_input, batch_input],
            outputs=[io.Latent.Output()],
        )

    @classmethod
    def execute(cls, seconds, batch_size) -> io.NodeOutput:
        """Build a zero tensor of shape ``[batch_size, 64, length]`` on the intermediate device.

        ``length`` is ``round(seconds * 48000 / 1920)``.
        """
        # presumably 48000 is the sample rate and 1920 the samples per latent frame -- TODO confirm
        length = round(seconds * 48000 / 1920)
        target_device = comfy.model_management.intermediate_device()
        zeros = torch.zeros([batch_size, 64, length], device=target_device)
        return io.NodeOutput({"samples": zeros, "type": "audio"})
|
||||
|
||||
class ReferenceAudio(io.ComfyNode):
    """Experimental node that attaches a reference-audio latent to a conditioning."""

    @classmethod
    def define_schema(cls):
        """Return the node schema: conditioning plus an optional latent in, conditioning out."""
        node_inputs = [
            io.Conditioning.Input("conditioning"),
            io.Latent.Input("latent", optional=True),
        ]
        node_outputs = [
            io.Conditioning.Output(),
        ]
        return io.Schema(
            node_id="ReferenceTimbreAudio",
            display_name="Reference Audio",
            category="advanced/conditioning/audio",
            is_experimental=True,
            description="This node sets the reference audio for ace step 1.5",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, conditioning, latent=None) -> io.NodeOutput:
        """Return the conditioning, extended with the latent's samples when a latent is given.

        Without a latent the conditioning is passed through untouched.
        """
        if latent is None:
            return io.NodeOutput(conditioning)

        # The samples are wrapped in a list and added with append=True.
        reference_values = {"reference_audio_timbre_latents": [latent["samples"]]}
        updated = node_helpers.conditioning_set_values(conditioning, reference_values, append=True)
        return io.NodeOutput(updated)
|
||||
|
||||
class AceExtension(ComfyExtension):
    """Extension object that exposes the ACE-Step node classes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension, in registration order."""
        node_classes = [
            TextEncodeAceStepAudio,
            EmptyAceStepLatentAudio,
            TextEncodeAceStepAudio15,
            EmptyAceStep15LatentAudio,
            ReferenceAudio,
        ]
        return node_classes
|
||||
|
||||
async def comfy_entrypoint() -> AceExtension:
|
||||
|
||||
@@ -28,6 +28,7 @@ class AlignYourStepsScheduler(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="AlignYourStepsScheduler",
|
||||
search_aliases=["AYS scheduler"],
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
|
||||
|
||||
@@ -55,7 +55,8 @@ class APG(io.ComfyNode):
|
||||
def pre_cfg_function(args):
|
||||
nonlocal running_avg, prev_sigma
|
||||
|
||||
if len(args["conds_out"]) == 1: return args["conds_out"]
|
||||
if len(args["conds_out"]) == 1:
|
||||
return args["conds_out"]
|
||||
|
||||
cond = args["conds_out"][0]
|
||||
uncond = args["conds_out"][1]
|
||||
|
||||
@@ -71,6 +71,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPAttentionMultiply",
|
||||
search_aliases=["clip attention scale", "text encoder attention"],
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@@ -6,279 +6,253 @@ import torch
|
||||
import comfy.model_management
|
||||
import folder_paths
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import random
|
||||
import hashlib
|
||||
import node_helpers
|
||||
import logging
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types import FileLocator
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, UI
|
||||
|
||||
class EmptyLatentAudio:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
class EmptyLatentAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentAudio",
|
||||
display_name="Empty Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
||||
IO.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[IO.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def generate(self, seconds, batch_size):
|
||||
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
||||
return ({"samples":latent, "type": "audio"}, )
|
||||
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
|
||||
return IO.NodeOutput({"samples":latent, "type": "audio"})
|
||||
|
||||
class ConditioningStableAudio:
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class ConditioningStableAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
|
||||
"seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ConditioningStableAudio",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
IO.Conditioning.Input("positive"),
|
||||
IO.Conditioning.Input("negative"),
|
||||
IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
|
||||
IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, positive, negative, seconds_start, seconds_total):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
||||
return (positive, negative)
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
class VAEEncodeAudio:
|
||||
append = execute # TODO: remove
|
||||
|
||||
|
||||
class VAEEncodeAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEEncodeAudio",
|
||||
search_aliases=["audio to latent"],
|
||||
display_name="VAE Encode Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[IO.Latent.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def encode(self, vae, audio):
|
||||
@classmethod
|
||||
def execute(cls, vae, audio) -> IO.NodeOutput:
|
||||
sample_rate = audio["sample_rate"]
|
||||
if 44100 != sample_rate:
|
||||
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
if vae_sample_rate != sample_rate:
|
||||
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate)
|
||||
else:
|
||||
waveform = audio["waveform"]
|
||||
|
||||
t = vae.encode(waveform.movedim(1, -1))
|
||||
return ({"samples":t}, )
|
||||
return IO.NodeOutput({"samples": t})
|
||||
|
||||
class VAEDecodeAudio:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "decode"
|
||||
encode = execute # TODO: remove
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def decode(self, vae, samples):
|
||||
def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
||||
if tile is not None:
|
||||
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
|
||||
else:
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||
|
||||
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
||||
|
||||
|
||||
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
    """Encode every waveform in ``audio`` and write one file per batch item.

    Args:
        self: Node instance providing ``prefix_append``, ``output_dir`` and ``type``.
        audio: Dict with ``"waveform"`` (iterated along its first dimension) and
            ``"sample_rate"``.
        filename_prefix: Prefix handed to ``folder_paths.get_save_image_path``
            after ``self.prefix_append`` is appended.
        format: ``"opus"``, ``"mp3"``, or anything else, which is encoded as FLAC.
        prompt: Optional prompt stored as JSON metadata under ``"prompt"``.
        extra_pnginfo: Optional mapping whose entries are stored as JSON metadata.
        quality: Bitrate label (or ``"V0"`` for mp3); unrecognised labels leave
            the encoder settings untouched.

    Returns:
        ``{"ui": {"audio": results}}`` where ``results`` lists the filename,
        subfolder and type of each written file.
    """
    # Opus supported sample rates, ascending.
    OPUS_RATES = (8000, 12000, 16000, 24000, 48000)
    OPUS_BIT_RATES = {"64k": 64000, "96k": 96000, "128k": 128000, "192k": 192000, "320k": 320000}
    MP3_BIT_RATES = {"128k": 128000, "320k": 320000}

    filename_prefix += self.prefix_append
    full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
    results: list[FileLocator] = []

    # Prepare metadata dictionary
    metadata = {}
    if not args.disable_metadata:
        if prompt is not None:
            metadata["prompt"] = json.dumps(prompt)
        if extra_pnginfo is not None:
            for x in extra_pnginfo:
                metadata[x] = json.dumps(extra_pnginfo[x])

    # The source rate is the same for every batch item, so read it once.
    source_rate = audio["sample_rate"]

    for batch_number, waveform in enumerate(audio["waveform"].cpu()):
        filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
        file = f"{filename_with_batch_num}_{counter:05}_.{format}"
        output_path = os.path.join(full_output_folder, file)

        sample_rate = source_rate
        if format == "opus":
            # Keep a supported rate as-is, otherwise move up to the next
            # supported rate; anything above 48 kHz is capped at 48 kHz.
            sample_rate = next((rate for rate in OPUS_RATES if rate >= source_rate), OPUS_RATES[-1])

        # Resample if necessary
        if sample_rate != source_rate:
            waveform = torchaudio.functional.resample(waveform, source_rate, sample_rate)

        # Encode into memory first; the file is only written after a successful encode.
        output_buffer = io.BytesIO()
        output_container = av.open(output_buffer, mode='w', format=format)
        try:
            # Set metadata on the container
            for key, value in metadata.items():
                output_container.metadata[key] = value

            layout = 'mono' if waveform.shape[0] == 1 else 'stereo'

            # Set up the output stream with appropriate properties
            if format == "opus":
                out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
                bit_rate = OPUS_BIT_RATES.get(quality)
                if bit_rate is not None:
                    out_stream.bit_rate = bit_rate
            elif format == "mp3":
                out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
                if quality == "V0":
                    # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
                    out_stream.codec_context.qscale = 1
                elif quality in MP3_BIT_RATES:
                    out_stream.bit_rate = MP3_BIT_RATES[quality]
            else:  # format == "flac"
                out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)

            # Interleave the channels into a single packed float32 row for PyAV.
            frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
            frame.sample_rate = sample_rate
            frame.pts = 0
            output_container.mux(out_stream.encode(frame))

            # Flush encoder
            output_container.mux(out_stream.encode(None))
        finally:
            # Always release the container, even when encoding or muxing fails.
            output_container.close()

        # Write the output to file
        with open(output_path, 'wb') as f:
            f.write(output_buffer.getbuffer())

        results.append({
            "filename": file,
            "subfolder": subfolder,
            "type": self.type
        })
        counter += 1

    return {"ui": {"audio": results}}
|
||||
|
||||
class SaveAudio:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
class VAEDecodeAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeAudio",
|
||||
search_aliases=["latent to audio"],
|
||||
display_name="VAE Decode Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
def execute(cls, vae, samples) -> IO.NodeOutput:
|
||||
return IO.NodeOutput(vae_decode_audio(vae, samples))
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_flac"
|
||||
decode = execute # TODO: remove
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
|
||||
|
||||
class SaveAudioMP3:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
class VAEDecodeAudioTiled(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeAudioTiled",
|
||||
search_aliases=["latent to audio"],
|
||||
display_name="VAE Decode Audio (Tiled)",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8),
|
||||
IO.Int.Input("overlap", default=64, min=0, max=1024, step=8),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput:
|
||||
return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap))
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_mp3"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
class SaveAudioOpus:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
class SaveAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudio",
|
||||
search_aliases=["export flac"],
|
||||
display_name="Save Audio (FLAC)",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
|
||||
)
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_opus"
|
||||
save_flac = execute # TODO: remove
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
class PreviewAudio(SaveAudio):
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_temp_directory()
|
||||
self.type = "temp"
|
||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
||||
class SaveAudioMP3(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioMP3",
|
||||
search_aliases=["export mp3"],
|
||||
display_name="Save Audio (MP3)",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"audio": ("AUDIO", ), },
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||
)
|
||||
)
|
||||
|
||||
save_mp3 = execute # TODO: remove
|
||||
|
||||
|
||||
class SaveAudioOpus(IO.ComfyNode):
    """Output node that saves audio in Opus format via ``UI.AudioSaveHelper``."""

    @classmethod
    def define_schema(cls):
        """Return the node schema: audio, filename prefix and a bitrate choice; no outputs."""
        bitrate_choices = ["64k", "96k", "128k", "192k", "320k"]
        node_inputs = [
            IO.Audio.Input("audio"),
            IO.String.Input("filename_prefix", default="audio/ComfyUI"),
            IO.Combo.Input("quality", options=bitrate_choices, default="128k"),
        ]
        return IO.Schema(
            node_id="SaveAudioOpus",
            search_aliases=["export opus"],
            display_name="Save Audio (Opus)",
            category="audio",
            inputs=node_inputs,
            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
            is_output_node=True,
        )

    @classmethod
    # NOTE(review): the "V3" default is not one of the schema's quality options -- confirm intent.
    def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
        """Delegate saving to ``UI.AudioSaveHelper.get_save_audio_ui`` and return its UI payload."""
        save_ui = UI.AudioSaveHelper.get_save_audio_ui(
            audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
        )
        return IO.NodeOutput(ui=save_ui)

    # Alias kept under the previous method name.
    save_opus = execute  # TODO: remove
|
||||
|
||||
|
||||
class PreviewAudio(IO.ComfyNode):
    """Output node that returns a ``UI.PreviewAudio`` payload for the given audio."""

    @classmethod
    def define_schema(cls):
        """Return the node schema: a single audio input and no outputs."""
        node_inputs = [
            IO.Audio.Input("audio"),
        ]
        hidden_inputs = [IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
        return IO.Schema(
            node_id="PreviewAudio",
            search_aliases=["play audio"],
            display_name="Preview Audio",
            category="audio",
            inputs=node_inputs,
            hidden=hidden_inputs,
            is_output_node=True,
        )

    @classmethod
    def execute(cls, audio) -> IO.NodeOutput:
        """Wrap the audio in ``UI.PreviewAudio`` and return it as the node's UI output."""
        preview_ui = UI.PreviewAudio(audio, cls=cls)
        return IO.NodeOutput(ui=preview_ui)

    # Alias kept under the previous method name.
    save_flac = execute  # TODO: remove
|
||||
|
||||
|
||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format."""
|
||||
@@ -316,26 +290,31 @@ def load(filepath: str) -> tuple[torch.Tensor, int]:
|
||||
wav = f32_pcm(wav)
|
||||
return wav, sr
|
||||
|
||||
class LoadAudio:
|
||||
class LoadAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
def define_schema(cls):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
||||
return IO.Schema(
|
||||
node_id="LoadAudio",
|
||||
search_aliases=["import audio", "open audio", "audio file"],
|
||||
display_name="Load Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
RETURN_TYPES = ("AUDIO", )
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, audio):
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||
waveform, sample_rate = load(audio_path)
|
||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||
return (audio, )
|
||||
return IO.NodeOutput(audio)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, audio):
|
||||
def fingerprint_inputs(cls, audio):
|
||||
image_path = folder_paths.get_annotated_filepath(audio)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
@@ -343,46 +322,71 @@ class LoadAudio:
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, audio):
|
||||
def validate_inputs(cls, audio):
|
||||
if not folder_paths.exists_annotated_filepath(audio):
|
||||
return "Invalid audio file: {}".format(audio)
|
||||
return True
|
||||
|
||||
class RecordAudio:
|
||||
load = execute # TODO: remove
|
||||
|
||||
|
||||
class RecordAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"audio": ("AUDIO_RECORD", {})}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RecordAudio",
|
||||
search_aliases=["microphone input", "audio capture", "voice input"],
|
||||
display_name="Record Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Custom("AUDIO_RECORD").Input("audio"),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
RETURN_TYPES = ("AUDIO", )
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, audio):
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||
|
||||
waveform, sample_rate = load(audio_path)
|
||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||
return (audio, )
|
||||
return IO.NodeOutput(audio)
|
||||
|
||||
load = execute # TODO: remove
|
||||
|
||||
|
||||
class TrimAudioDuration:
|
||||
class TrimAudioDuration(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"audio": ("AUDIO",),
|
||||
"start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
|
||||
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
|
||||
},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TrimAudioDuration",
|
||||
search_aliases=["cut audio", "audio clip", "shorten audio"],
|
||||
display_name="Trim Audio Duration",
|
||||
description="Trim audio tensor into chosen time range.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Float.Input(
|
||||
"start_index",
|
||||
default=0.0,
|
||||
min=-0xffffffffffffffff,
|
||||
max=0xffffffffffffffff,
|
||||
step=0.01,
|
||||
tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"duration",
|
||||
default=60.0,
|
||||
min=0.0,
|
||||
step=0.01,
|
||||
tooltip="Duration in seconds",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
FUNCTION = "trim"
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Trim audio tensor into chosen time range."
|
||||
|
||||
def trim(self, audio, start_index, duration):
|
||||
@classmethod
|
||||
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
audio_length = waveform.shape[-1]
|
||||
@@ -399,23 +403,31 @@ class TrimAudioDuration:
|
||||
if start_frame >= end_frame:
|
||||
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
|
||||
|
||||
return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
|
||||
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
|
||||
|
||||
trim = execute # TODO: remove
|
||||
|
||||
|
||||
class SplitAudioChannels:
|
||||
class SplitAudioChannels(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"audio": ("AUDIO",),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SplitAudioChannels",
|
||||
search_aliases=["stereo to mono"],
|
||||
display_name="Split Audio Channels",
|
||||
description="Separates the audio into left and right channels.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(display_name="left"),
|
||||
IO.Audio.Output(display_name="right"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("AUDIO", "AUDIO")
|
||||
RETURN_NAMES = ("left", "right")
|
||||
FUNCTION = "separate"
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Separates the audio into left and right channels."
|
||||
|
||||
def separate(self, audio):
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
@@ -425,7 +437,61 @@ class SplitAudioChannels:
|
||||
left_channel = waveform[..., 0:1, :]
|
||||
right_channel = waveform[..., 1:2, :]
|
||||
|
||||
return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
|
||||
return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
|
||||
|
||||
separate = execute # TODO: remove
|
||||
|
||||
class JoinAudioChannels(IO.ComfyNode):
    """Node that combines two mono audio inputs into one two-channel audio."""

    @classmethod
    def define_schema(cls):
        """Return the node schema: left and right audio in, one joined audio out."""
        node_inputs = [
            IO.Audio.Input("audio_left"),
            IO.Audio.Input("audio_right"),
        ]
        node_outputs = [
            IO.Audio.Output(display_name="audio"),
        ]
        return IO.Schema(
            node_id="JoinAudioChannels",
            display_name="Join Audio Channels",
            description="Joins left and right mono audio channels into a stereo audio.",
            category="audio",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
        """Join the two inputs along dimension 1 after aligning rate and length.

        Raises:
            ValueError: if either waveform's dimension 1 is not of size 1.
        """
        left_wave, left_rate = audio_left["waveform"], audio_left["sample_rate"]
        right_wave, right_rate = audio_right["waveform"], audio_right["sample_rate"]

        if not (left_wave.shape[1] == 1 and right_wave.shape[1] == 1):
            raise ValueError("AudioJoin: Both input audios must be mono.")

        # Bring both inputs to a common sample rate; the helper's body is not
        # visible here -- per the original comment it resamples to the higher rate.
        left_wave, right_wave, output_sample_rate = match_audio_sample_rates(
            left_wave, left_rate, right_wave, right_rate
        )

        # Trim whichever side is longer so both match the shorter length.
        left_len = left_wave.shape[-1]
        right_len = right_wave.shape[-1]
        shortest = min(left_len, right_len)
        if left_len > shortest:
            logging.info(f"JoinAudioChannels: Trimming left channel from {left_len} to {shortest} samples.")
            left_wave = left_wave[..., :shortest]
        if right_len > shortest:
            logging.info(f"JoinAudioChannels: Trimming right channel from {right_len} to {shortest} samples.")
            right_wave = right_wave[..., :shortest]

        # Join the channels into stereo
        stereo_waveform = torch.cat([left_wave[..., 0:1, :], right_wave[..., 0:1, :]], dim=1)
        return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate})
|
||||
|
||||
|
||||
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
|
||||
@@ -443,21 +509,30 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
|
||||
return waveform_1, waveform_2, output_sample_rate
|
||||
|
||||
|
||||
class AudioConcat:
|
||||
class AudioConcat(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"audio1": ("AUDIO",),
|
||||
"audio2": ("AUDIO",),
|
||||
"direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="AudioConcat",
|
||||
search_aliases=["join audio", "combine audio", "append audio"],
|
||||
display_name="Audio Concat",
|
||||
description="Concatenates the audio1 to audio2 in the specified direction.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio1"),
|
||||
IO.Audio.Input("audio2"),
|
||||
IO.Combo.Input(
|
||||
"direction",
|
||||
options=['after', 'before'],
|
||||
default="after",
|
||||
tooltip="Whether to append audio2 after or before audio1.",
|
||||
)
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "concat"
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
|
||||
|
||||
def concat(self, audio1, audio2, direction):
|
||||
@classmethod
|
||||
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
|
||||
waveform_1 = audio1["waveform"]
|
||||
waveform_2 = audio2["waveform"]
|
||||
sample_rate_1 = audio1["sample_rate"]
|
||||
@@ -477,26 +552,34 @@ class AudioConcat:
|
||||
elif direction == 'before':
|
||||
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
|
||||
|
||||
return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
|
||||
return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
|
||||
|
||||
concat = execute # TODO: remove
|
||||
|
||||
|
||||
class AudioMerge:
|
||||
class AudioMerge(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"audio1": ("AUDIO",),
|
||||
"audio2": ("AUDIO",),
|
||||
"merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
|
||||
},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="AudioMerge",
|
||||
search_aliases=["mix audio", "overlay audio", "layer audio"],
|
||||
display_name="Audio Merge",
|
||||
description="Combine two audio tracks by overlaying their waveforms.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio1"),
|
||||
IO.Audio.Input("audio2"),
|
||||
IO.Combo.Input(
|
||||
"merge_method",
|
||||
options=["add", "mean", "subtract", "multiply"],
|
||||
tooltip="The method used to combine the audio waveforms.",
|
||||
)
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
FUNCTION = "merge"
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
|
||||
|
||||
def merge(self, audio1, audio2, merge_method):
|
||||
@classmethod
|
||||
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
|
||||
waveform_1 = audio1["waveform"]
|
||||
waveform_2 = audio2["waveform"]
|
||||
sample_rate_1 = audio1["sample_rate"]
|
||||
@@ -530,85 +613,114 @@ class AudioMerge:
|
||||
if max_val > 1.0:
|
||||
waveform = waveform / max_val
|
||||
|
||||
return ({"waveform": waveform, "sample_rate": output_sample_rate},)
|
||||
return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
|
||||
|
||||
merge = execute # TODO: remove
|
||||
|
||||
|
||||
class AudioAdjustVolume:
|
||||
class AudioAdjustVolume(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"audio": ("AUDIO",),
|
||||
"volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="AudioAdjustVolume",
|
||||
search_aliases=["audio gain", "loudness", "audio level"],
|
||||
display_name="Audio Adjust Volume",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Int.Input(
|
||||
"volume",
|
||||
default=1,
|
||||
min=-100,
|
||||
max=100,
|
||||
tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
|
||||
)
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "adjust_volume"
|
||||
CATEGORY = "audio"
|
||||
|
||||
def adjust_volume(self, audio, volume):
|
||||
@classmethod
|
||||
def execute(cls, audio, volume) -> IO.NodeOutput:
|
||||
if volume == 0:
|
||||
return (audio,)
|
||||
return IO.NodeOutput(audio)
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
gain = 10 ** (volume / 20)
|
||||
waveform = waveform * gain
|
||||
|
||||
return ({"waveform": waveform, "sample_rate": sample_rate},)
|
||||
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
|
||||
|
||||
adjust_volume = execute # TODO: remove
|
||||
|
||||
|
||||
class EmptyAudio:
|
||||
class EmptyAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
|
||||
"sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
|
||||
"channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyAudio",
|
||||
search_aliases=["blank audio"],
|
||||
display_name="Empty Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Float.Input(
|
||||
"duration",
|
||||
default=60.0,
|
||||
min=0.0,
|
||||
max=0xffffffffffffffff,
|
||||
step=0.01,
|
||||
tooltip="Duration of the empty audio clip in seconds",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"sample_rate",
|
||||
default=44100,
|
||||
tooltip="Sample rate of the empty audio clip.",
|
||||
min=1,
|
||||
max=192000,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"channels",
|
||||
default=2,
|
||||
min=1,
|
||||
max=2,
|
||||
tooltip="Number of audio channels (1 for mono, 2 for stereo).",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "create_empty_audio"
|
||||
CATEGORY = "audio"
|
||||
|
||||
def create_empty_audio(self, duration, sample_rate, channels):
|
||||
@classmethod
|
||||
def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
|
||||
num_samples = int(round(duration * sample_rate))
|
||||
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
|
||||
return ({"waveform": waveform, "sample_rate": sample_rate},)
|
||||
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
|
||||
|
||||
create_empty_audio = execute # TODO: remove
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLatentAudio": EmptyLatentAudio,
|
||||
"VAEEncodeAudio": VAEEncodeAudio,
|
||||
"VAEDecodeAudio": VAEDecodeAudio,
|
||||
"SaveAudio": SaveAudio,
|
||||
"SaveAudioMP3": SaveAudioMP3,
|
||||
"SaveAudioOpus": SaveAudioOpus,
|
||||
"LoadAudio": LoadAudio,
|
||||
"PreviewAudio": PreviewAudio,
|
||||
"ConditioningStableAudio": ConditioningStableAudio,
|
||||
"RecordAudio": RecordAudio,
|
||||
"TrimAudioDuration": TrimAudioDuration,
|
||||
"SplitAudioChannels": SplitAudioChannels,
|
||||
"AudioConcat": AudioConcat,
|
||||
"AudioMerge": AudioMerge,
|
||||
"AudioAdjustVolume": AudioAdjustVolume,
|
||||
"EmptyAudio": EmptyAudio,
|
||||
}
|
||||
class AudioExtension(ComfyExtension):
    """Extension object that lists this module's audio node classes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the audio node classes, in registration order."""
        node_classes = [
            EmptyLatentAudio,
            VAEEncodeAudio,
            VAEDecodeAudio,
            VAEDecodeAudioTiled,
            SaveAudio,
            SaveAudioMP3,
            SaveAudioOpus,
            LoadAudio,
            PreviewAudio,
            ConditioningStableAudio,
            RecordAudio,
            TrimAudioDuration,
            SplitAudioChannels,
            JoinAudioChannels,
            AudioConcat,
            AudioMerge,
            AudioAdjustVolume,
            EmptyAudio,
        ]
        return node_classes
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"EmptyLatentAudio": "Empty Latent Audio",
|
||||
"VAEEncodeAudio": "VAE Encode Audio",
|
||||
"VAEDecodeAudio": "VAE Decode Audio",
|
||||
"PreviewAudio": "Preview Audio",
|
||||
"LoadAudio": "Load Audio",
|
||||
"SaveAudio": "Save Audio (FLAC)",
|
||||
"SaveAudioMP3": "Save Audio (MP3)",
|
||||
"SaveAudioOpus": "Save Audio (Opus)",
|
||||
"RecordAudio": "Record Audio",
|
||||
"TrimAudioDuration": "Trim Audio Duration",
|
||||
"SplitAudioChannels": "Split Audio Channels",
|
||||
"AudioConcat": "Audio Concat",
|
||||
"AudioMerge": "Audio Merge",
|
||||
"AudioAdjustVolume": "Audio Adjust Volume",
|
||||
"EmptyAudio": "Empty Audio",
|
||||
}
|
||||
async def comfy_entrypoint() -> AudioExtension:
    """Build and return a new AudioExtension instance."""
    extension = AudioExtension()
    return extension
|
||||
|
||||
@@ -10,6 +10,7 @@ class Canny(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Canny",
|
||||
search_aliases=["edge detection", "outline", "contour detection", "line art"],
|
||||
category="image/preprocessors",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
|
||||
42
comfy_extras/nodes_color.py
Normal file
42
comfy_extras/nodes_color.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class ColorToRGBInt(io.ComfyNode):
    """Node that packs a ``#RRGGBB`` color string into a single integer."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node: one color input, one integer output."""
        return io.Schema(
            node_id="ColorToRGBInt",
            display_name="Color to RGB Int",
            category="utils",
            description="Convert a color to a RGB integer value.",
            inputs=[
                io.Color.Input("color"),
            ],
            outputs=[
                io.Int.Output(display_name="rgb_int"),
            ],
        )

    @classmethod
    def execute(
        cls,
        color: str,
    ) -> io.NodeOutput:
        """Convert *color* to ``r * 65536 + g * 256 + b``.

        Args:
            color: string of the form ``#RRGGBB`` (hex digits, either case).

        Returns:
            io.NodeOutput wrapping the packed integer.

        Raises:
            ValueError: if *color* is not exactly ``#`` followed by six hex digits.
        """
        # expect format #RRGGBB
        # int(text, 16) also accepts signs and surrounding whitespace, so the
        # digits are checked explicitly; otherwise "#+1-2 3" would be accepted.
        if (
            len(color) != 7
            or color[0] != "#"
            or any(ch not in "0123456789abcdefABCDEF" for ch in color[1:])
        ):
            raise ValueError("Color must be in format #RRGGBB")
        r = int(color[1:3], 16)
        g = int(color[3:5], 16)
        b = int(color[5:7], 16)
        return io.NodeOutput(r * 256 * 256 + g * 256 + b)
|
||||
|
||||
|
||||
class ColorExtension(ComfyExtension):
    """Extension object that lists this module's color node classes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided here (only ColorToRGBInt)."""
        node_classes = [ColorToRGBInt]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ColorExtension:
    """Build and return a new ColorExtension instance."""
    extension = ColorExtension()
    return extension
|
||||
@@ -109,6 +109,7 @@ class PorterDuffImageComposite(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PorterDuffImageComposite",
|
||||
search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"],
|
||||
display_name="Porter-Duff Image Composite",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
@@ -165,6 +166,7 @@ class SplitImageWithAlpha(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SplitImageWithAlpha",
|
||||
search_aliases=["extract alpha", "separate transparency", "remove alpha"],
|
||||
display_name="Split Image with Alpha",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
@@ -188,6 +190,7 @@ class JoinImageWithAlpha(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="JoinImageWithAlpha",
|
||||
search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"],
|
||||
display_name="Join Image with Alpha",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
|
||||
@@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
@@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||
model = model.clone()
|
||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||
@@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
context_overlap=context_overlap,
|
||||
context_stride=context_stride,
|
||||
closed_loop=closed_loop,
|
||||
dim=dim)
|
||||
dim=dim,
|
||||
freenoise=freenoise,
|
||||
cond_retain_index_list=cond_retain_index_list,
|
||||
split_conds_to_windows=split_conds_to_windows
|
||||
)
|
||||
# make memory usage calculation only take into account the context window latents
|
||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||
if freenoise: # no other use for this wrapper at this time
|
||||
comfy.context_windows.create_sampler_sample_wrapper(model)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||
@@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
]
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
|
||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||
|
||||
|
||||
class ContextWindowsExtension(ComfyExtension):
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||
import nodes
|
||||
import comfy.utils
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class SetUnionControlNetType:
|
||||
class SetUnionControlNetType(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"control_net": ("CONTROL_NET", ),
|
||||
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SetUnionControlNetType",
|
||||
category="conditioning/controlnet",
|
||||
inputs=[
|
||||
io.ControlNet.Input("control_net"),
|
||||
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
|
||||
],
|
||||
outputs=[
|
||||
io.ControlNet.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
|
||||
FUNCTION = "set_controlnet_type"
|
||||
|
||||
def set_controlnet_type(self, control_net, type):
|
||||
@classmethod
|
||||
def execute(cls, control_net, type) -> io.NodeOutput:
|
||||
control_net = control_net.copy()
|
||||
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
|
||||
if type_number >= 0:
|
||||
@@ -22,27 +28,37 @@ class SetUnionControlNetType:
|
||||
else:
|
||||
control_net.set_extra_arg("control_type", [])
|
||||
|
||||
return (control_net,)
|
||||
return io.NodeOutput(control_net)
|
||||
|
||||
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||
set_controlnet_type = execute # TODO: remove
|
||||
|
||||
|
||||
class ControlNetInpaintingAliMamaApply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"control_net": ("CONTROL_NET", ),
|
||||
"vae": ("VAE", ),
|
||||
"image": ("IMAGE", ),
|
||||
"mask": ("MASK", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ControlNetInpaintingAliMamaApply",
|
||||
search_aliases=["masked controlnet"],
|
||||
category="conditioning/controlnet",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.ControlNet.Input("control_net"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image"),
|
||||
io.Mask.Input("mask"),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
],
|
||||
)
|
||||
|
||||
FUNCTION = "apply_inpaint_controlnet"
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
|
||||
extra_concat = []
|
||||
if control_net.concat_mask:
|
||||
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||
@@ -50,11 +66,20 @@ class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
||||
extra_concat = [mask]
|
||||
|
||||
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||
result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||
return io.NodeOutput(result[0], result[1])
|
||||
|
||||
apply_inpaint_controlnet = execute # TODO: remove
|
||||
|
||||
|
||||
class ControlNetExtension(ComfyExtension):
    """Extension object that lists this module's ControlNet node classes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the ControlNet node classes, in registration order."""
        node_classes = [
            SetUnionControlNetType,
            ControlNetInpaintingAliMamaApply,
        ]
        return node_classes
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SetUnionControlNetType": SetUnionControlNetType,
|
||||
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
|
||||
}
|
||||
|
||||
async def comfy_entrypoint() -> ControlNetExtension:
    """Build and return a new ControlNetExtension instance."""
    extension = ControlNetExtension()
    return extension
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1529
comfy_extras/nodes_dataset.py
Normal file
1529
comfy_extras/nodes_dataset.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -11,6 +11,7 @@ class DifferentialDiffusion(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DifferentialDiffusion",
|
||||
search_aliases=["inpaint gradient", "variable denoise strength"],
|
||||
display_name="Differential Diffusion",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
|
||||
@@ -9,15 +9,23 @@ if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def _extract_tensor(data, output_channels):
|
||||
"""Extract tensor from data, handling both single tensors and lists."""
|
||||
if isinstance(data, list):
|
||||
# LTX2 AV tensors: [video, audio]
|
||||
return data[0][:, :output_channels], data[1][:, :output_channels]
|
||||
return data[:, :output_channels], None
|
||||
|
||||
|
||||
def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
x: torch.Tensor = args[0]
|
||||
transformer_options: dict[str] = args[-1]
|
||||
if not isinstance(transformer_options, dict):
|
||||
transformer_options = kwargs.get("transformer_options")
|
||||
if not transformer_options:
|
||||
transformer_options = args[-2]
|
||||
easycache: EasyCacheHolder = transformer_options["easycache"]
|
||||
x, ax = _extract_tensor(args[0], easycache.output_channels)
|
||||
sigmas = transformer_options["sigmas"]
|
||||
uuids = transformer_options["uuids"]
|
||||
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
|
||||
@@ -29,11 +37,17 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
do_easycache = easycache.should_do_easycache(sigmas)
|
||||
if do_easycache:
|
||||
easycache.check_metadata(x)
|
||||
# if there isn't a cache diff for current conds, we cannot skip this step
|
||||
can_apply_cache_diff = easycache.can_apply_cache_diff(uuids)
|
||||
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
||||
if easycache.skip_current_step:
|
||||
if easycache.skip_current_step and can_apply_cache_diff:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
result = easycache.apply_cache_diff(x, uuids)
|
||||
if ax is not None:
|
||||
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
|
||||
return [result, result_audio]
|
||||
return result
|
||||
if easycache.initial_step:
|
||||
easycache.first_cond_uuid = uuids[0]
|
||||
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
|
||||
@@ -44,18 +58,23 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.cumulative_change_rate += approx_output_change_rate
|
||||
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
||||
if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
# other conds should also skip this step, and instead use their cached values
|
||||
easycache.skip_current_step = True
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
result = easycache.apply_cache_diff(x, uuids)
|
||||
if ax is not None:
|
||||
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
|
||||
return [result, result_audio]
|
||||
return result
|
||||
else:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
easycache.cumulative_change_rate = 0.0
|
||||
|
||||
output: torch.Tensor = executor(*args, **kwargs)
|
||||
full_output: torch.Tensor = executor(*args, **kwargs)
|
||||
output, audio_output = _extract_tensor(full_output, easycache.output_channels)
|
||||
if has_first_cond_uuid and easycache.has_output_prev_norm():
|
||||
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
@@ -72,22 +91,24 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
|
||||
# TODO: allow cache_diff to be offloaded
|
||||
easycache.update_cache_diff(output, next_x_prev, uuids)
|
||||
if audio_output is not None:
|
||||
easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
|
||||
if has_first_cond_uuid:
|
||||
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
|
||||
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
|
||||
easycache.output_prev_norm = output.flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
||||
return output
|
||||
return full_output
|
||||
|
||||
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
x: torch.Tensor = args[0]
|
||||
timestep: float = args[1]
|
||||
model_options: dict[str] = args[2]
|
||||
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
|
||||
if easycache.is_past_end_timestep(timestep):
|
||||
return executor(*args, **kwargs)
|
||||
x: torch.Tensor = args[0][:, :easycache.output_channels]
|
||||
# prepare next x_prev
|
||||
next_x_prev = x
|
||||
input_change = None
|
||||
@@ -173,7 +194,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
|
||||
|
||||
|
||||
class EasyCacheHolder:
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
|
||||
self.name = "EasyCache"
|
||||
self.reuse_threshold = reuse_threshold
|
||||
self.start_percent = start_percent
|
||||
@@ -195,6 +216,7 @@ class EasyCacheHolder:
|
||||
self.output_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_norm: torch.Tensor = None
|
||||
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
|
||||
self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
|
||||
self.output_change_rates = []
|
||||
self.approx_output_change_rates = []
|
||||
self.total_steps_skipped = 0
|
||||
@@ -202,6 +224,7 @@ class EasyCacheHolder:
|
||||
self.allow_mismatch = True
|
||||
self.cut_from_start = True
|
||||
self.state_metadata = None
|
||||
self.output_channels = output_channels
|
||||
|
||||
def is_past_end_timestep(self, timestep: float) -> bool:
    """Return True when the first entry of *timestep* is not greater than ``self.end_t``."""
    # NOTE(review): despite the ``float`` annotation, timestep is indexed and the
    # comparison result must provide .item() - presumably a tensor; confirm.
    above_end = (timestep[0] > self.end_t).item()
    return not above_end
|
||||
@@ -239,18 +262,24 @@ class EasyCacheHolder:
|
||||
return to_return.clone()
|
||||
return to_return
|
||||
|
||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
|
||||
if self.first_cond_uuid in uuids:
|
||||
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
    """Return True when every uuid in *uuids* has an entry in ``self.uuid_cache_diffs``."""
    for uuid in uuids:
        if uuid not in self.uuid_cache_diffs:
            return False
    # Also reached for an empty *uuids*, which counts as applicable.
    return True
|
||||
|
||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
|
||||
if self.first_cond_uuid in uuids and not is_audio:
|
||||
self.total_steps_skipped += 1
|
||||
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
|
||||
batch_offset = x.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
# slice out only what is relevant to this cond
|
||||
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
|
||||
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
||||
if x.shape[1:] != cache_diffs[uuid].shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
|
||||
slicing = []
|
||||
skip_this_dim = True
|
||||
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
|
||||
for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
|
||||
if skip_this_dim:
|
||||
skip_this_dim = False
|
||||
continue
|
||||
@@ -261,12 +290,12 @@ class EasyCacheHolder:
|
||||
slicing.append(slice(None, dim_u))
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
|
||||
x = x[slicing]
|
||||
x += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
batch_slice = batch_slice + slicing
|
||||
x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
|
||||
return x
|
||||
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
|
||||
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
|
||||
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if output.shape[1:] != x.shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
@@ -282,11 +311,11 @@ class EasyCacheHolder:
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
skip_dim = False
|
||||
x = x[slicing]
|
||||
x = x[tuple(slicing)]
|
||||
diff = output - x
|
||||
batch_offset = diff.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
|
||||
cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
|
||||
|
||||
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
    """Tell whether ``self.first_cond_uuid`` is among the given identifiers.

    Args:
        uuids: Identifiers to search.

    Returns:
        ``True`` if ``self.first_cond_uuid`` is contained in ``uuids``.
    """
    is_present = self.first_cond_uuid in uuids
    return is_present
|
||||
@@ -317,12 +346,14 @@ class EasyCacheHolder:
|
||||
self.output_prev_norm = None
|
||||
del self.uuid_cache_diffs
|
||||
self.uuid_cache_diffs = {}
|
||||
del self.uuid_cache_diffs_audio
|
||||
self.uuid_cache_diffs_audio = {}
|
||||
self.total_steps_skipped = 0
|
||||
self.state_metadata = None
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
|
||||
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
|
||||
|
||||
|
||||
class EasyCacheNode(io.ComfyNode):
|
||||
@@ -349,7 +380,7 @@ class EasyCacheNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
|
||||
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
|
||||
@@ -357,7 +388,7 @@ class EasyCacheNode(io.ComfyNode):
|
||||
|
||||
|
||||
class LazyCacheHolder:
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
|
||||
self.name = "LazyCache"
|
||||
self.reuse_threshold = reuse_threshold
|
||||
self.start_percent = start_percent
|
||||
@@ -381,6 +412,7 @@ class LazyCacheHolder:
|
||||
self.approx_output_change_rates = []
|
||||
self.total_steps_skipped = 0
|
||||
self.state_metadata = None
|
||||
self.output_channels = output_channels
|
||||
|
||||
def has_cache_diff(self) -> bool:
    """Tell whether a cache diff is currently stored.

    Returns:
        ``False`` while ``self.cache_diff`` is ``None``, ``True`` otherwise.
    """
    if self.cache_diff is None:
        return False
    return True
|
||||
@@ -455,7 +487,7 @@ class LazyCacheHolder:
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
|
||||
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
|
||||
|
||||
class LazyCacheNode(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -481,7 +513,7 @@ class LazyCacheNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
|
||||
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
@@ -2,7 +2,10 @@ import node_helpers
|
||||
import comfy.utils
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
import comfy.model_management
|
||||
import torch
|
||||
import math
|
||||
import nodes
|
||||
|
||||
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -30,6 +33,27 @@ class CLIPTextEncodeFlux(io.ComfyNode):
|
||||
|
||||
encode = execute # TODO: remove
|
||||
|
||||
class EmptyFlux2LatentImage(io.ComfyNode):
    """Node that emits an all-zero latent batch for the "Empty Flux 2 Latent" entry."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, display name, category, inputs and output."""
        node_inputs = [
            io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
            io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
            io.Int.Input("batch_size", default=1, min=1, max=4096),
        ]
        node_outputs = [
            io.Latent.Output(),
        ]
        return io.Schema(
            node_id="EmptyFlux2LatentImage",
            display_name="Empty Flux 2 Latent",
            category="latent",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
        """Build a zero tensor of shape ``[batch_size, 128, height // 16, width // 16]``.

        The tensor is allocated on ``comfy.model_management.intermediate_device()``
        and returned under the ``"samples"`` key.
        """
        latent_shape = [batch_size, 128, height // 16, width // 16]
        target_device = comfy.model_management.intermediate_device()
        zero_latent = torch.zeros(latent_shape, device=target_device)
        return io.NodeOutput({"samples": zero_latent})
|
||||
|
||||
class FluxGuidance(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -130,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="FluxKontextMultiReferenceLatentMethod",
|
||||
display_name="Edit Model Reference Method",
|
||||
category="advanced/conditioning/flux",
|
||||
inputs=[
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Combo.Input(
|
||||
"reference_latents_method",
|
||||
options=["offset", "index", "uxo/uno"],
|
||||
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
@@ -154,6 +179,58 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
|
||||
append = execute # TODO: remove
|
||||
|
||||
|
||||
def generalized_time_snr_shift(t, mu: float, sigma: float):
    """Evaluate ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    Args:
        t: Value(s) to shift; anything supporting ``1 / t``, subtraction and
            ``**`` works (the schedule helper in this file passes a tensor).
        mu: Shift amount, applied through ``math.exp``.
        sigma: Exponent applied to ``1 / t - 1``.

    Returns:
        The shifted value, of whatever type the arithmetic on ``t`` yields.
    """
    exp_mu = math.exp(mu)
    inverse_odds = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + inverse_odds)
|
||||
|
||||
|
||||
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute the shift parameter ``mu`` from sequence length and step count.

    Two linear fits in ``image_seq_len`` are evaluated. For sequence lengths
    above 4300 the second fit is returned directly and ``num_steps`` is
    ignored. Otherwise the result lies on the straight line (in
    ``num_steps``) that equals the first fit at 10 steps and the second fit
    at 200 steps.

    Args:
        image_seq_len: Sequence length fed into both linear fits.
        num_steps: Number of steps at which the line is evaluated.

    Returns:
        The resulting ``mu`` as a Python ``float``.
    """
    slope_10, intercept_10 = 8.73809524e-05, 1.89833333
    slope_200, intercept_200 = 0.00016927, 0.45666666

    mu_at_200 = slope_200 * image_seq_len + intercept_200
    if image_seq_len > 4300:
        # Long sequences use the 200-step fit regardless of the step count.
        return float(mu_at_200)

    mu_at_10 = slope_10 * image_seq_len + intercept_10

    # Line through (10, mu_at_10) and (200, mu_at_200).
    step_slope = (mu_at_200 - mu_at_10) / 190.0
    step_intercept = mu_at_200 - 200.0 * step_slope
    return float(step_slope * num_steps + step_intercept)
|
||||
|
||||
|
||||
def get_schedule(num_steps: int, image_seq_len: int) -> torch.Tensor:
    """Build the shifted timestep schedule for a given step count and sequence length.

    ``num_steps + 1`` points are spaced linearly from 1 down to 0 and passed
    through ``generalized_time_snr_shift`` with ``sigma`` fixed at 1.0 and
    ``mu`` taken from ``compute_empirical_mu``.

    Args:
        num_steps: Number of sampling steps; the result has one more entry.
        image_seq_len: Sequence length used to derive ``mu``.

    Returns:
        A 1-D tensor with ``num_steps + 1`` entries (the value is whatever the
        shift returns for a tensor input, not a Python list).
    """
    mu = compute_empirical_mu(image_seq_len, num_steps)
    linear_timesteps = torch.linspace(1, 0, num_steps + 1)
    return generalized_time_snr_shift(linear_timesteps, mu, 1.0)
|
||||
|
||||
|
||||
class Flux2Scheduler(io.ComfyNode):
    """Node that outputs sigmas computed by ``get_schedule`` from steps and image size."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, the three integer inputs and the sigmas output."""
        node_inputs = [
            io.Int.Input("steps", default=20, min=1, max=4096),
            io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
            io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
        ]
        node_outputs = [
            io.Sigmas.Output(),
        ]
        return io.Schema(
            node_id="Flux2Scheduler",
            category="sampling/custom_sampling/schedulers",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, steps, width, height) -> io.NodeOutput:
        """Derive the sequence length from the pixel area and return the schedule.

        The sequence length is ``width * height / (16 * 16)`` rounded to the
        nearest integer before being handed to ``get_schedule``.
        """
        patch_area = 16 * 16
        image_seq_len = round(width * height / patch_area)
        return io.NodeOutput(get_schedule(steps, image_seq_len))
|
||||
|
||||
|
||||
class FluxExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -163,6 +240,8 @@ class FluxExtension(ComfyExtension):
|
||||
FluxDisableGuidance,
|
||||
FluxKontextImageScale,
|
||||
FluxKontextMultiReferenceLatentMethod,
|
||||
EmptyFlux2LatentImage,
|
||||
Flux2Scheduler,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
import torch
|
||||
import logging
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
def Fourier_filter(x, threshold, scale):
|
||||
# FFT
|
||||
@@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale):
|
||||
return x_filtered.to(x.dtype)
|
||||
|
||||
|
||||
class FreeU:
|
||||
class FreeU(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FreeU",
|
||||
category="model_patches/unet",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, b1, b2, s1, s2):
|
||||
@classmethod
|
||||
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||
on_cpu_devices = {}
|
||||
@@ -59,23 +66,31 @@ class FreeU:
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_output_block_patch(output_block_patch)
|
||||
return (m, )
|
||||
return IO.NodeOutput(m)
|
||||
|
||||
class FreeU_V2:
|
||||
patch = execute # TODO: remove
|
||||
|
||||
|
||||
class FreeU_V2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FreeU_V2",
|
||||
category="model_patches/unet",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, b1, b2, s1, s2):
|
||||
@classmethod
|
||||
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||
on_cpu_devices = {}
|
||||
@@ -105,9 +120,19 @@ class FreeU_V2:
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_output_block_patch(output_block_patch)
|
||||
return (m, )
|
||||
return IO.NodeOutput(m)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"FreeU": FreeU,
|
||||
"FreeU_V2": FreeU_V2,
|
||||
}
|
||||
patch = execute # TODO: remove
|
||||
|
||||
|
||||
class FreelunchExtension(ComfyExtension):
    """Extension that exposes the ``FreeU`` and ``FreeU_V2`` node classes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [FreeU, FreeU_V2]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> FreelunchExtension:
    """Create and return a new ``FreelunchExtension`` instance."""
    extension = FreelunchExtension()
    return extension
|
||||
|
||||
@@ -58,6 +58,7 @@ class FreSca(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="FreSca",
|
||||
search_aliases=["frequency guidance"],
|
||||
display_name="FreSca",
|
||||
category="_for_testing",
|
||||
description="Applies frequency-dependent scaling to the guidance",
|
||||
|
||||
@@ -38,6 +38,7 @@ class CLIPTextEncodeHiDream(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeHiDream",
|
||||
search_aliases=["hidream prompt"],
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@@ -259,6 +259,7 @@ class SetClipHooks:
|
||||
return (clip,)
|
||||
|
||||
class ConditioningTimestepsRange:
|
||||
SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"]
|
||||
NodeId = 'ConditioningTimestepsRange'
|
||||
NodeName = 'Timesteps Range'
|
||||
@classmethod
|
||||
@@ -468,6 +469,7 @@ class SetHookKeyframes:
|
||||
return (hooks,)
|
||||
|
||||
class CreateHookKeyframe:
|
||||
SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"]
|
||||
NodeId = 'CreateHookKeyframe'
|
||||
NodeName = 'Create Hook Keyframe'
|
||||
@classmethod
|
||||
@@ -497,6 +499,7 @@ class CreateHookKeyframe:
|
||||
return (prev_hook_kf,)
|
||||
|
||||
class CreateHookKeyframesInterpolated:
|
||||
SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"]
|
||||
NodeId = 'CreateHookKeyframesInterpolated'
|
||||
NodeName = 'Create Hook Keyframes Interp.'
|
||||
@classmethod
|
||||
@@ -544,6 +547,7 @@ class CreateHookKeyframesInterpolated:
|
||||
return (prev_hook_kf,)
|
||||
|
||||
class CreateHookKeyframesFromFloats:
|
||||
SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"]
|
||||
NodeId = 'CreateHookKeyframesFromFloats'
|
||||
NodeName = 'Create Hook Keyframes From Floats'
|
||||
@classmethod
|
||||
@@ -618,6 +622,7 @@ class SetModelHooksOnCond:
|
||||
# Combine Hooks
|
||||
#------------------------------------------
|
||||
class CombineHooks:
|
||||
SEARCH_ALIASES = ["merge hooks"]
|
||||
NodeId = 'CombineHooks2'
|
||||
NodeName = 'Combine Hooks [2]'
|
||||
@classmethod
|
||||
|
||||
@@ -4,7 +4,10 @@ import torch
|
||||
import comfy.model_management
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
||||
from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler
|
||||
import folder_paths
|
||||
import json
|
||||
|
||||
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -37,6 +40,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyHunyuanLatentVideo",
|
||||
display_name="Empty HunyuanVideo 1.0 Latent",
|
||||
category="latent/video",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
@@ -52,11 +56,208 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples":latent})
|
||||
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
    """Variant of the empty Hunyuan video latent node with 32 channels and a /16 spatial scale."""

    @classmethod
    def define_schema(cls):
        """Reuse the parent schema, overriding only the node id and display name."""
        inherited = super().define_schema()
        inherited.node_id = "EmptyHunyuanVideo15Latent"
        inherited.display_name = "Empty HunyuanVideo 1.5 Latent"
        return inherited

    @classmethod
    def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
        """Build a zero latent of shape ``[batch, 32, frames, height // 16, width // 16]``.

        ``frames`` is ``((length - 1) // 4) + 1``. The output dict also carries
        ``"downscale_ratio_spacial": 16`` to match the spatial divisor used here
        (16 rather than the 8 used by the parent node).
        """
        latent_frames = ((length - 1) // 4) + 1
        latent_shape = [batch_size, 32, latent_frames, height // 16, width // 16]
        target_device = comfy.model_management.intermediate_device()
        zero_latent = torch.zeros(latent_shape, device=target_device)
        return io.NodeOutput({"samples": zero_latent, "downscale_ratio_spacial": 16})
|
||||
|
||||
|
||||
class HunyuanVideo15ImageToVideo(io.ComfyNode):
    """Node that prepares conditioning and an empty latent for image-to-video sampling."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and the three outputs."""
        node_inputs = [
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
            io.Vae.Input("vae"),
            io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
            io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
            io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4),
            io.Int.Input("batch_size", default=1, min=1, max=4096),
            io.Image.Input("start_image", optional=True),
            io.ClipVisionOutput.Input("clip_vision_output", optional=True),
        ]
        node_outputs = [
            io.Conditioning.Output(display_name="positive"),
            io.Conditioning.Output(display_name="negative"),
            io.Latent.Output(display_name="latent"),
        ]
        return io.Schema(
            node_id="HunyuanVideo15ImageToVideo",
            category="conditioning/video_models",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
        """Attach optional image / clip-vision conditioning and return a zero latent.

        A zero latent of shape ``[batch, 32, ((length - 1) // 4) + 1, height // 16,
        width // 16]`` is always produced. When ``start_image`` is given it is
        resized, VAE-encoded and stored on both conditionings as
        ``concat_latent_image`` together with a ``concat_mask`` that is 0 over
        the leading frames derived from the image count and 1 elsewhere. When
        ``clip_vision_output`` is given it is stored on both conditionings too.

        Returns:
            ``io.NodeOutput(positive, negative, {"samples": latent})``.
        """
        latent_frames = ((length - 1) // 4) + 1
        latent = torch.zeros(
            [batch_size, 32, latent_frames, height // 16, width // 16],
            device=comfy.model_management.intermediate_device(),
        )

        if start_image is not None:
            # Resize in channels-first layout, then go back to channels-last.
            channels_first = start_image[:length].movedim(-1, 1)
            start_image = comfy.utils.common_upscale(channels_first, width, height, "bilinear", "center").movedim(1, -1)

            # Only the first three channels of the image are encoded.
            encoded = vae.encode(start_image[:, :, :, :3])
            concat_shape = (latent.shape[0], 32, latent.shape[2], latent.shape[3], latent.shape[4])
            concat_latent_image = torch.zeros(concat_shape, device=comfy.model_management.intermediate_device())
            concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded

            mask_shape = (1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1])
            mask = torch.ones(mask_shape, device=start_image.device, dtype=start_image.dtype)
            known_frames = ((start_image.shape[0] - 1) // 4) + 1
            mask[:, :, :known_frames] = 0.0

            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})

        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        out_latent = {"samples": latent}
        return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class HunyuanVideo15SuperResolution(io.ComfyNode):
    """Node that builds the concat-latent conditioning for the super-resolution pass."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, inputs and the three outputs."""
        return io.Schema(
            node_id="HunyuanVideo15SuperResolution",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae", optional=True),
                io.Image.Input("start_image", optional=True),
                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
                io.Latent.Input("latent"),
                io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, latent, noise_augmentation, vae=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
        """Attach ``concat_latent_image`` / ``noise_augmentation`` to both conditionings.

        The concat tensor has ``2 * C + 2`` channels, where ``C`` is the channel
        count of ``latent["samples"]``:

        * ``[0, C)``: VAE encoding of ``start_image`` (zeros when absent),
        * ``[C + 1, 2 * C + 1)``: the input latent,
        * ``2 * C + 1``: filled with 1.

        Args:
            positive: Positive conditioning to extend.
            negative: Negative conditioning to extend.
            latent: Latent dict; its ``"samples"`` entry is read and the dict is
                passed through unchanged as the third output.
            noise_augmentation: Value stored on both conditionings.
            vae: VAE used to encode ``start_image``; required when
                ``start_image`` is given.
            start_image: Optional image to encode into the first ``C`` channels.
            clip_vision_output: Optional value stored on both conditionings.

        Returns:
            ``io.NodeOutput(positive, negative, latent)``.

        Raises:
            ValueError: If ``start_image`` is given without a ``vae``.
        """
        if start_image is not None and vae is None:
            # Both inputs are optional in the schema, but the image can only be
            # used after encoding; fail with a clear message instead of an
            # AttributeError on None further down.
            raise ValueError("HunyuanVideo15SuperResolution: a VAE is required when start_image is provided.")

        in_latent = latent["samples"]
        in_channels = in_latent.shape[1]
        cond_shape = [in_latent.shape[0], in_channels * 2 + 2, in_latent.shape[-3], in_latent.shape[-2], in_latent.shape[-1]]
        cond_latent = torch.zeros(cond_shape, device=comfy.model_management.intermediate_device())
        cond_latent[:, in_channels + 1 : 2 * in_channels + 1] = in_latent
        cond_latent[:, 2 * in_channels + 1] = 1

        if start_image is not None:
            # Resize to 16x the latent's spatial size before encoding.
            start_image = comfy.utils.common_upscale(start_image.movedim(-1, 1), in_latent.shape[-1] * 16, in_latent.shape[-2] * 16, "bilinear", "center").movedim(1, -1)
            encoded = vae.encode(start_image[:, :, :, :3])
            cond_latent[:, :in_channels, :encoded.shape[2], :, :] = encoded
            # NOTE(review): index ``in_channels + 1`` is the first channel of the
            # slice that was just filled with ``in_latent``; channel
            # ``in_channels`` itself is never written. Kept as-is -- confirm the
            # intended channel with the model definition.
            cond_latent[:, in_channels + 1, 0] = 1

        cond_values = {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}
        positive = node_helpers.conditioning_set_values(positive, cond_values)
        negative = node_helpers.conditioning_set_values(negative, cond_values)

        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        return io.NodeOutput(positive, negative, latent)
|
||||
|
||||
|
||||
class LatentUpscaleModelLoader(io.ComfyNode):
    """Node that loads a latent upscale model from the ``latent_upscale_models`` folder."""

    @classmethod
    def define_schema(cls):
        """Declare the node: one model-name combo input and one upscale-model output."""
        return io.Schema(
            node_id="LatentUpscaleModelLoader",
            display_name="Load Latent Upscale Model",
            category="loaders",
            inputs=[
                io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")),
            ],
            outputs=[
                io.LatentUpscaleModel.Output(),
            ],
        )

    @classmethod
    def execute(cls, model_name) -> io.NodeOutput:
        """Load the named file and build the matching upscale model.

        The model class and its configuration are chosen by probing the state
        dict for a characteristic weight key:

        * ``blocks.0.block.0.conv.weight``: ``HunyuanVideo15SRModel("720p", ...)``
        * ``up.0.block.0.conv1.conv.weight``: ``HunyuanVideo15SRModel("1080p", ...)``
        * ``post_upsample_res_blocks.0.conv2.bias``: ``LatentUpsampler`` built
          from the JSON ``config`` entry of the file metadata

        Args:
            model_name: File name inside the ``latent_upscale_models`` folder.

        Returns:
            ``io.NodeOutput`` wrapping the loaded model.

        Raises:
            ValueError: If none of the known weight keys is present. Previously
                this case fell through to the ``return`` with ``model`` unbound.
        """
        model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
        sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True)

        if "blocks.0.block.0.conv.weight" in sd:
            config = {
                "in_channels": sd["in_conv.conv.weight"].shape[1],
                "out_channels": sd["out_conv.conv.weight"].shape[0],
                "hidden_channels": sd["in_conv.conv.weight"].shape[0],
                # One block per "blocks.<i>.block.0.conv.weight" entry.
                "num_blocks": sum(1 for k in sd if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")),
                "global_residual": False,
            }
            model_type = "720p"
            model = HunyuanVideo15SRModel(model_type, config)
            model.load_sd(sd)
        elif "up.0.block.0.conv1.conv.weight" in sd:
            # Rename the first "nin_shortcut" occurrence in each key before loading.
            sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
            num_up_blocks = sum(1 for k in sd if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight"))
            config = {
                "z_channels": sd["conv_in.conv.weight"].shape[1],
                "out_channels": sd["conv_out.conv.weight"].shape[0],
                "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(num_up_blocks)),
            }
            model_type = "1080p"
            model = HunyuanVideo15SRModel(model_type, config)
            model.load_sd(sd)
        elif "post_upsample_res_blocks.0.conv2.bias" in sd:
            # NOTE(review): assumes the file metadata carries a JSON "config"
            # entry; a file without it fails on this lookup -- confirm.
            config = json.loads(metadata["config"])
            model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32]))
            model.load_state_dict(sd)
        else:
            raise ValueError(f"Unsupported latent upscale model '{model_name}': no known weight layout found in the state dict.")

        return io.NodeOutput(model)
|
||||
|
||||
|
||||
class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode):
    """Node that resizes a latent to a target size and runs it through an upscale model."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, display name, category, inputs and output."""
        node_inputs = [
            io.LatentUpscaleModel.Input("model"),
            io.Latent.Input("samples"),
            io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"),
            io.Int.Input("width", default=1280, min=0, max=16384, step=8),
            io.Int.Input("height", default=720, min=0, max=16384, step=8),
            io.Combo.Input("crop", options=["disabled", "center"]),
        ]
        node_outputs = [
            io.Latent.Output(),
        ]
        return io.Schema(
            node_id="HunyuanVideo15LatentUpscaleWithModel",
            display_name="Hunyuan Video 15 Latent Upscale With Model",
            category="latent",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput:
        """Resize ``samples["samples"]`` and pass it through ``model.resample_latent``.

        When both ``width`` and ``height`` are 0 the input is returned untouched.
        When exactly one is 0 it is derived from the other using the ratio of the
        latent's last two dimensions. Each resolved side is clamped to at least 64
        and divided by 16 before resizing. The result is moved to CPU as float.
        """
        if width == 0 and height == 0:
            return io.NodeOutput(samples)

        source = samples["samples"]
        if width == 0:
            # Only height given: derive width from the latent's aspect ratio.
            height = max(64, height)
            width = max(64, round(source.shape[-1] * height / source.shape[-2]))
        elif height == 0:
            # Only width given: derive height from the latent's aspect ratio.
            width = max(64, width)
            height = max(64, round(source.shape[-2] * width / source.shape[-1]))
        else:
            width = max(64, width)
            height = max(64, height)

        resized = comfy.utils.common_upscale(source, width // 16, height // 16, upscale_method, crop)
        upscaled = model.resample_latent(resized)
        return io.NodeOutput({"samples": upscaled.cpu().float()})
|
||||
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||
"1. The main content and theme of the video."
|
||||
@@ -210,6 +411,11 @@ class HunyuanExtension(ComfyExtension):
|
||||
CLIPTextEncodeHunyuanDiT,
|
||||
TextEncodeHunyuanVideo_ImageToVideo,
|
||||
EmptyHunyuanLatentVideo,
|
||||
EmptyHunyuanVideo15Latent,
|
||||
HunyuanVideo15ImageToVideo,
|
||||
HunyuanVideo15SuperResolution,
|
||||
HunyuanVideo15LatentUpscaleWithModel,
|
||||
LatentUpscaleModelLoader,
|
||||
HunyuanImageToVideo,
|
||||
EmptyHunyuanImageLatent,
|
||||
HunyuanRefinerLatent,
|
||||
|
||||
@@ -7,63 +7,79 @@ from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_fro
|
||||
import folder_paths
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa
|
||||
|
||||
class EmptyLatentHunyuan3Dv2:
|
||||
|
||||
class EmptyLatentHunyuan3Dv2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentHunyuan3Dv2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Int.Input("resolution", default=3072, min=1, max=8192),
|
||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/3d"
|
||||
|
||||
def generate(self, resolution, batch_size):
|
||||
@classmethod
|
||||
def execute(cls, resolution, batch_size) -> IO.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples": latent, "type": "hunyuan3dv2"}, )
|
||||
return IO.NodeOutput({"samples": latent, "type": "hunyuan3dv2"})
|
||||
|
||||
class Hunyuan3Dv2Conditioning:
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class Hunyuan3Dv2Conditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Hunyuan3Dv2Conditioning",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
IO.ClipVisionOutput.Input("clip_vision_output"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, clip_vision_output):
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_output) -> IO.NodeOutput:
|
||||
embeds = clip_vision_output.last_hidden_state
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
encode = execute # TODO: remove
|
||||
|
||||
|
||||
class Hunyuan3Dv2ConditioningMultiView:
|
||||
class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {},
|
||||
"optional": {"front": ("CLIP_VISION_OUTPUT",),
|
||||
"left": ("CLIP_VISION_OUTPUT",),
|
||||
"back": ("CLIP_VISION_OUTPUT",),
|
||||
"right": ("CLIP_VISION_OUTPUT",), }}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Hunyuan3Dv2ConditioningMultiView",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
IO.ClipVisionOutput.Input("front", optional=True),
|
||||
IO.ClipVisionOutput.Input("left", optional=True),
|
||||
IO.ClipVisionOutput.Input("back", optional=True),
|
||||
IO.ClipVisionOutput.Input("right", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, front=None, left=None, back=None, right=None):
|
||||
@classmethod
|
||||
def execute(cls, front=None, left=None, back=None, right=None) -> IO.NodeOutput:
|
||||
all_embeds = [front, left, back, right]
|
||||
out = []
|
||||
pos_embeds = None
|
||||
@@ -76,29 +92,35 @@ class Hunyuan3Dv2ConditioningMultiView:
|
||||
embeds = torch.cat(out, dim=1)
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
encode = execute # TODO: remove
|
||||
|
||||
|
||||
class VOXEL:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
class VAEDecodeHunyuan3D:
|
||||
class VAEDecodeHunyuan3D(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT", ),
|
||||
"vae": ("VAE", ),
|
||||
"num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
|
||||
"octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
|
||||
}}
|
||||
RETURN_TYPES = ("VOXEL",)
|
||||
FUNCTION = "decode"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeHunyuan3D",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Int.Input("num_chunks", default=8000, min=1000, max=500000),
|
||||
IO.Int.Input("octree_resolution", default=256, min=16, max=512),
|
||||
],
|
||||
outputs=[
|
||||
IO.Voxel.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
CATEGORY = "latent/3d"
|
||||
@classmethod
|
||||
def execute(cls, vae, samples, num_chunks, octree_resolution) -> IO.NodeOutput:
|
||||
voxels = Types.VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
|
||||
return IO.NodeOutput(voxels)
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
def decode(self, vae, samples, num_chunks, octree_resolution):
|
||||
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
|
||||
return (voxels, )
|
||||
|
||||
def voxel_to_mesh(voxels, threshold=0.5, device=None):
|
||||
if device is None:
|
||||
@@ -396,24 +418,24 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
|
||||
|
||||
return final_vertices, faces
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices, faces):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
|
||||
|
||||
class VoxelToMeshBasic:
|
||||
class VoxelToMeshBasic(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"voxel": ("VOXEL", ),
|
||||
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MESH",)
|
||||
FUNCTION = "decode"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMeshBasic",
|
||||
category="3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def decode(self, voxel, threshold):
|
||||
@classmethod
|
||||
def execute(cls, voxel, threshold) -> IO.NodeOutput:
|
||||
vertices = []
|
||||
faces = []
|
||||
for x in voxel.data:
|
||||
@@ -421,21 +443,29 @@ class VoxelToMeshBasic:
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return (MESH(torch.stack(vertices), torch.stack(faces)), )
|
||||
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
|
||||
|
||||
class VoxelToMesh:
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
class VoxelToMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"voxel": ("VOXEL", ),
|
||||
"algorithm": (["surface net", "basic"], ),
|
||||
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MESH",)
|
||||
FUNCTION = "decode"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMesh",
|
||||
category="3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def decode(self, voxel, algorithm, threshold):
|
||||
@classmethod
|
||||
def execute(cls, voxel, algorithm, threshold) -> IO.NodeOutput:
|
||||
vertices = []
|
||||
faces = []
|
||||
|
||||
@@ -449,7 +479,9 @@ class VoxelToMesh:
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return (MESH(torch.stack(vertices), torch.stack(faces)), )
|
||||
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
def save_glb(vertices, faces, filepath, metadata=None):
|
||||
@@ -581,50 +613,84 @@ def save_glb(vertices, faces, filepath, metadata=None):
|
||||
return filepath
|
||||
|
||||
|
||||
class SaveGLB:
|
||||
class SaveGLB(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"mesh": ("MESH", ),
|
||||
"filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveGLB",
|
||||
display_name="Save 3D Model",
|
||||
search_aliases=["export 3d model", "save mesh"],
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
IO.Mesh.Input("mesh"),
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
IO.File3DOBJ,
|
||||
IO.File3DFBX,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="Mesh or 3D file to save",
|
||||
),
|
||||
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
|
||||
)
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
@classmethod
|
||||
def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results = []
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
if isinstance(mesh, Types.File3D):
|
||||
# Handle File3D input - save BytesIO data to output folder
|
||||
ext = mesh.format or "glb"
|
||||
f = f"{filename}_{counter:05}_.{ext}"
|
||||
mesh.save_to(os.path.join(full_output_folder, f))
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return {"ui": {"3d": results}}
|
||||
else:
|
||||
# Handle Mesh input - save vertices and faces as GLB
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return IO.NodeOutput(ui={"3d": results})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
|
||||
"Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
|
||||
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
|
||||
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
|
||||
"VoxelToMeshBasic": VoxelToMeshBasic,
|
||||
"VoxelToMesh": VoxelToMesh,
|
||||
"SaveGLB": SaveGLB,
|
||||
}
|
||||
class Hunyuan3dExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
EmptyLatentHunyuan3Dv2,
|
||||
Hunyuan3Dv2Conditioning,
|
||||
Hunyuan3Dv2ConditioningMultiView,
|
||||
VAEDecodeHunyuan3D,
|
||||
VoxelToMeshBasic,
|
||||
VoxelToMesh,
|
||||
SaveGLB,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> Hunyuan3dExtension:
|
||||
return Hunyuan3dExtension()
|
||||
|
||||
@@ -2,6 +2,9 @@ import comfy.utils
|
||||
import folder_paths
|
||||
import torch
|
||||
import logging
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def load_hypernetwork_patch(path, strength):
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
@@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):
|
||||
|
||||
return hypernetwork_patch(out, strength)
|
||||
|
||||
class HypernetworkLoader:
|
||||
class HypernetworkLoader(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_hypernetwork"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="HypernetworkLoader",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
|
||||
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||
@classmethod
|
||||
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
|
||||
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||
model_hypernetwork = model.clone()
|
||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||
if patch is not None:
|
||||
model_hypernetwork.set_model_attn1_patch(patch)
|
||||
model_hypernetwork.set_model_attn2_patch(patch)
|
||||
return (model_hypernetwork,)
|
||||
return IO.NodeOutput(model_hypernetwork)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"HypernetworkLoader": HypernetworkLoader
|
||||
}
|
||||
load_hypernetwork = execute # TODO: remove
|
||||
|
||||
|
||||
class HyperNetworkExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
HypernetworkLoader,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> HyperNetworkExtension:
|
||||
return HyperNetworkExtension()
|
||||
|
||||
53
comfy_extras/nodes_image_compare.py
Normal file
53
comfy_extras/nodes_image_compare.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import nodes
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
|
||||
|
||||
class ImageCompare(IO.ComfyNode):
|
||||
"""Compares two images with a slider interface."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCompare",
|
||||
display_name="Image Compare",
|
||||
description="Compares two images side by side with a slider.",
|
||||
category="image",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.Image.Input("image_a", optional=True),
|
||||
IO.Image.Input("image_b", optional=True),
|
||||
IO.ImageCompare.Input("compare_view"),
|
||||
],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput:
|
||||
result = {"a_images": [], "b_images": []}
|
||||
|
||||
preview_node = nodes.PreviewImage()
|
||||
|
||||
if image_a is not None and len(image_a) > 0:
|
||||
saved = preview_node.save_images(image_a, "comfy.compare.a")
|
||||
result["a_images"] = saved["ui"]["images"]
|
||||
|
||||
if image_b is not None and len(image_b) > 0:
|
||||
saved = preview_node.save_images(image_b, "comfy.compare.b")
|
||||
result["b_images"] = saved["ui"]["images"]
|
||||
|
||||
return IO.NodeOutput(ui=result)
|
||||
|
||||
|
||||
class ImageCompareExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ImageCompare,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ImageCompareExtension:
|
||||
return ImageCompareExtension()
|
||||
@@ -2,280 +2,235 @@ from __future__ import annotations
|
||||
|
||||
import nodes
|
||||
import folder_paths
|
||||
from comfy.cli_args import args
|
||||
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
from inspect import cleandoc
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from comfy.comfy_types import FileLocator, IO
|
||||
from server import PromptServer
|
||||
from comfy_api.latest import ComfyExtension, IO, UI
|
||||
from typing_extensions import override
|
||||
|
||||
SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
class ImageCrop:
|
||||
class ImageCrop(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "crop"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCrop",
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def crop(self, image, width, height, x, y):
|
||||
@classmethod
|
||||
def execute(cls, image, width, height, x, y) -> IO.NodeOutput:
|
||||
x = min(x, image.shape[2] - 1)
|
||||
y = min(y, image.shape[1] - 1)
|
||||
to_x = width + x
|
||||
to_y = height + y
|
||||
img = image[:,y:to_y, x:to_x, :]
|
||||
return (img,)
|
||||
return IO.NodeOutput(img)
|
||||
|
||||
class RepeatImageBatch:
|
||||
crop = execute # TODO: remove
|
||||
|
||||
|
||||
class RepeatImageBatch(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "repeat"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RepeatImageBatch",
|
||||
search_aliases=["duplicate image", "clone image"],
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("amount", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def repeat(self, image, amount):
|
||||
@classmethod
|
||||
def execute(cls, image, amount) -> IO.NodeOutput:
|
||||
s = image.repeat((amount, 1,1,1))
|
||||
return (s,)
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
class ImageFromBatch:
|
||||
repeat = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageFromBatch(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
|
||||
"length": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "frombatch"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageFromBatch",
|
||||
search_aliases=["select image", "pick from batch", "extract image"],
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("batch_index", default=0, min=0, max=4095),
|
||||
IO.Int.Input("length", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def frombatch(self, image, batch_index, length):
|
||||
@classmethod
|
||||
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
|
||||
s_in = image
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s = s_in[batch_index:batch_index + length].clone()
|
||||
return (s,)
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
frombatch = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageAddNoise:
|
||||
class ImageAddNoise(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
|
||||
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "repeat"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageAddNoise",
|
||||
search_aliases=["film grain"],
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
def repeat(self, image, seed, strength):
|
||||
@classmethod
|
||||
def execute(cls, image, seed, strength) -> IO.NodeOutput:
|
||||
generator = torch.manual_seed(seed)
|
||||
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
|
||||
return (s,)
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
class SaveAnimatedWEBP:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
repeat = execute # TODO: remove
|
||||
|
||||
methods = {"default": 4, "fastest": 0, "slowest": 6}
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
||||
"lossless": ("BOOLEAN", {"default": True}),
|
||||
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
|
||||
"method": (list(s.methods.keys()),),
|
||||
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image/animation"
|
||||
|
||||
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
|
||||
method = self.methods.get(method)
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results: list[FileLocator] = []
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pil_images.append(img)
|
||||
|
||||
metadata = pil_images[0].getexif()
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
|
||||
if extra_pnginfo is not None:
|
||||
inital_exif = 0x010f
|
||||
for x in extra_pnginfo:
|
||||
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
|
||||
inital_exif -= 1
|
||||
|
||||
if num_frames == 0:
|
||||
num_frames = len(pil_images)
|
||||
|
||||
c = len(pil_images)
|
||||
for i in range(0, c, num_frames):
|
||||
file = f"{filename}_{counter:05}_.webp"
|
||||
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
|
||||
animated = num_frames != 1
|
||||
return { "ui": { "images": results, "animated": (animated,) } }
|
||||
|
||||
class SaveAnimatedPNG:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
class SaveAnimatedWEBP(IO.ComfyNode):
|
||||
COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6}
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
||||
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAnimatedWEBP",
|
||||
category="image/animation",
|
||||
inputs=[
|
||||
IO.Image.Input("images"),
|
||||
IO.String.Input("filename_prefix", default="ComfyUI"),
|
||||
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
|
||||
IO.Boolean.Input("lossless", default=True),
|
||||
IO.Int.Input("quality", default=80, min=0, max=100),
|
||||
IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())),
|
||||
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
@classmethod
|
||||
def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput:
|
||||
return IO.NodeOutput(
|
||||
ui=UI.ImageSaveHelper.get_save_animated_webp_ui(
|
||||
images=images,
|
||||
filename_prefix=filename_prefix,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
lossless=lossless,
|
||||
quality=quality,
|
||||
method=cls.COMPRESS_METHODS.get(method)
|
||||
)
|
||||
)
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image/animation"
|
||||
|
||||
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pil_images.append(img)
|
||||
|
||||
metadata = None
|
||||
if not args.disable_metadata:
|
||||
metadata = PngInfo()
|
||||
if prompt is not None:
|
||||
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
|
||||
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
|
||||
return { "ui": { "images": results, "animated": (True,)} }
|
||||
|
||||
class SVG:
|
||||
"""
|
||||
Stores SVG representations via a list of BytesIO objects.
|
||||
"""
|
||||
def __init__(self, data: list[BytesIO]):
|
||||
self.data = data
|
||||
|
||||
def combine(self, other: 'SVG') -> 'SVG':
|
||||
return SVG(self.data + other.data)
|
||||
|
||||
@staticmethod
|
||||
def combine_all(svgs: list['SVG']) -> 'SVG':
|
||||
all_svgs_list: list[BytesIO] = []
|
||||
for svg_item in svgs:
|
||||
all_svgs_list.extend(svg_item.data)
|
||||
return SVG(all_svgs_list)
|
||||
save_images = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageStitch:
|
||||
class SaveAnimatedPNG(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAnimatedPNG",
|
||||
category="image/animation",
|
||||
inputs=[
|
||||
IO.Image.Input("images"),
|
||||
IO.String.Input("filename_prefix", default="ComfyUI"),
|
||||
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
|
||||
IO.Int.Input("compress_level", default=4, min=0, max=9),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput:
|
||||
return IO.NodeOutput(
|
||||
ui=UI.ImageSaveHelper.get_save_animated_png_ui(
|
||||
images=images,
|
||||
filename_prefix=filename_prefix,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
compress_level=compress_level,
|
||||
)
|
||||
)
|
||||
|
||||
save_images = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageStitch(IO.ComfyNode):
|
||||
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageStitch",
|
||||
search_aliases=["combine images", "join images", "concatenate images", "side by side"],
|
||||
display_name="Image Stitch",
|
||||
description="Stitches image2 to image1 in the specified direction.\n"
|
||||
"If image2 is not provided, returns image1 unchanged.\n"
|
||||
"Optional spacing can be added between images.",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image1"),
|
||||
IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"),
|
||||
IO.Boolean.Input("match_image_size", default=True),
|
||||
IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2),
|
||||
IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"),
|
||||
IO.Image.Input("image2", optional=True),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",),
|
||||
"direction": (["right", "down", "left", "up"], {"default": "right"}),
|
||||
"match_image_size": ("BOOLEAN", {"default": True}),
|
||||
"spacing_width": (
|
||||
"INT",
|
||||
{"default": 0, "min": 0, "max": 1024, "step": 2},
|
||||
),
|
||||
"spacing_color": (
|
||||
["white", "black", "red", "green", "blue"],
|
||||
{"default": "white"},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image2": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stitch"
|
||||
CATEGORY = "image/transform"
|
||||
DESCRIPTION = """
|
||||
Stitches image2 to image1 in the specified direction.
|
||||
If image2 is not provided, returns image1 unchanged.
|
||||
Optional spacing can be added between images.
|
||||
"""
|
||||
|
||||
def stitch(
|
||||
self,
|
||||
def execute(
|
||||
cls,
|
||||
image1,
|
||||
direction,
|
||||
match_image_size,
|
||||
spacing_width,
|
||||
spacing_color,
|
||||
image2=None,
|
||||
):
|
||||
) -> IO.NodeOutput:
|
||||
if image2 is None:
|
||||
return (image1,)
|
||||
return IO.NodeOutput(image1)
|
||||
|
||||
# Handle batch size differences
|
||||
if image1.shape[0] != image2.shape[0]:
|
||||
@@ -412,36 +367,30 @@ Optional spacing can be added between images.
|
||||
images.insert(1, spacing)
|
||||
|
||||
concat_dim = 2 if direction in ["left", "right"] else 1
|
||||
return (torch.cat(images, dim=concat_dim),)
|
||||
return IO.NodeOutput(torch.cat(images, dim=concat_dim))
|
||||
|
||||
class ResizeAndPadImage:
|
||||
stitch = execute # TODO: remove
|
||||
|
||||
|
||||
class ResizeAndPadImage(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"target_width": ("INT", {
|
||||
"default": 512,
|
||||
"min": 1,
|
||||
"max": MAX_RESOLUTION,
|
||||
"step": 1
|
||||
}),
|
||||
"target_height": ("INT", {
|
||||
"default": 512,
|
||||
"min": 1,
|
||||
"max": MAX_RESOLUTION,
|
||||
"step": 1
|
||||
}),
|
||||
"padding_color": (["white", "black"],),
|
||||
"interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ResizeAndPadImage",
|
||||
search_aliases=["fit to size"],
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Combo.Input("padding_color", options=["white", "black"]),
|
||||
IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "resize_and_pad"
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
|
||||
@classmethod
|
||||
def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput:
|
||||
batch_size, orig_height, orig_width, channels = image.shape
|
||||
|
||||
scale_w = target_width / orig_width
|
||||
@@ -469,52 +418,47 @@ class ResizeAndPadImage:
|
||||
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
|
||||
|
||||
output = padded.permute(0, 2, 3, 1)
|
||||
return (output,)
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class SaveSVGNode:
|
||||
"""
|
||||
Save SVG files on disk.
|
||||
"""
|
||||
resize_and_pad = execute # TODO: remove
|
||||
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
RETURN_TYPES = ()
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "save_svg"
|
||||
CATEGORY = "image/save" # Changed
|
||||
OUTPUT_NODE = True
|
||||
class SaveSVGNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveSVGNode",
|
||||
search_aliases=["export vector", "save vector graphics"],
|
||||
description="Save SVG files on disk.",
|
||||
category="image/save",
|
||||
inputs=[
|
||||
IO.SVG.Input("svg"),
|
||||
IO.String.Input(
|
||||
"filename_prefix",
|
||||
default="svg/ComfyUI",
|
||||
tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
|
||||
),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"svg": ("SVG",), # Changed
|
||||
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||
}
|
||||
}
|
||||
|
||||
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results = list()
|
||||
def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results: list[UI.SavedResult] = []
|
||||
|
||||
# Prepare metadata JSON
|
||||
metadata_dict = {}
|
||||
if prompt is not None:
|
||||
metadata_dict["prompt"] = prompt
|
||||
if extra_pnginfo is not None:
|
||||
metadata_dict.update(extra_pnginfo)
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata_dict["prompt"] = cls.hidden.prompt
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
metadata_dict.update(cls.hidden.extra_pnginfo)
|
||||
|
||||
# Convert metadata to JSON string
|
||||
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
||||
|
||||
|
||||
for batch_number, svg_bytes in enumerate(svg.data):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
||||
@@ -544,57 +488,64 @@ class SaveSVGNode:
|
||||
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
||||
svg_file.write(svg_content.encode('utf-8'))
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output))
|
||||
counter += 1
|
||||
return { "ui": { "images": results } }
|
||||
return IO.NodeOutput(ui={"images": results})
|
||||
|
||||
class GetImageSize:
|
||||
save_svg = execute # TODO: remove
|
||||
|
||||
|
||||
class GetImageSize(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GetImageSize",
|
||||
search_aliases=["dimensions", "resolution", "image info"],
|
||||
display_name="Get Image Size",
|
||||
description="Returns width and height of the image, and passes it through unchanged.",
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Int.Output(display_name="width"),
|
||||
IO.Int.Output(display_name="height"),
|
||||
IO.Int.Output(display_name="batch_size"),
|
||||
],
|
||||
hidden=[IO.Hidden.unique_id],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
|
||||
RETURN_NAMES = ("width", "height", "batch_size")
|
||||
FUNCTION = "get_size"
|
||||
|
||||
CATEGORY = "image"
|
||||
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
|
||||
|
||||
def get_size(self, image, unique_id=None) -> tuple[int, int]:
|
||||
def execute(cls, image) -> IO.NodeOutput:
|
||||
height = image.shape[1]
|
||||
width = image.shape[2]
|
||||
batch_size = image.shape[0]
|
||||
|
||||
# Send progress text to display size on the node
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
|
||||
|
||||
return width, height, batch_size
|
||||
return IO.NodeOutput(width, height, batch_size)
|
||||
|
||||
class ImageRotate:
|
||||
get_size = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageRotate(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": (IO.IMAGE,),
|
||||
"rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
|
||||
}}
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "rotate"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageRotate",
|
||||
search_aliases=["turn", "flip orientation"],
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def rotate(self, image, rotation):
|
||||
@classmethod
|
||||
def execute(cls, image, rotation) -> IO.NodeOutput:
|
||||
rotate_by = 0
|
||||
if rotation.startswith("90"):
|
||||
rotate_by = 1
|
||||
@@ -604,41 +555,57 @@ class ImageRotate:
|
||||
rotate_by = 3
|
||||
|
||||
image = torch.rot90(image, k=rotate_by, dims=[2, 1])
|
||||
return (image,)
|
||||
return IO.NodeOutput(image)
|
||||
|
||||
class ImageFlip:
|
||||
rotate = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageFlip(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": (IO.IMAGE,),
|
||||
"flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
|
||||
}}
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "flip"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageFlip",
|
||||
search_aliases=["mirror", "reflect"],
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def flip(self, image, flip_method):
|
||||
@classmethod
|
||||
def execute(cls, image, flip_method) -> IO.NodeOutput:
|
||||
if flip_method.startswith("x"):
|
||||
image = torch.flip(image, dims=[1])
|
||||
elif flip_method.startswith("y"):
|
||||
image = torch.flip(image, dims=[2])
|
||||
|
||||
return (image,)
|
||||
return IO.NodeOutput(image)
|
||||
|
||||
class ImageScaleToMaxDimension:
|
||||
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
|
||||
flip = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageScaleToMaxDimension(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"image": ("IMAGE",),
|
||||
"upscale_method": (s.upscale_methods,),
|
||||
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "upscale"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageScaleToMaxDimension",
|
||||
category="image/upscaling",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input(
|
||||
"upscale_method",
|
||||
options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"],
|
||||
),
|
||||
IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/upscaling"
|
||||
|
||||
def upscale(self, image, upscale_method, largest_size):
|
||||
@classmethod
|
||||
def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
|
||||
height = image.shape[1]
|
||||
width = image.shape[2]
|
||||
|
||||
@@ -655,20 +622,30 @@ class ImageScaleToMaxDimension:
|
||||
samples = image.movedim(-1, 1)
|
||||
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return (s,)
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
"ImageFromBatch": ImageFromBatch,
|
||||
"ImageAddNoise": ImageAddNoise,
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
"SaveSVGNode": SaveSVGNode,
|
||||
"ImageStitch": ImageStitch,
|
||||
"ResizeAndPadImage": ResizeAndPadImage,
|
||||
"GetImageSize": GetImageSize,
|
||||
"ImageRotate": ImageRotate,
|
||||
"ImageFlip": ImageFlip,
|
||||
"ImageScaleToMaxDimension": ImageScaleToMaxDimension,
|
||||
}
|
||||
upscale = execute # TODO: remove
|
||||
|
||||
|
||||
class ImagesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ImageCrop,
|
||||
RepeatImageBatch,
|
||||
ImageFromBatch,
|
||||
ImageAddNoise,
|
||||
SaveAnimatedWEBP,
|
||||
SaveAnimatedPNG,
|
||||
SaveSVGNode,
|
||||
ImageStitch,
|
||||
ResizeAndPadImage,
|
||||
GetImageSize,
|
||||
ImageRotate,
|
||||
ImageFlip,
|
||||
ImageScaleToMaxDimension,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ImagesExtension:
|
||||
return ImagesExtension()
|
||||
|
||||
137
comfy_extras/nodes_kandinsky5.py
Normal file
137
comfy_extras/nodes_kandinsky5.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class Kandinsky5ImageToVideo(io.ComfyNode):
    """Node that prepares an empty video latent and optional start-image conditioning."""

    @classmethod
    def define_schema(cls):
        """Declare the node's id, category, inputs and its four outputs."""
        node_inputs = [
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
            io.Vae.Input("vae"),
            io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
            io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
            io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
            io.Int.Input("batch_size", default=1, min=1, max=4096),
            io.Image.Input("start_image", optional=True),
        ]
        node_outputs = [
            io.Conditioning.Output(display_name="positive"),
            io.Conditioning.Output(display_name="negative"),
            io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
            io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
        ]
        return io.Schema(
            node_id="Kandinsky5ImageToVideo",
            category="conditioning/video_models",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
        """Build the zero latent and, if a start image is given, attach it to both conditionings.

        The latent is allocated as [batch_size, 16, ((length - 1) // 4) + 1,
        height // 8, width // 8]. When ``start_image`` is provided it is resized
        to (width, height), its first three channels are VAE-encoded, and the
        encoding plus a mask are stored on ``positive`` and ``negative`` under
        the keys "time_dim_replace" and "concat_mask". Without a start image the
        "cond_latent" output is an empty dict.
        """
        latent_frames = ((length - 1) // 4) + 1
        latent = torch.zeros(
            [batch_size, 16, latent_frames, height // 8, width // 8],
            device=comfy.model_management.intermediate_device(),
        )

        cond_latent_out = {}
        if start_image is not None:
            # common_upscale is fed channels-first data; move channels back to the last axis afterwards.
            channels_first = start_image[:length].movedim(-1, 1)
            resized = comfy.utils.common_upscale(channels_first, width, height, "bilinear", "center")
            start_image = resized.movedim(1, -1)

            # Only the first three channels are passed to the VAE.
            encoded = vae.encode(start_image[:, :, :, :3])
            cond_latent_out["samples"] = encoded

            mask_shape = (1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1])
            mask = torch.ones(mask_shape, device=start_image.device, dtype=start_image.dtype)
            # Zero the mask over the latent frames derived from the supplied image count.
            covered_frames = ((start_image.shape[0] - 1) // 4) + 1
            mask[:, :, :covered_frames] = 0.0

            positive = node_helpers.conditioning_set_values(
                positive, {"time_dim_replace": encoded, "concat_mask": mask}
            )
            negative = node_helpers.conditioning_set_values(
                negative, {"time_dim_replace": encoded, "concat_mask": mask}
            )

        out_latent = {"samples": latent}
        return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
|
||||
|
||||
|
||||
def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
    """Re-scale ``source`` so its statistics move toward those of ``reference``.

    ``source`` is standardised with its own mean/std taken over dims (1, 3, 4)
    (so it must have at least five dimensions), then re-scaled with the global
    mean/std of ``reference``. The reference statistics are clamped to a window
    around the source statistics so the correction stays bounded.

    Args:
        source: Tensor to normalise.
        reference: Tensor whose global mean and std are the targets.
        clump_mean_low: How far below the source mean the target mean may go.
        clump_mean_high: How far above the source mean the target mean may go.
        clump_std_low: How far below the source std the target std may go.
        clump_std_high: How far above the source std the target std may go.

    Returns:
        The normalised tensor, or ``source`` itself (unchanged) when
        ``reference`` holds no elements.
    """
    # The mean/std of an empty tensor is NaN and would poison every output
    # value, so with nothing to match against the input is passed through.
    if reference.numel() == 0:
        return source

    source_mean = source.mean(dim=(1, 3, 4), keepdim=True)  # mean over C, H, W
    source_std = source.std(dim=(1, 3, 4), keepdim=True)  # std over C, H, W

    # Limit how far the reference statistics may pull away from the source's own.
    reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
    reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)

    # Standardise (epsilon avoids division by zero), then apply the target statistics.
    normalized = (source - source_mean) / (source_std + 1e-8)
    normalized = normalized * reference_std + reference_mean

    return normalized
|
||||
|
||||
|
||||
class NormalizeVideoLatentStart(io.ComfyNode):
    """Node that normalises the first latent frames against the frames that follow them."""

    @classmethod
    def define_schema(cls):
        """Declare the node's id, category, description, inputs and single latent output."""
        return io.Schema(
            node_id="NormalizeVideoLatentStart",
            category="conditioning/video_models",
            description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
            inputs=[
                io.Latent.Input("latent"),
                io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
                io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
            ],
            outputs=[
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
        """Return a copy of ``latent`` whose first frames are re-normalised.

        The first ``start_frame_count`` entries along dim 2 of
        ``latent["samples"]`` are normalised against the up-to
        ``reference_frame_count`` entries that follow them. The input latent is
        returned unchanged when it has at most one frame, or when no reference
        frames remain after the start frames.
        """
        source_samples = latent["samples"]
        total_frames = source_samples.shape[2]
        if total_frames <= 1:
            return io.NodeOutput(latent)

        reference_end = start_frame_count + min(reference_frame_count, total_frames - 1)
        reference_frames_data = source_samples[:, :, start_frame_count:reference_end]
        # If the start frames cover the whole latent the reference slice is empty;
        # normalising against it would write NaN into the start frames.
        if reference_frames_data.shape[2] == 0:
            return io.NodeOutput(latent)

        first_frames = source_samples[:, :, :start_frame_count]
        normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)

        # Write into a clone so the caller's tensor is left untouched.
        s = latent.copy()
        samples = source_samples.clone()
        samples[:, :, :start_frame_count] = normalized_first_frames
        s["samples"] = samples
        return io.NodeOutput(s)
|
||||
|
||||
|
||||
class CLIPTextEncodeKandinsky5(io.ComfyNode):
    """Node that encodes two separate prompts into one conditioning."""

    @classmethod
    def define_schema(cls):
        """Declare the node's id, search aliases, category, inputs and conditioning output."""
        prompt_inputs = [
            io.Clip.Input("clip"),
            io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
            io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
        ]
        return io.Schema(
            node_id="CLIPTextEncodeKandinsky5",
            search_aliases=["kandinsky prompt"],
            category="advanced/conditioning/kandinsky5",
            inputs=prompt_inputs,
            outputs=[io.Conditioning.Output()],
        )

    @classmethod
    def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
        """Tokenize both prompts and encode the merged token set.

        The ``clip_l`` prompt supplies the base token dict; its "qwen25_7b"
        entry is then replaced by the one produced from the ``qwen25_7b`` prompt.
        """
        merged_tokens = clip.tokenize(clip_l)
        qwen_tokens = clip.tokenize(qwen25_7b)
        merged_tokens["qwen25_7b"] = qwen_tokens["qwen25_7b"]

        conditioning = clip.encode_from_tokens_scheduled(merged_tokens)
        return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class Kandinsky5Extension(ComfyExtension):
    """Extension object exposing the node classes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [
            Kandinsky5ImageToVideo,
            NormalizeVideoLatentStart,
            CLIPTextEncodeKandinsky5,
        ]
        return node_classes
|
||||
|
||||
async def comfy_entrypoint() -> Kandinsky5Extension:
    """Create and return this module's extension instance."""
    extension = Kandinsky5Extension()
    return extension
|
||||
@@ -4,7 +4,8 @@ import torch
|
||||
import nodes
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
@@ -20,6 +21,7 @@ class LatentAdd(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentAdd",
|
||||
search_aliases=["combine latents", "sum latents"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
@@ -46,6 +48,7 @@ class LatentSubtract(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentSubtract",
|
||||
search_aliases=["difference latent", "remove features"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
@@ -72,6 +75,7 @@ class LatentMultiply(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentMultiply",
|
||||
search_aliases=["scale latent", "amplify latent", "latent gain"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
@@ -95,6 +99,7 @@ class LatentInterpolate(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentInterpolate",
|
||||
search_aliases=["blend latent", "mix latent", "lerp latent", "transition"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
@@ -133,6 +138,7 @@ class LatentConcat(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentConcat",
|
||||
search_aliases=["join latents", "stitch latents"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
@@ -172,6 +178,7 @@ class LatentCut(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentCut",
|
||||
search_aliases=["crop latent", "slice latent", "extract region"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
@@ -207,12 +214,56 @@ class LatentCut(io.ComfyNode):
|
||||
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentCutToBatch(io.ComfyNode):
    """Node that cuts a latent into fixed-size slices along one axis and stacks them on the batch axis."""

    @classmethod
    def define_schema(cls):
        """Declare the node's id, search aliases, category, inputs and latent output."""
        node_inputs = [
            io.Latent.Input("samples"),
            io.Combo.Input("dim", options=["t", "x", "y"]),
            io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
        ]
        return io.Schema(
            node_id="LatentCutToBatch",
            search_aliases=["slice to batch", "split latent", "tile latent"],
            category="latent/advanced",
            inputs=node_inputs,
            outputs=[io.Latent.Output()],
        )

    @classmethod
    def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
        """Split ``samples["samples"]`` along the chosen axis into the batch dimension.

        "x", "y" and "t" select the last, second-to-last and third-to-last tensor
        axis. If the selected axis index is below 2 the input is returned as is.
        ``slice_size`` is reduced to the axis length when larger, and trailing
        entries that do not fill a whole slice are dropped.
        """
        tensor = samples["samples"]

        # Map the axis name onto an index counted from the end of the tensor.
        for axis_name, offset_from_end in (("x", 1), ("y", 2), ("t", 3)):
            if axis_name in dim:
                dim = tensor.ndim - offset_from_end
                break

        if dim < 2:
            return io.NodeOutput(samples)

        # Bring the target axis next to the batch axis so a reshape can fold it in.
        moved = tensor.movedim(dim, 1)
        axis_length = moved.shape[1]
        if axis_length < slice_size:
            slice_size = axis_length
        elif axis_length % slice_size != 0:
            whole_slices = axis_length // slice_size
            moved = moved[:, :whole_slices * slice_size]

        target_shape = [-1, slice_size, *moved.shape[2:]]
        samples_out = samples.copy()
        samples_out["samples"] = moved.reshape(target_shape).movedim(1, dim)
        return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentBatch(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentBatch",
|
||||
search_aliases=["combine latents", "merge latents", "join latents"],
|
||||
category="latent/batch",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
@@ -267,6 +318,7 @@ class LatentApplyOperation(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentApplyOperation",
|
||||
search_aliases=["transform latent"],
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
@@ -322,6 +374,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentOperationTonemapReinhard",
|
||||
search_aliases=["hdr latent"],
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
@@ -338,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
|
||||
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
||||
normalized_latent = latent / latent_vector_magnitude
|
||||
|
||||
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||
dims = list(range(1, latent_vector_magnitude.ndim))
|
||||
mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
|
||||
std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
|
||||
|
||||
top = (std * 5 + mean) * multiplier
|
||||
|
||||
@@ -388,6 +442,42 @@ class LatentOperationSharpen(io.ComfyNode):
|
||||
return luminance * sharpened
|
||||
return io.NodeOutput(sharpen)
|
||||
|
||||
class ReplaceVideoLatentFrames(io.ComfyNode):
    """Node that overwrites a run of frames in one video latent with the frames of another."""

    @classmethod
    def define_schema(cls):
        """Declare the node's id, category, inputs and single latent output."""
        return io.Schema(
            node_id="ReplaceVideoLatentFrames",
            category="latent/batch",
            inputs=[
                io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
                io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
                io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
            ],
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, destination, index, source=None) -> io.NodeOutput:
        """Copy the source frames into a clone of the destination, starting at ``index``.

        Frames are taken along dim 2 of each latent's "samples" tensor. The
        destination is returned unchanged (with a logged warning where
        applicable) when no source is given, when ``index`` falls outside the
        destination even after negative wrap-around, or when the source frames
        would run past the end of the destination.
        """
        if source is None:
            return io.NodeOutput(destination)

        dest_frames = destination["samples"].shape[2]
        source_frames = source["samples"].shape[2]

        if index < 0:
            index = dest_frames + index

        # An index still negative after wrap-around lies before the first frame.
        # Left unchecked it would act as a from-the-end slice and silently write
        # to the wrong frames, so it is rejected like an index past the end.
        if index < 0 or index > dest_frames:
            logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
            return io.NodeOutput(destination)

        if index + source_frames > dest_frames:
            logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
            return io.NodeOutput(destination)

        # NOTE(review): the returned dict inherits its non-"samples" entries from
        # `source`, not `destination` - confirm this is intended.
        s = source.copy()
        s_source = source["samples"]
        # Clone so the caller's destination tensor is not modified in place.
        s_destination = destination["samples"].clone()
        s_destination[:, :, index:index + s_source.shape[2]] = s_source
        s["samples"] = s_destination
        return io.NodeOutput(s)
|
||||
|
||||
class LatentExtension(ComfyExtension):
|
||||
@override
|
||||
@@ -399,12 +489,14 @@ class LatentExtension(ComfyExtension):
|
||||
LatentInterpolate,
|
||||
LatentConcat,
|
||||
LatentCut,
|
||||
LatentCutToBatch,
|
||||
LatentBatch,
|
||||
LatentBatchSeedBehavior,
|
||||
LatentApplyOperation,
|
||||
LatentApplyOperationCFG,
|
||||
LatentOperationTonemapReinhard,
|
||||
LatentOperationSharpen,
|
||||
ReplaceVideoLatentFrames
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import nodes
|
||||
import folder_paths
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from comfy.comfy_types import IO
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@@ -11,9 +12,9 @@ from pathlib import Path
|
||||
def normalize_path(path):
|
||||
return path.replace('\\', '/')
|
||||
|
||||
class Load3D():
|
||||
class Load3D(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
def define_schema(cls):
|
||||
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
|
||||
|
||||
os.makedirs(input_dir, exist_ok=True)
|
||||
@@ -24,77 +25,32 @@ class Load3D():
|
||||
files = [
|
||||
normalize_path(str(file_path.relative_to(base_path)))
|
||||
for file_path in input_path.rglob("*")
|
||||
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
|
||||
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'}
|
||||
]
|
||||
return IO.Schema(
|
||||
node_id="Load3D",
|
||||
display_name="Load 3D & Animation",
|
||||
category="3d",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
|
||||
IO.Load3D.Input("image"),
|
||||
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
||||
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(display_name="image"),
|
||||
IO.Mask.Output(display_name="mask"),
|
||||
IO.String.Output(display_name="mesh_path"),
|
||||
IO.Image.Output(display_name="normal"),
|
||||
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||
IO.Video.Output(display_name="recording_video"),
|
||||
IO.File3DAny.Output(display_name="model_3d"),
|
||||
],
|
||||
)
|
||||
|
||||
return {"required": {
|
||||
"model_file": (sorted(files), {"file_upload": True}),
|
||||
"image": ("LOAD_3D", {}),
|
||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def process(self, model_file, image, **kwargs):
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
||||
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
|
||||
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
||||
|
||||
video = None
|
||||
|
||||
if image['recording'] != "":
|
||||
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
||||
|
||||
video = VideoFromFile(recording_video_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
|
||||
|
||||
class Load3DAnimation():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
|
||||
|
||||
os.makedirs(input_dir, exist_ok=True)
|
||||
|
||||
input_path = Path(input_dir)
|
||||
base_path = Path(folder_paths.get_input_directory())
|
||||
|
||||
files = [
|
||||
normalize_path(str(file_path.relative_to(base_path)))
|
||||
for file_path in input_path.rglob("*")
|
||||
if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
|
||||
]
|
||||
|
||||
return {"required": {
|
||||
"model_file": (sorted(files), {"file_upload": True}),
|
||||
"image": ("LOAD_3D_ANIMATION", {}),
|
||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def process(self, model_file, image, **kwargs):
|
||||
def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
||||
@@ -109,74 +65,66 @@ class Load3DAnimation():
|
||||
if image['recording'] != "":
|
||||
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
||||
|
||||
video = VideoFromFile(recording_video_path)
|
||||
video = InputImpl.VideoFromFile(recording_video_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, image['camera_info'], video
|
||||
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
|
||||
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
|
||||
|
||||
class Preview3D():
|
||||
process = execute # TODO: remove
|
||||
|
||||
|
||||
class Preview3D(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
},
|
||||
"optional": {
|
||||
"camera_info": ("LOAD3D_CAMERA", {})
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Preview3D",
|
||||
search_aliases=["view mesh", "3d viewer"],
|
||||
display_name="Preview 3D & Animation",
|
||||
category="3d",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
IO.String.Input("model_file", default="", multiline=False),
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
IO.File3DFBX,
|
||||
IO.File3DOBJ,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="3D model file or path string",
|
||||
),
|
||||
IO.Load3DCamera.Input("camera_info", optional=True),
|
||||
IO.Image.Input("bg_image", optional=True),
|
||||
],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
OUTPUT_NODE = True
|
||||
RETURN_TYPES = ()
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def process(self, model_file, **kwargs):
|
||||
camera_info = kwargs.get("camera_info", None)
|
||||
|
||||
return {
|
||||
"ui": {
|
||||
"result": [model_file, camera_info]
|
||||
}
|
||||
}
|
||||
|
||||
class Preview3DAnimation():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
},
|
||||
"optional": {
|
||||
"camera_info": ("LOAD3D_CAMERA", {})
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
RETURN_TYPES = ()
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def process(self, model_file, **kwargs):
|
||||
def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput:
|
||||
if isinstance(model_file, Types.File3D):
|
||||
filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
|
||||
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
|
||||
else:
|
||||
filename = model_file
|
||||
camera_info = kwargs.get("camera_info", None)
|
||||
bg_image = kwargs.get("bg_image", None)
|
||||
return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image))
|
||||
|
||||
return {
|
||||
"ui": {
|
||||
"result": [model_file, camera_info]
|
||||
}
|
||||
}
|
||||
process = execute # TODO: remove
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"Load3D": Load3D,
|
||||
"Load3DAnimation": Load3DAnimation,
|
||||
"Preview3D": Preview3D,
|
||||
"Preview3DAnimation": Preview3DAnimation
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Load3D": "Load 3D",
|
||||
"Load3DAnimation": "Load 3D - Animation",
|
||||
"Preview3D": "Preview 3D",
|
||||
"Preview3DAnimation": "Preview 3D - Animation"
|
||||
}
|
||||
class Load3DExtension(ComfyExtension):
    """Extension object listing the 3D nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes this extension contributes."""
        node_classes = [
            Load3D,
            Preview3D,
        ]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> Load3DExtension:
    """Build and return a fresh Load3DExtension instance."""
    extension = Load3DExtension()
    return extension
|
||||
|
||||
274
comfy_extras/nodes_logic.py
Normal file
274
comfy_extras/nodes_logic.py
Normal file
@@ -0,0 +1,274 @@
|
||||
from __future__ import annotations
|
||||
from typing import TypedDict
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.latest import _io
|
||||
|
||||
# sentinel for missing inputs
|
||||
MISSING = object()
|
||||
|
||||
|
||||
class SwitchNode(io.ComfyNode):
    """Boolean switch that forwards one of two lazily evaluated inputs."""

    @classmethod
    def define_schema(cls):
        """Declare the selector, the two lazy branches and the matched output."""
        matched = io.MatchType.Template("switch")
        branch_inputs = [
            io.Boolean.Input("switch"),
            io.MatchType.Input("on_false", template=matched, lazy=True),
            io.MatchType.Input("on_true", template=matched, lazy=True),
        ]
        branch_outputs = [
            io.MatchType.Output(template=matched, display_name="output"),
        ]
        return io.Schema(
            node_id="ComfySwitchNode",
            display_name="Switch",
            category="logic",
            is_experimental=True,
            inputs=branch_inputs,
            outputs=branch_outputs,
        )

    @classmethod
    def check_lazy_status(cls, switch, on_false=None, on_true=None):
        """Name the branch selected by ``switch`` if its value is still None.

        Returns a one-element list with the input name to evaluate, or None
        when the selected branch already has a value.
        """
        if switch:
            return ["on_true"] if on_true is None else None
        return ["on_false"] if on_false is None else None

    @classmethod
    def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
        """Output ``on_true`` when ``switch`` is truthy, otherwise ``on_false``."""
        if switch:
            chosen = on_true
        else:
            chosen = on_false
        return io.NodeOutput(chosen)
|
||||
|
||||
|
||||
class SoftSwitchNode(io.ComfyNode):
    """Switch whose two branches are optional; a lone connected branch always wins."""

    @classmethod
    def define_schema(cls):
        """Declare the selector, two optional lazy branches and the matched output."""
        matched = io.MatchType.Template("switch")
        branch_inputs = [
            io.Boolean.Input("switch"),
            io.MatchType.Input("on_false", template=matched, lazy=True, optional=True),
            io.MatchType.Input("on_true", template=matched, lazy=True, optional=True),
        ]
        branch_outputs = [
            io.MatchType.Output(template=matched, display_name="output"),
        ]
        return io.Schema(
            node_id="ComfySoftSwitchNode",
            display_name="Soft Switch",
            category="logic",
            is_experimental=True,
            inputs=branch_inputs,
            outputs=branch_outputs,
        )

    @classmethod
    def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
        """Return the name of the input that still has to be evaluated, if any.

        MISSING (rather than None) marks an input that was not supplied at all,
        because None is what a connected-but-unevaluated input arrives as.
        That distinction lets the switch value be ignored when only one branch
        exists while execute() can still run.
        """
        # A branch that is absent forces evaluation of the other one.
        if on_false is MISSING:
            return ["on_true"]
        if on_true is MISSING:
            return ["on_false"]

        # Both branches exist: regular lazy switch behaviour.
        if switch:
            return ["on_true"] if on_true is None else None
        return ["on_false"] if on_false is None else None

    @classmethod
    def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING):
        """Reject the node when neither branch is supplied.

        Per the original note, this check runs before check_lazy_status(), so
        the both-missing case never reaches it. Returns True when valid,
        otherwise an error message string.
        """
        neither_connected = on_false is MISSING and on_true is MISSING
        if neither_connected:
            return "At least one of on_false or on_true must be connected to Switch node"
        return True

    @classmethod
    def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
        """Output the only supplied branch, or pick by ``switch`` when both exist."""
        if on_true is MISSING:
            return io.NodeOutput(on_false)
        if on_false is MISSING:
            return io.NodeOutput(on_true)
        chosen = on_true if switch else on_false
        return io.NodeOutput(chosen)
|
||||
|
||||
|
||||
class CustomComboNode(io.ComfyNode):
    """
    Frontend node that lets the user write their own options for a combo.

    It exists so the node has a backend representation, which avoids some annoyances.
    """

    @classmethod
    def define_schema(cls):
        """Declare an option-less combo input plus STRING and INDEX outputs."""
        combo_input = io.Combo.Input("choice", options=[])
        result_outputs = [
            io.String.Output(display_name="STRING"),
            io.Int.Output(display_name="INDEX"),
        ]
        return io.Schema(
            node_id="CustomCombo",
            display_name="Custom Combo",
            category="utils",
            is_experimental=True,
            inputs=[combo_input],
            outputs=result_outputs,
            accept_all_inputs=True,
        )

    @classmethod
    def validate_inputs(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> bool:
        """Accept every input unconditionally.

        NOTE: DO NOT copy this unless you want to skip validation entirely on
        the node's inputs. It is done here because the widgets (besides the
        combo dropdown) are fully frontend defined, so the check that the
        chosen option is in the options list must be skipped: the options are
        written by the user.
        """
        return True

    @classmethod
    def execute(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> io.NodeOutput:
        """Pass the chosen option and its index straight through as outputs."""
        result = io.NodeOutput(choice, index)
        return result
|
||||
|
||||
|
||||
class DCTestNode(io.ComfyNode):
    """Test node exercising a DynamicCombo input, including a nested DynamicCombo."""

    class DCValues(TypedDict):
        # Value received for the "combo" input: the selected option name under
        # "combo", plus the entries belonging to that option.
        combo: str
        string: str
        integer: int
        image: io.Image.Type
        # Nested dynamic-combo payload. The original `dict[str]` was a
        # malformed generic; `dict` needs both a key and a value type.
        subcombo: dict[str, object]

    @classmethod
    def define_schema(cls):
        """Declare one DynamicCombo with four options; option4 nests a second combo."""
        nested_combo = io.DynamicCombo.Input("subcombo", options=[
            io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
            io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
        ])
        return io.Schema(
            node_id="DCTestNode",
            display_name="DCTest",
            category="logic",
            is_output_node=True,
            inputs=[io.DynamicCombo.Input("combo", options=[
                io.DynamicCombo.Option("option1", [io.String.Input("string")]),
                io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
                io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
                io.DynamicCombo.Option("option4", [nested_combo]),
            ])],
            outputs=[io.AnyType.Output()],
        )

    @classmethod
    def execute(cls, combo: DCValues) -> io.NodeOutput:
        """Return the value that belongs to the selected option.

        option1/2/3 return the "string"/"integer"/"image" entry unchanged;
        option4 returns the nested "subcombo" entry formatted as a string.

        Raises:
            ValueError: if the selected option is not one of the four known names.
        """
        combo_val = combo["combo"]
        if combo_val == "option1":
            return io.NodeOutput(combo["string"])
        if combo_val == "option2":
            return io.NodeOutput(combo["integer"])
        if combo_val == "option3":
            return io.NodeOutput(combo["image"])
        if combo_val == "option4":
            return io.NodeOutput(f"{combo['subcombo']}")
        raise ValueError(f"Invalid combo: {combo_val}")
|
||||
|
||||
|
||||
class AutogrowNamesTestNode(io.ComfyNode):
    """Test node for an Autogrow input built from a fixed list of names."""

    @classmethod
    def define_schema(cls):
        """Declare an autogrow float input named a/b/c and one string output."""
        names_template = _io.Autogrow.TemplateNames(
            input=io.Float.Input("float"),
            names=["a", "b", "c"],
        )
        autogrow_input = _io.Autogrow.Input("autogrow", template=names_template)
        return io.Schema(
            node_id="AutogrowNamesTestNode",
            display_name="AutogrowNamesTest",
            category="logic",
            inputs=[autogrow_input],
            outputs=[io.String.Output()],
        )

    @classmethod
    def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
        """Join the string form of every autogrow value with commas."""
        pieces = [str(value) for value in autogrow.values()]
        return io.NodeOutput(",".join(pieces))
|
||||
|
||||
class AutogrowPrefixTestNode(io.ComfyNode):
    """Test node for an Autogrow input built from a prefix template."""

    @classmethod
    def define_schema(cls):
        """Declare an autogrow float input (prefix "float", min 1, max 10) and one string output."""
        prefix_template = _io.Autogrow.TemplatePrefix(
            input=io.Float.Input("float"),
            prefix="float",
            min=1,
            max=10,
        )
        autogrow_input = _io.Autogrow.Input("autogrow", template=prefix_template)
        return io.Schema(
            node_id="AutogrowPrefixTestNode",
            display_name="AutogrowPrefixTest",
            category="logic",
            inputs=[autogrow_input],
            outputs=[io.String.Output()],
        )

    @classmethod
    def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
        """Join the string form of every autogrow value with commas."""
        pieces = [str(value) for value in autogrow.values()]
        return io.NodeOutput(",".join(pieces))
|
||||
|
||||
class ComboOutputTestNode(io.ComfyNode):
    """Test node that echoes two combo selections through two combo outputs."""

    @classmethod
    def define_schema(cls):
        """Declare two three-option combo inputs and two combo outputs."""
        first_combo = io.Combo.Input("combo", options=["option1", "option2", "option3"])
        second_combo = io.Combo.Input("combo2", options=["option4", "option5", "option6"])
        return io.Schema(
            node_id="ComboOptionTestNode",
            display_name="ComboOptionTest",
            category="logic",
            inputs=[first_combo, second_combo],
            outputs=[io.Combo.Output(), io.Combo.Output()],
        )

    @classmethod
    def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput:
        """Return both selections unchanged, in input order."""
        result = io.NodeOutput(combo, combo2)
        return result
|
||||
|
||||
class ConvertStringToComboNode(io.ComfyNode):
    """Re-emit a string input on a combo-typed output."""

    @classmethod
    def define_schema(cls):
        """Declare one string input and one combo output."""
        string_input = io.String.Input("string")
        combo_output = io.Combo.Output()
        return io.Schema(
            node_id="ConvertStringToComboNode",
            search_aliases=["string to dropdown", "text to combo"],
            display_name="Convert String to Combo",
            category="logic",
            inputs=[string_input],
            outputs=[combo_output],
        )

    @classmethod
    def execute(cls, string: str) -> io.NodeOutput:
        """Return the input string unchanged."""
        result = io.NodeOutput(string)
        return result
|
||||
|
||||
class InvertBooleanNode(io.ComfyNode):
    """Logical NOT of a boolean input."""

    @classmethod
    def define_schema(cls):
        """Declare one boolean input and one boolean output."""
        boolean_input = io.Boolean.Input("boolean")
        boolean_output = io.Boolean.Output()
        return io.Schema(
            node_id="InvertBooleanNode",
            search_aliases=["not", "toggle", "negate", "flip boolean"],
            display_name="Invert Boolean",
            category="logic",
            inputs=[boolean_input],
            outputs=[boolean_output],
        )

    @classmethod
    def execute(cls, boolean: bool) -> io.NodeOutput:
        """Return the negation of ``boolean``."""
        inverted = not boolean
        return io.NodeOutput(inverted)
|
||||
|
||||
class LogicExtension(ComfyExtension):
    """Extension object listing the logic nodes that are currently enabled."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the enabled node classes; the commented-out entries stay disabled."""
        enabled_nodes = [
            SwitchNode,
            CustomComboNode,
            # SoftSwitchNode,
            # ConvertStringToComboNode,
            # DCTestNode,
            # AutogrowNamesTestNode,
            # AutogrowPrefixTestNode,
            # ComboOutputTestNode,
            # InvertBooleanNode,
        ]
        return enabled_nodes
|
||||
|
||||
async def comfy_entrypoint() -> LogicExtension:
    """Build and return a fresh LogicExtension instance."""
    extension = LogicExtension()
    return extension
|
||||
79
comfy_extras/nodes_lora_debug.py
Normal file
79
comfy_extras/nodes_lora_debug.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import folder_paths
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
|
||||
|
||||
class LoraLoaderBypass:
    """
    Apply LoRA in bypass mode without modifying base model weights.

    Bypass mode computes: output = base_forward(x) + lora_path(x)
    This is useful for training and when model weights are offloaded.
    """

    def __init__(self):
        # Cache of the most recently loaded LoRA as a (path, state) pair.
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        """Describe the required inputs: model, clip, LoRA file and both strengths."""
        # Shared numeric range for both strength widgets.
        strength_range = {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}
        required = {
            "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
            "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
            "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
            "strength_model": ("FLOAT", {**strength_range, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
            "strength_clip": ("FLOAT", {**strength_range, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL", "CLIP")
    OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
    FUNCTION = "load_lora"

    CATEGORY = "loaders"
    DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios."
    EXPERIMENTAL = True

    def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
        """Load the named LoRA (reusing the cached copy) and apply it in bypass mode.

        Returns the inputs untouched when both strengths are zero; otherwise
        returns the (model, clip) pair from comfy.sd.load_bypass_lora_for_models.
        """
        nothing_to_apply = strength_model == 0 and strength_clip == 0
        if nothing_to_apply:
            return (model, clip)

        lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)

        # Reuse the cached LoRA only when it came from the same file;
        # a cache entry for a different path is dropped.
        lora = None
        cached = self.loaded_lora
        if cached is not None:
            if cached[0] == lora_path:
                lora = cached[1]
            else:
                self.loaded_lora = None

        if lora is None:
            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        patched = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip)
        model_lora, clip_lora = patched
        return (model_lora, clip_lora)
|
||||
|
||||
|
||||
class LoraLoaderBypassModelOnly(LoraLoaderBypass):
    """Bypass-mode LoRA loader variant that takes and returns only the model."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the required inputs: model, LoRA file and model strength."""
        required = {
            "model": ("MODEL",),
            "lora_name": (folder_paths.get_filename_list("loras"), ),
            "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_lora_model_only"

    def load_lora_model_only(self, model, lora_name, strength_model):
        """Run load_lora with no clip and zero clip strength; return only the model."""
        outputs = self.load_lora(model, None, lora_name, strength_model, 0)
        return (outputs[0],)
|
||||
|
||||
|
||||
# Maps each node identifier to the class that implements it.
NODE_CLASS_MAPPINGS = dict(
    LoraLoaderBypass=LoraLoaderBypass,
    LoraLoaderBypassModelOnly=LoraLoaderBypassModelOnly,
)

# Maps each node identifier to its display name.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    LoraLoaderBypass="Load LoRA (Bypass) (For debugging)",
    LoraLoaderBypassModelOnly="Load LoRA (Bypass, Model Only) (for debugging)",
)
|
||||
@@ -7,6 +7,7 @@ import logging
|
||||
from enum import Enum
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from tqdm.auto import trange
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
@@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
|
||||
"full_diff": LORAType.FULL_DIFF}
|
||||
|
||||
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
|
||||
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
|
||||
comfy.model_management.load_models_gpu([model_diff])
|
||||
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
|
||||
|
||||
for k in sd:
|
||||
if k.endswith(".weight"):
|
||||
sd_keys = list(sd.keys())
|
||||
for index in trange(len(sd_keys), unit="weight"):
|
||||
k = sd_keys[index]
|
||||
op_keys = sd_keys[index].rsplit('.', 1)
|
||||
if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
|
||||
continue
|
||||
op = comfy.utils.get_attr(model_diff.model, op_keys[0])
|
||||
if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
|
||||
weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
|
||||
else:
|
||||
weight_diff = sd[k]
|
||||
|
||||
if op_keys[1] == "weight":
|
||||
if lora_type == LORAType.STANDARD:
|
||||
if weight_diff.ndim < 2:
|
||||
if bias_diff:
|
||||
@@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
|
||||
elif lora_type == LORAType.FULL_DIFF:
|
||||
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||
|
||||
elif bias_diff and k.endswith(".bias"):
|
||||
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||
elif bias_diff and op_keys[1] == "bias":
|
||||
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
|
||||
return output_sd
|
||||
|
||||
class LoraSave(io.ComfyNode):
|
||||
@@ -78,6 +89,7 @@ class LoraSave(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoraSave",
|
||||
search_aliases=["export lora"],
|
||||
display_name="Extract and Save Lora",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
|
||||
@@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class LTXVImgToVideoInplace(io.ComfyNode):
    """Write VAE-encoded image frames over the leading frames of an existing latent."""

    @classmethod
    def define_schema(cls):
        """Declare vae/image/latent/strength/bypass inputs and a single latent output."""
        return io.Schema(
            node_id="LTXVImgToVideoInplace",
            category="conditioning/video_models",
            inputs=[
                io.Vae.Input("vae"),
                io.Image.Input("image"),
                io.Latent.Input("latent"),
                io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
                io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.")
            ],
            outputs=[
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
        """Encode ``image`` and place it at the start of the latent's frame axis.

        Args:
            vae: object providing ``downscale_index_formula`` and ``encode``.
            image: image batch; axes 1 and 2 are compared against the target
                height and width, and only the first 3 entries of the last
                axis are encoded.
            latent: dict whose "samples" entry is a 5-D tensor.
            strength: the noise mask over the written frames is ``1.0 - strength``.
            bypass: when true, the input latent is returned untouched.

        Returns:
            io.NodeOutput wrapping a dict with "samples" and "noise_mask".
        """
        if bypass:
            # Wrap in NodeOutput so this branch matches the declared return type.
            return io.NodeOutput(latent)

        # Work on a copy: the frame write below is in-place and must not
        # alter the tensor that the caller's latent dict still references.
        samples = latent["samples"].clone()
        _, height_scale_factor, width_scale_factor = (
            vae.downscale_index_formula
        )

        batch, _, latent_frames, latent_height, latent_width = samples.shape
        width = latent_width * width_scale_factor
        height = latent_height * height_scale_factor

        # Resize only when the image does not already match the pixel size
        # implied by the latent.
        if image.shape[1] != height or image.shape[2] != width:
            pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        else:
            pixels = image
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)

        # Overwrite as many leading latent frames as the encoder produced.
        samples[:, :, :t.shape[2]] = t

        # Mask starts at 1 everywhere; the written frames get 1 - strength.
        conditioning_latent_frames_mask = torch.ones(
            (batch, 1, latent_frames, 1, 1),
            dtype=torch.float32,
            device=samples.device,
        )
        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength

        return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})

    generate = execute  # TODO: remove
|
||||
|
||||
|
||||
def conditioning_get_any_value(conditioning, key, default=None):
|
||||
for t in conditioning:
|
||||
if key in t[1]:
|
||||
@@ -106,12 +159,12 @@ def get_keyframe_idxs(cond):
|
||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||
if keyframe_idxs is None:
|
||||
return None, 0
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
|
||||
# keyframe_idxs contains start/end positions (last dimension); check for unique values using only the start positions
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||
return keyframe_idxs, num_keyframes
|
||||
|
||||
class LTXVAddGuide(io.ComfyNode):
|
||||
NUM_PREFIX_FRAMES = 2
|
||||
PATCHIFIER = SymmetricPatchifier(1)
|
||||
PATCHIFIER = SymmetricPatchifier(1, start_end=True)
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -170,11 +223,24 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return frame_idx, latent_idx
|
||||
|
||||
@classmethod
|
||||
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
|
||||
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1):
|
||||
keyframe_idxs, _ = get_keyframe_idxs(cond)
|
||||
_, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
|
||||
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
|
||||
pixel_coords[:, 0] += frame_idx
|
||||
|
||||
# The following adjusts keyframe end positions for small grid IC-LoRA.
|
||||
# After dilation, the small grid has the same size and position as the large grid,
|
||||
# but each token encodes a larger image patch. We adjust the end position (not start)
|
||||
# so that RoPE represents the correct middle point of each token.
|
||||
# keyframe_idxs dims: (batch, spatial_dim [t,h,w], token_id, [start, end])
|
||||
# We only adjust h,w (not t) in dim 1, and only end (not start) in dim 3.
|
||||
spatial_end_offset = (latent_downscale_factor - 1) * torch.tensor(
|
||||
scale_factors[1:],
|
||||
device=pixel_coords.device,
|
||||
).view(1, -1, 1, 1)
|
||||
pixel_coords[:, 1:, :, 1:] += spatial_end_offset.to(pixel_coords.dtype)
|
||||
|
||||
if keyframe_idxs is None:
|
||||
keyframe_idxs = pixel_coords
|
||||
else:
|
||||
@@ -182,26 +248,35 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
||||
|
||||
@classmethod
|
||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
|
||||
_, latent_idx = cls.get_latent_index(
|
||||
cond=positive,
|
||||
latent_length=latent_image.shape[2],
|
||||
guide_length=guiding_latent.shape[2],
|
||||
frame_idx=frame_idx,
|
||||
scale_factors=scale_factors,
|
||||
)
|
||||
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
|
||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1):
|
||||
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
|
||||
raise ValueError("Adding guide to a combined AV latent is not supported.")
|
||||
|
||||
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
||||
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
|
||||
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
if guide_mask is not None:
|
||||
target_h = max(noise_mask.shape[3], guide_mask.shape[3])
|
||||
target_w = max(noise_mask.shape[4], guide_mask.shape[4])
|
||||
|
||||
if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1:
|
||||
noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w)
|
||||
|
||||
if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1:
|
||||
guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w)
|
||||
mask = guide_mask - strength
|
||||
else:
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
# This solves audio video combined latent case where latent_image has audio latent concatenated
|
||||
# in channel dimension with video latent. The solution is to pad guiding latent accordingly.
|
||||
if latent_image.shape[1] > guiding_latent.shape[1]:
|
||||
pad_len = latent_image.shape[1] - guiding_latent.shape[1]
|
||||
guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0)
|
||||
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
|
||||
noise_mask = torch.cat([noise_mask, mask], dim=2)
|
||||
return positive, negative, latent_image, noise_mask
|
||||
@@ -238,33 +313,17 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
|
||||
|
||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||
positive,
|
||||
negative,
|
||||
frame_idx,
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t[:, :, :num_prefix_frames],
|
||||
t,
|
||||
strength,
|
||||
scale_factors,
|
||||
)
|
||||
|
||||
latent_idx += num_prefix_frames
|
||||
|
||||
t = t[:, :, num_prefix_frames:]
|
||||
if t.shape[2] == 0:
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
latent_image, noise_mask = cls.replace_latent_frames(
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t,
|
||||
latent_idx,
|
||||
strength,
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
@@ -507,18 +566,90 @@ class LTXVPreprocess(io.ComfyNode):
|
||||
|
||||
preprocess = execute # TODO: remove
|
||||
|
||||
|
||||
import comfy.nested_tensor
|
||||
class LTXVConcatAVLatent(io.ComfyNode):
    """Bundle a video latent and an audio latent into one nested latent."""

    @classmethod
    def define_schema(cls):
        """Declare the video/audio latent inputs and the combined latent output."""
        latent_inputs = [
            io.Latent.Input("video_latent"),
            io.Latent.Input("audio_latent"),
        ]
        latent_outputs = [
            io.Latent.Output(display_name="latent"),
        ]
        return io.Schema(
            node_id="LTXVConcatAVLatent",
            category="latent/video/ltxv",
            inputs=latent_inputs,
            outputs=latent_outputs,
        )

    @classmethod
    def execute(cls, video_latent, audio_latent) -> io.NodeOutput:
        """Merge both latent dicts and pair their samples (and masks) as NestedTensors.

        Entries of ``audio_latent`` override same-named entries of
        ``video_latent``. A nested "noise_mask" is produced only when at least
        one input has a mask; the missing one is filled with ones shaped like
        that input's samples.
        """
        merged = {**video_latent, **audio_latent}

        video_mask = video_latent.get("noise_mask", None)
        audio_mask = audio_latent.get("noise_mask", None)
        has_any_mask = not (video_mask is None and audio_mask is None)

        if has_any_mask:
            if video_mask is None:
                video_mask = torch.ones_like(video_latent["samples"])
            if audio_mask is None:
                audio_mask = torch.ones_like(audio_latent["samples"])
            merged["noise_mask"] = comfy.nested_tensor.NestedTensor((video_mask, audio_mask))

        sample_pair = (video_latent["samples"], audio_latent["samples"])
        merged["samples"] = comfy.nested_tensor.NestedTensor(sample_pair)

        return io.NodeOutput(merged)
|
||||
|
||||
|
||||
class LTXVSeparateAVLatent(io.ComfyNode):
    """Split a combined AV latent back into a video latent and an audio latent."""

    @classmethod
    def define_schema(cls):
        """Declare the combined latent input and the two separated latent outputs."""
        split_outputs = [
            io.Latent.Output(display_name="video_latent"),
            io.Latent.Output(display_name="audio_latent"),
        ]
        return io.Schema(
            node_id="LTXVSeparateAVLatent",
            category="latent/video/ltxv",
            description="LTXV Separate AV Latent",
            inputs=[io.Latent.Input("av_latent")],
            outputs=split_outputs,
        )

    @classmethod
    def execute(cls, av_latent) -> io.NodeOutput:
        """Unbind "samples" (and "noise_mask" when set) into per-modality latent dicts.

        Element 0 of each unbound sequence goes to the video latent and
        element 1 to the audio latent. Both outputs start as shallow copies of
        ``av_latent``, so all other entries are carried over to each.
        """
        sample_parts = av_latent["samples"].unbind()

        video_latent = av_latent.copy()
        audio_latent = av_latent.copy()
        video_latent["samples"] = sample_parts[0]
        audio_latent["samples"] = sample_parts[1]

        # A missing key and an explicit None are both treated as "no mask".
        nested_mask = av_latent["noise_mask"] if "noise_mask" in av_latent else None
        if nested_mask is not None:
            mask_parts = nested_mask.unbind()
            video_latent["noise_mask"] = mask_parts[0]
            audio_latent["noise_mask"] = mask_parts[1]

        return io.NodeOutput(video_latent, audio_latent)
|
||||
|
||||
|
||||
class LtxvExtension(ComfyExtension):
    """Extension object listing the LTXV nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes this extension contributes."""
        node_classes = [
            EmptyLTXVLatentVideo,
            LTXVImgToVideo,
            LTXVImgToVideoInplace,
            ModelSamplingLTXV,
            LTXVConditioning,
            LTXVScheduler,
            LTXVAddGuide,
            LTXVPreprocess,
            LTXVCropGuides,
            LTXVConcatAVLatent,
            LTXVSeparateAVLatent,
        ]
        return node_classes
|
||||
|
||||
|
||||
|
||||
224
comfy_extras/nodes_lt_audio.py
Normal file
224
comfy_extras/nodes_lt_audio.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import folder_paths
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import torch
|
||||
|
||||
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class LTXVAudioVAELoader(io.ComfyNode):
    """Load an AudioVAE from a file in the "checkpoints" folder."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the checkpoint selector input and the Audio VAE output."""
        checkpoint_choice = io.Combo.Input(
            "ckpt_name",
            options=folder_paths.get_filename_list("checkpoints"),
            tooltip="Audio VAE checkpoint to load.",
        )
        return io.Schema(
            node_id="LTXVAudioVAELoader",
            display_name="LTXV Audio VAE Loader",
            category="audio",
            inputs=[checkpoint_choice],
            outputs=[io.Vae.Output(display_name="Audio VAE")],
        )

    @classmethod
    def execute(cls, ckpt_name: str) -> io.NodeOutput:
        """Resolve the checkpoint path, load it with metadata and build an AudioVAE."""
        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        loaded = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
        sd, metadata = loaded
        audio_vae = AudioVAE(sd, metadata)
        return io.NodeOutput(audio_vae)
|
||||
|
||||
|
||||
class LTXVAudioVAEEncode(io.ComfyNode):
    """Encode audio into an audio latent with an Audio VAE."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the audio and Audio VAE inputs and the latent output."""
        audio_input = io.Audio.Input("audio", tooltip="The audio to be encoded.")
        vae_input = io.Vae.Input(
            id="audio_vae",
            display_name="Audio VAE",
            tooltip="The Audio VAE model to use for encoding.",
        )
        return io.Schema(
            node_id="LTXVAudioVAEEncode",
            display_name="LTXV Audio VAE Encode",
            category="audio",
            inputs=[audio_input, vae_input],
            outputs=[io.Latent.Output(display_name="Audio Latent")],
        )

    @classmethod
    def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
        """Return a latent dict with the encoded samples, the VAE sample rate and type "audio"."""
        audio_latents = audio_vae.encode(audio)
        latent = {
            "samples": audio_latents,
            "sample_rate": int(audio_vae.sample_rate),
            "type": "audio",
        }
        return io.NodeOutput(latent)
|
||||
|
||||
|
||||
class LTXVAudioVAEDecode(io.ComfyNode):
    """Decode an audio latent back into a waveform with an Audio VAE."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the latent and Audio VAE inputs and the audio output."""
        latent_input = io.Latent.Input("samples", tooltip="The latent to be decoded.")
        vae_input = io.Vae.Input(
            id="audio_vae",
            display_name="Audio VAE",
            tooltip="The Audio VAE model used for decoding the latent.",
        )
        return io.Schema(
            node_id="LTXVAudioVAEDecode",
            display_name="LTXV Audio VAE Decode",
            category="audio",
            inputs=[latent_input, vae_input],
            outputs=[io.Audio.Output(display_name="Audio")],
        )

    @classmethod
    def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
        """Decode ``samples["samples"]`` and return a waveform/sample_rate dict.

        When the latent reports ``is_nested``, only its last unbound element
        is decoded. The decoded audio is moved to the latent's device.
        """
        audio_latent = samples["samples"]
        if audio_latent.is_nested:
            # Take the final component of the nested latent.
            audio_latent = audio_latent.unbind()[-1]

        decoded = audio_vae.decode(audio_latent)
        audio = decoded.to(audio_latent.device)
        output_audio_sample_rate = audio_vae.output_sample_rate

        result = {
            "waveform": audio,
            "sample_rate": int(output_audio_sample_rate),
        }
        return io.NodeOutput(result)
|
||||
|
||||
|
||||
class LTXVEmptyLatentAudio(io.ComfyNode):
    """Create an all-zero audio latent sized from an Audio VAE's configuration."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare frame count, frame rate, batch size and Audio VAE inputs."""
        frames_input = io.Int.Input(
            "frames_number",
            default=97,
            min=1,
            max=1000,
            step=1,
            display_mode=io.NumberDisplay.number,
            tooltip="Number of frames.",
        )
        rate_input = io.Int.Input(
            "frame_rate",
            default=25,
            min=1,
            max=1000,
            step=1,
            display_mode=io.NumberDisplay.number,
            tooltip="Number of frames per second.",
        )
        batch_input = io.Int.Input(
            "batch_size",
            default=1,
            min=1,
            max=4096,
            display_mode=io.NumberDisplay.number,
            tooltip="The number of latent audio samples in the batch.",
        )
        vae_input = io.Vae.Input(
            id="audio_vae",
            display_name="Audio VAE",
            tooltip="The Audio VAE model to get configuration from.",
        )
        return io.Schema(
            node_id="LTXVEmptyLatentAudio",
            display_name="LTXV Empty Latent Audio",
            category="latent/audio",
            inputs=[frames_input, rate_input, batch_input, vae_input],
            outputs=[io.Latent.Output(display_name="Latent")],
        )

    @classmethod
    def execute(
        cls,
        frames_number: int,
        frame_rate: int,
        batch_size: int,
        audio_vae: AudioVAE,
    ) -> io.NodeOutput:
        """Build a zero-filled audio latent.

        The tensor shape is (batch_size, latent channels, number of audio
        latents, latent frequency bins), with the last three sizes read from
        ``audio_vae``. It is allocated on the intermediate device.
        """
        assert audio_vae is not None, "Audio VAE model is required"

        z_channels = audio_vae.latent_channels
        audio_freq = audio_vae.latent_frequency_bins
        sampling_rate = int(audio_vae.sample_rate)

        num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)

        latent_shape = (batch_size, z_channels, num_audio_latents, audio_freq)
        audio_latents = torch.zeros(
            latent_shape,
            device=comfy.model_management.intermediate_device(),
        )

        latent = {
            "samples": audio_latents,
            "sample_rate": sampling_rate,
            "type": "audio",
        }
        return io.NodeOutput(latent)
|
||||
|
||||
|
||||
class LTXAVTextEncoderLoader(io.ComfyNode):
    """Node that loads a CLIP object from a text-encoder file plus a checkpoint file."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node: text encoder file, checkpoint file and target device."""
        encoder_choice = io.Combo.Input(
            "text_encoder",
            options=folder_paths.get_filename_list("text_encoders"),
        )
        checkpoint_choice = io.Combo.Input(
            "ckpt_name",
            options=folder_paths.get_filename_list("checkpoints"),
        )
        device_choice = io.Combo.Input(
            "device",
            options=["default", "cpu"],
        )
        return io.Schema(
            node_id="LTXAVTextEncoderLoader",
            display_name="LTXV Audio Text Encoder Loader",
            category="advanced/loaders",
            description="[Recipes]\n\nltxav: gemma 3 12B",
            inputs=[encoder_choice, checkpoint_choice, device_choice],
            outputs=[io.Clip.Output()],
        )

    @classmethod
    def execute(cls, text_encoder, ckpt_name, device="default"):
        """Resolve both file paths and load them together as an LTXV-type CLIP.

        When ``device`` is ``"cpu"`` both the load and offload devices are
        pinned to the CPU; otherwise no model options are passed.
        """
        ltxv_type = comfy.sd.CLIPType.LTXV

        encoder_path = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
        checkpoint_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

        if device == "cpu":
            cpu = torch.device("cpu")
            model_options = {"load_device": cpu, "offload_device": cpu}
        else:
            model_options = {}

        clip = comfy.sd.load_clip(
            ckpt_paths=[encoder_path, checkpoint_path],
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
            clip_type=ltxv_type,
            model_options=model_options,
        )
        return io.NodeOutput(clip)
|
||||
|
||||
|
||||
class LTXVAudioExtension(ComfyExtension):
    """Extension that exposes the LTXV audio nodes defined in this module."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [
            LTXVAudioVAELoader,
            LTXVAudioVAEEncode,
            LTXVAudioVAEDecode,
            LTXVEmptyLatentAudio,
            LTXAVTextEncoderLoader,
        ]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ComfyExtension:
    """Create and return the extension object for this module."""
    extension = LTXVAudioExtension()
    return extension
|
||||
75
comfy_extras/nodes_lt_upsampler.py
Normal file
75
comfy_extras/nodes_lt_upsampler.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from comfy import model_management
|
||||
import math
|
||||
|
||||
class LTXVLatentUpsampler:
    """
    Upsamples a video latent by a factor of 2.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the required inputs: the latent, the upscale model and the VAE."""
        return {
            "required": {
                "samples": ("LATENT",),
                "upscale_model": ("LATENT_UPSCALE_MODEL",),
                "vae": ("VAE",),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upsample_latent"
    CATEGORY = "latent/video"
    EXPERIMENTAL = True

    def upsample_latent(
        self,
        samples: dict,
        upscale_model,
        vae,
    ) -> tuple:
        """
        Upsample the input latent using the provided model.

        The latent is un-normalized with the VAE's per-channel statistics,
        passed through ``upscale_model`` on the torch device, and normalized
        again. The upscale model is always moved back to the CPU afterwards,
        even if the forward pass raises.

        Args:
            samples (dict): Input latent dict; ``samples["samples"]`` holds the tensor.
            upscale_model (LatentUpsampler): Loaded upscale model.
            vae: VAE whose ``first_stage_model.per_channel_statistics`` is used
                to un-normalize and re-normalize the latent.

        Returns:
            tuple: One-element tuple with a copy of ``samples`` whose
            ``"samples"`` entry is the upsampled latent (in the input dtype,
            on the intermediate device) and whose ``"noise_mask"`` entry, if
            any, has been removed.
        """
        device = model_management.get_torch_device()
        memory_required = model_management.module_size(upscale_model)

        model_dtype = next(upscale_model.parameters()).dtype
        latents = samples["samples"]
        input_dtype = latents.dtype

        # Rough per-element estimate for the activations on top of the weights.
        memory_required += math.prod(latents.shape) * 3000.0  # TODO: more accurate
        model_management.free_memory(memory_required, device)

        try:
            upscale_model.to(device)  # TODO: use the comfy model management system.

            latents = latents.to(dtype=model_dtype, device=device)

            # Upsample latents without tiling.
            latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents)
            upsampled_latents = upscale_model(latents)
        finally:
            # Release the device memory held by the model regardless of outcome.
            upscale_model.cpu()

        upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize(
            upsampled_latents
        )
        upsampled_latents = upsampled_latents.to(
            dtype=input_dtype, device=model_management.intermediate_device()
        )

        return_dict = samples.copy()
        return_dict["samples"] = upsampled_latents
        # NOTE(review): presumably dropped because an existing mask would no
        # longer match the upsampled latent — confirm.
        return_dict.pop("noise_mask", None)
        return (return_dict,)
|
||||
|
||||
|
||||
# Maps the node identifier to the class implementing it.
NODE_CLASS_MAPPINGS = {"LTXVLatentUpsampler": LTXVLatentUpsampler}
|
||||
@@ -79,6 +79,7 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeLumina2",
|
||||
search_aliases=["lumina prompt"],
|
||||
display_name="CLIP Text Encode for Lumina2",
|
||||
category="conditioning",
|
||||
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
|
||||
|
||||
@@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Mahiro",
|
||||
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
|
||||
display_name="Mahiro CFG",
|
||||
category="_for_testing",
|
||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||
inputs=[
|
||||
|
||||
@@ -3,11 +3,10 @@ import scipy.ndimage
|
||||
import torch
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
import folder_paths
|
||||
import random
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, UI
|
||||
|
||||
import nodes
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
||||
source = source.to(destination.device)
|
||||
@@ -46,202 +45,221 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||
return destination
|
||||
|
||||
class LatentCompositeMasked:
|
||||
class LatentCompositeMasked(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"destination": ("LATENT",),
|
||||
"source": ("LATENT",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"resize_source": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "composite"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LatentCompositeMasked",
|
||||
search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"],
|
||||
category="latent",
|
||||
inputs=[
|
||||
IO.Latent.Input("destination"),
|
||||
IO.Latent.Input("source"),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
|
||||
IO.Boolean.Input("resize_source", default=False),
|
||||
IO.Mask.Input("mask", optional=True),
|
||||
],
|
||||
outputs=[IO.Latent.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def composite(self, destination, source, x, y, resize_source, mask = None):
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
||||
output = destination.copy()
|
||||
destination = destination["samples"].clone()
|
||||
source = source["samples"]
|
||||
output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
|
||||
return (output,)
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class ImageCompositeMasked:
|
||||
composite = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageCompositeMasked(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"destination": ("IMAGE",),
|
||||
"source": ("IMAGE",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"resize_source": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "composite"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCompositeMasked",
|
||||
search_aliases=["paste image", "overlay", "layer"],
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("destination"),
|
||||
IO.Image.Input("source"),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Boolean.Input("resize_source", default=False),
|
||||
IO.Mask.Input("mask", optional=True),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
def composite(self, destination, source, x, y, resize_source, mask = None):
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
||||
destination, source = node_helpers.image_alpha_fix(destination, source)
|
||||
destination = destination.clone().movedim(-1, 1)
|
||||
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
||||
return (output,)
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class MaskToImage:
|
||||
composite = execute # TODO: remove
|
||||
|
||||
|
||||
class MaskToImage(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MaskToImage",
|
||||
search_aliases=["convert mask"],
|
||||
display_name="Convert Mask to Image",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "mask_to_image"
|
||||
|
||||
def mask_to_image(self, mask):
|
||||
@classmethod
|
||||
def execute(cls, mask) -> IO.NodeOutput:
|
||||
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
return (result,)
|
||||
return IO.NodeOutput(result)
|
||||
|
||||
class ImageToMask:
|
||||
mask_to_image = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageToMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"channel": (["red", "green", "blue", "alpha"],),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageToMask",
|
||||
search_aliases=["extract channel", "channel to mask"],
|
||||
display_name="Convert Image to Mask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "image_to_mask"
|
||||
|
||||
def image_to_mask(self, image, channel):
|
||||
@classmethod
|
||||
def execute(cls, image, channel) -> IO.NodeOutput:
|
||||
channels = ["red", "green", "blue", "alpha"]
|
||||
mask = image[:, :, :, channels.index(channel)]
|
||||
return (mask,)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
class ImageColorToMask:
|
||||
image_to_mask = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageColorToMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageColorToMask",
|
||||
search_aliases=["color keying", "chroma key"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "image_to_mask"
|
||||
|
||||
def image_to_mask(self, image, color):
|
||||
@classmethod
|
||||
def execute(cls, image, color) -> IO.NodeOutput:
|
||||
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
|
||||
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
|
||||
mask = torch.where(temp == color, 1.0, 0).float()
|
||||
return (mask,)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
class SolidMask:
|
||||
image_to_mask = execute # TODO: remove
|
||||
|
||||
|
||||
class SolidMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SolidMask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "solid"
|
||||
|
||||
def solid(self, value, width, height):
|
||||
@classmethod
|
||||
def execute(cls, value, width, height) -> IO.NodeOutput:
|
||||
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
|
||||
return (out,)
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
class InvertMask:
|
||||
solid = execute # TODO: remove
|
||||
|
||||
|
||||
class InvertMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="InvertMask",
|
||||
search_aliases=["reverse mask", "flip mask"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "invert"
|
||||
|
||||
def invert(self, mask):
|
||||
@classmethod
|
||||
def execute(cls, mask) -> IO.NodeOutput:
|
||||
out = 1.0 - mask
|
||||
return (out,)
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
class CropMask:
|
||||
invert = execute # TODO: remove
|
||||
|
||||
|
||||
class CropMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="CropMask",
|
||||
search_aliases=["cut mask", "extract mask region", "mask slice"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "crop"
|
||||
|
||||
def crop(self, mask, x, y, width, height):
|
||||
@classmethod
|
||||
def execute(cls, mask, x, y, width, height) -> IO.NodeOutput:
|
||||
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
|
||||
out = mask[:, y:y + height, x:x + width]
|
||||
return (out,)
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
class MaskComposite:
|
||||
crop = execute # TODO: remove
|
||||
|
||||
|
||||
class MaskComposite(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"destination": ("MASK",),
|
||||
"source": ("MASK",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MaskComposite",
|
||||
search_aliases=["combine masks", "blend masks", "layer masks"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("destination"),
|
||||
IO.Mask.Input("source"),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "combine"
|
||||
|
||||
def combine(self, destination, source, x, y, operation):
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
|
||||
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
|
||||
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
|
||||
|
||||
@@ -267,28 +285,30 @@ class MaskComposite:
|
||||
|
||||
output = torch.clamp(output, 0.0, 1.0)
|
||||
|
||||
return (output,)
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class FeatherMask:
|
||||
combine = execute # TODO: remove
|
||||
|
||||
|
||||
class FeatherMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FeatherMask",
|
||||
search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "feather"
|
||||
|
||||
def feather(self, mask, left, top, right, bottom):
|
||||
@classmethod
|
||||
def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput:
|
||||
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
|
||||
|
||||
left = min(left, output.shape[-1])
|
||||
@@ -312,26 +332,29 @@ class FeatherMask:
|
||||
feather_rate = (y + 1) / bottom
|
||||
output[:, -y, :] *= feather_rate
|
||||
|
||||
return (output,)
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class GrowMask:
|
||||
feather = execute # TODO: remove
|
||||
|
||||
|
||||
class GrowMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"tapered_corners": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrowMask",
|
||||
search_aliases=["expand mask", "shrink mask"],
|
||||
display_name="Grow Mask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Boolean.Input("tapered_corners", default=True),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "expand_mask"
|
||||
|
||||
def expand_mask(self, mask, expand, tapered_corners):
|
||||
@classmethod
|
||||
def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput:
|
||||
c = 0 if tapered_corners else 1
|
||||
kernel = np.array([[c, 1, c],
|
||||
[1, 1, 1],
|
||||
@@ -347,69 +370,76 @@ class GrowMask:
|
||||
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
|
||||
output = torch.from_numpy(output)
|
||||
out.append(output)
|
||||
return (torch.stack(out, dim=0),)
|
||||
return IO.NodeOutput(torch.stack(out, dim=0))
|
||||
|
||||
class ThresholdMask:
|
||||
expand_mask = execute # TODO: remove
|
||||
|
||||
|
||||
class ThresholdMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ThresholdMask",
|
||||
search_aliases=["binary mask"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "image_to_mask"
|
||||
|
||||
def image_to_mask(self, mask, value):
|
||||
@classmethod
|
||||
def execute(cls, mask, value) -> IO.NodeOutput:
|
||||
mask = (mask > value).float()
|
||||
return (mask,)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
image_to_mask = execute # TODO: remove
|
||||
|
||||
|
||||
# Mask Preview - original implement from
|
||||
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
|
||||
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
|
||||
class MaskPreview(nodes.SaveImage):
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_temp_directory()
|
||||
self.type = "temp"
|
||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
||||
self.compress_level = 4
|
||||
class MaskPreview(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MaskPreview",
|
||||
search_aliases=["show mask", "view mask", "inspect mask", "debug mask"],
|
||||
display_name="Preview Mask",
|
||||
category="mask",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {"mask": ("MASK",), },
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "mask"
|
||||
|
||||
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
|
||||
def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput:
|
||||
return IO.NodeOutput(ui=UI.PreviewMask(mask))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentCompositeMasked": LatentCompositeMasked,
|
||||
"ImageCompositeMasked": ImageCompositeMasked,
|
||||
"MaskToImage": MaskToImage,
|
||||
"ImageToMask": ImageToMask,
|
||||
"ImageColorToMask": ImageColorToMask,
|
||||
"SolidMask": SolidMask,
|
||||
"InvertMask": InvertMask,
|
||||
"CropMask": CropMask,
|
||||
"MaskComposite": MaskComposite,
|
||||
"FeatherMask": FeatherMask,
|
||||
"GrowMask": GrowMask,
|
||||
"ThresholdMask": ThresholdMask,
|
||||
"MaskPreview": MaskPreview
|
||||
}
|
||||
class MaskExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
LatentCompositeMasked,
|
||||
ImageCompositeMasked,
|
||||
MaskToImage,
|
||||
ImageToMask,
|
||||
ImageColorToMask,
|
||||
SolidMask,
|
||||
InvertMask,
|
||||
CropMask,
|
||||
MaskComposite,
|
||||
FeatherMask,
|
||||
GrowMask,
|
||||
ThresholdMask,
|
||||
MaskPreview,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageToMask": "Convert Image to Mask",
|
||||
"MaskToImage": "Convert Mask to Image",
|
||||
}
|
||||
|
||||
async def comfy_entrypoint() -> MaskExtension:
|
||||
return MaskExtension()
|
||||
|
||||
@@ -299,6 +299,7 @@ class RescaleCFG:
|
||||
return (m, )
|
||||
|
||||
class ModelComputeDtype:
|
||||
SEARCH_ALIASES = ["model precision", "change dtype"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
|
||||
@@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode):
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Sampling
|
||||
"PatchModelAddDownscale": "",
|
||||
}
|
||||
|
||||
class ModelDownscaleExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
|
||||
@@ -91,6 +91,7 @@ class CLIPMergeSimple:
|
||||
|
||||
|
||||
class CLIPSubtract:
|
||||
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip1": ("CLIP",),
|
||||
@@ -113,6 +114,7 @@ class CLIPSubtract:
|
||||
|
||||
|
||||
class CLIPAdd:
|
||||
SEARCH_ALIASES = ["combine clip"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip1": ("CLIP",),
|
||||
@@ -225,6 +227,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
||||
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
|
||||
|
||||
class CheckpointSave:
|
||||
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@@ -337,6 +340,7 @@ class VAESave:
|
||||
return {}
|
||||
|
||||
class ModelSave:
|
||||
SEARCH_ALIASES = ["export model", "checkpoint save"]
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ import comfy.ops
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.latent_formats
|
||||
import comfy.ldm.lumina.controlnet
|
||||
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
@@ -189,6 +191,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):
|
||||
|
||||
return embedding
|
||||
|
||||
def z_image_convert(sd):
    """Convert a state dict's attention keys to the fused-qkv naming.

    For every attention block the separate ``to_q`` / ``to_k`` / ``to_v``
    weights are concatenated along dim 0, in q, k, v order, and stored under
    ``.attention.qkv.weight``. The output projection and q/k norm keys are
    renamed (``to_out.0`` -> ``out``, ``norm_q`` -> ``q_norm``,
    ``norm_k`` -> ``k_norm``). All other entries are copied unchanged.

    Args:
        sd: Mapping from parameter name to tensor.

    Returns:
        A new dict with the converted keys, inserted in sorted-key order.

    Raises:
        KeyError: If a ``to_v`` weight has no matching ``to_q`` and ``to_k``
            weight in the same attention block.
    """
    replace_keys = {
        ".attention.to_out.0.bias": ".attention.out.bias",
        ".attention.norm_k.weight": ".attention.k_norm.weight",
        ".attention.norm_q.weight": ".attention.q_norm.weight",
        ".attention.to_out.0.weight": ".attention.out.weight",
    }
    q_suffix = ".attention.to_q.weight"
    k_suffix = ".attention.to_k.weight"
    v_suffix = ".attention.to_v.weight"

    # Pending q/k weights, keyed by the block prefix they belong to, so a
    # weight from one attention block can never be fused into another.
    pending = {}
    out_sd = {}
    for k in sorted(sd.keys()):
        w = sd[k]

        if k.endswith(k_suffix):
            pending.setdefault(k[:-len(k_suffix)], {})["k"] = w
            continue
        if k.endswith(q_suffix):
            pending.setdefault(k[:-len(q_suffix)], {})["q"] = w
            continue

        k_out = k
        if k.endswith(v_suffix):
            # In sorted order to_k and to_q precede to_v within a block, so
            # both must already be pending here.
            parts = pending.pop(k[:-len(v_suffix)], {})
            if "q" not in parts or "k" not in parts:
                raise KeyError(
                    "missing to_q/to_k weight for attention block of '{}'".format(k)
                )
            w = torch.cat([parts["q"], parts["k"], w], dim=0)
            k_out = k_out.replace(v_suffix, ".attention.qkv.weight")

        for r, rr in replace_keys.items():
            k_out = k_out.replace(r, rr)
        out_sd[k_out] = w

    return out_sd
|
||||
|
||||
class ModelPatchLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -211,10 +242,34 @@ class ModelPatchLoader:
|
||||
elif 'feature_embedder.mid_layer_norm.bias' in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
|
||||
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
|
||||
sd = z_image_convert(sd)
|
||||
config = {}
|
||||
if 'control_layers.4.adaLN_modulation.0.weight' not in sd:
|
||||
config['n_control_layers'] = 3
|
||||
config['additional_in_dim'] = 17
|
||||
config['refiner_control'] = True
|
||||
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
|
||||
config['n_control_layers'] = 15
|
||||
config['additional_in_dim'] = 17
|
||||
config['refiner_control'] = True
|
||||
ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None)
|
||||
if ref_weight is not None:
|
||||
if torch.count_nonzero(ref_weight) == 0:
|
||||
config['broken'] = True
|
||||
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
|
||||
elif "audio_proj.proj1.weight" in sd:
|
||||
model = MultiTalkModelPatch(
|
||||
audio_window=5, context_tokens=32, vae_scale=4,
|
||||
in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0],
|
||||
intermediate_dim=sd["audio_proj.proj1.weight"].shape[0],
|
||||
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
operations=comfy.ops.manual_cast)
|
||||
|
||||
model.load_state_dict(sd)
|
||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
return (model,)
|
||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
||||
return (model_patcher,)
|
||||
|
||||
|
||||
class DiffSynthCnetPatch:
|
||||
@@ -263,6 +318,129 @@ class DiffSynthCnetPatch:
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
class ZImageControlPatch:
    """Block patch that adds Z-Image control-net residuals to the image tokens.

    The same instance is registered both as a noise-refiner patch and as a
    double-block patch by the node that creates it; ``__call__`` tells the two
    apart through ``kwargs["block_type"]``.
    """

    def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
        """Store the inputs and pre-encode the conditioning latent when possible.

        Args:
            model_patch: Patcher whose ``.model`` is the control network.
            vae: VAE used to encode the control / inpaint images.
            image: Control image or ``None`` (indexed as [batch, height, width, ...]).
            strength: Multiplier applied to the control residual.
            inpaint_image: Optional inpaint image, same layout as ``image``.
            mask: Optional inpaint mask.

        Raises:
            ValueError: If no usable conditioning image was supplied.
        """
        self.model_patch = model_patch
        self.vae = vae
        self.image = image
        self.inpaint_image = inpaint_image
        self.mask = mask
        self.strength = strength
        self.is_inpaint = self.model_patch.model.additional_in_dim > 0

        # Without a control image the latent can only be built from the inpaint
        # image, and only by inpaint-capable models. Fail here with a clear
        # message instead of an AttributeError/TypeError deep inside encoding.
        if self.image is None and (self.inpaint_image is None or not self.is_inpaint):
            raise ValueError("ZImageControlPatch requires a control image (or an inpaint image when the control model supports inpainting).")

        # When both images are given but differ in shape, encoding is deferred
        # to __call__, where both are rescaled to the latent's pixel size.
        skip_encoding = False
        if self.image is not None and self.inpaint_image is not None:
            if self.image.shape != self.inpaint_image.shape:
                skip_encoding = True

        if skip_encoding:
            self.encoded_image = None
        else:
            self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image)
            if self.image is None:
                self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2])
            else:
                self.encoded_image_size = (self.image.shape[1], self.image.shape[2])
        # (last control layer index, (residual, control hidden state)) or None.
        self.temp_data = None

    def encode_latent_cond(self, control_image=None, inpaint_image=None):
        """Encode the control (and, for inpaint models, mask + inpaint) images.

        Returns the control latent alone for non-inpaint models, otherwise the
        channel-wise concatenation ``[control latent, mask, inpaint latent]``.
        """
        latent_image = None
        if control_image is not None:
            latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))

        if self.is_inpaint:
            if inpaint_image is None:
                # Neutral mid-grey stand-in when no inpaint image was given.
                inpaint_image = torch.ones_like(control_image) * 0.5

            if self.mask is not None:
                mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
                # Pull pixels where the rounded mask is 0 back to mid-grey.
                inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5

            inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))

            if self.mask is None:
                mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
            else:
                mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")

            if latent_image is None:
                latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))

            return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
        else:
            return latent_image

    def __call__(self, kwargs):
        """Apply the control residual for the block described by ``kwargs``.

        ``img`` is modified in place; ``img`` and ``txt`` are removed from
        ``kwargs`` before it is returned.
        """
        x = kwargs.get("x")
        img = kwargs.get("img")
        img_input = kwargs.get("img_input")
        txt = kwargs.get("txt")
        pe = kwargs.get("pe")
        vec = kwargs.get("vec")
        block_index = kwargs.get("block_index")
        block_type = kwargs.get("block_type", "")
        spacial_compression = self.vae.spacial_compression_encode()
        # (Re-)encode whenever nothing is cached or the cached latent was made
        # for a different pixel resolution than the current latent implies.
        if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
            image_scaled = None
            if self.image is not None:
                image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
                self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2])

            inpaint_scaled = None
            if self.inpaint_image is not None:
                inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
                self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2])

            # Remember which models are in use so they can be reloaded after
            # the VAE encode.
            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
            self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled)
            comfy.model_management.load_models_gpu(loaded_models)

        cnet_blocks = self.model_patch.model.n_control_layers
        # NOTE(review): 30 is presumably the number of main transformer blocks - confirm.
        div = round(30 / cnet_blocks)

        cnet_index = (block_index // div)
        cnet_index_float = (block_index / div)

        kwargs.pop("img")  # we do ops in place
        kwargs.pop("txt")

        if cnet_index_float > (cnet_blocks - 1):
            self.temp_data = None
            return kwargs

        # Start a new pass: run the control model's entry forward and seed the
        # layer counter (-3 for the noise refiner pass, -1 otherwise).
        if self.temp_data is None or self.temp_data[0] > cnet_index:
            if block_type == "noise_refiner":
                self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
            else:
                self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))

        if block_type == "noise_refiner":
            next_layer = self.temp_data[0] + 1
            self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
            if self.temp_data[1][0] is not None:
                img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
        else:
            # Advance the control layers until they catch up with this block.
            while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
                next_layer = self.temp_data[0] + 1
                self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))

            # Only blocks that map exactly onto a control layer get a residual.
            if cnet_index_float == self.temp_data[0]:
                img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
                if cnet_blocks == self.temp_data[0] + 1:
                    self.temp_data = None

        return kwargs

    def to(self, device_or_dtype):
        """Move the cached latent to a device and drop per-pass state; dtypes are ignored."""
        if isinstance(device_or_dtype, torch.device):
            if self.encoded_image is not None:
                self.encoded_image = self.encoded_image.to(device_or_dtype)
            self.temp_data = None
        return self

    def models(self):
        """Return the model patchers this patch depends on."""
        return [self.model_patch]
|
||||
|
||||
class QwenImageDiffsynthControlnet:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -279,9 +457,12 @@ class QwenImageDiffsynthControlnet:
|
||||
|
||||
CATEGORY = "advanced/loaders/qwen"
|
||||
|
||||
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
|
||||
def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None):
|
||||
model_patched = model.clone()
|
||||
image = image[:, :, :, :3]
|
||||
if image is not None:
|
||||
image = image[:, :, :, :3]
|
||||
if inpaint_image is not None:
|
||||
inpaint_image = inpaint_image[:, :, :, :3]
|
||||
if mask is not None:
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
@@ -289,9 +470,25 @@ class QwenImageDiffsynthControlnet:
|
||||
mask = mask.unsqueeze(2)
|
||||
mask = 1.0 - mask
|
||||
|
||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
|
||||
patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask)
|
||||
model_patched.set_model_noise_refiner_patch(patch)
|
||||
model_patched.set_model_double_block_patch(patch)
|
||||
else:
|
||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||
return (model_patched,)
|
||||
|
||||
class ZImageFunControlnet(QwenImageDiffsynthControlnet):
    """Variant of the diffsynth controlnet node with optional image, inpaint image and mask."""

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "model": ("MODEL",),
            "model_patch": ("MODEL_PATCH",),
            "vae": ("VAE",),
            "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
        }
        # All conditioning inputs are optional for this node.
        optional_inputs = {
            "image": ("IMAGE",),
            "inpaint_image": ("IMAGE",),
            "mask": ("MASK",),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    CATEGORY = "advanced/loaders/zimage"
|
||||
|
||||
class UsoStyleProjectorPatch:
|
||||
def __init__(self, model_patch, encoded_image):
|
||||
@@ -336,8 +533,41 @@ class USOStyleReference:
|
||||
return (model_patched,)
|
||||
|
||||
|
||||
class MultiTalkModelPatch(torch.nn.Module):
    """Container module holding the audio projection and the per-layer attention blocks."""

    def __init__(
        self,
        audio_window: int = 5,
        intermediate_dim: int = 512,
        in_dim: int = 5120,
        out_dim: int = 768,
        context_tokens: int = 32,
        vae_scale: int = 4,
        num_layers: int = 40,

        device=None, dtype=None, operations=None
    ):
        super().__init__()
        # Construction options shared by every sub-module.
        factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}

        self.audio_proj = MultiTalkAudioProjModel(
            seq_len=audio_window,
            seq_len_vf=audio_window + vae_scale - 1,
            intermediate_dim=intermediate_dim,
            out_dim=out_dim,
            context_tokens=context_tokens,
            **factory_kwargs,
        )

        # One attention block per layer.
        attention_blocks = []
        for _ in range(num_layers):
            attention_blocks.append(WanMultiTalkAttentionBlock(in_dim, out_dim, **factory_kwargs))
        self.blocks = torch.nn.ModuleList(attention_blocks)
|
||||
|
||||
|
||||
# Maps each node identifier string to the class that implements it.
NODE_CLASS_MAPPINGS = {
    "ModelPatchLoader": ModelPatchLoader,
    "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
    "ZImageFunControlnet": ZImageFunControlnet,
    "USOStyleReference": USOStyleReference,
}
|
||||
|
||||
@@ -12,6 +12,7 @@ class Morphology(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Morphology",
|
||||
search_aliases=["erode", "dilate"],
|
||||
display_name="ImageMorphology",
|
||||
category="image/postprocessing",
|
||||
inputs=[
|
||||
@@ -57,6 +58,7 @@ class ImageRGBToYUV(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageRGBToYUV",
|
||||
search_aliases=["color space conversion"],
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
@@ -78,6 +80,7 @@ class ImageYUVToRGB(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageYUVToRGB",
|
||||
search_aliases=["color space conversion"],
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
io.Image.Input("Y"),
|
||||
|
||||
99
comfy_extras/nodes_nag.py
Normal file
99
comfy_extras/nodes_nag.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class NAGuidance(io.ComfyNode):
    """Node that patches a model's attention output with Normalized Attention Guidance."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        node_inputs = [
            io.Model.Input("model", tooltip="The model to apply NAG to."),
            io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
            io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
            io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
            # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
            # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
        ]
        node_outputs = [
            io.Model.Output(tooltip="The patched model with NAG enabled."),
        ]
        return io.Schema(
            node_id="NAGuidance",
            display_name="Normalized Attention Guidance",
            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
            category="",
            is_experimental=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
        m = model.clone()

        # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
        # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)

        def nag_attention_output_patch(out, extra_options):
            """Replace the positive attention rows with the guided, norm-limited blend."""
            cond_or_uncond = extra_options.get("cond_or_uncond", None)
            # Guidance needs both a positive (0) and a negative (1) entry in the batch.
            if cond_or_uncond is None or 1 not in cond_or_uncond or 0 not in cond_or_uncond:
                return out

            # sigma = extra_options.get("sigmas", None)
            # if sigma is not None and len(sigma) > 0:
            #     sigma = sigma[0].item()
            #     if sigma > sigma_start or sigma < sigma_end:
            #         return out

            img_slice = extra_options.get("img_slice", None)
            full_out = out
            if img_slice is not None:
                out = out[:, img_slice[0]:img_slice[1]]  # only apply on img part

            chunk = out.shape[0] // len(cond_or_uncond)
            neg_at = cond_or_uncond.index(1)
            pos_at = cond_or_uncond.index(0)
            pos_rows = slice(chunk * pos_at, chunk * (pos_at + 1))
            neg_rows = slice(chunk * neg_at, chunk * (neg_at + 1))
            z_pos = out[pos_rows]
            z_neg = out[neg_rows]

            # Extrapolate away from the negative branch.
            guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)

            eps = 1e-6
            norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
            norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)

            # Cap the L1-norm growth of the guided features at nag_tau.
            ratio = norm_guided / norm_pos
            scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio

            z_final = (guided * scale_factor) * nag_alpha + z_pos * (1.0 - nag_alpha)

            if img_slice is None:
                out[pos_rows] = z_final
                return out

            img_cols = slice(img_slice[0], img_slice[1])
            full_out[neg_rows, img_cols] = z_final
            full_out[pos_rows, img_cols] = z_final
            return full_out

        m.set_model_attn1_output_patch(nag_attention_output_patch)
        m.disable_model_cfg1_optimization()

        return io.NodeOutput(m)
|
||||
|
||||
|
||||
class NagExtension(ComfyExtension):
    """Extension exposing the NAG node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        nodes = [NAGuidance]
        return nodes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> NagExtension:
    """Return a new instance of this module's extension."""
    extension = NagExtension()
    return extension
|
||||
39
comfy_extras/nodes_nop.py
Normal file
39
comfy_extras/nodes_nop.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from typing_extensions import override
|
||||
# If you write a node that is so useless that it breaks ComfyUI it will be featured in this exclusive list
|
||||
|
||||
# "native" block swap nodes are placebo at best and break the ComfyUI memory management system.
|
||||
# They are also considered harmful because instead of users reporting issues with the built in
|
||||
# memory management they install these stupid nodes and complain even harder. Now it completely
|
||||
# breaks with some of the new ComfyUI memory optimizations so I have made the decision to NOP it
|
||||
# out of all workflows.
|
||||
class wanBlockSwap(io.ComfyNode):
    """Deprecated pass-through node: returns the model it receives unchanged."""

    @classmethod
    def define_schema(cls):
        schema = io.Schema(
            node_id="wanBlockSwap",
            category="",
            description="NOP",
            inputs=[io.Model.Input("model")],
            outputs=[io.Model.Output()],
            is_deprecated=True,
        )
        return schema

    @classmethod
    def execute(cls, model) -> io.NodeOutput:
        # Intentionally does nothing to the model.
        return io.NodeOutput(model)
|
||||
|
||||
|
||||
class NopExtension(ComfyExtension):
    """Extension exposing the no-op replacement nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        nodes = [wanBlockSwap]
        return nodes
|
||||
|
||||
async def comfy_entrypoint() -> NopExtension:
    """Return a new instance of this module's extension."""
    extension = NopExtension()
    return extension
|
||||
@@ -7,6 +7,7 @@ class CLIPTextEncodePixArtAlpha(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodePixArtAlpha",
|
||||
search_aliases=["pixart prompt"],
|
||||
category="advanced/conditioning",
|
||||
description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
|
||||
inputs=[
|
||||
|
||||
@@ -4,11 +4,15 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import TypedDict, Literal
|
||||
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy_extras.nodes_latent import reshape_latent_to
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class Blend(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -221,6 +225,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
||||
io.Image.Input("image"),
|
||||
io.Combo.Input("upscale_method", options=cls.upscale_methods),
|
||||
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
||||
io.Int.Input("resolution_steps", default=1, min=1, max=256),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
@@ -228,18 +233,429 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
|
||||
def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
|
||||
samples = image.movedim(-1,1)
|
||||
total = int(megapixels * 1024 * 1024)
|
||||
total = megapixels * 1024 * 1024
|
||||
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
|
||||
height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps
|
||||
|
||||
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
||||
s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
|
||||
s = s.movedim(1,-1)
|
||||
return io.NodeOutput(s)
|
||||
|
||||
class ResizeType(str, Enum):
    """Resize modes; each value is the string used as the option key."""

    # Relative scaling.
    SCALE_BY = "scale by multiplier"
    # Explicit target sizes.
    SCALE_DIMENSIONS = "scale dimensions"
    SCALE_LONGER_DIMENSION = "scale longer dimension"
    SCALE_SHORTER_DIMENSION = "scale shorter dimension"
    SCALE_WIDTH = "scale width"
    SCALE_HEIGHT = "scale height"
    # Pixel budget, reference tensor, divisibility.
    SCALE_TOTAL_PIXELS = "scale total pixels"
    MATCH_SIZE = "match size"
    SCALE_TO_MULTIPLE = "scale to multiple"
|
||||
|
||||
def is_image(input: torch.Tensor) -> bool:
    """Return True for a 4-dimensional tensor, False otherwise."""
    # images: [batch, height, width, channels]; masks: [batch, height, width]
    rank = len(input.shape)
    return rank == 4
|
||||
|
||||
def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
    """Bring an image or mask into channel-second layout.

    Images get their last axis moved to position 1; masks get a new axis
    inserted at position 1.
    """
    return input.movedim(-1, 1) if is_type_image else input.unsqueeze(1)
|
||||
|
||||
def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
    """Undo ``init_image_mask_input``.

    Images get axis 1 moved back to the end; masks get axis 1 squeezed away.
    """
    return input.movedim(1, -1) if is_type_image else input.squeeze(1)
|
||||
|
||||
def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor:
    """Resize an image or mask by a uniform factor.

    Args:
        input: Image (4-D) or mask (3-D) tensor.
        multiplier: Factor applied to both width and height.
        scale_method: Interpolation name forwarded to ``common_upscale``.

    Returns:
        The resized tensor in the same layout as ``input``.
    """
    is_type_image = is_image(input)
    input = init_image_mask_input(input, is_type_image)
    # Clamp to at least one pixel: small inputs combined with a small
    # multiplier would otherwise round to a zero-sized target.
    width = max(1, round(input.shape[-1] * multiplier))
    height = max(1, round(input.shape[-2] * multiplier))

    input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
    input = finalize_image_mask_input(input, is_type_image)
    return input
|
||||
|
||||
def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor:
    """Resize to explicit dimensions; a 0 side is derived from the other one.

    When both ``width`` and ``height`` are 0 the input is returned untouched.
    """
    if (width, height) == (0, 0):
        return input

    is_type_image = is_image(input)
    tensor = init_image_mask_input(input, is_type_image)
    src_h, src_w = tensor.shape[-2], tensor.shape[-1]

    # Fill in the missing side so the aspect ratio is kept (never below 1 px).
    if width == 0:
        width = max(1, round(src_w * height / src_h))
    elif height == 0:
        height = max(1, round(src_h * width / src_w))

    tensor = comfy.utils.common_upscale(tensor, width, height, scale_method, crop)
    return finalize_image_mask_input(tensor, is_type_image)
|
||||
|
||||
def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor:
    """Resize so the longer edge equals ``longer_size``, keeping the aspect ratio.

    Args:
        input: Image (4-D) or mask (3-D) tensor.
        longer_size: Target length of the longer edge in pixels.
        scale_method: Interpolation name forwarded to ``common_upscale``.

    Returns:
        The resized tensor in the same layout as ``input``.
    """
    is_type_image = is_image(input)
    input = init_image_mask_input(input, is_type_image)
    width = input.shape[-1]
    height = input.shape[-2]

    # The derived (shorter) side is clamped to one pixel so that extreme
    # aspect ratios cannot round it down to zero.
    if height > width:
        width = max(1, round((width / height) * longer_size))
        height = longer_size
    elif width > height:
        height = max(1, round((height / width) * longer_size))
        width = longer_size
    else:
        height = longer_size
        width = longer_size

    input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
    input = finalize_image_mask_input(input, is_type_image)
    return input
|
||||
|
||||
def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor:
    """Resize so the shorter edge equals ``shorter_size``, keeping the aspect ratio."""
    is_type_image = is_image(input)
    tensor = init_image_mask_input(input, is_type_image)
    src_w = tensor.shape[-1]
    src_h = tensor.shape[-2]

    # The shorter (or equal) side takes shorter_size; the other is scaled to match.
    new_w = shorter_size if src_w <= src_h else round((src_w / src_h) * shorter_size)
    new_h = shorter_size if src_h <= src_w else round((src_h / src_w) * shorter_size)

    tensor = comfy.utils.common_upscale(tensor, new_w, new_h, scale_method, "disabled")
    return finalize_image_mask_input(tensor, is_type_image)
|
||||
|
||||
def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor:
    """Resize to roughly ``megapixels`` * 1024 * 1024 pixels, keeping the aspect ratio."""
    is_type_image = is_image(input)
    tensor = init_image_mask_input(input, is_type_image)
    src_w, src_h = tensor.shape[-1], tensor.shape[-2]

    target_pixels = int(megapixels * 1024 * 1024)
    # Linear factor that maps the current area onto the target area.
    factor = math.sqrt(target_pixels / (src_w * src_h))
    new_w = round(src_w * factor)
    new_h = round(src_h * factor)

    tensor = comfy.utils.common_upscale(tensor, new_w, new_h, scale_method, "disabled")
    return finalize_image_mask_input(tensor, is_type_image)
|
||||
|
||||
def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor:
    """Resize ``input`` to the width and height of ``match`` (image or mask)."""
    is_type_image = is_image(input)
    tensor = init_image_mask_input(input, is_type_image)
    reference = init_image_mask_input(match, is_image(match))

    # After conversion the last two axes of the reference are (height, width).
    target_h, target_w = reference.shape[-2], reference.shape[-1]
    tensor = comfy.utils.common_upscale(tensor, target_w, target_h, scale_method, crop)
    return finalize_image_mask_input(tensor, is_type_image)
|
||||
|
||||
def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: str) -> torch.Tensor:
    """Make width and height divisible by ``multiple`` via scale-then-center-crop.

    The input is returned unchanged when ``multiple`` <= 1, when it is already
    aligned, or when rounding down would give an empty dimension.
    """
    if multiple <= 1:
        return input

    is_type_image = is_image(input)
    if is_type_image:
        _, height, width, _ = input.shape
    else:
        _, height, width = input.shape

    # Largest multiples that still fit inside the current size.
    target_w = width - (width % multiple)
    target_h = height - (height % multiple)
    if 0 in (target_w, target_h) or (target_w, target_h) == (width, height):
        return input

    # Scale uniformly by the larger of the two ratios so the result covers the
    # target box, then crop the overshoot evenly from both sides.
    ratio_w = target_w / width
    ratio_h = target_h / height
    if ratio_w >= ratio_h:
        scaled_w = target_w
        scaled_h = max(target_h, int(math.ceil(height * ratio_w)))
    else:
        scaled_h = target_h
        scaled_w = max(target_w, int(math.ceil(width * ratio_h)))

    tensor = init_image_mask_input(input, is_type_image)
    tensor = comfy.utils.common_upscale(tensor, scaled_w, scaled_h, scale_method, "disabled")
    tensor = finalize_image_mask_input(tensor, is_type_image)

    left = (scaled_w - target_w) // 2
    top = (scaled_h - target_h) // 2
    rows = slice(top, top + target_h)
    cols = slice(left, left + target_w)
    if is_type_image:
        return tensor[:, rows, cols, :]
    return tensor[:, rows, cols]
|
||||
|
||||
class ResizeImageMaskNode(io.ComfyNode):
    """Node that resizes an image or a mask using one of several strategies."""

    scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    crop_methods = ["disabled", "center"]

    class ResizeTypedDict(TypedDict):
        # Selected mode plus the values of whichever sub-inputs that mode has.
        resize_type: ResizeType
        scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
        crop: Literal["disabled", "center"]
        multiplier: float
        width: int
        height: int
        longer_size: int
        shorter_size: int
        megapixels: float
        multiple: int

    @classmethod
    def define_schema(cls):
        template = io.MatchType.Template("input_type", [io.Image, io.Mask])
        # Shared by the two modes that can change the aspect ratio.
        crop_combo = io.Combo.Input(
            "crop",
            options=cls.crop_methods,
            default="center",
            tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.",
        )

        def pixel_input(name, tooltip):
            # Every pixel-size widget uses the same numeric range.
            return io.Int.Input(name, default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip=tooltip)

        resize_options = [
            io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
                pixel_input("width", "Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."),
                pixel_input("height", "Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."),
                crop_combo,
            ]),
            io.DynamicCombo.Option(ResizeType.SCALE_BY, [
                io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."),
            ]),
            io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
                pixel_input("longer_size", "The longer edge will be resized to this value. Aspect ratio is preserved."),
            ]),
            io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
                pixel_input("shorter_size", "The shorter edge will be resized to this value. Aspect ratio is preserved."),
            ]),
            io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
                pixel_input("width", "Target width in pixels. Height auto-adjusts to preserve aspect ratio."),
            ]),
            io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
                pixel_input("height", "Target height in pixels. Width auto-adjusts to preserve aspect ratio."),
            ]),
            io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
                io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."),
            ]),
            io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
                io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."),
                crop_combo,
            ]),
            io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
                io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."),
            ]),
        ]

        return io.Schema(
            node_id="ResizeImageMaskNode",
            display_name="Resize Image/Mask",
            description="Resize an image or mask using various scaling methods.",
            category="transform",
            search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
            inputs=[
                io.MatchType.Input("input", template=template),
                io.DynamicCombo.Input(
                    "resize_type",
                    tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.",
                    options=resize_options,
                ),
                io.Combo.Input(
                    "scale_method",
                    options=cls.scale_methods,
                    default="area",
                    tooltip="Interpolation algorithm. 'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.",
                ),
            ],
            outputs=[io.MatchType.Output(template=template, display_name="resized")]
        )

    @classmethod
    def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput:
        # Dispatch on the selected mode; each branch reads only its own keys.
        selected_type = resize_type["resize_type"]
        if selected_type == ResizeType.SCALE_BY:
            resized = scale_by(input, resize_type["multiplier"], scale_method)
        elif selected_type == ResizeType.SCALE_DIMENSIONS:
            resized = scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"])
        elif selected_type == ResizeType.SCALE_LONGER_DIMENSION:
            resized = scale_longer_dimension(input, resize_type["longer_size"], scale_method)
        elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION:
            resized = scale_shorter_dimension(input, resize_type["shorter_size"], scale_method)
        elif selected_type == ResizeType.SCALE_WIDTH:
            # Height 0 means "derive from width".
            resized = scale_dimensions(input, resize_type["width"], 0, scale_method)
        elif selected_type == ResizeType.SCALE_HEIGHT:
            # Width 0 means "derive from height".
            resized = scale_dimensions(input, 0, resize_type["height"], scale_method)
        elif selected_type == ResizeType.SCALE_TOTAL_PIXELS:
            resized = scale_total_pixels(input, resize_type["megapixels"], scale_method)
        elif selected_type == ResizeType.MATCH_SIZE:
            resized = scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"])
        elif selected_type == ResizeType.SCALE_TO_MULTIPLE:
            resized = scale_to_multiple_cover(input, resize_type["multiple"], scale_method)
        else:
            raise ValueError(f"Unsupported resize type: {selected_type}")
        return io.NodeOutput(resized)
|
||||
|
||||
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
|
||||
if len(images) == 0:
|
||||
return None
|
||||
# first, get the max channels count
|
||||
max_channels = max(image.shape[-1] for image in images)
|
||||
# then, pad all images to have the same channels count
|
||||
padded_images: list[torch.Tensor] = []
|
||||
for image in images:
|
||||
if image.shape[-1] < max_channels:
|
||||
padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0))
|
||||
else:
|
||||
padded_images.append(image)
|
||||
# resize all images to be the same size as the first image
|
||||
resized_images: list[torch.Tensor] = []
|
||||
first_image_shape = padded_images[0].shape
|
||||
for image in padded_images:
|
||||
if image.shape[1:] != first_image_shape[1:]:
|
||||
resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1))
|
||||
else:
|
||||
resized_images.append(image)
|
||||
# batch the images in the format [b, h, w, c]
|
||||
return torch.cat(resized_images, dim=0)
|
||||
|
||||
def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None:
|
||||
if len(masks) == 0:
|
||||
return None
|
||||
# resize all masks to be the same size as the first mask
|
||||
resized_masks: list[torch.Tensor] = []
|
||||
first_mask_shape = masks[0].shape
|
||||
for mask in masks:
|
||||
if mask.shape[1:] != first_mask_shape[1:]:
|
||||
mask = init_image_mask_input(mask, is_type_image=False)
|
||||
mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center")
|
||||
resized_masks.append(finalize_image_mask_input(mask, is_type_image=False))
|
||||
else:
|
||||
resized_masks.append(mask)
|
||||
# batch the masks in the format [b, h, w]
|
||||
return torch.cat(resized_masks, dim=0)
|
||||
|
||||
def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None:
    """Concatenate several latent dicts into one batched latent dict.

    The result starts as a shallow copy of the first latent dict. Every
    latent's "samples" tensor is passed through ``reshape_latent_to`` with the
    first latent's sample shape and concatenated along dim 0. "batch_index"
    entries are collected in order; a latent without one contributes
    ``0..n-1`` for its own batch size.

    Returns:
        The combined latent dict, or ``None`` when ``latents`` is empty.
    """
    if not latents:
        return None
    head = latents[0]
    combined = head.copy()
    combined["batch_index"] = []
    reference_samples = head["samples"]
    pieces: list[torch.Tensor] = []
    for entry in latents:
        entry_samples = entry["samples"]
        # first, deal with latent tensors
        pieces.append(reshape_latent_to(reference_samples.shape, entry_samples, repeat_batch=False))
        # next, deal with batch_index
        fallback_indices = list(range(entry["samples"].shape[0]))
        combined["batch_index"].extend(entry.get("batch_index", fallback_indices))
    combined["samples"] = torch.cat(pieces, dim=0)
    return combined
|
||||
|
||||
class BatchImagesNode(io.ComfyNode):
    """Node that batches a growing set of image inputs via ``batch_images``."""

    @classmethod
    def define_schema(cls):
        """Declare an autogrow "images" input (2 to 50 slots) and one image output."""
        image_slot = io.Image.Input("image")
        growing_images = io.Autogrow.TemplatePrefix(image_slot, prefix="image", min=2, max=50)
        node_inputs = [io.Autogrow.Input("images", template=growing_images)]
        node_outputs = [io.Image.Output()]
        return io.Schema(
            node_id="BatchImagesNode",
            display_name="Batch Images",
            category="image",
            search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput:
        """Batch every connected image, in input order."""
        connected = list(images.values())
        return io.NodeOutput(batch_images(connected))
|
||||
|
||||
class BatchMasksNode(io.ComfyNode):
    """Node that batches a growing set of mask inputs via ``batch_masks``."""

    @classmethod
    def define_schema(cls):
        """Declare an autogrow "masks" input (2 to 50 slots) and one mask output."""
        mask_slot = io.Mask.Input("mask")
        growing_masks = io.Autogrow.TemplatePrefix(mask_slot, prefix="mask", min=2, max=50)
        node_inputs = [io.Autogrow.Input("masks", template=growing_masks)]
        node_outputs = [io.Mask.Output()]
        return io.Schema(
            node_id="BatchMasksNode",
            search_aliases=["combine masks", "stack masks", "merge masks"],
            display_name="Batch Masks",
            category="mask",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput:
        """Batch every connected mask, in input order."""
        connected = list(masks.values())
        return io.NodeOutput(batch_masks(connected))
|
||||
|
||||
class BatchLatentsNode(io.ComfyNode):
    """Node that batches a growing set of latent inputs via ``batch_latents``."""

    @classmethod
    def define_schema(cls):
        """Declare an autogrow "latents" input (2 to 50 slots) and one latent output."""
        latent_slot = io.Latent.Input("latent")
        growing_latents = io.Autogrow.TemplatePrefix(latent_slot, prefix="latent", min=2, max=50)
        node_inputs = [io.Autogrow.Input("latents", template=growing_latents)]
        node_outputs = [io.Latent.Output()]
        return io.Schema(
            node_id="BatchLatentsNode",
            search_aliases=["combine latents", "stack latents", "merge latents"],
            display_name="Batch Latents",
            category="latent",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput:
        """Batch every connected latent, in input order."""
        connected = list(latents.values())
        return io.NodeOutput(batch_latents(connected))
|
||||
|
||||
class BatchImagesMasksLatentsNode(io.ComfyNode):
    """Node that batches images, masks, or latents, dispatching on the first input."""

    @classmethod
    def define_schema(cls):
        """Declare a match-typed autogrow input (1 to 50 slots) and a matching output."""
        shared_type = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent])
        typed_slot = io.MatchType.Input("input", shared_type)
        growing_inputs = io.Autogrow.TemplatePrefix(typed_slot, prefix="input", min=1, max=50)
        node_inputs = [io.Autogrow.Input("inputs", template=growing_inputs)]
        node_outputs = [io.MatchType.Output(id=None, template=shared_type)]
        return io.Schema(
            node_id="BatchImagesMasksLatentsNode",
            search_aliases=["combine batch", "merge batch", "stack inputs"],
            display_name="Batch Images/Masks/Latents",
            category="util",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
        """Batch the connected values with the helper matching the first value's kind."""
        values = list(inputs.values())
        first = values[0]
        # latents
        if isinstance(first, dict):
            return io.NodeOutput(batch_latents(values))
        # images
        if is_image(first):
            return io.NodeOutput(batch_images(values))
        # masks
        return io.NodeOutput(batch_masks(values))
|
||||
|
||||
|
||||
class PostProcessingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -249,6 +665,11 @@ class PostProcessingExtension(ComfyExtension):
|
||||
Quantize,
|
||||
Sharpen,
|
||||
ImageScaleToTotalPixels,
|
||||
ResizeImageMaskNode,
|
||||
BatchImagesNode,
|
||||
BatchMasksNode,
|
||||
BatchLatentsNode,
|
||||
# BatchImagesMasksLatentsNode,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> PostProcessingExtension:
|
||||
|
||||
@@ -16,6 +16,7 @@ class PreviewAny():
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "utils"
|
||||
SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"]
|
||||
|
||||
def main(self, source=None):
|
||||
value = 'None'
|
||||
@@ -39,5 +40,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PreviewAny": "Preview Any",
|
||||
"PreviewAny": "Preview as Text",
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ class Float(io.ComfyNode):
|
||||
display_name="Float",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
|
||||
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1),
|
||||
],
|
||||
outputs=[io.Float.Output()],
|
||||
)
|
||||
|
||||
@@ -3,7 +3,9 @@ import comfy.utils
|
||||
import math
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
import comfy.model_management
|
||||
import torch
|
||||
import nodes
|
||||
|
||||
class TextEncodeQwenImageEdit(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class EmptyQwenImageLayeredLatentImage(io.ComfyNode):
    """Node that produces a zero-filled layered latent."""

    @classmethod
    def define_schema(cls):
        """Declare width/height/layers/batch_size inputs and a single latent output."""
        size_limit = nodes.MAX_RESOLUTION
        node_inputs = [
            io.Int.Input("width", default=640, min=16, max=size_limit, step=16),
            io.Int.Input("height", default=640, min=16, max=size_limit, step=16),
            io.Int.Input("layers", default=3, min=0, max=size_limit, step=1),
            io.Int.Input("batch_size", default=1, min=1, max=4096),
        ]
        return io.Schema(
            node_id="EmptyQwenImageLayeredLatentImage",
            display_name="Empty Qwen Image Layered Latent",
            category="latent/qwen",
            inputs=node_inputs,
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput:
        """Return zeros shaped [batch_size, 16, layers + 1, height // 8, width // 8]."""
        latent_shape = [batch_size, 16, layers + 1, height // 8, width // 8]
        target_device = comfy.model_management.intermediate_device()
        empty = torch.zeros(latent_shape, device=target_device)
        return io.NodeOutput({"samples": empty})
|
||||
|
||||
|
||||
class QwenExtension(ComfyExtension):
    """Extension exposing the Qwen image nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        provided: list[type[io.ComfyNode]] = [
            TextEncodeQwenImageEdit,
            TextEncodeQwenImageEditPlus,
            EmptyQwenImageLayeredLatentImage,
        ]
        return provided
|
||||
|
||||
|
||||
|
||||
103
comfy_extras/nodes_replacements.py
Normal file
103
comfy_extras/nodes_replacements.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from comfy_api.latest import ComfyExtension, io, ComfyAPI
|
||||
|
||||
api = ComfyAPI()
|
||||
|
||||
|
||||
async def register_replacements():
    """Register all built-in node replacements."""
    # Awaited one after another, in this order.
    registrars = (
        register_replacements_longeredge,
        register_replacements_batchimages,
        register_replacements_upscaleimage,
        register_replacements_controlnet,
        register_replacements_load3d,
        register_replacements_preview3d,
        register_replacements_svdimg2vid,
        register_replacements_conditioningavg,
    )
    for registrar in registrars:
        await registrar()
|
||||
|
||||
async def register_replacements_longeredge():
    """Register ResizeImagesByLongerEdge -> ImageScaleToMaxDimension."""
    # No dynamic inputs here
    mapped_inputs = [
        {"new_id": "image", "old_id": "images"},
        {"new_id": "largest_size", "old_id": "longer_edge"},
        {"new_id": "upscale_method", "set_value": "lanczos"},
    ]
    # just to test the frontend output_mapping code, does nothing really here
    mapped_outputs = [{"new_idx": 0, "old_idx": 0}]
    replacement = io.NodeReplace(
        new_node_id="ImageScaleToMaxDimension",
        old_node_id="ResizeImagesByLongerEdge",
        old_widget_ids=["longer_edge"],
        input_mapping=mapped_inputs,
        output_mapping=mapped_outputs,
    )
    await api.node_replacement.register(replacement)
|
||||
|
||||
async def register_replacements_batchimages():
    """Register ImageBatch -> BatchImagesNode."""
    # BatchImages node uses Autogrow
    mapped_inputs = [
        {"new_id": "images.image0", "old_id": "image1"},
        {"new_id": "images.image1", "old_id": "image2"},
    ]
    replacement = io.NodeReplace(
        new_node_id="BatchImagesNode",
        old_node_id="ImageBatch",
        input_mapping=mapped_inputs,
    )
    await api.node_replacement.register(replacement)
|
||||
|
||||
async def register_replacements_upscaleimage():
    """Register ImageScaleBy -> ResizeImageMaskNode."""
    # ResizeImageMaskNode uses DynamicCombo
    mapped_inputs = [
        {"new_id": "input", "old_id": "image"},
        {"new_id": "resize_type", "set_value": "scale by multiplier"},
        {"new_id": "resize_type.multiplier", "old_id": "scale_by"},
        {"new_id": "scale_method", "old_id": "upscale_method"},
    ]
    replacement = io.NodeReplace(
        new_node_id="ResizeImageMaskNode",
        old_node_id="ImageScaleBy",
        old_widget_ids=["upscale_method", "scale_by"],
        input_mapping=mapped_inputs,
    )
    await api.node_replacement.register(replacement)
|
||||
|
||||
async def register_replacements_controlnet():
    """Register T2IAdapterLoader -> ControlNetLoader."""
    # T2IAdapterLoader → ControlNetLoader
    mapped_inputs = [
        {"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
    ]
    replacement = io.NodeReplace(
        new_node_id="ControlNetLoader",
        old_node_id="T2IAdapterLoader",
        input_mapping=mapped_inputs,
    )
    await api.node_replacement.register(replacement)
|
||||
|
||||
async def register_replacements_load3d():
    """Register Load3DAnimation -> Load3D."""
    # Load3DAnimation merged into Load3D
    replacement = io.NodeReplace(
        new_node_id="Load3D",
        old_node_id="Load3DAnimation",
    )
    await api.node_replacement.register(replacement)
|
||||
|
||||
async def register_replacements_preview3d():
    """Register Preview3DAnimation -> Preview3D."""
    # Preview3DAnimation merged into Preview3D
    replacement = io.NodeReplace(
        new_node_id="Preview3D",
        old_node_id="Preview3DAnimation",
    )
    await api.node_replacement.register(replacement)
|
||||
|
||||
async def register_replacements_svdimg2vid():
    """Register SDV_img2vid_Conditioning -> SVD_img2vid_Conditioning."""
    # Typo fix: SDV → SVD
    replacement = io.NodeReplace(
        new_node_id="SVD_img2vid_Conditioning",
        old_node_id="SDV_img2vid_Conditioning",
    )
    await api.node_replacement.register(replacement)
|
||||
|
||||
async def register_replacements_conditioningavg():
    """Register the trailing-space "ConditioningAverage " id -> "ConditioningAverage"."""
    # Typo fix: trailing space in node name
    replacement = io.NodeReplace(
        new_node_id="ConditioningAverage",
        old_node_id="ConditioningAverage ",
    )
    await api.node_replacement.register(replacement)
|
||||
|
||||
class NodeReplacementsExtension(ComfyExtension):
    """Extension that registers node replacements on load and provides no nodes."""

    async def on_load(self) -> None:
        """Register every built-in node replacement."""
        await register_replacements()

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return an empty list; this extension contributes no node classes."""
        no_nodes: list[type[io.ComfyNode]] = []
        return no_nodes
|
||||
|
||||
async def comfy_entrypoint() -> NodeReplacementsExtension:
    """Create and return this module's extension instance."""
    extension = NodeReplacementsExtension()
    return extension
|
||||
47
comfy_extras/nodes_rope.py
Normal file
47
comfy_extras/nodes_rope.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class ScaleROPE(io.ComfyNode):
    """Experimental node that sets scale/shift ROPE options on a cloned model."""

    @classmethod
    def define_schema(cls):
        """Declare the model input plus a scale/shift float pair for x, y and t."""
        node_inputs = [io.Model.Input("model")]
        # Same order as the execute() parameters: scale then shift, for x, y, t.
        for axis in ("x", "y", "t"):
            node_inputs.append(io.Float.Input(f"scale_{axis}", default=1.0, min=0.0, max=100.0, step=0.1))
            node_inputs.append(io.Float.Input(f"shift_{axis}", default=0.0, min=-256.0, max=256.0, step=0.1))
        return io.Schema(
            node_id="ScaleROPE",
            category="advanced/model_patches",
            description="Scale and shift the ROPE of the model.",
            is_experimental=True,
            inputs=node_inputs,
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
        """Clone the model, apply the rope options to the clone, and return it."""
        patched = model.clone()
        patched.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
        return io.NodeOutput(patched)
|
||||
|
||||
|
||||
class RopeExtension(ComfyExtension):
    """Extension exposing the ScaleROPE node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        provided: list[type[io.ComfyNode]] = [ScaleROPE]
        return provided
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> RopeExtension:
    """Create and return this module's extension instance."""
    extension = RopeExtension()
    return extension
|
||||
@@ -55,7 +55,7 @@ class EmptySD3LatentImage(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples":latent})
|
||||
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
@@ -65,6 +65,7 @@ class CLIPTextEncodeSD3(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeSD3",
|
||||
search_aliases=["sd3 prompt"],
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@@ -11,6 +11,7 @@ class StringConcatenate(io.ComfyNode):
|
||||
node_id="StringConcatenate",
|
||||
display_name="Concatenate",
|
||||
category="utils/string",
|
||||
search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"],
|
||||
inputs=[
|
||||
io.String.Input("string_a", multiline=True),
|
||||
io.String.Input("string_b", multiline=True),
|
||||
@@ -31,6 +32,7 @@ class StringSubstring(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringSubstring",
|
||||
search_aliases=["extract text", "text portion"],
|
||||
display_name="Substring",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@@ -53,6 +55,7 @@ class StringLength(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringLength",
|
||||
search_aliases=["character count", "text size"],
|
||||
display_name="Length",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@@ -73,6 +76,7 @@ class CaseConverter(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CaseConverter",
|
||||
search_aliases=["text case", "uppercase", "lowercase", "capitalize"],
|
||||
display_name="Case Converter",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@@ -105,6 +109,7 @@ class StringTrim(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringTrim",
|
||||
search_aliases=["clean whitespace", "remove whitespace"],
|
||||
display_name="Trim",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@@ -135,6 +140,7 @@ class StringReplace(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringReplace",
|
||||
search_aliases=["find and replace", "substitute", "swap text"],
|
||||
display_name="Replace",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@@ -157,6 +163,7 @@ class StringContains(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringContains",
|
||||
search_aliases=["text includes", "string includes"],
|
||||
display_name="Contains",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@@ -184,6 +191,7 @@ class StringCompare(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringCompare",
|
||||
search_aliases=["text match", "string equals", "starts with", "ends with"],
|
||||
display_name="Compare",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@@ -219,6 +227,7 @@ class RegexMatch(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexMatch",
|
||||
search_aliases=["pattern match", "text contains", "string match"],
|
||||
display_name="Regex Match",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@@ -259,6 +268,7 @@ class RegexExtract(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexExtract",
|
||||
search_aliases=["pattern extract", "text parser", "parse text"],
|
||||
display_name="Regex Extract",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@@ -333,6 +343,7 @@ class RegexReplace(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexReplace",
|
||||
search_aliases=["pattern replace", "find and replace", "substitution"],
|
||||
display_name="Regex Replace",
|
||||
category="utils/string",
|
||||
description="Find and replace text using regex patterns.",
|
||||
|
||||
47
comfy_extras/nodes_toolkit.py
Normal file
47
comfy_extras/nodes_toolkit.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class CreateList(io.ComfyNode):
    """Node that flattens its autogrow inputs into a single output list."""

    @classmethod
    def define_schema(cls):
        """Declare a match-typed autogrow input and a list output of the same type."""
        shared_type = io.MatchType.Template("type")
        growing_inputs = io.Autogrow.TemplatePrefix(
            input=io.MatchType.Input("input", template=shared_type),
            prefix="input",
        )
        list_output = io.MatchType.Output(
            template=shared_type,
            is_output_list=True,
            display_name="list",
        )
        return io.Schema(
            node_id="CreateList",
            display_name="Create List",
            category="logic",
            is_input_list=True,
            search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
            inputs=[io.Autogrow.Input("inputs", template=growing_inputs)],
            outputs=[list_output],
        )

    @classmethod
    def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
        """Concatenate the items of every input value, in input order."""
        flattened = [item for group in inputs.values() for item in group]
        return io.NodeOutput(flattened)
|
||||
|
||||
|
||||
class ToolkitExtension(ComfyExtension):
    """Extension exposing the CreateList node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        provided: list[type[io.ComfyNode]] = [CreateList]
        return provided
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ToolkitExtension:
    """Create and return this module's extension instance."""
    extension = ToolkitExtension()
    return extension
|
||||
@@ -2,6 +2,8 @@ from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.torch_helpers import set_torch_compile_wrapper
|
||||
|
||||
def skip_torch_compile_dict(guard_entries):
    """Return one boolean per guard entry.

    The flag is False for entries whose ``name`` contains
    "transformer_options" and True for every other entry.
    """
    flags = []
    for guard in guard_entries:
        mentions_options = "transformer_options" in guard.name
        flags.append(not mentions_options)
    return flags
|
||||
|
||||
class TorchCompileModel(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -23,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, model, backend) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
set_torch_compile_wrapper(model=m, backend=backend)
|
||||
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -53,6 +53,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
node_id="ImageUpscaleWithModel",
|
||||
display_name="Upscale Image (using Model)",
|
||||
category="image/upscaling",
|
||||
search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"],
|
||||
inputs=[
|
||||
io.UpscaleModel.Input("upscale_model"),
|
||||
io.Image.Input("image"),
|
||||
@@ -78,18 +79,20 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
overlap = 32
|
||||
|
||||
oom = True
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
oom = False
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
tile //= 2
|
||||
if tile < 128:
|
||||
raise e
|
||||
try:
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
oom = False
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
tile //= 2
|
||||
if tile < 128:
|
||||
raise e
|
||||
finally:
|
||||
upscale_model.to("cpu")
|
||||
|
||||
upscale_model.to("cpu")
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||
return io.NodeOutput(s)
|
||||
|
||||
|
||||
@@ -8,10 +8,7 @@ import json
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
from fractions import Fraction
|
||||
from comfy_api.input import AudioInput, ImageInput, VideoInput
|
||||
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
|
||||
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
|
||||
from comfy.cli_args import args
|
||||
|
||||
class SaveWEBM(io.ComfyNode):
|
||||
@@ -19,6 +16,7 @@ class SaveWEBM(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveWEBM",
|
||||
search_aliases=["export webm"],
|
||||
category="image/video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
@@ -28,7 +26,6 @@ class SaveWEBM(io.ComfyNode):
|
||||
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
|
||||
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
@@ -73,22 +70,22 @@ class SaveVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveVideo",
|
||||
search_aliases=["export video"],
|
||||
display_name="Save Video",
|
||||
category="image/video",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to save."),
|
||||
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
||||
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
|
||||
def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
|
||||
width, height = video.get_dimensions()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
filename_prefix,
|
||||
@@ -105,10 +102,10 @@ class SaveVideo(io.ComfyNode):
|
||||
metadata["prompt"] = cls.hidden.prompt
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
video.save_to(
|
||||
os.path.join(full_output_folder, file),
|
||||
format=format,
|
||||
format=Types.VideoContainer(format),
|
||||
codec=codec,
|
||||
metadata=saved_metadata
|
||||
)
|
||||
@@ -121,6 +118,7 @@ class CreateVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CreateVideo",
|
||||
search_aliases=["images to video"],
|
||||
display_name="Create Video",
|
||||
category="image/video",
|
||||
description="Create a video from images.",
|
||||
@@ -135,9 +133,9 @@ class CreateVideo(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
|
||||
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
|
||||
return io.NodeOutput(
|
||||
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
)
|
||||
|
||||
class GetVideoComponents(io.ComfyNode):
|
||||
@@ -145,6 +143,7 @@ class GetVideoComponents(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GetVideoComponents",
|
||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||
display_name="Get Video Components",
|
||||
category="image/video",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
@@ -159,11 +158,11 @@ class GetVideoComponents(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: VideoInput) -> io.NodeOutput:
|
||||
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
||||
components = video.get_components()
|
||||
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
|
||||
|
||||
class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -172,6 +171,7 @@ class LoadVideo(io.ComfyNode):
|
||||
files = folder_paths.filter_files_content_types(files, ["video"])
|
||||
return io.Schema(
|
||||
node_id="LoadVideo",
|
||||
search_aliases=["import video", "open video", "video file"],
|
||||
display_name="Load Video",
|
||||
category="image/video",
|
||||
inputs=[
|
||||
@@ -185,7 +185,7 @@ class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, file) -> io.NodeOutput:
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return io.NodeOutput(VideoFromFile(video_path))
|
||||
return io.NodeOutput(InputImpl.VideoFromFile(video_path))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(s, file):
|
||||
@@ -202,6 +202,56 @@ class LoadVideo(io.ComfyNode):
|
||||
|
||||
return True
|
||||
|
||||
class VideoSlice(io.ComfyNode):
    """Node that trims a video to a start time and duration."""

    @classmethod
    def define_schema(cls):
        """Declare the video, start_time, duration and strict_duration inputs and a video output."""
        start_time_input = io.Float.Input(
            "start_time",
            default=0.0,
            max=1e5,
            min=-1e5,
            step=0.001,
            tooltip="Start time in seconds",
        )
        duration_input = io.Float.Input(
            "duration",
            default=0.0,
            min=0.0,
            step=0.001,
            tooltip="Duration in seconds, or 0 for unlimited duration",
        )
        strict_input = io.Boolean.Input(
            "strict_duration",
            default=False,
            tooltip="If True, when the specified duration is not possible, an error will be raised.",
        )
        return io.Schema(
            node_id="Video Slice",
            display_name="Video Slice",
            search_aliases=[
                "trim video duration",
                "skip first frames",
                "frame load cap",
                "start time",
            ],
            category="image/video",
            inputs=[
                io.Video.Input("video"),
                start_time_input,
                duration_input,
                strict_input,
            ],
            outputs=[
                io.Video.Output(),
            ],
        )

    @classmethod
    def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
        """Return the trimmed video.

        Raises:
            ValueError: If ``video.as_trimmed`` returns ``None``.
        """
        trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
        if trimmed is None:
            raise ValueError(
                f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
            )
        return io.NodeOutput(trimmed)
|
||||
|
||||
|
||||
class VideoExtension(ComfyExtension):
|
||||
@override
|
||||
@@ -212,6 +262,7 @@ class VideoExtension(ComfyExtension):
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
LoadVideo,
|
||||
VideoSlice,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> VideoExtension:
|
||||
|
||||
@@ -8,9 +8,10 @@ import comfy.latent_formats
|
||||
import comfy.clip_vision
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
from typing import Tuple, TypedDict
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import logging
|
||||
|
||||
class WanImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -286,6 +287,7 @@ class WanVaceToVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanVaceToVideo",
|
||||
search_aliases=["video conditioning", "video control"],
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
@@ -704,6 +706,7 @@ class WanTrackToVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanTrackToVideo",
|
||||
search_aliases=["motion tracking", "trajectory video", "point tracking", "keypoint animation"],
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
@@ -817,7 +820,7 @@ def get_sample_indices(original_fps,
|
||||
if required_duration > total_frames / original_fps:
|
||||
raise ValueError("required_duration must be less than video length")
|
||||
|
||||
if not fixed_start is None and fixed_start >= 0:
|
||||
if fixed_start is not None and fixed_start >= 0:
|
||||
start_frame = fixed_start
|
||||
else:
|
||||
max_start = total_frames - required_origin_frames
|
||||
@@ -1288,6 +1291,171 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
|
||||
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
|
||||
class WanInfiniteTalkToVideo(io.ComfyNode):
    """Node that prepares a model, conditioning and an empty latent for InfiniteTalk.

    The node clones the incoming model, attaches projected audio embeddings to
    its transformer options, registers an outer-sample wrapper plus attention
    patches, and returns the patched model together with updated conditioning,
    a zero latent and the number of leading frames to trim.
    """

    class DCValues(TypedDict):
        # Values delivered by the "mode" DynamicCombo input. The three optional
        # entries are only declared by the "two_speakers" option.
        mode: str
        audio_encoder_output_2: io.AudioEncoderOutput.Type
        mask_1: io.Mask.Type
        mask_2: io.Mask.Type

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs and outputs."""
        return io.Schema(
            node_id="WanInfiniteTalkToVideo",
            category="conditioning/video_models",
            inputs=[
                io.DynamicCombo.Input("mode", options=[
                    io.DynamicCombo.Option("single_speaker", []),
                    io.DynamicCombo.Option("two_speakers", [
                        io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
                        io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
                        io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
                    ]),
                ]),
                io.Model.Input("model"),
                io.ModelPatch.Input("model_patch"),
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
                io.Image.Input("start_image", optional=True),
                io.AudioEncoderOutput.Input("audio_encoder_output_1"),
                io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
                io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
                io.Image.Input("previous_frames", optional=True),
            ],
            outputs=[
                io.Model.Output(display_name="model"),
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
                io.Int.Output(display_name="trim_image"),
            ],
        )

    @classmethod
    def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
                start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
        """Build the patched model, conditioning and latent for one generation window.

        Raises:
            ValueError: if fewer previous frames than ``motion_frame_count`` are
                given, if the two-speaker inputs are only partially supplied, or
                if neither ``start_image`` nor ``previous_frames`` is provided.
        """
        if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
            raise ValueError("Not enough previous frames provided.")

        if mode["mode"] == "two_speakers":
            # The entries are optional inputs, so read them defensively.
            audio_encoder_output_2 = mode.get("audio_encoder_output_2")
            mask_1 = mode.get("mask_1")
            mask_2 = mode.get("mask_2")

        if audio_encoder_output_2 is not None:
            if mask_1 is None or mask_2 is None:
                raise ValueError("Masks must be provided if two audio encoder outputs are used.")

        ref_masks = None
        if mask_1 is not None and mask_2 is not None:
            if audio_encoder_output_2 is None:
                raise ValueError("Second audio encoder output must be provided if two masks are used.")
            ref_masks = torch.cat([mask_1, mask_2])

        # Without previous frames the motion latent is taken from the encoded
        # start image, so at least one of the two sources must exist. Failing
        # here gives a clear message instead of an unbound-variable error later.
        if start_image is None and previous_frames is None:
            raise ValueError("Either start_image or previous_frames must be provided.")

        latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        if start_image is not None:
            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            # Frames beyond the supplied start image are filled with mid-gray (0.5).
            image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
            image[:start_image.shape[0]] = start_image

            concat_latent_image = vae.encode(image[:, :, :, :3])
            concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
            # Zero the mask over the latent frames covered by the start image.
            concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0

            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})

        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        model_patched = model.clone()

        encoded_audio_list = []
        seq_lengths = []

        for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]:
            if audio_encoder_output is None:
                continue
            all_layers = audio_encoder_output["encoded_audio_all_layers"]
            encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:]  # shape: [num_layers, T, 512]
            encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1)  # shape: [T, num_layers, 512]
            encoded_audio_list.append(encoded_audio)
            seq_lengths.append(encoded_audio.shape[0])

        # Pad / combine depending on multi_audio_type.
        # Only "add" is selected here; the "para" branch is kept for reference.
        multi_audio_type = "add"
        if len(encoded_audio_list) > 1:
            if multi_audio_type == "para":
                max_len = max(seq_lengths)
                padded = []
                for emb in encoded_audio_list:
                    if emb.shape[0] < max_len:
                        # Allocate the padding on the embedding's device so the cat cannot mix devices.
                        pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype, device=emb.device)
                        emb = torch.cat([emb, pad], dim=0)
                    padded.append(emb)
                encoded_audio_list = padded
            elif multi_audio_type == "add":
                # Each speaker occupies its own consecutive slice of one shared timeline.
                total_len = sum(seq_lengths)
                full_list = []
                offset = 0
                for emb, seq_len in zip(encoded_audio_list, seq_lengths):
                    full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype, device=emb.device)
                    full[offset:offset + seq_len] = emb
                    full_list.append(full)
                    offset += seq_len
                encoded_audio_list = full_list

        token_ref_target_masks = None
        if ref_masks is not None:
            token_ref_target_masks = torch.nn.functional.interpolate(
                ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
            token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)

        # when extending from previous frames
        if previous_frames is not None:
            motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            frame_offset = previous_frames.shape[0] - motion_frame_count

            audio_start = frame_offset
            audio_end = audio_start + length
            logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")

            motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
            trim_image = motion_frame_count
        else:
            audio_start = trim_image = 0
            audio_end = length
            motion_frames_latent = concat_latent_image[:, :, :1]

        audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
        model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed

        # add outer sample wrapper
        model_patched.add_wrapper_with_key(
            comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
            "infinite_talk_outer_sample",
            InfiniteTalkOuterSampleWrapper(
                motion_frames_latent,
                model_patch,
                is_extend=previous_frames is not None,
            ))
        # add cross-attention patch
        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch")
        if token_ref_target_masks is not None:
            model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch")

        out_latent = {}
        out_latent["samples"] = latent
        return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||
|
||||
|
||||
class WanExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -1307,6 +1475,7 @@ class WanExtension(ComfyExtension):
|
||||
WanHuMoImageToVideo,
|
||||
WanAnimateToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
WanInfiniteTalkToVideo,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanExtension:
|
||||
|
||||
536
comfy_extras/nodes_wanmove.py
Normal file
536
comfy_extras/nodes_wanmove.py
Normal file
@@ -0,0 +1,536 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import numpy as np
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.nodes_wan import parse_json_tracks
|
||||
|
||||
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
# When enabled, every position is shifted by one before it is embedded.
SKIP_ZERO = False


def get_pos_emb(
    pos_k: torch.Tensor,  # A 1D tensor containing positions for which to generate embeddings.
    pos_emb_dim: int,
    theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))),  # Function to compute thetas based on position and embedding dimensions.
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:  # The position embeddings (batch_size, pos_emb_dim)
    """Return sine/cosine position embeddings for the positions in ``pos_k``.

    The first half of the last dimension holds the sines and the second half
    the cosines of ``position / theta``, where ``theta`` comes from
    ``theta_func(index, pos_emb_dim)`` for indices ``0 .. pos_emb_dim // 2 - 1``.
    ``pos_emb_dim`` must be even.
    """
    assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."

    positions = pos_k.to(device, dtype)
    if SKIP_ZERO:
        positions = positions + 1
    num_positions = positions.size(0)

    # One row of frequency indices per position, shaped for broadcasting.
    half_dim = pos_emb_dim // 2
    indices = torch.arange(0, half_dim, device=device, dtype=dtype)
    thetas = theta_func(indices.view(1, -1).expand(num_positions, -1), pos_emb_dim)

    # Column vector of positions divided by every theta.
    angles = positions.view(-1, 1).to(dtype) / thetas

    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
|
||||
|
||||
def create_pos_embeddings(
    pred_tracks: torch.Tensor,  # the predicted tracks, [T, N, 2]
    pred_visibility: torch.Tensor,  # the predicted visibility [T, N]
    downsample_ratios: list[int],  # the ratios for downsampling time, height, and width
    height: int,  # the height of the feature map
    width: int,  # the width of the feature map
    track_num: int = -1,  # the number of tracks to use
    t_down_strategy: str = "sample",  # the strategy for downsampling time dimension
):
    """Convert pixel-space tracks into downsampled (row, column) grid positions.

    Tracks are taken in a random order (``torch.randperm``) and at most
    ``track_num`` of them are used (all of them when ``track_num`` is -1).
    For every ``t_down``-th frame the position is either sampled directly or,
    with the "average" strategy, averaged over the following ``t_down`` frames
    (the first frame is always sampled). Entries that are invisible or outside
    ``[0, width) x [0, height)`` keep the value -1.

    Returns a long tensor of shape [N, T', 2] whose last axis is (height, width).
    """
    assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."

    num_frames, num_points, _ = pred_tracks.shape
    t_down, h_down, w_down = downsample_ratios
    num_slots = (num_frames - 1) // t_down + 1
    track_pos = torch.full((num_points, num_slots, 2), -1, dtype=torch.long)

    if track_num == -1:
        track_num = num_points

    chosen = torch.randperm(num_points)[:track_num]
    tracks = pred_tracks[:, chosen]
    visibility = pred_visibility[:, chosen]

    # Stepping by t_down from 0 means the enumerate index equals t_idx // t_down.
    for slot, t_idx in enumerate(range(0, num_frames, t_down)):
        if t_down_strategy == "sample" or t_idx == 0:
            frame_tracks = tracks[t_idx]  # [N, 2]
            frame_visibility = visibility[t_idx]  # [N]
        else:
            window = slice(t_idx, t_idx + t_down)
            frame_tracks = tracks[window].mean(dim=0)
            frame_visibility = torch.any(visibility[window], dim=0)

        for i in range(track_num):
            if not frame_visibility[i]:
                continue
            point = frame_tracks[i]
            if point[0] < 0 or point[1] < 0 or point[0] >= width or point[1] >= height:
                continue
            x, y = point
            col = int(x // w_down)
            row = int(y // h_down)
            track_pos[i, slot, 0] = row
            track_pos[i, slot, 1] = col

    return track_pos  # the position embeddings, [N, T', 2], 2 = height, width
|
||||
|
||||
def replace_feature(
    vae_feature: torch.Tensor,  # [B, C', T', H', W']
    track_pos: torch.Tensor,  # [B, N, T', 2]
    strength: float = 1.0
) -> torch.Tensor:
    """Copy each track's first-frame feature onto its later-frame positions.

    For every track and every time step >= 1 with a valid position, the
    feature vector at the track's time-0 position is blended into the
    feature at the current position:
    ``dst + (src - dst) * strength``.

    A position is valid when both coordinates are >= 0. Tracks whose time-0
    position is invalid are skipped entirely; otherwise the negative
    coordinate would wrap around and read a feature from the last row/column.

    NOTE: ``vae_feature`` is modified in place and the same tensor is returned.
    """
    batch = vae_feature.shape[0]
    assert batch == track_pos.shape[0], "Batch size mismatch."
    num_tracks = track_pos.shape[1]

    # Shuffle the trajectory order
    track_pos = track_pos[:, torch.randperm(num_tracks), :, :]

    # Coordinates at time steps >= 1 and the source coordinates at time 0
    current_pos = track_pos[:, :, 1:, :]  # [B, N, T-1, 2]
    source_pos = track_pos[:, :, :1, :]  # [B, N, 1, 2]

    # A target is usable only if both it and the track's source are valid
    source_valid = (source_pos[..., 0] >= 0) & (source_pos[..., 1] >= 0)  # [B, N, 1]
    mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) & source_valid  # [B, N, T-1]

    # Get all valid indices
    valid_indices = mask.nonzero(as_tuple=False)  # [num_valid, 3]
    if valid_indices.shape[0] == 0:
        return vae_feature

    # Decompose valid indices into each dimension
    batch_idx = valid_indices[:, 0]
    track_idx = valid_indices[:, 1]
    t_rel = valid_indices[:, 2]
    t_target = t_rel + 1  # Convert to original time step indices

    # Extract target position coordinates
    h_target = current_pos[batch_idx, track_idx, t_rel, 0].long()  # Ensure integer indices
    w_target = current_pos[batch_idx, track_idx, t_rel, 1].long()

    # Extract source position coordinates (t=0)
    h_source = track_pos[batch_idx, track_idx, 0, 0].long()
    w_source = track_pos[batch_idx, track_idx, 0, 1].long()

    # Get source features and assign to target positions
    src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
    dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]

    vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength

    return vae_feature
|
||||
|
||||
# Visualize functions
|
||||
|
||||
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
    """Draw a fading polyline through ``points`` onto the RGBA ``overlay``.

    The points are traversed in reverse order. Along the path both the alpha
    (scaled by ``opacity``) and the stroke width decrease linearly with the
    distance travelled, so the line is strongest at the last point given and
    fades towards the first. Nothing is drawn when the path has zero length.
    """
    draw = ImageDraw.Draw(overlay, 'RGBA')
    path = points[::-1]
    segments = list(zip(path[:-1], path[1:]))

    # Euclidean length of every segment, and of the whole path.
    segment_lengths = []
    for seg_start, seg_end in segments:
        dx = seg_end[0] - seg_start[0]
        dy = seg_end[1] - seg_start[1]
        segment_lengths.append((dx * dx + dy * dy) ** 0.5)
    total_length = sum(segment_lengths)

    if total_length == 0:
        return

    travelled = 0

    # Draw the gradient polyline, one short horizontal stroke per step.
    for (seg_start, seg_end), seg_len in zip(segments, segment_lengths):
        steps = max(int(seg_len), 1)

        for step in range(steps):
            ratio = (travelled + (step / steps) * seg_len) / total_length
            fade = 1 - ratio

            color = (*start_color, int(255 * fade * opacity))
            x = int(seg_start[0] + (seg_end[0] - seg_start[0]) * step / steps)
            y = int(seg_start[1] + (seg_end[1] - seg_start[1]) * step / steps)
            stroke = max(int(line_width * fade), 1)

            draw.line([(x, y), (x + 1, y)], fill=color, width=stroke)

        travelled += seg_len
|
||||
|
||||
|
||||
def add_weighted(rgb, track):
    """Alpha-composite the RGBA image ``track`` over the RGB image ``rgb``.

    The alpha channel of ``track`` (0-255) weights its colour channels against
    ``rgb``; the result is returned as a uint8 PIL image.
    """
    base = np.array(rgb)  # [H, W, C] "RGB"
    layer = np.array(track)  # [H, W, C] "RGBA"

    # Repeat the normalised alpha channel once per colour channel.
    alpha_channel = layer[:, :, 3] / 255.0
    weight = np.stack((alpha_channel,) * 3, axis=-1)

    blended = layer[:, :, :3] * weight + base * (1 - weight)
    return Image.fromarray(blended.astype(np.uint8))
|
||||
|
||||
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
    """Render track markers and fading trails on top of every video frame.

    ``video`` is converted with ``.byte().cpu().numpy()`` and its first three
    dimensions are read as (frames, height, width). Only the first entry of
    ``tracks`` (and of ``visibility`` when given) is used; it is indexed as
    ``[frame, track]`` and yields an (x, y) coordinate.

    For each frame a circle is drawn at every visible track position, followed
    by a gradient polyline over up to ``track_frame`` preceding positions.
    Returns a list with one PIL image per frame.
    """
    palette = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]

    video = video.byte().cpu().numpy()  # (81, 480, 832, 3)
    tracks = tracks[0].long().detach().cpu().numpy()
    if visibility is not None:
        visibility = visibility[0].detach().cpu().numpy()

    num_frames, height, width = video.shape[:3]
    num_tracks = tracks.shape[1]
    circle_alpha = int(255 * opacity)

    def composite(rgba_layer, base):
        # Blend an RGBA PIL layer over a float frame using the layer's alpha.
        layer = np.array(rgba_layer)
        alpha = layer[:, :, 3:4] / 255.0
        return layer[:, :, :3] * alpha + base * (1 - alpha)

    rendered = []
    for t in range(num_frames):
        frame = video[t].astype(np.float32)

        # One RGBA layer collects the circles of all tracks in this frame.
        circle_layer = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        circle_draw = ImageDraw.Draw(circle_layer)
        first_trail_frame = max(t - track_frame, 0)

        trails = []
        for n in range(num_tracks):
            if visibility is not None and visibility[t, n] == 0:
                continue

            cx, cy = tracks[t, n][0], tracks[t, n][1]
            color = palette[n % len(palette)]
            circle_draw.ellipse(
                (cx - circle_size, cy - circle_size, cx + circle_size, cy + circle_size),
                fill=color + (circle_alpha,)
            )

            # Remember the recent path so all trails can be drawn on one layer.
            history = tracks[first_trail_frame:t + 1, n]
            if len(history) > 1:
                trails.append((history, color))

        # Blend circles layer once
        frame = composite(circle_layer, frame)

        # Draw all trails on a single layer and blend it once
        if trails:
            trail_layer = Image.new("RGBA", (width, height), (0, 0, 0, 0))
            for history, color in trails:
                _draw_gradient_polyline_on_overlay(trail_layer, line_width, history, color, opacity)
            frame = composite(trail_layer, frame)

        rendered.append(Image.fromarray(frame.astype(np.uint8)))

    return rendered
|
||||
|
||||
|
||||
class WanMoveVisualizeTracks(io.ComfyNode):
    """Node that draws the given tracks on top of an image batch for preview."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs and its single image output."""
        node_inputs = [
            io.Image.Input("images"),
            io.Tracks.Input("tracks", optional=True),
            io.Int.Input("line_resolution", default=24, min=1, max=1024),
            io.Int.Input("circle_size", default=12, min=1, max=128),
            io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
            io.Int.Input("line_width", default=16, min=1, max=128),
        ]
        return io.Schema(
            node_id="WanMoveVisualizeTracks",
            category="conditioning/video_models",
            inputs=node_inputs,
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput:
        """Return ``images`` with tracks drawn on them, or unchanged when no tracks are given."""
        if tracks is None:
            return io.NodeOutput(images)

        paths = tracks["track_path"].unsqueeze(0)
        visibility = tracks["track_visibility"].unsqueeze(0)

        # The drawing helper works on 0-255 values.
        frames = images * 255.0
        track_frames = paths.shape[1]
        if frames.shape[0] != track_frames:
            # Tile the image batch by the integer ratio of track frames to images.
            frames = frames.repeat(track_frames // images.shape[0], 1, 1, 1)

        drawn = draw_tracks_on_video(
            frames,
            paths,
            visibility,
            track_frame=line_resolution,
            circle_size=circle_size,
            opacity=opacity,
            line_width=line_width,
        )
        # to_tensor yields channel-first tensors; move channels back to the last axis.
        frame_tensors = [TF.to_tensor(frame) for frame in drawn]
        result = torch.stack(frame_tensors, dim=0).movedim(1, -1).float()

        return io.NodeOutput(result.to(comfy.model_management.intermediate_device()))
|
||||
|
||||
|
||||
class WanMoveTracksFromCoords(io.ComfyNode):
    """Node that turns a JSON string of track coordinates into a tracks dictionary."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs and outputs."""
        return io.Schema(
            node_id="WanMoveTracksFromCoords",
            category="conditioning/video_models",
            inputs=[
                io.String.Input("track_coords", force_input=True, default="[]", optional=True),
                io.Mask.Input("track_mask", optional=True),
            ],
            outputs=[
                io.Tracks.Output(),
                io.Int.Output(display_name="track_length"),
            ],
        )

    @classmethod
    def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
        """Parse ``track_coords`` and build the track path and visibility tensors.

        Returns a dictionary with ``track_path`` ([frames, num_tracks, 2]) and
        ``track_visibility`` ([frames, num_tracks]) plus the number of frames.

        Raises:
            ValueError: if no tracks could be parsed from ``track_coords``.
        """
        device = comfy.model_management.intermediate_device()

        tracks_data = parse_json_tracks(track_coords)
        if len(tracks_data) == 0:
            # Fail with a readable message instead of an IndexError below.
            raise ValueError("track_coords does not contain any tracks.")
        track_length = len(tracks_data[0])

        # Transpose from per-track point lists to per-frame lists of [x, y].
        track_list = [
            [[track[frame]['x'], track[frame]['y']] for track in tracks_data]
            for frame in range(track_length)
        ]
        tracks = torch.tensor(track_list, dtype=torch.float32, device=device)  # [frames, num_tracks, 2]

        num_tracks = tracks.shape[-2]
        if track_mask is None:
            track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
        else:
            # A frame counts as visible when its mask has any positive value.
            # The per-frame flag is repeated for every track so the result has
            # the same [frames, num_tracks] layout as the unmasked case, which
            # is what consumers indexing visibility per track rely on.
            frame_visible = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
            track_visibility = frame_visible.repeat(1, num_tracks).to(device)

        out_track_info = {}
        out_track_info["track_path"] = tracks
        out_track_info["track_visibility"] = track_visibility
        return io.NodeOutput(out_track_info, track_length)
|
||||
|
||||
|
||||
class GenerateTracks(io.ComfyNode):
    """Node that generates parallel motion tracks along a line or a quadratic Bezier curve."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs and outputs."""
        return io.Schema(
            node_id="GenerateTracks",
            search_aliases=["motion paths", "camera movement", "trajectory"],
            category="conditioning/video_models",
            inputs=[
                io.Int.Input("width", default=832, min=16, max=4096, step=16),
                io.Int.Input("height", default=480, min=16, max=4096, step=16),
                io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
                io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
                io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
                io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
                io.Int.Input("num_frames", default=81, min=1, max=1024),
                io.Int.Input("num_tracks", default=5, min=1, max=100),
                io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
                io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
                io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
                io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
                io.Combo.Input(
                    "interpolation",
                    options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
                    tooltip="Controls the timing/speed of movement along the path.",
                ),
                io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
            ],
            outputs=[
                io.Tracks.Output(),
                io.Int.Output(display_name="track_length"),
            ],
        )

    @classmethod
    def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
                track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
        """Generate ``num_tracks`` tracks of ``num_frames`` points each.

        Returns a dictionary with ``track_path`` ([frames, num_tracks, 2], pixel
        coordinates) and ``track_visibility`` ([frames, num_tracks]) plus the
        number of frames.

        Raises:
            ValueError: if ``interpolation`` is not one of the supported modes.
        """
        device = comfy.model_management.intermediate_device()
        track_length = num_frames

        # normalized coordinates to pixel coordinates
        start_x_px = start_x * width
        start_y_px = start_y * height
        mid_x_px = mid_x * width
        mid_y_px = mid_y * height
        end_x_px = end_x * width
        end_y_px = end_y * height

        track_spread_px = track_spread * (width + height) / 2  # Use average of width/height for spread to keep it proportional

        t = torch.linspace(0, 1, num_frames, device=device)
        if interpolation == "constant":  # All points stay at start position
            interp_values = torch.zeros_like(t)
        elif interpolation == "linear":
            interp_values = t
        elif interpolation == "ease_in":
            interp_values = t ** 2
        elif interpolation == "ease_out":
            interp_values = 1 - (1 - t) ** 2
        elif interpolation == "ease_in_out":
            interp_values = t * t * (3 - 2 * t)
        else:
            # Without this branch an unexpected value would surface later as an unbound variable.
            raise ValueError(f"Unknown interpolation mode: {interpolation}")

        if bezier:  # apply interpolation to t for timing control along the bezier path
            t_interp = interp_values
            one_minus_t = 1 - t_interp
            x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
            y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
            tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
            tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
        else:  # calculate base x and y positions for each frame (center track)
            x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
            y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
            # For non-bezier, tangent is constant (direction from start to end)
            tangent_x = torch.full_like(t, end_x_px - start_x_px)
            tangent_y = torch.full_like(t, end_y_px - start_y_px)

        track_list = []
        for frame_idx in range(num_frames):
            # Calculate perpendicular direction at this frame
            tx = tangent_x[frame_idx].item()
            ty = tangent_y[frame_idx].item()
            length = (tx ** 2 + ty ** 2) ** 0.5

            if length > 0:  # Perpendicular unit vector (rotate 90 degrees)
                perp_x = -ty / length
                perp_y = tx / length
            else:  # If tangent is zero, spread horizontally
                perp_x = 1.0
                perp_y = 0.0

            center_x = x_positions[frame_idx].item()
            center_y = y_positions[frame_idx].item()

            frame_tracks = []
            for track_idx in range(num_tracks):
                # center tracks around the main path; offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
                offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
                frame_tracks.append([center_x + perp_x * offset, center_y + perp_y * offset])
            track_list.append(frame_tracks)

        tracks = torch.tensor(track_list, dtype=torch.float32, device=device)  # [frames, num_tracks, 2]

        if track_mask is None:
            track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
        else:
            # A frame counts as visible when its mask has any positive value.
            # The per-frame flag is repeated for every track so the result has
            # the same [frames, num_tracks] layout as the unmasked case, which
            # is what consumers indexing visibility per track rely on.
            frame_visible = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
            track_visibility = frame_visible.repeat(1, num_tracks).to(device)

        out_track_info = {}
        out_track_info["track_path"] = tracks
        out_track_info["track_visibility"] = track_visibility
        return io.NodeOutput(out_track_info, track_length)
|
||||
|
||||
|
||||
class WanMoveConcatTrack(io.ComfyNode):
    """Node that merges two track sets into one."""

    @classmethod
    def define_schema(cls):
        """Declare the two track inputs (the second optional) and the track output."""
        node_inputs = [
            io.Tracks.Input("tracks_1"),
            io.Tracks.Input("tracks_2", optional=True),
        ]
        node_outputs = [
            io.Tracks.Output(),
        ]
        return io.Schema(
            node_id="WanMoveConcatTrack",
            category="conditioning/video_models",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput:
        """Return ``tracks_1`` unchanged when ``tracks_2`` is missing, else both combined."""
        if tracks_2 is None:
            return io.NodeOutput(tracks_1)

        # Paths are joined along dim 1 (the track dimension), visibility along its last dim.
        combined_paths = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1)
        combined_visibility = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1)

        merged = {
            "track_path": combined_paths,
            "track_visibility": combined_visibility,
        }
        return io.NodeOutput(merged)
|
||||
|
||||
|
||||
class WanMoveTrackToVideo(io.ComfyNode):
    """Node that builds image-to-video conditioning with optional track guidance."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs and outputs."""
        node_inputs = [
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
            io.Vae.Input("vae"),
            io.Tracks.Input("tracks", optional=True),
            io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
            io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
            io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
            io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
            io.Int.Input("batch_size", default=1, min=1, max=4096),
            io.Image.Input("start_image"),
            io.ClipVisionOutput.Input("clip_vision_output", optional=True),
        ]
        node_outputs = [
            io.Conditioning.Output(display_name="positive"),
            io.Conditioning.Output(display_name="negative"),
            io.Latent.Output(display_name="latent"),
        ]
        return io.Schema(
            node_id="WanMoveTrackToVideo",
            category="conditioning/video_models",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
        """Encode the start image, apply track features and attach them to the conditioning.

        Returns the updated positive and negative conditioning and a zero latent
        of shape [batch_size, 16, (length - 1) // 4 + 1, height // 8, width // 8].
        """
        target_device = comfy.model_management.intermediate_device()
        latent_frames = ((length - 1) // 4) + 1
        latent = torch.zeros([batch_size, 16, latent_frames, height // 8, width // 8], device=target_device)

        if start_image is not None:
            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            # Frames beyond the supplied start image are filled with mid-gray (0.5).
            image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
            image[:start_image.shape[0]] = start_image

            concat_latent_image = vae.encode(image[:, :, :, :3])
            mask_shape = (1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1])
            mask = torch.ones(mask_shape, device=start_image.device, dtype=start_image.dtype)
            # Zero the mask over the latent frames covered by the start image.
            covered_frames = ((start_image.shape[0] - 1) // 4) + 1
            mask[:, :, :covered_frames] = 0.0

        if tracks is None or not strength > 0.0:
            concat_latent_image_pos = concat_latent_image
        else:
            tracks_path = tracks["track_path"][:length]  # [T, N, 2]
            num_tracks = tracks_path.shape[-2]

            all_visible = torch.ones((length, num_tracks), dtype=torch.bool, device=target_device)
            track_visibility = tracks.get("track_visibility", all_visible)

            # Downsample ratios 4 / 8 / 8 for time / height / width match the latent layout above.
            track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
            track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
            # NOTE(review): replace_feature writes into its input tensor and returns it, so
            # concat_latent_image and concat_latent_image_pos refer to the same modified tensor.
            # Confirm whether the negative conditioning is meant to see the track features.
            concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)

        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask})
        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})

        if clip_vision_output is not None:
            vision_values = {"clip_vision_output": clip_vision_output}
            positive = node_helpers.conditioning_set_values(positive, vision_values)
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        out_latent = {"samples": latent}
        return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanMoveExtension(ComfyExtension):
    """Extension that registers the Wan-Move nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[io.ComfyNode]] = [
            WanMoveTrackToVideo,
            WanMoveTracksFromCoords,
            WanMoveConcatTrack,
            WanMoveVisualizeTracks,
            GenerateTracks,
        ]
        return node_classes
|
||||
|
||||
async def comfy_entrypoint() -> WanMoveExtension:
    """Create and return a fresh WanMoveExtension instance.

    NOTE(review): presumably awaited by the ComfyUI extension loader; the
    caller is not visible from this file.
    """
    extension = WanMoveExtension()
    return extension
|
||||
@@ -5,6 +5,7 @@ MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
|
||||
class WebcamCapture(nodes.LoadImage):
|
||||
SEARCH_ALIASES = ["camera input", "live capture", "camera feed", "snapshot"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
|
||||
88
comfy_extras/nodes_zimage.py
Normal file
88
comfy_extras/nodes_zimage.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import node_helpers
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import math
|
||||
import comfy.utils
|
||||
|
||||
|
||||
class TextEncodeZImageOmni(io.ComfyNode):
    """Text-encode node that can attach up to three reference images to the conditioning.

    Each supplied image may be run through an optional image encoder and/or an
    optional VAE; the results are appended to the conditioning next to the
    encoded prompt.
    """

    @classmethod
    def define_schema(cls):
        """Return the schema: node id, category, inputs and the single conditioning output."""
        node_inputs = [
            io.Clip.Input("clip"),
            io.ClipVision.Input("image_encoder", optional=True),
            io.String.Input("prompt", multiline=True, dynamic_prompts=True),
            io.Boolean.Input("auto_resize_images", default=True),
            io.Vae.Input("vae", optional=True),
            io.Image.Input("image1", optional=True),
            io.Image.Input("image2", optional=True),
            io.Image.Input("image3", optional=True),
        ]
        node_outputs = [
            io.Conditioning.Output(),
        ]
        return io.Schema(
            node_id="TextEncodeZImageOmni",
            category="advanced/conditioning",
            is_experimental=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
        """Encode the prompt and attach whatever reference-image data is available.

        Args:
            clip: Text encoder exposing ``tokenize`` and ``encode_from_tokens_scheduled``.
            prompt: Prompt text to encode.
            image_encoder: Optional encoder; when given, ``encode_image`` is called per image.
            auto_resize_images: When true, images are rescaled before VAE encoding.
            vae: Optional VAE; when given, every image is encoded into a reference latent.
            image1, image2, image3: Optional reference images; ``None`` entries are ignored.

        Returns:
            ``io.NodeOutput`` wrapping the conditioning, with ``reference_latents``,
            ``clip_vision_outputs`` and ``reference_latents_text_embeds`` appended
            whenever the corresponding list is non-empty.
        """
        references = [img for img in (image1, image2, image3) if img is not None]

        if references:
            # One opening segment, one separator between each pair of images and
            # one closing segment: len(references) + 1 entries in total.
            wrapper_segments = (
                ["<|im_start|>user\n<|vision_start|>"]
                + ["<|vision_end|><|vision_start|>"] * (len(references) - 1)
                + ["<|vision_end|><|im_end|>"]
            )
            llama_template = "<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"
        else:
            wrapper_segments = []
            llama_template = None

        vision_outputs = []
        latents = []
        target_area = 1024 * 1024
        for picture in references:
            if image_encoder is not None:
                vision_outputs.append(image_encoder.encode_image(picture))

            if vae is None:
                continue

            if auto_resize_images:
                # Rescale to roughly target_area pixels while keeping the aspect
                # ratio, snapping both sides to multiples of 8.
                # NOTE(review): assumes channels-last images, so that after
                # movedim dims 2 and 3 are height and width - confirm.
                channels_first = picture.movedim(-1, 1)
                src_h = channels_first.shape[2]
                src_w = channels_first.shape[3]
                factor = math.sqrt(target_area / (src_w * src_h))
                new_w = round(src_w * factor / 8.0) * 8
                new_h = round(src_h * factor / 8.0) * 8

                resized = comfy.utils.common_upscale(channels_first, new_w, new_h, "area", "disabled")
                picture = resized.movedim(1, -1)
            latents.append(vae.encode(picture))

        prompt_tokens = clip.tokenize(prompt, llama_template=llama_template)
        conditioning = clip.encode_from_tokens_scheduled(prompt_tokens)

        # Each wrapper segment is encoded on its own with a pass-through template.
        segment_embeds = []
        for segment in wrapper_segments:
            segment_tokens = clip.tokenize(segment, llama_template="{}")
            encoded_segment = clip.encode_from_tokens_scheduled(segment_tokens)
            segment_embeds.append(encoded_segment[0][0])

        optional_values = (
            ("reference_latents", latents),
            ("clip_vision_outputs", vision_outputs),
            ("reference_latents_text_embeds", segment_embeds),
        )
        for key, values in optional_values:
            if values:
                conditioning = node_helpers.conditioning_set_values(conditioning, {key: values}, append=True)

        return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class ZImageExtension(ComfyExtension):
    """Extension object that exposes the Z-Image node class of this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [
            TextEncodeZImageOmni,
        ]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ZImageExtension:
    """Create and return a fresh ZImageExtension instance.

    NOTE(review): presumably awaited by the ComfyUI extension loader; the
    caller is not visible from this file.
    """
    extension = ZImageExtension()
    return extension
|
||||
Reference in New Issue
Block a user