Merge branch 'master' into rename-mahiro

This commit is contained in:
Jedrzej Kosinski
2026-02-27 09:14:21 -08:00
committed by GitHub
213 changed files with 10378 additions and 1284 deletions

View File

@@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)

View File

@@ -47,8 +47,8 @@ class SamplerLCMUpscale(io.ComfyNode):
node_id="SamplerLCMUpscale",
category="sampling/custom_sampling/samplers",
inputs=[
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01),
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1),
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True),
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True),
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
],
outputs=[io.Sampler.Output()],
@@ -94,7 +94,7 @@ class SamplerEulerCFGpp(io.ComfyNode):
display_name="SamplerEulerCFG++",
category="_for_testing", # "sampling/custom_sampling/samplers"
inputs=[
io.Combo.Input("version", options=["regular", "alternative"]),
io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
],
outputs=[io.Sampler.Output()],
is_experimental=True,

View File

@@ -26,6 +26,7 @@ class APG(io.ComfyNode):
max=10.0,
step=0.01,
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
advanced=True,
),
io.Float.Input(
"norm_threshold",
@@ -34,6 +35,7 @@ class APG(io.ComfyNode):
max=50.0,
step=0.1,
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
advanced=True,
),
io.Float.Input(
"momentum",
@@ -42,6 +44,7 @@ class APG(io.ComfyNode):
max=1.0,
step=0.01,
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
advanced=True,
),
],
outputs=[io.Model.Output()],

View File

@@ -28,10 +28,10 @@ class UNetSelfAttentionMultiply(io.ComfyNode):
category="_for_testing/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
],
outputs=[io.Model.Output()],
is_experimental=True,
@@ -51,10 +51,10 @@ class UNetCrossAttentionMultiply(io.ComfyNode):
category="_for_testing/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
],
outputs=[io.Model.Output()],
is_experimental=True,
@@ -75,10 +75,10 @@ class CLIPAttentionMultiply(io.ComfyNode):
category="_for_testing/attention_experiments",
inputs=[
io.Clip.Input("clip"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
],
outputs=[io.Clip.Output()],
is_experimental=True,
@@ -109,10 +109,10 @@ class UNetTemporalAttentionMultiply(io.ComfyNode):
category="_for_testing/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
],
outputs=[io.Model.Output()],
is_experimental=True,

View File

@@ -22,7 +22,7 @@ class EmptyLatentAudio(IO.ComfyNode):
inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch.",
),
],
outputs=[IO.Latent.Output()],
@@ -159,6 +159,7 @@ class SaveAudio(IO.ComfyNode):
search_aliases=["export flac"],
display_name="Save Audio (FLAC)",
category="audio",
essentials_category="Audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
@@ -300,6 +301,7 @@ class LoadAudio(IO.ComfyNode):
search_aliases=["import audio", "open audio", "audio file"],
display_name="Load Audio",
category="audio",
essentials_category="Audio",
inputs=[
IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
],
@@ -677,6 +679,7 @@ class EmptyAudio(IO.ComfyNode):
tooltip="Sample rate of the empty audio clip.",
min=1,
max=192000,
advanced=True,
),
IO.Int.Input(
"channels",
@@ -684,6 +687,7 @@ class EmptyAudio(IO.ComfyNode):
min=1,
max=2,
tooltip="Number of audio channels (1 for mono, 2 for stereo).",
advanced=True,
),
],
outputs=[IO.Audio.Output()],
@@ -698,6 +702,67 @@ class EmptyAudio(IO.ComfyNode):
create_empty_audio = execute # TODO: remove
class AudioEqualizer3Band(IO.ComfyNode):
    """Three-band audio equalizer node.

    Applies up to three cascaded biquad filters to the input audio:
    a low shelf (bass), a peaking filter (mids), and a high shelf
    (treble). Bands whose gain is 0 dB are skipped entirely, leaving
    the signal untouched for that band.
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="AudioEqualizer3Band",
            search_aliases=["eq", "bass boost", "treble boost", "equalizer"],
            display_name="Audio Equalizer (3-Band)",
            category="audio",
            is_experimental=True,
            inputs=[
                IO.Audio.Input("audio"),
                IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Low frequencies (Bass)"),
                IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"),
                IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"),
                IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"),
                IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"),
                IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"),
                IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"),
            ],
            outputs=[IO.Audio.Output()],
        )

    @classmethod
    def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
        """Run the three EQ bands in series over a cloned waveform.

        The input audio dict is never mutated; a clone is filtered and
        returned alongside the original sample rate.
        """
        sample_rate = audio["sample_rate"]
        processed = audio["waveform"].clone()

        # Low shelf (bass). Fixed Q of 0.707 (Butterworth response).
        if low_gain_dB != 0:
            processed = torchaudio.functional.bass_biquad(
                processed,
                sample_rate,
                gain=low_gain_dB,
                central_freq=float(low_freq),
                Q=0.707,
            )

        # Peaking EQ (mids). Bandwidth is user-controlled via mid_q.
        if mid_gain_dB != 0:
            processed = torchaudio.functional.equalizer_biquad(
                processed,
                sample_rate,
                center_freq=float(mid_freq),
                gain=mid_gain_dB,
                Q=mid_q,
            )

        # High shelf (treble). Fixed Q of 0.707, mirroring the low shelf.
        if high_gain_dB != 0:
            processed = torchaudio.functional.treble_biquad(
                processed,
                sample_rate,
                gain=high_gain_dB,
                central_freq=float(high_freq),
                Q=0.707,
            )

        return IO.NodeOutput({"waveform": processed, "sample_rate": sample_rate})
class AudioExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -720,6 +785,7 @@ class AudioExtension(ComfyExtension):
AudioMerge,
AudioAdjustVolume,
EmptyAudio,
AudioEqualizer3Band,
]
async def comfy_entrypoint() -> AudioExtension:

View File

@@ -174,10 +174,10 @@ class WanCameraEmbedding(io.ComfyNode):
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True),
io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True),
io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True),
io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True),
io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True, advanced=True),
io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True, advanced=True),
io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True, advanced=True),
io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True, advanced=True),
],
outputs=[
io.WanCameraEmbedding.Output(display_name="camera_embedding"),

View File

@@ -10,8 +10,10 @@ class Canny(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Canny",
display_name="Canny",
search_aliases=["edge detection", "outline", "contour detection", "line art"],
category="image/preprocessors",
essentials_category="Image Tools",
inputs=[
io.Image.Input("image"),
io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01),

View File

@@ -48,6 +48,7 @@ class ChromaRadianceOptions(io.ComfyNode):
min=0.0,
max=1.0,
tooltip="First sigma that these options will be in effect.",
advanced=True,
),
io.Float.Input(
id="end_sigma",
@@ -55,12 +56,14 @@ class ChromaRadianceOptions(io.ComfyNode):
min=0.0,
max=1.0,
tooltip="Last sigma that these options will be in effect.",
advanced=True,
),
io.Int.Input(
id="nerf_tile_size",
default=-1,
min=-1,
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
advanced=True,
),
],
outputs=[io.Model.Output()],

View File

@@ -35,8 +35,8 @@ class CLIPTextEncodeSDXL(io.ComfyNode):
io.Clip.Input("clip"),
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION, advanced=True),
io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION, advanced=True),
io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.String.Input("text_g", multiline=True, dynamic_prompts=True),

View File

@@ -38,8 +38,8 @@ class T5TokenizerOptions(io.ComfyNode):
category="_for_testing/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1),
io.Int.Input("min_length", default=0, min=0, max=10000, step=1),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True),
],
outputs=[io.Clip.Output()],
is_experimental=True,

View File

@@ -14,15 +14,15 @@ class ContextWindowsManualNode(io.ComfyNode):
description="Manually set context windows.",
inputs=[
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True),
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True),
io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
@@ -67,15 +67,15 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
schema.description = "Manually set context windows for WAN-like models (dim=2)."
schema.inputs = [
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."),
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."),
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True),
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True),
io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),

View File

@@ -48,8 +48,8 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode):
io.Image.Input("image"),
io.Mask.Input("mask"),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),

View File

@@ -50,9 +50,9 @@ class KarrasScheduler(io.ComfyNode):
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False),
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False),
io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
],
outputs=[io.Sigmas.Output()]
)
@@ -72,8 +72,8 @@ class ExponentialScheduler(io.ComfyNode):
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False),
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
],
outputs=[io.Sigmas.Output()]
)
@@ -93,9 +93,9 @@ class PolyexponentialScheduler(io.ComfyNode):
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False),
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False),
io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
],
outputs=[io.Sigmas.Output()]
)
@@ -115,10 +115,10 @@ class LaplaceScheduler(io.ComfyNode):
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False),
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False),
io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False),
io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False, advanced=True),
io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False, advanced=True),
],
outputs=[io.Sigmas.Output()]
)
@@ -164,8 +164,8 @@ class BetaSamplingScheduler(io.ComfyNode):
inputs=[
io.Model.Input("model"),
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False, advanced=True),
io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False, advanced=True),
],
outputs=[io.Sigmas.Output()]
)
@@ -185,9 +185,9 @@ class VPScheduler(io.ComfyNode):
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False), #TODO: fix default values
io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False),
io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False),
io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), #TODO: fix default values
io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
],
outputs=[io.Sigmas.Output()]
)
@@ -398,9 +398,9 @@ class SamplerDPMPP_3M_SDE(io.ComfyNode):
node_id="SamplerDPMPP_3M_SDE",
category="sampling/custom_sampling/samplers",
inputs=[
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Combo.Input("noise_device", options=['gpu', 'cpu']),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True),
],
outputs=[io.Sampler.Output()]
)
@@ -424,9 +424,9 @@ class SamplerDPMPP_2M_SDE(io.ComfyNode):
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=['midpoint', 'heun']),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Combo.Input("noise_device", options=['gpu', 'cpu']),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True),
],
outputs=[io.Sampler.Output()]
)
@@ -450,10 +450,10 @@ class SamplerDPMPP_SDE(io.ComfyNode):
node_id="SamplerDPMPP_SDE",
category="sampling/custom_sampling/samplers",
inputs=[
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False),
io.Combo.Input("noise_device", options=['gpu', 'cpu']),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True),
],
outputs=[io.Sampler.Output()]
)
@@ -496,8 +496,8 @@ class SamplerEulerAncestral(io.ComfyNode):
node_id="SamplerEulerAncestral",
category="sampling/custom_sampling/samplers",
inputs=[
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
],
outputs=[io.Sampler.Output()]
)
@@ -538,7 +538,7 @@ class SamplerLMS(io.ComfyNode):
return io.Schema(
node_id="SamplerLMS",
category="sampling/custom_sampling/samplers",
inputs=[io.Int.Input("order", default=4, min=1, max=100)],
inputs=[io.Int.Input("order", default=4, min=1, max=100, advanced=True)],
outputs=[io.Sampler.Output()]
)
@@ -556,16 +556,16 @@ class SamplerDPMAdaptative(io.ComfyNode):
node_id="SamplerDPMAdaptative",
category="sampling/custom_sampling/samplers",
inputs=[
io.Int.Input("order", default=3, min=2, max=3),
io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Int.Input("order", default=3, min=2, max=3, advanced=True),
io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
],
outputs=[io.Sampler.Output()]
)
@@ -588,9 +588,9 @@ class SamplerER_SDE(io.ComfyNode):
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]),
io.Int.Input("max_stage", default=3, min=1, max=3),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type.", advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
],
outputs=[io.Sampler.Output()]
)
@@ -622,17 +622,18 @@ class SamplerSASolver(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerSASolver",
search_aliases=["sde"],
category="sampling/custom_sampling/samplers",
inputs=[
io.Model.Input("model"),
io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001),
io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Int.Input("predictor_order", default=3, min=1, max=6),
io.Int.Input("corrector_order", default=4, min=0, max=6),
io.Boolean.Input("use_pece"),
io.Boolean.Input("simple_order_2"),
io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False, advanced=True),
io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001, advanced=True),
io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001, advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Int.Input("predictor_order", default=3, min=1, max=6, advanced=True),
io.Int.Input("corrector_order", default=4, min=0, max=6, advanced=True),
io.Boolean.Input("use_pece", advanced=True),
io.Boolean.Input("simple_order_2", advanced=True),
],
outputs=[io.Sampler.Output()]
)
@@ -666,12 +667,13 @@ class SamplerSEEDS2(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
search_aliases=["sde", "exp heun"],
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength", advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier", advanced=True),
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)", advanced=True),
],
outputs=[io.Sampler.Output()],
description=(
@@ -728,7 +730,7 @@ class SamplerCustom(io.ComfyNode):
category="sampling/custom_sampling",
inputs=[
io.Model.Input("model"),
io.Boolean.Input("add_noise", default=True),
io.Boolean.Input("add_noise", default=True, advanced=True),
io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
io.Conditioning.Input("positive"),

View File

@@ -222,6 +222,7 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
"filename_prefix",
default="image",
tooltip="Prefix for saved image filenames.",
advanced=True,
),
],
outputs=[],
@@ -262,6 +263,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
"filename_prefix",
default="image",
tooltip="Prefix for saved image filenames.",
advanced=True,
),
],
outputs=[],
@@ -741,6 +743,7 @@ class NormalizeImagesNode(ImageProcessingNode):
min=0.0,
max=1.0,
tooltip="Mean value for normalization.",
advanced=True,
),
io.Float.Input(
"std",
@@ -748,6 +751,7 @@ class NormalizeImagesNode(ImageProcessingNode):
min=0.001,
max=1.0,
tooltip="Standard deviation for normalization.",
advanced=True,
),
]
@@ -961,6 +965,7 @@ class ImageDeduplicationNode(ImageProcessingNode):
min=0.0,
max=1.0,
tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
advanced=True,
),
]
@@ -1039,6 +1044,7 @@ class ImageGridNode(ImageProcessingNode):
min=32,
max=2048,
tooltip="Width of each cell in the grid.",
advanced=True,
),
io.Int.Input(
"cell_height",
@@ -1046,9 +1052,10 @@ class ImageGridNode(ImageProcessingNode):
min=32,
max=2048,
tooltip="Height of each cell in the grid.",
advanced=True,
),
io.Int.Input(
"padding", default=4, min=0, max=50, tooltip="Padding between images."
"padding", default=4, min=0, max=50, tooltip="Padding between images.", advanced=True
),
]
@@ -1339,6 +1346,7 @@ class SaveTrainingDataset(io.ComfyNode):
min=1,
max=100000,
tooltip="Number of samples per shard file.",
advanced=True,
),
],
outputs=[],

View File

@@ -108,7 +108,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels)
x: torch.Tensor = args[0][:, :easycache.output_channels]
# prepare next x_prev
next_x_prev = x
input_change = None
@@ -367,10 +367,10 @@ class EasyCacheNode(io.ComfyNode):
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add EasyCache to."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps.", advanced=True),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache.", advanced=True),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache.", advanced=True),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information.", advanced=True),
],
outputs=[
io.Model.Output(tooltip="The model with EasyCache."),
@@ -500,10 +500,10 @@ class LazyCacheNode(io.ComfyNode):
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add LazyCache to."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps.", advanced=True),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache.", advanced=True),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache.", advanced=True),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information.", advanced=True),
],
outputs=[
io.Model.Output(tooltip="The model with LazyCache."),

View File

@@ -28,6 +28,7 @@ class EpsilonScaling(io.ComfyNode):
max=1.5,
step=0.001,
display_mode=io.NumberDisplay.number,
advanced=True,
),
],
outputs=[
@@ -97,6 +98,7 @@ class TemporalScoreRescaling(io.ComfyNode):
max=100.0,
step=0.001,
display_mode=io.NumberDisplay.number,
advanced=True,
),
io.Float.Input(
"tsr_sigma",
@@ -109,6 +111,7 @@ class TemporalScoreRescaling(io.ComfyNode):
max=100.0,
step=0.001,
display_mode=io.NumberDisplay.number,
advanced=True,
),
],
outputs=[

View File

@@ -161,6 +161,7 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
advanced=True,
),
],
outputs=[

View File

@@ -32,10 +32,10 @@ class FreeU(IO.ComfyNode):
category="model_patches/unet",
inputs=[
IO.Model.Input("model"),
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01, advanced=True),
IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01, advanced=True),
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01, advanced=True),
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01, advanced=True),
],
outputs=[
IO.Model.Output(),
@@ -79,10 +79,10 @@ class FreeU_V2(IO.ComfyNode):
category="model_patches/unet",
inputs=[
IO.Model.Input("model"),
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01, advanced=True),
IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01, advanced=True),
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01, advanced=True),
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01, advanced=True),
],
outputs=[
IO.Model.Output(),

View File

@@ -65,11 +65,11 @@ class FreSca(io.ComfyNode):
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01,
tooltip="Scaling factor for low-frequency components"),
tooltip="Scaling factor for low-frequency components", advanced=True),
io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01,
tooltip="Scaling factor for high-frequency components"),
tooltip="Scaling factor for high-frequency components", advanced=True),
io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1,
tooltip="Number of frequency indices around center to consider as low-frequency"),
tooltip="Number of frequency indices around center to consider as low-frequency", advanced=True),
],
outputs=[
io.Model.Output(),

View File

@@ -342,7 +342,7 @@ class GITSScheduler(io.ComfyNode):
node_id="GITSScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05, advanced=True),
io.Int.Input("steps", default=10, min=2, max=1000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
],

896
comfy_extras/nodes_glsl.py Normal file
View File

@@ -0,0 +1,896 @@
import os
import sys
import re
import logging
import ctypes.util
import importlib.util
from typing import TypedDict
import numpy as np
import torch
import nodes
from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__)
def _check_opengl_availability():
    """Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
    logger.debug("_check_opengl_availability: starting")
    missing_pkgs = []
    # find_spec probes installability without paying the cost of importing.
    logger.debug("_check_opengl_availability: checking for glfw package")
    if importlib.util.find_spec("glfw") is None:
        missing_pkgs.append("glfw")
    logger.debug("_check_opengl_availability: checking for OpenGL package")
    if importlib.util.find_spec("OpenGL") is None:
        missing_pkgs.append("PyOpenGL")
    if missing_pkgs:
        raise RuntimeError(
            f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
        )
    # On Linux without a display server, look for a headless rendering backend.
    logger.debug(f"_check_opengl_availability: platform={sys.platform}")
    if sys.platform.startswith("linux"):
        display_env = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
        logger.debug(f"_check_opengl_availability: has_display={bool(display_env)}")
        if not display_env:
            logger.debug("_check_opengl_availability: checking for EGL library")
            egl_lib = ctypes.util.find_library("EGL")
            logger.debug("_check_opengl_availability: checking for OSMesa library")
            osmesa_lib = ctypes.util.find_library("OSMesa")
            # Error disabled for CI as it fails this check
            # if not egl_lib and not osmesa_lib:
            #     raise RuntimeError(
            #         "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
            #         "See error below for installation instructions."
            #     )
            logger.debug(f"Headless mode: EGL={'yes' if egl_lib else 'no'}, OSMesa={'yes' if osmesa_lib else 'no'}")
    logger.debug("_check_opengl_availability: completed")
# Run early check at import time
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
_check_opengl_availability()

# OpenGL modules - initialized lazily when context is created:
# `gl` is populated by _import_opengl(); `glfw` / `EGL` are set by
# GLContext.__init__ when the corresponding backend wins.
gl = None
glfw = None
EGL = None
def _import_opengl():
    """Import OpenGL module. Called after context is created."""
    global gl
    # Fast path: already imported by a previous call.
    if gl is not None:
        return gl
    logger.debug("_import_opengl: importing OpenGL.GL")
    import OpenGL.GL as _opengl_gl
    gl = _opengl_gl
    logger.debug("_import_opengl: import completed")
    return gl
class SizeModeInput(TypedDict):
    """Resolved value of the `size_mode` DynamicCombo input.

    `width` and `height` are only read when size_mode == "custom"
    (see GLSLShader.execute).
    """
    size_mode: str  # "from_input" or "custom"
    width: int  # custom output width in pixels
    height: int  # custom output height in pixels
MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# Vertex shader using gl_VertexID trick - no VBO needed.
# Draws a single triangle that covers the entire screen:
#
# (-1,3)
# /|
# / | <- visible area is the unit square from (-1,-1) to (1,1)
# / | parts outside get clipped away
# (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
gl_Position = vec4(verts[gl_VertexID], 0, 1);
}
"""
DEFAULT_FRAGMENT_SHADER = """#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
void main() {
fragColor0 = texture(u_image0, v_texCoord);
}
"""
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
# Remove any existing #version directive
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
# Remove precision qualifiers (not needed in desktop GLSL)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source
def _detect_output_count(source: str) -> int:
"""Detect how many fragColor outputs are used in the shader.
Returns the count of outputs needed (1 to MAX_OUTPUTS).
"""
matches = re.findall(r"fragColor(\d+)", source)
if not matches:
return 1 # Default to 1 output if none found
max_index = max(int(m) for m in matches)
return min(max_index + 1, MAX_OUTPUTS)
def _detect_pass_count(source: str) -> int:
"""Detect multi-pass rendering from #pragma passes N directive.
Returns the number of passes (1 if not specified).
"""
match = re.search(r'#pragma\s+passes\s+(\d+)', source)
if match:
return max(1, int(match.group(1)))
return 1
def _init_glfw():
    """Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
    logger.debug("_init_glfw: starting")
    # On macOS, glfw.init() must be called from main thread or it hangs forever
    if sys.platform == "darwin":
        logger.debug("_init_glfw: skipping on macOS")
        raise RuntimeError("GLFW backend not supported on macOS")
    logger.debug("_init_glfw: importing glfw module")
    import glfw as _glfw
    logger.debug("_init_glfw: calling glfw.init()")
    if not _glfw.init():
        raise RuntimeError("glfw.init() failed")
    try:
        # Request an invisible window with a core-profile GL 3.3 context.
        logger.debug("_init_glfw: setting window hints")
        for hint, value in (
            (_glfw.VISIBLE, _glfw.FALSE),
            (_glfw.CONTEXT_VERSION_MAJOR, 3),
            (_glfw.CONTEXT_VERSION_MINOR, 3),
            (_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE),
        ):
            _glfw.window_hint(hint, value)
        logger.debug("_init_glfw: calling create_window()")
        win = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
        if not win:
            raise RuntimeError("glfw.create_window() failed")
        logger.debug("_init_glfw: calling make_context_current()")
        _glfw.make_context_current(win)
        logger.debug("_init_glfw: completed successfully")
        return win, _glfw
    except Exception:
        # Roll back the library init so a later backend attempt starts clean.
        logger.debug("_init_glfw: failed, terminating glfw")
        _glfw.terminate()
        raise
def _init_egl():
    """Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
    logger.debug("_init_egl: starting")
    from OpenGL import EGL as _EGL
    from OpenGL.EGL import (
        eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
        eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
        eglTerminate, eglDestroyContext, eglDestroySurface,
        EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
        EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
        EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
        EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
    )
    logger.debug("_init_egl: imports completed")
    # Track partially-created handles so the except block can release
    # exactly what was acquired before the failure.
    display = None
    context = None
    surface = None
    try:
        logger.debug("_init_egl: calling eglGetDisplay()")
        display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
        if display == _EGL.EGL_NO_DISPLAY:
            raise RuntimeError("eglGetDisplay() failed")
        logger.debug("_init_egl: calling eglInitialize()")
        major, minor = _EGL.EGLint(), _EGL.EGLint()
        if not eglInitialize(display, major, minor):
            display = None  # Not initialized, don't terminate
            raise RuntimeError("eglInitialize() failed")
        logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
        # Request an RGBA8, pbuffer-capable config that can back a desktop
        # OpenGL (not GLES) context; no depth buffer is needed.
        config_attribs = [
            EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
            EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
            EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
            EGL_DEPTH_SIZE, 0, EGL_NONE
        ]
        configs = (_EGL.EGLConfig * 1)()
        num_configs = _EGL.EGLint()
        if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
            raise RuntimeError("eglChooseConfig() failed")
        config = configs[0]
        logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
        if not eglBindAPI(EGL_OPENGL_API):
            raise RuntimeError("eglBindAPI() failed")
        logger.debug("_init_egl: calling eglCreateContext()")
        # Core-profile OpenGL 3.3, matching what the GLFW backend requests.
        context_attribs = [
            _EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
            _EGL.EGL_CONTEXT_MINOR_VERSION, 3,
            _EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
            EGL_NONE
        ]
        context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
        if context == EGL_NO_CONTEXT:
            raise RuntimeError("eglCreateContext() failed")
        logger.debug("_init_egl: calling eglCreatePbufferSurface()")
        # Small placeholder surface; actual rendering targets FBOs
        # (see _render_shader_batch), so 64x64 is sufficient.
        pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
        surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
        if surface == _EGL.EGL_NO_SURFACE:
            raise RuntimeError("eglCreatePbufferSurface() failed")
        logger.debug("_init_egl: calling eglMakeCurrent()")
        if not eglMakeCurrent(display, surface, surface, context):
            raise RuntimeError("eglMakeCurrent() failed")
        logger.debug("_init_egl: completed successfully")
        return display, context, surface, _EGL
    except Exception:
        logger.debug("_init_egl: failed, cleaning up")
        # Clean up any resources on failure
        if surface is not None:
            eglDestroySurface(display, surface)
        if context is not None:
            eglDestroyContext(display, context)
        if display is not None:
            eglTerminate(display)
        raise
def _init_osmesa():
    """Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
    import ctypes
    logger.debug("_init_osmesa: starting")
    # Select the OSMesa platform before PyOpenGL's bindings are imported below.
    os.environ["PYOPENGL_PLATFORM"] = "osmesa"
    logger.debug("_init_osmesa: importing OpenGL.osmesa")
    from OpenGL import GL as _gl
    from OpenGL.osmesa import (
        OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
        OSMESA_RGBA,
    )
    logger.debug("_init_osmesa: imports completed")
    # RGBA context with 24 depth bits, no stencil/accum buffers.
    ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
    if not ctx:
        raise RuntimeError("OSMesaCreateContextExt() failed")
    # OSMesa draws into this caller-owned CPU buffer; it is returned alongside
    # the context so the caller keeps it alive for the context's lifetime.
    width, height = 64, 64
    buffer = (ctypes.c_ubyte * (width * height * 4))()
    logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
    if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
        OSMesaDestroyContext(ctx)
        raise RuntimeError("OSMesaMakeCurrent() failed")
    logger.debug("_init_osmesa: completed successfully")
    return ctx, buffer
class GLContext:
    """Manages OpenGL context and resources for shader execution.

    Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).

    Process-wide singleton: __new__ always returns the same instance and
    __init__ performs backend setup only once (guarded by _initialized).
    """
    _instance = None  # the singleton instance
    _initialized = False  # True once a backend has been brought up successfully

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Constructing the singleton again is a no-op once initialized.
        if GLContext._initialized:
            logger.debug("GLContext.__init__: already initialized, skipping")
            return
        logger.debug("GLContext.__init__: starting initialization")
        global glfw, EGL
        import time
        start = time.perf_counter()
        self._backend = None
        self._window = None
        self._egl_display = None
        self._egl_context = None
        self._egl_surface = None
        self._osmesa_ctx = None
        self._osmesa_buffer = None
        self._vao = None
        # Try backends in order: GLFW → EGL → OSMesa; collect failures so the
        # final error message can report why every backend was rejected.
        errors = []
        logger.debug("GLContext.__init__: trying GLFW backend")
        try:
            self._window, glfw = _init_glfw()
            self._backend = "glfw"
            logger.debug("GLContext.__init__: GLFW backend succeeded")
        except Exception as e:
            logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
            errors.append(("GLFW", e))
        if self._backend is None:
            logger.debug("GLContext.__init__: trying EGL backend")
            try:
                self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
                self._backend = "egl"
                logger.debug("GLContext.__init__: EGL backend succeeded")
            except Exception as e:
                logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
                errors.append(("EGL", e))
        if self._backend is None:
            logger.debug("GLContext.__init__: trying OSMesa backend")
            try:
                self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
                self._backend = "osmesa"
                logger.debug("GLContext.__init__: OSMesa backend succeeded")
            except Exception as e:
                logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
                errors.append(("OSMesa", e))
        if self._backend is None:
            # Every backend failed: raise with platform-specific guidance.
            if sys.platform == "win32":
                platform_help = (
                    "Windows: Ensure GPU drivers are installed and display is available.\n"
                    "  CPU-only/headless mode is not supported on Windows."
                )
            elif sys.platform == "darwin":
                platform_help = (
                    "macOS: GLFW is not supported.\n"
                    "  Install OSMesa via Homebrew: brew install mesa\n"
                    "  Then: pip install PyOpenGL PyOpenGL-accelerate"
                )
            else:
                platform_help = (
                    "Linux: Install one of these backends:\n"
                    "  Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
                    "  Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
                    "  Headless (CPU): sudo apt install libosmesa6"
                )
            error_details = "\n".join(f"  {name}: {err}" for name, err in errors)
            raise RuntimeError(
                f"Failed to create OpenGL context.\n\n"
                f"Backend errors:\n{error_details}\n\n"
                f"{platform_help}"
            )
        # Now import OpenGL.GL (after context is current)
        logger.debug("GLContext.__init__: importing OpenGL.GL")
        _import_opengl()
        # Create VAO (required for core profile, but OSMesa may use compat profile)
        logger.debug("GLContext.__init__: creating VAO")
        # Pre-bind vao so the except handler below cannot hit an unbound-local
        # NameError when glGenVertexArrays itself raises (previously masked
        # the real error).
        vao = None
        try:
            vao = gl.glGenVertexArrays(1)
            gl.glBindVertexArray(vao)
            self._vao = vao  # Only store after successful bind
            logger.debug("GLContext.__init__: VAO created successfully")
        except Exception as e:
            logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
            # OSMesa with older Mesa may not support VAOs.
            # Clean up if we created but couldn't bind.
            if vao:
                try:
                    gl.glDeleteVertexArrays(1, [vao])
                except Exception:
                    pass
        elapsed = (time.perf_counter() - start) * 1000
        # Log device info
        renderer = gl.glGetString(gl.GL_RENDERER)
        vendor = gl.glGetString(gl.GL_VENDOR)
        version = gl.glGetString(gl.GL_VERSION)
        renderer = renderer.decode() if renderer else "Unknown"
        vendor = vendor.decode() if vendor else "Unknown"
        version = version.decode() if version else "Unknown"
        GLContext._initialized = True
        logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")

    def make_current(self):
        """Re-bind this context (and its VAO, if any) for subsequent GL calls."""
        if self._backend == "glfw":
            glfw.make_context_current(self._window)
        elif self._backend == "egl":
            from OpenGL.EGL import eglMakeCurrent
            eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
        elif self._backend == "osmesa":
            from OpenGL.osmesa import OSMesaMakeCurrent
            OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
        if self._vao is not None:
            gl.glBindVertexArray(self._vao)
def _compile_shader(source: str, shader_type: int) -> int:
    """Compile a single shader stage and return its GL handle.

    Raises RuntimeError (with the driver's info log) on failure; the
    partially-created shader object is deleted before raising.
    """
    handle = gl.glCreateShader(shader_type)
    gl.glShaderSource(handle, source)
    gl.glCompileShader(handle)
    compiled_ok = gl.glGetShaderiv(handle, gl.GL_COMPILE_STATUS) == gl.GL_TRUE
    if compiled_ok:
        return handle
    info_log = gl.glGetShaderInfoLog(handle).decode()
    gl.glDeleteShader(handle)
    raise RuntimeError(f"Shader compilation failed:\n{info_log}")
def _create_program(vertex_source: str, fragment_source: str) -> int:
    """Create and link a shader program from vertex + fragment sources.

    Shader objects are deleted after linking (or on failure); raises
    RuntimeError with the link log if linking fails.
    """
    vs = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
    try:
        fs = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
    except RuntimeError:
        # Don't leak the vertex shader if the fragment stage fails to compile.
        gl.glDeleteShader(vs)
        raise
    prog = gl.glCreateProgram()
    for shader in (vs, fs):
        gl.glAttachShader(prog, shader)
    gl.glLinkProgram(prog)
    # Shader objects are no longer needed once attached and linked.
    gl.glDeleteShader(vs)
    gl.glDeleteShader(fs)
    linked_ok = gl.glGetProgramiv(prog, gl.GL_LINK_STATUS) == gl.GL_TRUE
    if not linked_ok:
        link_log = gl.glGetProgramInfoLog(prog).decode()
        gl.glDeleteProgram(prog)
        raise RuntimeError(f"Program linking failed:\n{link_log}")
    return prog
def _render_shader_batch(
    fragment_code: str,
    width: int,
    height: int,
    image_batches: list[list[np.ndarray]],
    floats: list[float],
    ints: list[int],
) -> list[list[np.ndarray]]:
    """
    Render a fragment shader for multiple batches efficiently.

    Compiles shader once, reuses framebuffer/textures across batches.
    Supports multi-pass rendering via #pragma passes N directive.

    Args:
        fragment_code: User's fragment shader code
        width: Output width
        height: Output height
        image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
        floats: List of float uniforms
        ints: List of int uniforms

    Returns:
        List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
    """
    import time
    start_time = time.perf_counter()
    if not image_batches:
        return []
    ctx = GLContext()
    ctx.make_current()
    # Convert from GLSL ES to desktop GLSL 330
    fragment_source = _convert_es_to_desktop(fragment_code)
    # Detect how many outputs the shader actually uses
    num_outputs = _detect_output_count(fragment_code)
    # Detect multi-pass rendering
    num_passes = _detect_pass_count(fragment_code)
    # Track resources for cleanup (all released in the finally block below)
    program = None
    fbo = None
    output_textures = []
    input_textures = []
    ping_pong_textures = []
    ping_pong_fbos = []
    num_inputs = len(image_batches[0])
    try:
        # Compile shaders (once for all batches)
        try:
            program = _create_program(VERTEX_SHADER, fragment_source)
        except RuntimeError:
            # Log the converted source so compile errors can be matched to lines.
            logger.error(f"Fragment shader:\n{fragment_source}")
            raise
        gl.glUseProgram(program)
        # Create framebuffer with only the needed color attachments
        fbo = gl.glGenFramebuffers(1)
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
        draw_buffers = []
        for i in range(num_outputs):
            tex = gl.glGenTextures(1)
            output_textures.append(tex)
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
            # RGBA32F target: shader output is read back as float, unquantized.
            gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
            draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)
        gl.glDrawBuffers(num_outputs, draw_buffers)
        if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
            raise RuntimeError("Framebuffer is not complete")
        # Create ping-pong resources for multi-pass rendering
        if num_passes > 1:
            for _ in range(2):
                pp_tex = gl.glGenTextures(1)
                ping_pong_textures.append(pp_tex)
                gl.glBindTexture(gl.GL_TEXTURE_2D, pp_tex)
                gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
                gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
                gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
                gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
                gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
                pp_fbo = gl.glGenFramebuffers(1)
                ping_pong_fbos.append(pp_fbo)
                gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, pp_fbo)
                gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, pp_tex, 0)
                gl.glDrawBuffers(1, [gl.GL_COLOR_ATTACHMENT0])
                if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
                    raise RuntimeError("Ping-pong framebuffer is not complete")
        # Create input textures (reused for all batches)
        for i in range(num_inputs):
            tex = gl.glGenTextures(1)
            input_textures.append(tex)
            gl.glActiveTexture(gl.GL_TEXTURE0 + i)
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
            # Bind sampler uniform u_image{i} to texture unit i (if the
            # shader declares it; location is -1 when optimized out).
            loc = gl.glGetUniformLocation(program, f"u_image{i}")
            if loc >= 0:
                gl.glUniform1i(loc, i)
        # Set static uniforms (once for all batches)
        loc = gl.glGetUniformLocation(program, "u_resolution")
        if loc >= 0:
            gl.glUniform2f(loc, float(width), float(height))
        for i, v in enumerate(floats):
            loc = gl.glGetUniformLocation(program, f"u_float{i}")
            if loc >= 0:
                gl.glUniform1f(loc, v)
        for i, v in enumerate(ints):
            loc = gl.glGetUniformLocation(program, f"u_int{i}")
            if loc >= 0:
                gl.glUniform1i(loc, v)
        # Get u_pass uniform location for multi-pass
        pass_loc = gl.glGetUniformLocation(program, "u_pass")
        gl.glViewport(0, 0, width, height)
        gl.glDisable(gl.GL_BLEND)  # Ensure no alpha blending - write output directly
        # Process each batch
        all_batch_outputs = []
        for images in image_batches:
            # Update input textures with this batch's images
            for i, img in enumerate(images):
                gl.glActiveTexture(gl.GL_TEXTURE0 + i)
                gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])
                # Flip vertically for GL coordinates, ensure RGBA
                h, w, c = img.shape
                if c == 3:
                    # Expand RGB to RGBA with opaque alpha.
                    img_upload = np.empty((h, w, 4), dtype=np.float32)
                    img_upload[:, :, :3] = img[::-1, :, :]
                    img_upload[:, :, 3] = 1.0
                else:
                    img_upload = np.ascontiguousarray(img[::-1, :, :])
                gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload)
            if num_passes == 1:
                # Single pass - render directly to output FBO
                gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
                if pass_loc >= 0:
                    gl.glUniform1i(pass_loc, 0)
                gl.glClearColor(0, 0, 0, 0)
                gl.glClear(gl.GL_COLOR_BUFFER_BIT)
                gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
            else:
                # Multi-pass rendering with ping-pong
                for p in range(num_passes):
                    is_last_pass = (p == num_passes - 1)
                    # Set pass uniform
                    if pass_loc >= 0:
                        gl.glUniform1i(pass_loc, p)
                    if is_last_pass:
                        # Last pass renders to the main output FBO
                        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
                    else:
                        # Intermediate passes render to ping-pong FBO
                        target_fbo = ping_pong_fbos[p % 2]
                        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, target_fbo)
                    # Set input texture for this pass
                    gl.glActiveTexture(gl.GL_TEXTURE0)
                    if p == 0:
                        # First pass reads from original input
                        gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[0])
                    else:
                        # Subsequent passes read from previous pass output
                        source_tex = ping_pong_textures[(p - 1) % 2]
                        gl.glBindTexture(gl.GL_TEXTURE_2D, source_tex)
                    gl.glClearColor(0, 0, 0, 0)
                    gl.glClear(gl.GL_COLOR_BUFFER_BIT)
                    gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
            # Read back outputs for this batch
            # (glGetTexImage is synchronous, implicitly waits for rendering)
            batch_outputs = []
            for tex in output_textures:
                gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
                data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
                # Un-flip back to image row order; copy so the buffer can be freed.
                img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
                batch_outputs.append(img[::-1, :, :].copy())
            # Pad with black images for unused outputs
            black_img = np.zeros((height, width, 4), dtype=np.float32)
            for _ in range(num_outputs, MAX_OUTPUTS):
                batch_outputs.append(black_img)
            all_batch_outputs.append(batch_outputs)
        elapsed = (time.perf_counter() - start_time) * 1000
        num_batches = len(image_batches)
        pass_info = f", {num_passes} passes" if num_passes > 1 else ""
        logger.info(f"GLSL shader executed in {elapsed:.1f}ms ({num_batches} batch{'es' if num_batches != 1 else ''}, {width}x{height}{pass_info})")
        return all_batch_outputs
    finally:
        # Unbind before deleting
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
        gl.glUseProgram(0)
        for tex in input_textures:
            gl.glDeleteTextures(int(tex))
        for tex in output_textures:
            gl.glDeleteTextures(int(tex))
        for tex in ping_pong_textures:
            gl.glDeleteTextures(int(tex))
        if fbo is not None:
            gl.glDeleteFramebuffers(1, [fbo])
        for pp_fbo in ping_pong_fbos:
            gl.glDeleteFramebuffers(1, [pp_fbo])
        if program is not None:
            gl.glDeleteProgram(program)
class GLSLShader(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
    """Build the node schema: shader source, output sizing, autogrow
    image/float/int uniform inputs, and four image outputs (MRT)."""
    img_template = io.Autogrow.TemplatePrefix(
        io.Image.Input("image"),
        prefix="image",
        min=1,
        max=MAX_IMAGES,
    )
    flt_template = io.Autogrow.TemplatePrefix(
        io.Float.Input("float", default=0.0),
        prefix="u_float",
        min=0,
        max=MAX_UNIFORMS,
    )
    integer_template = io.Autogrow.TemplatePrefix(
        io.Int.Input("int", default=0),
        prefix="u_int",
        min=0,
        max=MAX_UNIFORMS,
    )
    # Size fields shown only when the "custom" option is selected.
    custom_size_fields = [
        io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION),
        io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION),
    ]
    size_mode_input = io.DynamicCombo.Input(
        "size_mode",
        options=[
            io.DynamicCombo.Option("from_input", []),
            io.DynamicCombo.Option("custom", custom_size_fields),
        ],
        tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
    )
    shader_input = io.String.Input(
        "fragment_shader",
        default=DEFAULT_FRAGMENT_SHADER,
        multiline=True,
        tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
    )
    # One output per MRT slot, each wired to fragColor{i}.
    image_outputs = [
        io.Image.Output(
            display_name=f"IMAGE{i}",
            tooltip=f"Available via layout(location = {i}) out vec4 fragColor{i} in the shader code",
        )
        for i in range(MAX_OUTPUTS)
    ]
    return io.Schema(
        node_id="GLSLShader",
        display_name="GLSL Shader",
        category="image/shader",
        description=(
            "Apply GLSL ES fragment shaders to images. "
            "u_resolution (vec2) is always available."
        ),
        inputs=[
            shader_input,
            size_mode_input,
            io.Autogrow.Input("images", template=img_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
            io.Autogrow.Input("floats", template=flt_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
            io.Autogrow.Input("ints", template=integer_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
        ],
        outputs=image_outputs,
    )
@classmethod
def execute(
    cls,
    fragment_shader: str,
    size_mode: SizeModeInput,
    images: io.Autogrow.Type,
    floats: io.Autogrow.Type = None,
    ints: io.Autogrow.Type = None,
    **kwargs,
) -> io.NodeOutput:
    """Run the fragment shader over every frame of the input image batch.

    Returns MAX_OUTPUTS image tensors (one per fragColor slot; unused slots
    are black frames), with UI preview images attached.
    """
    # Unconnected autogrow slots arrive as None; keep only real inputs.
    frame_sources = [img for img in images.values() if img is not None]
    uniform_floats = [0.0 if v is None else v for v in floats.values()] if floats else []
    uniform_ints = [0 if v is None else v for v in ints.values()] if ints else []
    if not frame_sources:
        raise ValueError("At least one input image is required")
    # Resolve the render-target size.
    if size_mode["size_mode"] == "custom":
        out_width = size_mode["width"]
        out_height = size_mode["height"]
    else:
        # Image tensors are indexed (batch, height, width, ...) here.
        out_height, out_width = frame_sources[0].shape[1:3]
    batch_size = frame_sources[0].shape[0]
    # One entry per frame index: that frame from every connected image input.
    image_batches = [
        [tensor[idx].cpu().numpy().astype(np.float32) for tensor in frame_sources]
        for idx in range(batch_size)
    ]
    rendered = _render_shader_batch(
        fragment_shader,
        out_width,
        out_height,
        image_batches,
        uniform_floats,
        uniform_ints,
    )
    # Regroup per-frame results into per-output-slot frame lists, then stack.
    per_slot = [[] for _ in range(MAX_OUTPUTS)]
    for frame_outputs in rendered:
        for slot, frame in enumerate(frame_outputs):
            per_slot[slot].append(torch.from_numpy(frame))
    output_tensors = [torch.stack(frames, dim=0) for frames in per_slot]
    return io.NodeOutput(
        *output_tensors,
        ui=cls._build_ui_output(frame_sources, output_tensors[0]),
    )
@classmethod
def _build_ui_output(
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
) -> dict[str, list]:
"""Build UI output with input and output images for client-side shader execution."""
input_images_ui = []
for img in image_list:
input_images_ui.extend(ui.ImageSaveHelper.save_images(
img,
filename_prefix="GLSLShader_input",
folder_type=io.FolderType.temp,
cls=None,
compress_level=1,
))
output_images_ui = ui.ImageSaveHelper.save_images(
output_batch,
filename_prefix="GLSLShader_output",
folder_type=io.FolderType.temp,
cls=None,
compress_level=1,
)
return {"input_images": input_images_ui, "images": output_images_ui}
class GLSLExtension(ComfyExtension):
    """Extension that registers the GLSL shader node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # The only node shipped by this extension.
        return [GLSLShader]
async def comfy_entrypoint() -> GLSLExtension:
    """Entry point called by ComfyUI to load this extension."""
    return GLSLExtension()

View File

@@ -233,8 +233,8 @@ class SetClipHooks:
return {
"required": {
"clip": ("CLIP",),
"apply_to_conds": ("BOOLEAN", {"default": True}),
"schedule_clip": ("BOOLEAN", {"default": False})
"apply_to_conds": ("BOOLEAN", {"default": True, "advanced": True}),
"schedule_clip": ("BOOLEAN", {"default": False, "advanced": True})
},
"optional": {
"hooks": ("HOOKS",)
@@ -512,7 +512,7 @@ class CreateHookKeyframesInterpolated:
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"keyframes_count": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}),
"print_keyframes": ("BOOLEAN", {"default": False}),
"print_keyframes": ("BOOLEAN", {"default": False, "advanced": True}),
},
"optional": {
"prev_hook_kf": ("HOOK_KEYFRAMES",),
@@ -557,7 +557,7 @@ class CreateHookKeyframesFromFloats:
"floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"print_keyframes": ("BOOLEAN", {"default": False}),
"print_keyframes": ("BOOLEAN", {"default": False, "advanced": True}),
},
"optional": {
"prev_hook_kf": ("HOOK_KEYFRAMES",),

View File

@@ -138,7 +138,7 @@ class HunyuanVideo15SuperResolution(io.ComfyNode):
io.Image.Input("start_image", optional=True),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Latent.Input("latent"),
io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01),
io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01, advanced=True),
],
outputs=[
@@ -285,6 +285,7 @@ class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
min=1,
max=512,
tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
advanced=True,
),
],
outputs=[
@@ -313,7 +314,7 @@ class HunyuanImageToVideo(io.ComfyNode):
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"], advanced=True),
io.Image.Input("start_image", optional=True),
],
outputs=[
@@ -384,7 +385,7 @@ class HunyuanRefinerLatent(io.ComfyNode):
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Latent.Input("latent"),
io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01),
io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01, advanced=True),
],
outputs=[

View File

@@ -106,8 +106,8 @@ class VAEDecodeHunyuan3D(IO.ComfyNode):
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.Int.Input("num_chunks", default=8000, min=1000, max=500000),
IO.Int.Input("octree_resolution", default=256, min=16, max=512),
IO.Int.Input("num_chunks", default=8000, min=1000, max=500000, advanced=True),
IO.Int.Input("octree_resolution", default=256, min=16, max=512, advanced=True),
],
outputs=[
IO.Voxel.Output(),
@@ -456,7 +456,7 @@ class VoxelToMesh(IO.ComfyNode):
category="3d",
inputs=[
IO.Voxel.Input("voxel"),
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[
@@ -621,6 +621,7 @@ class SaveGLB(IO.ComfyNode):
display_name="Save 3D Model",
search_aliases=["export 3d model", "save mesh"],
category="3d",
essentials_category="Basics",
is_output_node=True,
inputs=[
IO.MultiType.Input(

View File

@@ -30,10 +30,10 @@ class HyperTile(io.ComfyNode):
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Int.Input("tile_size", default=256, min=1, max=2048),
io.Int.Input("swap_size", default=2, min=1, max=128),
io.Int.Input("max_depth", default=0, min=0, max=10),
io.Boolean.Input("scale_depth", default=False),
io.Int.Input("tile_size", default=256, min=1, max=2048, advanced=True),
io.Int.Input("swap_size", default=2, min=1, max=128, advanced=True),
io.Int.Input("max_depth", default=0, min=0, max=10, advanced=True),
io.Boolean.Input("scale_depth", default=False, advanced=True),
],
outputs=[
io.Model.Output(),

View File

@@ -6,6 +6,7 @@ import folder_paths
import json
import os
import re
import math
import torch
import comfy.utils
@@ -23,8 +24,10 @@ class ImageCrop(IO.ComfyNode):
return IO.Schema(
node_id="ImageCrop",
search_aliases=["trim"],
display_name="Image Crop",
display_name="Image Crop (Deprecated)",
category="image/transform",
is_deprecated=True,
essentials_category="Image Tools",
inputs=[
IO.Image.Input("image"),
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
@@ -47,6 +50,57 @@ class ImageCrop(IO.ComfyNode):
crop = execute # TODO: remove
class ImageCropV2(IO.ComfyNode):
    """Crop an image to a bounding-box region and preview the result."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ImageCropV2",
            search_aliases=["trim"],
            display_name="Image Crop",
            category="image/transform",
            inputs=[
                IO.Image.Input("image"),
                IO.BoundingBox.Input("crop_region", component="ImageCrop"),
            ],
            outputs=[IO.Image.Output()],
        )

    @classmethod
    def execute(cls, image, crop_region) -> IO.NodeOutput:
        # Region defaults mirror the BoundingBox primitive node.
        region_x = crop_region.get("x", 0)
        region_y = crop_region.get("y", 0)
        region_w = crop_region.get("width", 512)
        region_h = crop_region.get("height", 512)
        # Clamp the origin inside the image so the slice is never empty;
        # slicing itself handles a region extending past the right/bottom edge.
        region_x = min(region_x, image.shape[2] - 1)
        region_y = min(region_y, image.shape[1] - 1)
        cropped = image[:, region_y:region_y + region_h, region_x:region_x + region_w, :]
        return IO.NodeOutput(cropped, ui=UI.PreviewImage(cropped))
class BoundingBox(IO.ComfyNode):
    """Primitive node producing a bounding-box dict {x, y, width, height}."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="PrimitiveBoundingBox",
            display_name="Bounding Box",
            category="utils/primitive",
            inputs=[
                IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
                IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
                IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION),
                IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION),
            ],
            outputs=[IO.BoundingBox.Output()],
        )

    @classmethod
    def execute(cls, x, y, width, height) -> IO.NodeOutput:
        # Package the four scalars into the dict consumed by crop nodes.
        return IO.NodeOutput(dict(x=x, y=y, width=width, height=height))
class RepeatImageBatch(IO.ComfyNode):
@classmethod
def define_schema(cls):
@@ -175,7 +229,7 @@ class SaveAnimatedPNG(IO.ComfyNode):
IO.Image.Input("images"),
IO.String.Input("filename_prefix", default="ComfyUI"),
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
IO.Int.Input("compress_level", default=4, min=0, max=9),
IO.Int.Input("compress_level", default=4, min=0, max=9, advanced=True),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
@@ -212,8 +266,8 @@ class ImageStitch(IO.ComfyNode):
IO.Image.Input("image1"),
IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"),
IO.Boolean.Input("match_image_size", default=True),
IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2),
IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"),
IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2, advanced=True),
IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white", advanced=True),
IO.Image.Input("image2", optional=True),
],
outputs=[IO.Image.Output()],
@@ -383,8 +437,8 @@ class ResizeAndPadImage(IO.ComfyNode):
IO.Image.Input("image"),
IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
IO.Combo.Input("padding_color", options=["white", "black"]),
IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]),
IO.Combo.Input("padding_color", options=["white", "black"], advanced=True),
IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"], advanced=True),
],
outputs=[IO.Image.Output()],
)
@@ -535,8 +589,10 @@ class ImageRotate(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageRotate",
display_name="Image Rotate",
search_aliases=["turn", "flip orientation"],
category="image/transform",
essentials_category="Image Tools",
inputs=[
IO.Image.Input("image"),
IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
@@ -627,11 +683,179 @@ class ImageScaleToMaxDimension(IO.ComfyNode):
upscale = execute # TODO: remove
class SplitImageToTileList(IO.ComfyNode):
    """Split an image into an output list of overlapping tiles."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="SplitImageToTileList",
            category="image/batch",
            search_aliases=["split image", "tile image", "slice image"],
            display_name="Split Image into List of Tiles",
            description="Splits an image into a batched list of tiles with a specified overlap.",
            inputs=[
                IO.Image.Input("image"),
                IO.Int.Input("tile_width", default=1024, min=64, max=MAX_RESOLUTION),
                IO.Int.Input("tile_height", default=1024, min=64, max=MAX_RESOLUTION),
                IO.Int.Input("overlap", default=128, min=0, max=4096),
            ],
            outputs=[
                IO.Image.Output(is_output_list=True),
            ],
        )

    @staticmethod
    def get_grid_coords(width, height, tile_width, tile_height, overlap):
        """Return (x_start, y_start, x_end, y_end) tuples covering the image.

        Tiles advance by (tile size - overlap); tiles at the right/bottom edge
        are shifted inward so each keeps the full tile size when possible.
        """
        step_x = max(1, tile_width - overlap)
        step_y = max(1, tile_height - overlap)
        coords = []
        y = 0
        while True:
            y_end = min(y + tile_height, height)
            y_start = max(0, y_end - tile_height)
            x = 0
            while True:
                x_end = min(x + tile_width, width)
                x_start = max(0, x_end - tile_width)
                coords.append((x_start, y_start, x_end, y_end))
                if x_end >= width:
                    break
                x += step_x
            if y_end >= height:
                break
            y += step_y
        return coords

    @classmethod
    def execute(cls, image, tile_width, tile_height, overlap):
        _, img_h, img_w, _ = image.shape
        tiles = [
            image[:, ys:ye, xs:xe, :]
            for (xs, ys, xe, ye) in cls.get_grid_coords(
                img_w, img_h, tile_width, tile_height, overlap
            )
        ]
        return IO.NodeOutput(tiles)
class ImageMergeTileList(IO.ComfyNode):
    """Blend a list of tiles (as produced by SplitImageToTileList) back into one image."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ImageMergeTileList",
            display_name="Merge List of Tiles to Image",
            category="image/batch",
            search_aliases=["split image", "tile image", "slice image"],
            is_input_list=True,
            inputs=[
                IO.Image.Input("image_list"),
                IO.Int.Input("final_width", default=1024, min=64, max=32768),
                IO.Int.Input("final_height", default=1024, min=64, max=32768),
                IO.Int.Input("overlap", default=128, min=0, max=4096),
            ],
            outputs=[
                IO.Image.Output(is_output_list=False),
            ],
        )

    @staticmethod
    def get_grid_coords(width, height, tile_width, tile_height, overlap):
        """Return (x_start, y_start, x_end, y_end) tuples covering the image.

        Must stay in lockstep with SplitImageToTileList.get_grid_coords so
        tiles land back at their source positions.
        """
        step_x = max(1, tile_width - overlap)
        step_y = max(1, tile_height - overlap)
        coords = []
        y = 0
        while True:
            y_end = min(y + tile_height, height)
            y_start = max(0, y_end - tile_height)
            x = 0
            while True:
                x_end = min(x + tile_width, width)
                x_start = max(0, x_end - tile_width)
                coords.append((x_start, y_start, x_end, y_end))
                if x_end >= width:
                    break
                x += step_x
            if y_end >= height:
                break
            y += step_y
        return coords

    @classmethod
    def execute(cls, image_list, final_width, final_height, overlap):
        # Inputs arrive as lists because is_input_list=True; widget values are singletons.
        out_w = final_width[0]
        out_h = final_height[0]
        ovlp = overlap[0]
        feather_strength = 1.0  # full sine feathering across overlaps
        reference = image_list[0]
        b, tile_h, tile_w, c = reference.shape
        device = reference.device
        dtype = reference.dtype
        coords = cls.get_grid_coords(out_w, out_h, tile_w, tile_h, ovlp)
        canvas = torch.zeros((b, out_h, out_w, c), device=device, dtype=dtype)
        weights = torch.zeros((b, out_h, out_w, 1), device=device, dtype=dtype)
        if ovlp > 0:
            # Separable sine window, clamped away from zero so weights never vanish.
            row_w = torch.clamp(
                torch.sin(math.pi * torch.linspace(0, 1, tile_h, device=device, dtype=dtype)),
                min=1e-5,
            )
            col_w = torch.clamp(
                torch.sin(math.pi * torch.linspace(0, 1, tile_w, device=device, dtype=dtype)),
                min=1e-5,
            )
            sine_mask = (row_w.unsqueeze(1) * col_w.unsqueeze(0)).unsqueeze(0).unsqueeze(-1)
            weight_mask = torch.lerp(torch.ones_like(sine_mask), sine_mask, feather_strength)
        else:
            weight_mask = torch.ones((1, tile_h, tile_w, 1), device=device, dtype=dtype)
        # zip stops at the shorter sequence, matching the original index guard.
        for tile, (x_start, y_start, x_end, y_end) in zip(image_list, coords):
            # Guard against tiles smaller than the expected grid cell.
            real_h = min(y_end - y_start, tile.shape[1])
            real_w = min(x_end - x_start, tile.shape[2])
            tile_crop = tile[:, :real_h, :real_w, :]
            mask_crop = weight_mask[:, :real_h, :real_w, :]
            canvas[:, y_start:y_start + real_h, x_start:x_start + real_w, :] += tile_crop * mask_crop
            weights[:, y_start:y_start + real_h, x_start:x_start + real_w, :] += mask_crop
        weights[weights == 0] = 1.0  # avoid divide-by-zero on uncovered pixels
        return IO.NodeOutput(canvas / weights)
class ImagesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ImageCrop,
ImageCropV2,
BoundingBox,
RepeatImageBatch,
ImageFromBatch,
ImageAddNoise,
@@ -644,6 +868,8 @@ class ImagesExtension(ComfyExtension):
ImageRotate,
ImageFlip,
ImageScaleToMaxDimension,
SplitImageToTileList,
ImageMergeTileList,
]

View File

@@ -391,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
dims = list(range(1, latent_vector_magnitude.ndim))
mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
top = (std * 5 + mean) * multiplier
@@ -412,9 +413,9 @@ class LatentOperationSharpen(io.ComfyNode):
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1, advanced=True),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1, advanced=True),
io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01, advanced=True),
],
outputs=[
io.LatentOperation.Output(),

View File

@@ -31,6 +31,7 @@ class Load3D(IO.ComfyNode):
node_id="Load3D",
display_name="Load 3D & Animation",
category="3d",
essentials_category="Basics",
is_experimental=True,
inputs=[
IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
@@ -97,8 +98,8 @@ class Preview3D(IO.ComfyNode):
],
tooltip="3D model file or path string",
),
IO.Load3DCamera.Input("camera_info", optional=True),
IO.Image.Input("bg_image", optional=True),
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
IO.Image.Input("bg_image", optional=True, advanced=True),
],
outputs=[],
)

View File

@@ -7,6 +7,7 @@ import logging
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from tqdm.auto import trange
CLAMP_QUANTILE = 0.99
@@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
comfy.model_management.load_models_gpu([model_diff])
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
sd_keys = list(sd.keys())
for index in trange(len(sd_keys), unit="weight"):
k = sd_keys[index]
op_keys = sd_keys[index].rsplit('.', 1)
if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
continue
op = comfy.utils.get_attr(model_diff.model, op_keys[0])
if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
else:
weight_diff = sd[k]
if op_keys[1] == "weight":
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
@@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
elif bias_diff and op_keys[1] == "bias":
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
return output_sd
class LoraSave(io.ComfyNode):
@@ -83,9 +94,9 @@ class LoraSave(io.ComfyNode):
category="_for_testing",
inputs=[
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
io.Boolean.Input("bias_diff", default=True),
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys()), advanced=True),
io.Boolean.Input("bias_diff", default=True, advanced=True),
io.Model.Input(
"model_diff",
tooltip="The ModelSubtract output to be converted to a lora.",

View File

@@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode):
generate = execute # TODO: remove
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
    """Append a guide_attention_entry to both positive and negative conditioning.

    Each entry tracks one guide reference for per-reference attention control.
    Entries are derived independently from each conditioning to avoid
    cross-contamination.
    """
    new_entry = {
        "pre_filter_count": pre_filter_count,
        "strength": strength,
        "pixel_mask": None,
        "latent_shape": latent_shape,
    }

    def _with_entry(cond):
        # Take existing entries from the first conditioning item that has them.
        existing = next(
            (
                t[1]["guide_attention_entries"]
                for t in cond
                if t[1].get("guide_attention_entries") is not None
            ),
            [],
        )
        # Shallow copy and append (no deepcopy needed — entries contain
        # only scalars and None for pixel_mask at this call site).
        return node_helpers.conditioning_set_values(
            cond, {"guide_attention_entries": [*existing, new_entry]}
        )

    return _with_entry(positive), _with_entry(negative)
def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning:
if key in t[1]:
@@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode):
scale_factors,
)
# Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove
@@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode):
latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes]
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
positive = node_helpers.conditioning_set_values(positive, {
"keyframe_idxs": None,
"guide_attention_entries": None,
})
negative = node_helpers.conditioning_set_values(negative, {
"keyframe_idxs": None,
"guide_attention_entries": None,
})
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
@@ -450,6 +493,7 @@ class LTXVScheduler(io.ComfyNode):
id="stretch",
default=True,
tooltip="Stretch the sigmas to be in the range [terminal, 1].",
advanced=True,
),
io.Float.Input(
id="terminal",
@@ -458,6 +502,7 @@ class LTXVScheduler(io.ComfyNode):
max=0.99,
step=0.01,
tooltip="The terminal value of the sigmas after stretching.",
advanced=True,
),
io.Latent.Input("latent", optional=True),
],

View File

@@ -189,6 +189,7 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
io.Combo.Input(
"device",
options=["default", "cpu"],
advanced=True,
)
],
outputs=[io.Clip.Output()],

View File

@@ -12,8 +12,8 @@ class RenormCFG(io.ComfyNode):
category="advanced/model",
inputs=[
io.Model.Input("model"),
io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01, advanced=True),
io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01, advanced=True),
],
outputs=[
io.Model.Output(),

View File

@@ -348,7 +348,7 @@ class GrowMask(IO.ComfyNode):
inputs=[
IO.Mask.Input("mask"),
IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("tapered_corners", default=True),
IO.Boolean.Input("tapered_corners", default=True, advanced=True),
],
outputs=[IO.Mask.Output()],
)

View File

@@ -52,8 +52,8 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
"zsnr": ("BOOLEAN", {"default": False}),
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],),
"zsnr": ("BOOLEAN", {"default": False, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
@@ -76,6 +76,8 @@ class ModelSamplingDiscrete:
sampling_type = comfy.model_sampling.X0
elif sampling == "img_to_img":
sampling_type = comfy.model_sampling.IMG_TO_IMG
elif sampling == "img_to_img_flow":
sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
@@ -153,8 +155,8 @@ class ModelSamplingFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
"max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01, "advanced": True}),
"base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "advanced": True}),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
}}
@@ -190,8 +192,8 @@ class ModelSamplingContinuousEDM:
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False, "advanced": True}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
@@ -235,8 +237,8 @@ class ModelSamplingContinuousV:
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction"],),
"sigma_max": ("FLOAT", {"default": 500.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_max": ("FLOAT", {"default": 500.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False, "advanced": True}),
"sigma_min": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1000.0, "step":0.001, "round": False, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
@@ -303,7 +305,7 @@ class ModelComputeDtype:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"dtype": (["default", "fp32", "fp16", "bf16"],),
"dtype": (["default", "fp32", "fp16", "bf16"], {"advanced": True}),
}}
RETURN_TYPES = ("MODEL",)

View File

@@ -13,11 +13,11 @@ class PatchModelAddDownscale(io.ComfyNode):
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Int.Input("block_number", default=3, min=1, max=32, step=1),
io.Int.Input("block_number", default=3, min=1, max=32, step=1, advanced=True),
io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
io.Boolean.Input("downscale_after_skip", default=True),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True),
io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001, advanced=True),
io.Boolean.Input("downscale_after_skip", default=True, advanced=True),
io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
],

99
comfy_extras/nodes_nag.py Normal file
View File

@@ -0,0 +1,99 @@
import torch
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class NAGuidance(io.ComfyNode):
    """Node that patches a model with Normalized Attention Guidance (NAG).

    NAG extrapolates the self-attention output away from the negative-prompt
    branch and renormalizes it, enabling negative prompts on distilled/schnell
    models (see the node description below).
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        # UI declaration: one model input, three float controls, one model output.
        return io.Schema(
            node_id="NAGuidance",
            display_name="Normalized Attention Guidance",
            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
            category="advanced/guidance",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to apply NAG to."),
                io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
                io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
                io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
                # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
                # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
            ],
            outputs=[
                io.Model.Output(tooltip="The patched model with NAG enabled."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
        """Return a clone of `model` with the NAG attn1 output patch installed.

        Args:
            model: Model to patch; it is cloned, the input is not modified.
            nag_scale: Extrapolation strength away from the negative branch.
            nag_alpha: Blend between the normalized guided output and the raw positive output.
            nag_tau: Cap on how much the guided output's L1 norm may exceed the positive's.
        """
        m = model.clone()
        # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
        # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)
        def nag_attention_output_patch(out, extra_options):
            # Only act when this batch contains both the cond (0) and uncond (1) halves.
            cond_or_uncond = extra_options.get("cond_or_uncond", None)
            if cond_or_uncond is None:
                return out
            if not (1 in cond_or_uncond and 0 in cond_or_uncond):
                return out
            # sigma = extra_options.get("sigmas", None)
            # if sigma is not None and len(sigma) > 0:
            #     sigma = sigma[0].item()
            #     if sigma > sigma_start or sigma < sigma_end:
            #         return out
            # If an image-token slice is provided, restrict the patch to that span
            # (presumably the non-text part of a joint sequence — TODO confirm).
            img_slice = extra_options.get("img_slice", None)
            if img_slice is not None:
                orig_out = out
                out = out[:, img_slice[0]:img_slice[1]] # only apply on img part
            # Assumes the batch is split into equal-sized contiguous halves,
            # one per entry of cond_or_uncond.
            batch_size = out.shape[0]
            half_size = batch_size // len(cond_or_uncond)
            ind_neg = cond_or_uncond.index(1)
            ind_pos = cond_or_uncond.index(0)
            z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
            z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]
            # Extrapolate the positive output away from the negative one.
            guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
            eps = 1e-6
            # L1 norms over the feature dim, clamped to avoid division by zero.
            norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
            norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
            ratio = norm_guided / norm_pos
            # Rescale so the guided norm exceeds the positive norm by at most nag_tau.
            scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
            guided_normalized = guided * scale_factor
            # Blend normalized guidance with the original positive output.
            z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)
            if img_slice is not None:
                # Write z_final into BOTH halves of the sliced region.
                orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
                orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
                return orig_out
            else:
                # NOTE(review): unlike the img_slice branch, only the positive half
                # is overwritten here — confirm the asymmetry is intentional.
                out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
                return out
        m.set_model_attn1_output_patch(nag_attention_output_patch)
        # CFG-of-1 shortcut must stay off so the uncond branch is still evaluated.
        m.disable_model_cfg1_optimization()
        return io.NodeOutput(m)
class NagExtension(ComfyExtension):
    """Extension that registers the Normalized Attention Guidance node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # The only node shipped by this extension.
        return [NAGuidance]
async def comfy_entrypoint() -> NagExtension:
    """Entry point called by ComfyUI to load the NAG extension."""
    return NagExtension()

View File

@@ -29,7 +29,7 @@ class PerpNeg(io.ComfyNode):
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("empty_conditioning"),
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01, advanced=True),
],
outputs=[
io.Model.Output(),
@@ -134,7 +134,7 @@ class PerpNegGuider(io.ComfyNode):
io.Conditioning.Input("negative"),
io.Conditioning.Input("empty_conditioning"),
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01, advanced=True),
],
outputs=[
io.Guider.Output(),

View File

@@ -19,6 +19,7 @@ class Blend(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageBlend",
display_name="Image Blend",
category="image/postprocessing",
inputs=[
io.Image.Input("image1"),
@@ -76,6 +77,7 @@ class Blur(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageBlur",
display_name="Image Blur",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
@@ -179,9 +181,9 @@ class Sharpen(io.ComfyNode):
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01),
io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01),
io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1, advanced=True),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01, advanced=True),
io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01, advanced=True),
],
outputs=[
io.Image.Output(),
@@ -225,7 +227,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
io.Image.Input("image"),
io.Combo.Input("upscale_method", options=cls.upscale_methods),
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
io.Int.Input("resolution_steps", default=1, min=1, max=256),
io.Int.Input("resolution_steps", default=1, min=1, max=256, advanced=True),
],
outputs=[
io.Image.Output(),
@@ -565,6 +567,7 @@ class BatchImagesNode(io.ComfyNode):
node_id="BatchImagesNode",
display_name="Batch Images",
category="image",
essentials_category="Image Tools",
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
inputs=[
io.Autogrow.Input("images", template=autogrow_template)
@@ -655,6 +658,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
batched = batch_masks(values)
return io.NodeOutput(batched)
class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:

View File

@@ -29,6 +29,7 @@ class StringMultiline(io.ComfyNode):
node_id="PrimitiveStringMultiline",
display_name="String (Multiline)",
category="utils/primitive",
essentials_category="Basics",
inputs=[
io.String.Input("value", multiline=True),
],

View File

@@ -116,7 +116,7 @@ class EmptyQwenImageLayeredLatentImage(io.ComfyNode):
inputs=[
io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1, advanced=True),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[

View File

@@ -0,0 +1,103 @@
from comfy_api.latest import ComfyExtension, io, ComfyAPI
# Module-level API handle; used by the registrar functions below to install
# node replacements when the extension loads.
api = ComfyAPI()
async def register_replacements():
    """Register all built-in node replacements."""
    # Run every registrar sequentially, preserving the original call order.
    registrars = (
        register_replacements_longeredge,
        register_replacements_batchimages,
        register_replacements_upscaleimage,
        register_replacements_controlnet,
        register_replacements_load3d,
        register_replacements_preview3d,
        register_replacements_svdimg2vid,
        register_replacements_conditioningavg,
    )
    for registrar in registrars:
        await registrar()
async def register_replacements_longeredge():
    """Map the legacy ResizeImagesByLongerEdge node onto ImageScaleToMaxDimension."""
    # No dynamic inputs here
    replacement = io.NodeReplace(
        new_node_id="ImageScaleToMaxDimension",
        old_node_id="ResizeImagesByLongerEdge",
        old_widget_ids=["longer_edge"],
        input_mapping=[
            {"new_id": "image", "old_id": "images"},
            {"new_id": "largest_size", "old_id": "longer_edge"},
            {"new_id": "upscale_method", "set_value": "lanczos"},
        ],
        # just to test the frontend output_mapping code, does nothing really here
        output_mapping=[{"new_idx": 0, "old_idx": 0}],
    )
    await api.node_replacement.register(replacement)
async def register_replacements_batchimages():
    """Point the legacy ImageBatch node at the Autogrow-based BatchImagesNode."""
    # BatchImages node uses Autogrow
    replacement = io.NodeReplace(
        new_node_id="BatchImagesNode",
        old_node_id="ImageBatch",
        input_mapping=[
            {"new_id": "images.image0", "old_id": "image1"},
            {"new_id": "images.image1", "old_id": "image2"},
        ],
    )
    await api.node_replacement.register(replacement)
async def register_replacements_upscaleimage():
    """Map the legacy ImageScaleBy node onto ResizeImageMaskNode."""
    # ResizeImageMaskNode uses DynamicCombo
    replacement = io.NodeReplace(
        new_node_id="ResizeImageMaskNode",
        old_node_id="ImageScaleBy",
        old_widget_ids=["upscale_method", "scale_by"],
        input_mapping=[
            {"new_id": "input", "old_id": "image"},
            {"new_id": "resize_type", "set_value": "scale by multiplier"},
            {"new_id": "resize_type.multiplier", "old_id": "scale_by"},
            {"new_id": "scale_method", "old_id": "upscale_method"},
        ],
    )
    await api.node_replacement.register(replacement)
async def register_replacements_controlnet():
    """T2IAdapterLoader → ControlNetLoader."""
    replacement = io.NodeReplace(
        new_node_id="ControlNetLoader",
        old_node_id="T2IAdapterLoader",
        input_mapping=[
            {"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
        ],
    )
    await api.node_replacement.register(replacement)
async def register_replacements_load3d():
    """Redirect the retired Load3DAnimation node to Load3D."""
    # Load3DAnimation merged into Load3D
    replacement = io.NodeReplace(
        new_node_id="Load3D",
        old_node_id="Load3DAnimation",
    )
    await api.node_replacement.register(replacement)
async def register_replacements_preview3d():
    """Redirect the retired Preview3DAnimation node to Preview3D."""
    # Preview3DAnimation merged into Preview3D
    replacement = io.NodeReplace(
        new_node_id="Preview3D",
        old_node_id="Preview3DAnimation",
    )
    await api.node_replacement.register(replacement)
async def register_replacements_svdimg2vid():
    """Redirect the misspelled node id to its corrected form."""
    # Typo fix: SDV → SVD
    replacement = io.NodeReplace(
        new_node_id="SVD_img2vid_Conditioning",
        old_node_id="SDV_img2vid_Conditioning",
    )
    await api.node_replacement.register(replacement)
async def register_replacements_conditioningavg():
    """Redirect the historical node id (with trailing space) to the clean one."""
    # Typo fix: trailing space in node name
    replacement = io.NodeReplace(
        new_node_id="ConditioningAverage",
        old_node_id="ConditioningAverage ",
    )
    await api.node_replacement.register(replacement)
class NodeReplacementsExtension(ComfyExtension):
    """Extension that contributes no nodes but installs node replacements on load."""

    async def on_load(self) -> None:
        # Install the built-in replacement table as soon as the extension loads.
        await register_replacements()

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # This extension defines replacements only, no node classes.
        empty: list[type[io.ComfyNode]] = []
        return empty
async def comfy_entrypoint() -> NodeReplacementsExtension:
    """ComfyUI loader hook for the replacements extension."""
    ext = NodeReplacementsExtension()
    return ext

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import math
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class AspectRatio(str, Enum):
    # Combo-box labels for the supported aspect ratios. The string values are
    # user-facing AND act as lookup keys into ASPECT_RATIOS — keep them in sync.
    # _H suffix = horizontal (landscape), _V suffix = vertical (portrait).
    SQUARE = "1:1 (Square)"
    PHOTO_H = "3:2 (Photo)"
    STANDARD_H = "4:3 (Standard)"
    WIDESCREEN_H = "16:9 (Widescreen)"
    ULTRAWIDE_H = "21:9 (Ultrawide)"
    PHOTO_V = "2:3 (Portrait Photo)"
    STANDARD_V = "3:4 (Portrait Standard)"
    WIDESCREEN_V = "9:16 (Portrait Widescreen)"
# (width_ratio, height_ratio) for each aspect-ratio label.
# Keys must match the AspectRatio enum's string values exactly.
ASPECT_RATIOS: dict[str, tuple[int, int]] = {
    "1:1 (Square)": (1, 1),
    "3:2 (Photo)": (3, 2),
    "4:3 (Standard)": (4, 3),
    "16:9 (Widescreen)": (16, 9),
    "21:9 (Ultrawide)": (21, 9),
    "2:3 (Portrait Photo)": (2, 3),
    "3:4 (Portrait Standard)": (3, 4),
    "9:16 (Portrait Widescreen)": (9, 16),
}
class ResolutionSelector(io.ComfyNode):
    """Calculate width and height from aspect ratio and megapixel target."""

    @classmethod
    def define_schema(cls):
        aspect_input = io.Combo.Input(
            "aspect_ratio",
            options=AspectRatio,
            default=AspectRatio.SQUARE,
            tooltip="The aspect ratio for the output dimensions.",
        )
        megapixel_input = io.Float.Input(
            "megapixels",
            default=1.0,
            min=0.1,
            max=16.0,
            step=0.1,
            tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.",
        )
        return io.Schema(
            node_id="ResolutionSelector",
            display_name="Resolution Selector",
            category="utils",
            description="Calculate width and height from aspect ratio and megapixel target. Useful for setting up Empty Latent Image dimensions.",
            inputs=[aspect_input, megapixel_input],
            outputs=[
                io.Int.Output("width", tooltip="Calculated width in pixels (multiple of 8)."),
                io.Int.Output("height", tooltip="Calculated height in pixels (multiple of 8)."),
            ],
        )

    @classmethod
    def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput:
        """Solve (w_ratio*s) * (h_ratio*s) == target pixels, then snap to /8."""
        w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
        target_pixels = megapixels * 1024 * 1024
        base = math.sqrt(target_pixels / (w_ratio * h_ratio))

        def snap8(value: float) -> int:
            # Round to the nearest multiple of 8, as latent-space sizes expect.
            return round(value / 8) * 8

        return io.NodeOutput(snap8(w_ratio * base), snap8(h_ratio * base))
class ResolutionExtension(ComfyExtension):
    """Extension exposing the Resolution Selector utility node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        exported: list[type[io.ComfyNode]] = [ResolutionSelector]
        return exported
async def comfy_entrypoint() -> ResolutionExtension:
    """ComfyUI loader hook for this extension."""
    return ResolutionExtension()

View File

@@ -12,14 +12,14 @@ class ScaleROPE(io.ComfyNode):
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1, advanced=True),
io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1, advanced=True),
io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1, advanced=True),
io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1, advanced=True),
io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1, advanced=True),
io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1, advanced=True),
],

View File

@@ -117,7 +117,7 @@ class SelfAttentionGuidance(io.ComfyNode):
inputs=[
io.Model.Input("model"),
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1),
io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1, advanced=True),
],
outputs=[
io.Model.Output(),

View File

@@ -72,7 +72,7 @@ class CLIPTextEncodeSD3(io.ComfyNode):
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
io.Combo.Input("empty_padding", options=["none", "empty_prompt"], advanced=True),
],
outputs=[
io.Conditioning.Output(),
@@ -179,10 +179,10 @@ class SkipLayerGuidanceSD3(io.ComfyNode):
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
inputs=[
io.Model.Input("model"),
io.String.Input("layers", default="7, 8, 9", multiline=False),
io.String.Input("layers", default="7, 8, 9", multiline=False, advanced=True),
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001, advanced=True),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[
io.Model.Output(),

View File

@@ -0,0 +1,740 @@
import torch
import comfy.utils
import numpy as np
import math
import colorsys
from tqdm import tqdm
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_lotus import LotusConditioning
def _preprocess_keypoints(kp_raw, sc_raw):
    """Insert the neck keypoint and remap MMPose body ordering to OpenPose.

    Returns (kp, sc): kp is (134, 2) and sc is (134,) for full wholebody input.
    Layout of the returned arrays:
        0-17    body (18 kp, OpenPose order, neck at index 1)
        18-23   feet (6 kp)
        24-91   face (68 kp)
        92-112  right hand (21 kp)
        113-133 left hand (21 kp)
    Inputs with fewer than 17 keypoints are returned unchanged (as float32 arrays).
    """
    points = np.array(kp_raw, dtype=np.float32)
    confidences = np.array(sc_raw, dtype=np.float32)
    if len(points) < 17:
        return points, confidences
    # Synthesize the neck as the shoulder midpoint; zero its score when either
    # shoulder detection is weak.
    neck_xy = (points[5] + points[6]) / 2
    both_visible = confidences[5] > 0.3 and confidences[6] > 0.3
    neck_conf = min(confidences[5], confidences[6]) if both_visible else 0
    points = np.insert(points, 17, neck_xy, axis=0)
    confidences = np.insert(confidences, 17, neck_conf)
    # Scatter MMPose-ordered body joints into their OpenPose slots.
    src = np.array([17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3])
    dst = np.array([1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17])
    remapped_pts, remapped_conf = points.copy(), confidences.copy()
    remapped_pts[dst] = points[src]
    remapped_conf[dst] = confidences[src]
    return remapped_pts, remapped_conf
def _to_openpose_frames(all_keypoints, all_scores, height, width):
    """Convert raw keypoint lists to a list of OpenPose-style frame dicts.

    Each frame dict contains canvas_width, canvas_height and a "people" list;
    every person dict holds flat [x, y, score, ...] lists in absolute pixels:
        pose_keypoints_2d        - 18 body kp
        foot_keypoints_2d        - 6 foot kp
        face_keypoints_2d        - 70 face kp (68 landmarks + REye body[14] + LEye body[15])
        hand_right_keypoints_2d  - 21 right-hand kp
        hand_left_keypoints_2d   - 21 left-hand kp
    """
    def _pack(xy, conf):
        # Interleave x, y, score into the flat OpenPose layout.
        return np.stack([xy[:, 0], xy[:, 1], conf], axis=1).flatten().tolist()

    def _person(kp_raw, sc_raw):
        kp, sc = _preprocess_keypoints(kp_raw, sc_raw)
        # 70 face kp = 68 face landmarks + REye (body[14]) + LEye (body[15])
        face_xy = np.concatenate([kp[24:92], kp[[14, 15]]], axis=0)
        face_conf = np.concatenate([sc[24:92], sc[[14, 15]]], axis=0)
        return {
            "pose_keypoints_2d": _pack(kp[0:18], sc[0:18]),
            "foot_keypoints_2d": _pack(kp[18:24], sc[18:24]),
            "face_keypoints_2d": _pack(face_xy, face_conf),
            "hand_right_keypoints_2d": _pack(kp[92:113], sc[92:113]),
            "hand_left_keypoints_2d": _pack(kp[113:134], sc[113:134]),
        }

    frames = []
    for idx in range(len(all_keypoints)):
        people = [_person(k, s) for k, s in zip(all_keypoints[idx], all_scores[idx])]
        frames.append({"canvas_width": width, "canvas_height": height, "people": people})
    return frames
class KeypointDraw:
    """
    Pose keypoint drawing class that supports both numpy and cv2 backends.

    When cv2 is importable its native drawing primitives are used; otherwise the
    static methods below act as a pure-NumPy fallback. They mirror the cv2 call
    signatures closely enough for this class's own calls, and accept/ignore
    extra cv2-only keyword arguments (e.g. ``thickness=-1`` on ``circle``).
    """
    def __init__(self):
        # Pick the drawing backend: cv2 if available, else this class itself
        # (whose staticmethods shadow the cv2 functions used here).
        try:
            import cv2
            self.draw = cv2
        except ImportError:
            self.draw = self
        # Hand connections (same for both hands)
        self.hand_edges = [
            [0, 1], [1, 2], [2, 3], [3, 4],          # thumb
            [0, 5], [5, 6], [6, 7], [7, 8],          # index
            [0, 9], [9, 10], [10, 11], [11, 12],     # middle
            [0, 13], [13, 14], [14, 15], [15, 16],   # ring
            [0, 17], [17, 18], [18, 19], [19, 20],   # pinky
        ]
        # Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed)
        self.body_limbSeq = [
            [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
            [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
            [1, 16], [16, 18]
        ]
        # Colors matching DWPose
        self.colors = [
            [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
            [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
            [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
            [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
        ]
    @staticmethod
    def circle(canvas_np, center, radius, color, **kwargs):
        """Draw a filled circle using NumPy vectorized operations.

        ``**kwargs`` swallows cv2-only arguments such as ``thickness``.
        """
        cx, cy = center
        h, w = canvas_np.shape[:2]
        radius_int = int(np.ceil(radius))
        # Bounding box of the circle, clipped to the canvas.
        y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1)
        x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1)
        if y_max <= y_min or x_max <= x_min:
            return
        # Paint every pixel within `radius` of the center.
        y, x = np.ogrid[y_min:y_max, x_min:x_max]
        mask = (x - cx)**2 + (y - cy)**2 <= radius**2
        canvas_np[y_min:y_max, x_min:x_max][mask] = color
    @staticmethod
    def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs):
        """Draw line using Bresenham's algorithm with NumPy operations."""
        x0, y0, x1, y1 = *pt1, *pt2
        h, w = canvas_np.shape[:2]
        # Bresenham setup: step directions and accumulated error term.
        dx, dy = abs(x1 - x0), abs(y1 - y0)
        sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1)
        err, x, y, line_points = dx - dy, x0, y0, []
        while True:
            line_points.append((x, y))
            if x == x1 and y == y1:
                break
            e2 = 2 * err
            if e2 > -dy:
                err, x = err - dy, x + sx
            if e2 < dx:
                err, y = err + dx, y + sy
        if thickness > 1:
            # Thick line: stamp a small disc at every Bresenham point.
            radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5))
            for px, py in line_points:
                y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1)
                if y_max > y_min and x_max > x_min:
                    yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
                    canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color
        else:
            # Thin line: write all in-bounds points in one fancy-indexed assignment.
            line_points = np.array(line_points)
            valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w)
            if (valid_points := line_points[valid]).size:
                canvas_np[valid_points[:, 1], valid_points[:, 0]] = color
    @staticmethod
    def fillConvexPoly(canvas_np, pts, color, **kwargs):
        """Fill polygon using vectorized scanline algorithm."""
        if len(pts) < 3:
            return
        pts = np.array(pts, dtype=np.int32)
        h, w = canvas_np.shape[:2]
        # Bounding box of the polygon, clipped to the canvas.
        y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1)
        if y_max <= y_min or x_max <= x_min:
            return
        yy, xx = np.mgrid[y_min:y_max, x_min:x_max]
        mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool)
        # Even-odd rule: XOR-toggle the mask at each edge crossing so pixels
        # inside the polygon end up with an odd crossing count (True).
        for i in range(len(pts)):
            p1, p2 = pts[i], pts[(i + 1) % len(pts)]
            y1, y2 = p1[1], p2[1]
            if y1 == y2:
                continue
            # Orient each edge top-to-bottom so the interpolation is uniform.
            if y1 > y2:
                p1, p2, y1, y2 = p2, p1, p2[1], p1[1]
            if not (edge_mask := (yy >= y1) & (yy < y2)).any():
                continue
            mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1))
        canvas_np[y_min:y_max, x_min:x_max][mask] = color
    @staticmethod
    def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs):
        """Python implementation of cv2.ellipse2Poly.

        Samples the rotated ellipse arc every ``delta`` degrees and returns a
        deduplicated list of integer [x, y] vertices.
        """
        axes = (axes[0] + 0.5, axes[1] + 0.5)  # to better match cv2 output
        angle = angle % 360
        # Normalize the arc range into ascending order within [0, 360].
        if arc_start > arc_end:
            arc_start, arc_end = arc_end, arc_start
        while arc_start < 0:
            arc_start, arc_end = arc_start + 360, arc_end + 360
        while arc_end > 360:
            arc_end, arc_start = arc_end - 360, arc_start - 360
        if arc_end - arc_start > 360:
            arc_start, arc_end = 0, 360
        angle_rad = math.radians(angle)
        alpha, beta = math.cos(angle_rad), math.sin(angle_rad)
        pts = []
        # Sample the arc, rotating each point by `angle` around `center`.
        for i in range(arc_start, arc_end + delta, delta):
            theta_rad = math.radians(min(i, arc_end))
            x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad)
            pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))])
        # Drop consecutive duplicates produced by integer rounding.
        unique_pts, prev_pt = [], (float('inf'), float('inf'))
        for pt in pts:
            if (pt_tuple := tuple(pt)) != prev_pt:
                unique_pts.append(pt)
            prev_pt = pt_tuple
        return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]]
    def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3,
                                 draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3):
        """
        Draw wholebody keypoints (134 keypoints after processing) in DWPose style.
        Expected keypoint format (after neck insertion and remapping):
        - Body: 0-17 (18 keypoints in OpenPose format, neck at index 1)
        - Foot: 18-23 (6 keypoints)
        - Face: 24-91 (68 landmarks)
        - Right hand: 92-112 (21 keypoints)
        - Left hand: 113-133 (21 keypoints)
        Args:
            canvas: The canvas to draw on (numpy array)
            keypoints: Array of keypoint coordinates (absolute pixels)
            scores: Optional confidence scores for each keypoint
            threshold: Minimum confidence threshold for drawing keypoints
            draw_body/draw_feet/draw_face/draw_hands: toggles per body part
            stick_width: half-thickness of the body limb ellipses
            face_point_size: radius of the face landmark dots
        Returns:
            canvas: The canvas with keypoints drawn
        """
        H, W, C = canvas.shape
        # Draw body limbs as filled rotated ellipses between joint pairs.
        if draw_body and len(keypoints) >= 18:
            for i, limb in enumerate(self.body_limbSeq):
                # Convert from 1-indexed to 0-indexed
                idx1, idx2 = limb[0] - 1, limb[1] - 1
                if idx1 >= 18 or idx2 >= 18:
                    continue
                if scores is not None:
                    if scores[idx1] < threshold or scores[idx2] < threshold:
                        continue
                # NOTE: Y holds x-coordinates and X holds y-coordinates here,
                # mirroring the original DWPose drawing code's naming.
                Y = [keypoints[idx1][0], keypoints[idx2][0]]
                X = [keypoints[idx1][1], keypoints[idx2][1]]
                mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2
                length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2)
                if length < 1:
                    continue
                angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
                polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1)
                self.draw.fillConvexPoly(canvas, polygon, self.colors[i % len(self.colors)])
        # Draw body keypoints
        if draw_body and len(keypoints) >= 18:
            for i in range(18):
                if scores is not None and scores[i] < threshold:
                    continue
                x, y = int(keypoints[i][0]), int(keypoints[i][1])
                if 0 <= x < W and 0 <= y < H:
                    self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1)
        # Draw foot keypoints (18-23, 6 keypoints)
        if draw_feet and len(keypoints) >= 24:
            for i in range(18, 24):
                if scores is not None and scores[i] < threshold:
                    continue
                x, y = int(keypoints[i][0]), int(keypoints[i][1])
                if 0 <= x < W and 0 <= y < H:
                    self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1)
        # Draw right hand (92-112)
        if draw_hands and len(keypoints) >= 113:
            eps = 0.01  # near-zero coordinates mean "keypoint missing"
            for ie, edge in enumerate(self.hand_edges):
                idx1, idx2 = 92 + edge[0], 92 + edge[1]
                if scores is not None:
                    if scores[idx1] < threshold or scores[idx2] < threshold:
                        continue
                x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
                x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
                if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
                    if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
                        # HSV to RGB conversion for rainbow colors
                        r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
                        color = (int(r * 255), int(g * 255), int(b * 255))
                        self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
            # Draw right hand keypoints
            for i in range(92, 113):
                if scores is not None and scores[i] < threshold:
                    continue
                x, y = int(keypoints[i][0]), int(keypoints[i][1])
                if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
                    self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
        # Draw left hand (113-133)
        if draw_hands and len(keypoints) >= 134:
            eps = 0.01
            for ie, edge in enumerate(self.hand_edges):
                idx1, idx2 = 113 + edge[0], 113 + edge[1]
                if scores is not None:
                    if scores[idx1] < threshold or scores[idx2] < threshold:
                        continue
                x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
                x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
                if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
                    if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
                        # HSV to RGB conversion for rainbow colors
                        r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
                        color = (int(r * 255), int(g * 255), int(b * 255))
                        self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
            # Draw left hand keypoints
            for i in range(113, 134):
                if scores is not None and i < len(scores) and scores[i] < threshold:
                    continue
                x, y = int(keypoints[i][0]), int(keypoints[i][1])
                if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
                    self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
        # Draw face keypoints (24-91) - white dots only, no lines
        if draw_face and len(keypoints) >= 92:
            eps = 0.01
            for i in range(24, 92):
                if scores is not None and scores[i] < threshold:
                    continue
                x, y = int(keypoints[i][0]), int(keypoints[i][1])
                if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
                    self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1)
        return canvas
class SDPoseDrawKeypoints(io.ComfyNode):
    """Render OpenPose-style keypoint frames onto black canvases as an image batch."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SDPoseDrawKeypoints",
            category="image/preprocessors",
            search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "pose"],
            inputs=[
                io.Custom("POSE_KEYPOINT").Input("keypoints"),
                io.Boolean.Input("draw_body", default=True),
                io.Boolean.Input("draw_hands", default=True),
                io.Boolean.Input("draw_face", default=True),
                io.Boolean.Input("draw_feet", default=False),
                io.Int.Input("stick_width", default=4, min=1, max=10, step=1),
                io.Int.Input("face_point_size", default=3, min=1, max=10, step=1),
                io.Float.Input("score_threshold", default=0.3, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, keypoints, draw_body, draw_hands, draw_face, draw_feet, stick_width, face_point_size, score_threshold) -> io.NodeOutput:
        # With no frames there is nothing to draw; emit a small black placeholder.
        if not keypoints:
            return io.NodeOutput(torch.zeros((1, 64, 64, 3), dtype=torch.float32))

        height = keypoints[0]["canvas_height"]
        width = keypoints[0]["canvas_width"]

        def _split(flat, count):
            # Flat [x, y, score, ...] -> ((count, 2) coords, (count,) scores).
            triples = np.array(flat, dtype=np.float32).reshape(count, 3)
            return triples[:, :2], triples[:, 2]

        def _empty(count):
            return np.zeros((count, 2), dtype=np.float32), np.zeros(count, dtype=np.float32)

        def _person_arrays(person):
            # Assemble the 134-kp wholebody layout: body, feet, face(68), rhand, lhand.
            body_xy, body_conf = _split(person["pose_keypoints_2d"], 18)
            foot_flat = person.get("foot_keypoints_2d")
            foot_xy, foot_conf = _split(foot_flat, 6) if foot_flat else _empty(6)
            face_xy, face_conf = _split(person["face_keypoints_2d"], 70)
            # Drop the two appended eye keypoints; the body pass already draws them.
            face_xy, face_conf = face_xy[:68], face_conf[:68]
            rhand_xy, rhand_conf = _split(person["hand_right_keypoints_2d"], 21)
            lhand_xy, lhand_conf = _split(person["hand_left_keypoints_2d"], 21)
            xy = np.concatenate([body_xy, foot_xy, face_xy, rhand_xy, lhand_xy], axis=0)
            conf = np.concatenate([body_conf, foot_conf, face_conf, rhand_conf, lhand_conf], axis=0)
            return xy, conf

        drawer = KeypointDraw()
        rendered = []
        for frame in tqdm(keypoints, desc="Drawing keypoints on frames"):
            canvas = np.zeros((height, width, 3), dtype=np.uint8)
            for person in frame["people"]:
                xy, conf = _person_arrays(person)
                canvas = drawer.draw_wholebody_keypoints(
                    canvas, xy, conf,
                    threshold=score_threshold,
                    draw_body=draw_body, draw_feet=draw_feet,
                    draw_face=draw_face, draw_hands=draw_hands,
                    stick_width=stick_width, face_point_size=face_point_size,
                )
            rendered.append(canvas)

        stacked = np.stack(rendered)  # handles the single-frame case too
        return io.NodeOutput(torch.from_numpy(stacked).float() / 255.0)
class SDPoseKeypointExtractor(io.ComfyNode):
    """Extract wholebody pose keypoints from images with the SDPose model.

    Runs a single denoising step through a patched SDPose UNet to capture an
    intermediate feature map, then decodes it with the model's heatmap head.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SDPoseKeypointExtractor",
            category="image/preprocessors",
            search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "sdpose"],
            description="Extract pose keypoints from images using the SDPose model: https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints",
            inputs=[
                io.Model.Input("model"),
                io.Vae.Input("vae"),
                io.Image.Input("image"),
                io.Int.Input("batch_size", default=16, min=1, max=10000, step=1),
                io.BoundingBox.Input("bboxes", optional=True, force_input=True, tooltip="Optional bounding boxes for more accurate detections. Required for multi-person detection."),
            ],
            outputs=[
                io.Custom("POSE_KEYPOINT").Output("keypoints", tooltip="Keypoints in OpenPose frame format (canvas_width, canvas_height, people)"),
            ],
        )

    @classmethod
    def execute(cls, model, vae, image, batch_size, bboxes=None) -> io.NodeOutput:
        """Run the SDPose pipeline on an image batch and return OpenPose frames.

        Raises:
            ValueError: if the supplied model has no ``heatmap_head`` (i.e. it
                is not an SDPose checkpoint).
        """
        # Fix: import comfy.sample explicitly. This module's top level only
        # imports comfy.utils, so accessing comfy.sample as a package attribute
        # would only work if another module had already imported it. Make the
        # dependency explicit so this node never relies on import order.
        import comfy.sample

        height, width = image.shape[-3], image.shape[-2]
        context = LotusConditioning().execute().result[0]
        # Use output_block_patch to capture the last 640-channel feature
        def output_patch(h, hsp, transformer_options):
            nonlocal captured_feat
            if h.shape[1] == 640:  # Capture the features for wholebody
                captured_feat = h.clone()
            return h, hsp
        model_clone = model.clone()
        model_clone.model_options["transformer_options"] = {"patches": {"output_block_patch": [output_patch]}}
        if not hasattr(model.model.diffusion_model, 'heatmap_head'):
            raise ValueError("The provided model does not have a heatmap_head. Please use SDPose model from here https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints.")
        head = model.model.diffusion_model.heatmap_head
        total_images = image.shape[0]
        captured_feat = None
        model_h = int(head.heatmap_size[0]) * 4  # e.g. 192 * 4 = 768
        model_w = int(head.heatmap_size[1]) * 4  # e.g. 256 * 4 = 1024
        def _run_on_latent(latent_batch):
            """Run one forward pass and return (keypoints_list, scores_list) for the batch."""
            nonlocal captured_feat
            captured_feat = None
            _ = comfy.sample.sample(
                model_clone,
                noise=torch.zeros_like(latent_batch),
                steps=1, cfg=1.0,
                sampler_name="euler", scheduler="simple",
                positive=context, negative=context,
                latent_image=latent_batch, disable_noise=True, disable_pbar=True,
            )
            return head(captured_feat)  # keypoints_batch, scores_batch
        # all_keypoints / all_scores are lists-of-lists:
        #   outer index = input image index
        #   inner index = detected person (one per bbox, or one for full-image)
        all_keypoints = []  # shape: [n_images][n_persons]
        all_scores = []     # shape: [n_images][n_persons]
        pbar = comfy.utils.ProgressBar(total_images)
        if bboxes is not None:
            # Normalize bbox input: a single dict becomes one list-of-one; an
            # empty list means "no boxes for any image" (full-image fallback).
            if not isinstance(bboxes, list):
                bboxes = [[bboxes]]
            elif len(bboxes) == 0:
                bboxes = [None] * total_images
            # --- bbox-crop mode: one forward pass per crop -------------------------
            for img_idx in tqdm(range(total_images), desc="Extracting keypoints from crops"):
                img = image[img_idx:img_idx + 1]  # (1, H, W, C)
                # Broadcasting: if fewer bbox lists than images, repeat the last one.
                img_bboxes = bboxes[min(img_idx, len(bboxes) - 1)] if bboxes else None
                img_keypoints = []
                img_scores = []
                if img_bboxes:
                    for bbox in img_bboxes:
                        # Clamp the bbox to the image and skip degenerate crops.
                        x1 = max(0, int(bbox["x"]))
                        y1 = max(0, int(bbox["y"]))
                        x2 = min(width, int(bbox["x"] + bbox["width"]))
                        y2 = min(height, int(bbox["y"] + bbox["height"]))
                        if x2 <= x1 or y2 <= y1:
                            continue
                        crop_h_px, crop_w_px = y2 - y1, x2 - x1
                        crop = img[:, y1:y2, x1:x2, :]  # (1, crop_h, crop_w, C)
                        # scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size.
                        scale = min(model_h / crop_h_px, model_w / crop_w_px)
                        scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale))
                        pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2
                        crop_chw = crop.permute(0, 3, 1, 2).float()  # BHWC → BCHW
                        scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
                        padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
                        padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
                        crop_resized = padded.permute(0, 2, 3, 1)  # BCHW → BHWC
                        latent_crop = vae.encode(crop_resized)
                        kp_batch, sc_batch = _run_on_latent(latent_crop)
                        kp, sc = kp_batch[0], sc_batch[0]  # (K, 2), coords in model pixel space
                        # remove padding offset, undo scale, offset to full-image coordinates.
                        kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
                        kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1
                        kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1
                        img_keypoints.append(kp)
                        img_scores.append(sc)
                else:
                    # No bboxes for this image: run on the full image
                    latent_img = vae.encode(img)
                    kp_batch, sc_batch = _run_on_latent(latent_img)
                    img_keypoints.append(kp_batch[0])
                    img_scores.append(sc_batch[0])
                all_keypoints.append(img_keypoints)
                all_scores.append(img_scores)
                pbar.update(1)
        else:  # full-image mode, batched
            tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints")
            for batch_start in range(0, total_images, batch_size):
                batch_end = min(batch_start + batch_size, total_images)
                latent_batch = vae.encode(image[batch_start:batch_end])
                kp_batch, sc_batch = _run_on_latent(latent_batch)
                for kp, sc in zip(kp_batch, sc_batch):
                    all_keypoints.append([kp])
                    all_scores.append([sc])
                    tqdm_pbar.update(1)
                pbar.update(batch_end - batch_start)
        openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
        return io.NodeOutput(openpose_frames)
def get_face_bboxes(kp2ds, scale, image_shape):
    """Compute an expanded face bounding box from normalized keypoints.

    Args:
        kp2ds: (N, 2) array of normalized keypoints; row 0 is a dummy entry
            and is ignored.
        scale: area multiplier applied to the tight face box.
        image_shape: (height, width) of the target image in pixels.
    Returns:
        [min_x, max_x, min_y, max_y] in integer pixels, or [0, 0, 0, 0] when
        the keypoints are degenerate (zero width or height).
    """
    h, w = image_shape
    pts = kp2ds.copy()[1:] * (w, h)
    lo = np.min(pts, axis=0)
    hi = np.max(pts, axis=0)
    box_w = hi[0] - lo[0]
    box_h = hi[1] - lo[1]
    if box_w <= 0 or box_h <= 0:
        return [0, 0, 0, 0]
    # Grow the box so its area is multiplied by `scale`, keeping the aspect ratio.
    grown_area = box_w * box_h * scale
    grown_w = np.sqrt(grown_area * (box_w / box_h))
    grown_h = np.sqrt(grown_area * (box_h / box_w))
    dx = (grown_w - box_w) / 2
    dy = (grown_h - box_h) / 4
    # Vertical expansion is deliberately asymmetric: 3/4 of the growth goes
    # upward (toward the forehead), 1/4 downward. Clamp to the image bounds.
    x1 = max(lo[0] - dx, 0)
    x2 = min(hi[0] + dx, w)
    y1 = max(lo[1] - 3 * dy, 0)
    y2 = min(hi[1] + dy, h)
    return [int(x1), int(x2), int(y1), int(y2)]
class SDPoseFaceBBoxes(io.ComfyNode):
    """Derive expanded (optionally square) face bounding boxes from pose keypoints."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SDPoseFaceBBoxes",
            category="image/preprocessors",
            search_aliases=["face bbox", "face bounding box", "pose", "keypoints"],
            inputs=[
                io.Custom("POSE_KEYPOINT").Input("keypoints"),
                io.Float.Input("scale", default=1.5, min=1.0, max=10.0, step=0.1, tooltip="Multiplier for the bounding box area around each detected face."),
                io.Boolean.Input("force_square", default=True, tooltip="Expand the shorter bbox axis so the crop region is always square."),
            ],
            outputs=[
                io.BoundingBox.Output("bboxes", tooltip="Face bounding boxes per frame, compatible with SDPoseKeypointExtractor bboxes input."),
            ],
        )

    @staticmethod
    def _squarify(x1, y1, x2, y2, w, h):
        # Grow the shorter side to match the longer one, centred on the box,
        # then clamp to the canvas and re-anchor so the full side still fits.
        bw, bh = x2 - x1, y2 - y1
        if bw == bh:
            return x1, y1, x2, y2
        side = max(bw, bh)
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
        half = side // 2
        x1 = max(0, cx - half)
        y1 = max(0, cy - half)
        x2 = min(w, x1 + side)
        y2 = min(h, y1 + side)
        # Re-anchor if clamped
        x1 = max(0, x2 - side)
        y1 = max(0, y2 - side)
        return x1, y1, x2, y2

    @classmethod
    def execute(cls, keypoints, scale, force_square) -> io.NodeOutput:
        all_bboxes = []
        for frame in keypoints:
            canvas_h = frame["canvas_height"]
            canvas_w = frame["canvas_width"]
            frame_bboxes = []
            for person in frame["people"]:
                face_flat = person.get("face_keypoints_2d", [])
                if not face_flat:
                    continue
                # Parse absolute-pixel face keypoints (70 kp: 68 landmarks + REye + LEye)
                face_xy = np.array(face_flat, dtype=np.float32).reshape(-1, 3)[:, :2]
                normalized = face_xy / np.array([canvas_w, canvas_h], dtype=np.float32)
                # get_face_bboxes ignores row 0, so prepend a dummy keypoint.
                padded = np.vstack([np.zeros((1, 2), dtype=np.float32), normalized])
                x1, x2, y1, y2 = get_face_bboxes(padded, scale, (canvas_h, canvas_w))
                if x2 <= x1 or y2 <= y1:
                    continue
                if force_square:
                    x1, y1, x2, y2 = cls._squarify(x1, y1, x2, y2, canvas_w, canvas_h)
                frame_bboxes.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1})
            all_bboxes.append(frame_bboxes)
        return io.NodeOutput(all_bboxes)
class CropByBBoxes(io.ComfyNode):
    """Crop per-frame bbox regions from an image batch and resize them to a fixed size."""
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="CropByBBoxes",
            category="image/preprocessors",
            search_aliases=["crop", "face crop", "bbox crop", "pose", "bounding box"],
            description="Crop and resize regions from the input image batch based on provided bounding boxes.",
            inputs=[
                io.Image.Input("image"),
                io.BoundingBox.Input("bboxes", force_input=True),
                io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."),
                io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."),
                io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."),
            ],
            outputs=[
                io.Image.Output(tooltip="All crops stacked into a single image batch."),
            ],
        )
    @classmethod
    def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput:
        """Crop the union of each frame's bboxes and resize to a fixed size.

        Args:
            image: Image batch indexed as (batch, height, width, channels).
            bboxes: Per-frame lists of bbox dicts with "x", "y", "width",
                "height" keys; a single non-list value is wrapped as one
                frame's list. When there are fewer bbox entries than frames,
                the last entry is reused for the remaining frames.
            output_width: Width every crop is resized to.
            output_height: Height every crop is resized to.
            padding: Extra pixels added on each side of the union bbox.

        Returns:
            NodeOutput with all crops stacked into one image batch. Frames
            with no bboxes are skipped; if nothing gets cropped, the input
            image is returned unchanged.
        """
        total_frames = image.shape[0]
        img_h = image.shape[1]
        img_w = image.shape[2]
        num_ch = image.shape[3]
        # Accept a bare bbox dict by wrapping it as one frame with one bbox.
        if not isinstance(bboxes, list):
            bboxes = [[bboxes]]
        elif len(bboxes) == 0:
            return io.NodeOutput(image)
        crops = []
        for frame_idx in range(total_frames):
            # Reuse the last bbox list when there are fewer entries than frames.
            frame_bboxes = bboxes[min(frame_idx, len(bboxes) - 1)]
            if not frame_bboxes:
                continue
            frame_chw = image[frame_idx].permute(2, 0, 1).unsqueeze(0) # BHWC → BCHW (1, C, H, W)
            # Union all bboxes for this frame into a single crop region
            x1 = min(int(b["x"]) for b in frame_bboxes)
            y1 = min(int(b["y"]) for b in frame_bboxes)
            x2 = max(int(b["x"] + b["width"]) for b in frame_bboxes)
            y2 = max(int(b["y"] + b["height"]) for b in frame_bboxes)
            if padding > 0:
                x1 = max(0, x1 - padding)
                y1 = max(0, y1 - padding)
                x2 = min(img_w, x2 + padding)
                y2 = min(img_h, y2 + padding)
            # Clamp to image bounds even when no padding was applied.
            x1, x2 = max(0, x1), min(img_w, x2)
            y1, y2 = max(0, y1), min(img_h, y2)
            # Fallback for empty/degenerate crops
            if x2 <= x1 or y2 <= y1:
                # Best-effort region: 30% of the short side, horizontally
                # centered, anchored 10% down from the top of the frame.
                fallback_size = int(min(img_h, img_w) * 0.3)
                fb_x1 = max(0, (img_w - fallback_size) // 2)
                fb_y1 = max(0, int(img_h * 0.1))
                fb_x2 = min(img_w, fb_x1 + fallback_size)
                fb_y2 = min(img_h, fb_y1 + fallback_size)
                if fb_x2 <= fb_x1 or fb_y2 <= fb_y1:
                    # Even the fallback is degenerate; emit a black crop so the
                    # output batch stays aligned with processed frames.
                    crops.append(torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device))
                    continue
                x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2
            crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w)
            resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
            crops.append(resized)
        if not crops:
            return io.NodeOutput(image)
        out_images = torch.cat(crops, dim=0).permute(0, 2, 3, 1) # (N, H, W, C)
        return io.NodeOutput(out_images)
class SDPoseExtension(ComfyExtension):
    """Registers the SDPose preprocessor nodes with ComfyUI."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        node_classes = [
            SDPoseKeypointExtractor,
            SDPoseDrawKeypoints,
            SDPoseFaceBBoxes,
            CropByBBoxes,
        ]
        return node_classes
async def comfy_entrypoint() -> SDPoseExtension:
    """Entry point used by ComfyUI to load this extension."""
    extension = SDPoseExtension()
    return extension

View File

@@ -15,7 +15,7 @@ class SD_4XUpscale_Conditioning(io.ComfyNode):
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Float.Input("scale_ratio", default=4.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),

View File

@@ -21,11 +21,11 @@ class SkipLayerGuidanceDiT(io.ComfyNode):
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.String.Input("double_layers", default="7, 8, 9"),
io.String.Input("single_layers", default="7, 8, 9"),
io.String.Input("double_layers", default="7, 8, 9", advanced=True),
io.String.Input("single_layers", default="7, 8, 9", advanced=True),
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001, advanced=True),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001, advanced=True),
io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
@@ -101,10 +101,10 @@ class SkipLayerGuidanceDiTSimple(io.ComfyNode):
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.String.Input("double_layers", default="7, 8, 9"),
io.String.Input("single_layers", default="7, 8, 9"),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
io.String.Input("double_layers", default="7, 8, 9", advanced=True),
io.String.Input("single_layers", default="7, 8, 9", advanced=True),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[
io.Model.Output(),

View File

@@ -75,8 +75,8 @@ class StableZero123_Conditioning_Batched(io.ComfyNode):
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False, advanced=True),
io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False, advanced=True)
],
outputs=[
io.Conditioning.Output(display_name="positive"),

View File

@@ -33,7 +33,7 @@ class StableCascade_EmptyLatentImage(io.ComfyNode):
inputs=[
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
io.Int.Input("compression", default=42, min=4, max=128, step=1, advanced=True),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
@@ -62,7 +62,7 @@ class StableCascade_StageC_VAEEncode(io.ComfyNode):
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
io.Int.Input("compression", default=42, min=4, max=128, step=1, advanced=True),
],
outputs=[
io.Latent.Output(display_name="stage_c"),

View File

@@ -169,7 +169,7 @@ class StringContains(io.ComfyNode):
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("substring", multiline=True),
io.Boolean.Input("case_sensitive", default=True),
io.Boolean.Input("case_sensitive", default=True, advanced=True),
],
outputs=[
io.Boolean.Output(display_name="contains"),
@@ -198,7 +198,7 @@ class StringCompare(io.ComfyNode):
io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True),
io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]),
io.Boolean.Input("case_sensitive", default=True),
io.Boolean.Input("case_sensitive", default=True, advanced=True),
],
outputs=[
io.Boolean.Output(),
@@ -233,9 +233,9 @@ class RegexMatch(io.ComfyNode):
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True),
io.Boolean.Input("case_insensitive", default=True),
io.Boolean.Input("multiline", default=False),
io.Boolean.Input("dotall", default=False),
io.Boolean.Input("case_insensitive", default=True, advanced=True),
io.Boolean.Input("multiline", default=False, advanced=True),
io.Boolean.Input("dotall", default=False, advanced=True),
],
outputs=[
io.Boolean.Output(display_name="matches"),
@@ -275,10 +275,10 @@ class RegexExtract(io.ComfyNode):
io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True),
io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]),
io.Boolean.Input("case_insensitive", default=True),
io.Boolean.Input("multiline", default=False),
io.Boolean.Input("dotall", default=False),
io.Int.Input("group_index", default=1, min=0, max=100),
io.Boolean.Input("case_insensitive", default=True, advanced=True),
io.Boolean.Input("multiline", default=False, advanced=True),
io.Boolean.Input("dotall", default=False, advanced=True),
io.Int.Input("group_index", default=1, min=0, max=100, advanced=True),
],
outputs=[
io.String.Output(),
@@ -351,10 +351,10 @@ class RegexReplace(io.ComfyNode):
io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True),
io.String.Input("replace", multiline=True),
io.Boolean.Input("case_insensitive", default=True, optional=True),
io.Boolean.Input("multiline", default=False, optional=True),
io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."),
io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."),
io.Boolean.Input("case_insensitive", default=True, optional=True, advanced=True),
io.Boolean.Input("multiline", default=False, optional=True, advanced=True),
io.Boolean.Input("dotall", default=False, optional=True, advanced=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."),
io.Int.Input("count", default=0, min=0, max=100, optional=True, advanced=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."),
],
outputs=[
io.String.Output(),

View File

@@ -0,0 +1,176 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class TextGenerate(io.ComfyNode):
    """Generate text from a prompt (and optional image) with an LLM loaded via CLIP."""

    @classmethod
    def define_schema(cls):
        # Sampling controls are only shown when the combo is set to "on";
        # "off" exposes no extra inputs (greedy decoding path).
        sampling_options = [
            io.DynamicCombo.Option(
                key="on",
                inputs=[
                    io.Float.Input("temperature", default=0.7, min=0.01, max=2.0, step=0.000001),
                    io.Int.Input("top_k", default=64, min=0, max=1000),
                    io.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01),
                    io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
                    io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
                    io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
                ]
            ),
            io.DynamicCombo.Option(key="off", inputs=[]),
        ]
        return io.Schema(
            node_id="TextGenerate",
            category="textgen/",
            search_aliases=["LLM", "gemma"],
            inputs=[
                io.Clip.Input("clip"),
                io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
                io.Image.Input("image", optional=True),
                io.Int.Input("max_length", default=256, min=1, max=2048),
                io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
            ],
            outputs=[
                io.String.Output(display_name="generated_text"),
            ],
        )

    @classmethod
    def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
        tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1)
        # When sampling is "off" the dynamic inputs are absent, so each .get()
        # falls back to a neutral value and do_sample is False.
        generated_ids = clip.generate(
            tokens,
            do_sample=sampling_mode.get("sampling_mode") == "on",
            max_length=max_length,
            temperature=sampling_mode.get("temperature", 1.0),
            top_k=sampling_mode.get("top_k", 50),
            top_p=sampling_mode.get("top_p", 1.0),
            min_p=sampling_mode.get("min_p", 0.0),
            repetition_penalty=sampling_mode.get("repetition_penalty", 1.0),
            seed=sampling_mode.get("seed", None),
        )
        generated_text = clip.decode(generated_ids, skip_special_tokens=True)
        return io.NodeOutput(generated_text)
# System prompt for text-to-video prompt enhancement (no conditioning image).
# Runtime string: fed verbatim into the Gemma chat template by
# TextGenerateLTX2Prompt, so its wording must not be edited casually.
LTX2_T2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
#### Guidelines
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
- Speech (only when requested):
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
- Specify language if not English and accent if relevant.
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if unspecified. Omit if unclear.
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
#### Important notes:
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is requested.
- Camera motion: DO NOT invent camera motion unless requested by the user.
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological scene description.
- Format: DO NOT start your response with special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits or introduce new elements. Add/enhance audio descriptions if missing.
#### Output Format (Strict):
- Single continuous paragraph in natural language (English).
- NO titles, headings, prefaces, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video generation.
#### Example
Input: "A woman at a coffee shop talking on the phone"
Output:
Style: realistic with cinematic lighting. In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully, lowering the phone.
"""
# System prompt for image-to-video prompt enhancement (first frame provided).
# Runtime string: used when TextGenerateLTX2Prompt receives an image input.
LTX2_I2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image.
#### Guidelines:
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene).
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts.
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Chronological flow: Use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.")
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
#### Important notes:
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion only if specified in the input.
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style (optional) and chronological scene description.
- Format: Never start output with punctuation marks or special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
#### Output Format (Strict):
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
#### Example output:
Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
"""
class TextGenerateLTX2Prompt(TextGenerate):
    """TextGenerate specialized for LTX2 prompt enhancement.

    Wraps the user's raw prompt in a Gemma-style chat template with the LTX2
    text-to-video or image-to-video system prompt, depending on whether a
    conditioning image is supplied, then delegates generation to the parent.
    """

    @classmethod
    def define_schema(cls):
        base = super().define_schema()
        # Reuse the parent's inputs/outputs under a dedicated node id.
        return io.Schema(
            node_id="TextGenerateLTX2Prompt",
            category=base.category,
            inputs=base.inputs,
            outputs=base.outputs,
            search_aliases=["prompt enhance", "LLM", "gemma"],
        )

    @classmethod
    def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
        if image is None:
            system_prompt = LTX2_T2V_SYSTEM_PROMPT
            image_block = ""
        else:
            system_prompt = LTX2_I2V_SYSTEM_PROMPT
            # Soft token marks where the image embedding is spliced in.
            image_block = "\n<image_soft_token>\n\n"
        formatted_prompt = (
            f"<start_of_turn>system\n{system_prompt.strip()}<end_of_turn>\n"
            f"<start_of_turn>user\n{image_block}User Raw Input Prompt: {prompt}.<end_of_turn>\n"
            "<start_of_turn>model\n"
        )
        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image)
class TextgenExtension(ComfyExtension):
    """Registers the text generation nodes with ComfyUI."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        node_classes = [
            TextGenerate,
            TextGenerateLTX2Prompt,
        ]
        return node_classes
async def comfy_entrypoint() -> TextgenExtension:
    """Entry point used by ComfyUI to load this extension."""
    extension = TextgenExtension()
    return extension

View File

@@ -16,6 +16,7 @@ class TorchCompileModel(io.ComfyNode):
io.Combo.Input(
"backend",
options=["inductor", "cudagraphs"],
advanced=True,
),
],
outputs=[io.Model.Output()],
@@ -24,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
@classmethod
def execute(cls, model, backend) -> io.NodeOutput:
m = model.clone()
m = model.clone(disable_dynamic=True)
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
return io.NodeOutput(m)

View File

@@ -4,6 +4,7 @@ import os
import numpy as np
import safetensors
import torch
import torch.nn as nn
import torch.utils.checkpoint
from tqdm.auto import trange
from PIL import Image, ImageDraw, ImageFont
@@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
"""
CFGGuider with modifications for training specific logic
"""
def __init__(self, *args, offloading=False, **kwargs):
super().__init__(*args, **kwargs)
self.offloading = offloading
def outer_sample(
self,
noise,
@@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
noise.shape,
self.conds,
self.model_options,
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
force_full_load=not self.offloading,
force_offload=self.offloading,
)
)
torch.cuda.empty_cache()
device = self.model_patcher.load_device
if denoise_mask is not None:
@@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
return result
def patch(m):
def find_modules_at_depth(
model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
) -> list[nn.Module]:
"""
Find modules at a specific depth level for gradient checkpointing.
Args:
model: The model to search
depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
result: Accumulator for results
current_depth: Current recursion depth
name: Current module name for logging
Returns:
List of modules at the target depth
"""
if result is None:
result = []
name = name or "root"
# Skip container modules (they don't have meaningful forward)
is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
has_forward = hasattr(model, "forward") and not is_container
if has_forward:
current_depth += 1
if current_depth == depth:
result.append(model)
logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
return result
# Recurse into children
for next_name, child in model.named_children():
find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
return result
class OffloadCheckpointFunction(torch.autograd.Function):
"""
Gradient checkpointing that works with weight offloading.
Forward: no_grad -> compute -> weights can be freed
Backward: enable_grad -> recompute -> backward -> weights can be freed
For single input, single output modules (Linear, Conv*).
"""
@staticmethod
def forward(ctx, x: torch.Tensor, forward_fn):
ctx.save_for_backward(x)
ctx.forward_fn = forward_fn
with torch.no_grad():
return forward_fn(x)
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
x, = ctx.saved_tensors
forward_fn = ctx.forward_fn
# Clear context early
ctx.forward_fn = None
with torch.enable_grad():
x_detached = x.detach().requires_grad_(True)
y = forward_fn(x_detached)
y.backward(grad_out)
grad_x = x_detached.grad
# Explicit cleanup
del y, x_detached, forward_fn
return grad_x, None
def patch(m, offloading=False):
if not hasattr(m, "forward"):
return
org_forward = m.forward
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
# Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
def checkpointing_fwd(x):
return OffloadCheckpointFunction.apply(x, org_forward)
# Branch 2: Others -> standard checkpoint
else:
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
m.org_forward = org_forward
m.forward = checkpointing_fwd
@@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
default=True,
tooltip="Use gradient checkpointing for training.",
),
io.Int.Input(
"checkpoint_depth",
default=1,
min=1,
max=5,
tooltip="Depth level for gradient checkpointing.",
),
io.Boolean.Input(
"offloading",
default=False,
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
),
io.Combo.Input(
"existing_lora",
options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype,
algorithm,
gradient_checkpointing,
checkpoint_depth,
offloading,
existing_lora,
bucket_mode,
bypass_mode,
@@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
offloading = offloading[0]
checkpoint_depth = checkpoint_depth[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
@@ -1019,6 +1124,15 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
if mp.is_dynamic():
if not bypass_mode:
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
bypass_mode = True
offloading = True
elif offloading:
if not bypass_mode:
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
@@ -1054,16 +1168,18 @@ class TrainLoraNode(io.ComfyNode):
# Setup gradient checkpointing
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(
mp.model.diffusion_model
):
patch(m)
modules_to_patch = find_modules_at_depth(
mp.model.diffusion_model, depth=checkpoint_depth
)
logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
for m in modules_to_patch:
patch(m, offloading=offloading)
torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=True
[mp], memory_required=1e20, force_full_load=not offloading
)
torch.cuda.empty_cache()
@@ -1100,7 +1216,7 @@ class TrainLoraNode(io.ComfyNode):
)
# Setup guider
guider = TrainGuider(mp)
guider = TrainGuider(mp, offloading=offloading)
guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled
@@ -1113,6 +1229,7 @@ class TrainLoraNode(io.ComfyNode):
# Run training loop
try:
comfy.model_management.in_training = True
_run_training_loop(
guider,
train_sampler,
@@ -1123,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
multi_res,
)
finally:
comfy.model_management.in_training = False
# Eject bypass hooks if they were injected
if bypass_injections is not None:
for injection in bypass_injections:
@@ -1132,19 +1250,20 @@ class TrainLoraNode(io.ComfyNode):
unpatch(m)
del train_sampler, optimizer
# Finalize adapters
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
del adapter
del all_weight_adapters
# mp in train node is highly specialized for training
# use it in inference will result in bad behavior so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):#
class LoraModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
@@ -1166,6 +1285,11 @@ class LoraModelLoader(io.ComfyNode):#
max=100.0,
tooltip="How strongly to modify the diffusion model. This value can be negative.",
),
io.Boolean.Input(
"bypass",
default=False,
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
),
],
outputs=[
io.Model.Output(
@@ -1175,13 +1299,18 @@ class LoraModelLoader(io.ComfyNode):#
)
@classmethod
def execute(cls, model, lora, strength_model):
def execute(cls, model, lora, strength_model, bypass=False):
if strength_model == 0:
return io.NodeOutput(model)
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
if bypass:
model_lora, _ = comfy.sd.load_bypass_lora_for_models(
model, None, lora, strength_model, 0
)
else:
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
return io.NodeOutput(model_lora)

View File

@@ -73,6 +73,7 @@ class SaveVideo(io.ComfyNode):
search_aliases=["export video"],
display_name="Save Video",
category="image/video",
essentials_category="Basics",
description="Saves the input images to your ComfyUI output directory.",
inputs=[
io.Video.Input("video", tooltip="The video to save."),
@@ -174,6 +175,7 @@ class LoadVideo(io.ComfyNode):
search_aliases=["import video", "open video", "video file"],
display_name="Load Video",
category="image/video",
essentials_category="Basics",
inputs=[
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
],
@@ -202,6 +204,57 @@ class LoadVideo(io.ComfyNode):
return True
class VideoSlice(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Video Slice",
display_name="Video Slice",
search_aliases=[
"trim video duration",
"skip first frames",
"frame load cap",
"start time",
],
category="image/video",
essentials_category="Video Tools",
inputs=[
io.Video.Input("video"),
io.Float.Input(
"start_time",
default=0.0,
max=1e5,
min=-1e5,
step=0.001,
tooltip="Start time in seconds",
),
io.Float.Input(
"duration",
default=0.0,
min=0.0,
step=0.001,
tooltip="Duration in seconds, or 0 for unlimited duration",
),
io.Boolean.Input(
"strict_duration",
default=False,
tooltip="If True, when the specified duration is not possible, an error will be raised.",
),
],
outputs=[
io.Video.Output(),
],
)
@classmethod
def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
if trimmed is not None:
return io.NodeOutput(trimmed)
raise ValueError(
f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
)
class VideoExtension(ComfyExtension):
@override
@@ -212,6 +265,7 @@ class VideoExtension(ComfyExtension):
CreateVideo,
GetVideoComponents,
LoadVideo,
VideoSlice,
]
async def comfy_entrypoint() -> VideoExtension:

View File

@@ -32,9 +32,9 @@ class SVD_img2vid_Conditioning:
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023, "advanced": True}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01, "advanced": True})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@@ -60,7 +60,7 @@ class VideoLinearCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@@ -84,7 +84,7 @@ class VideoTriangleCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

View File

@@ -717,8 +717,8 @@ class WanTrackToVideo(io.ComfyNode):
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1),
io.Int.Input("topk", default=2, min=1, max=10),
io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1, advanced=True),
io.Int.Input("topk", default=2, min=1, max=10, advanced=True),
io.Image.Input("start_image"),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
],
@@ -1323,7 +1323,7 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
io.AudioEncoderOutput.Input("audio_encoder_output_1"),
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context.", advanced=True),
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
io.Image.Input("previous_frames", optional=True),
],

View File

@@ -252,9 +252,9 @@ class WanMoveVisualizeTracks(io.ComfyNode):
io.Image.Input("images"),
io.Tracks.Input("tracks", optional=True),
io.Int.Input("line_resolution", default=24, min=1, max=1024),
io.Int.Input("circle_size", default=12, min=1, max=128),
io.Int.Input("circle_size", default=12, min=1, max=128, advanced=True),
io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
io.Int.Input("line_width", default=16, min=1, max=128),
io.Int.Input("line_width", default=16, min=1, max=128, advanced=True),
],
outputs=[
io.Image.Output(),

View File

@@ -16,7 +16,7 @@ class TextEncodeZImageOmni(io.ComfyNode):
io.Clip.Input("clip"),
io.ClipVision.Input("image_encoder", optional=True),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Boolean.Input("auto_resize_images", default=True),
io.Boolean.Input("auto_resize_images", default=True, advanced=True),
io.Vae.Input("vae", optional=True),
io.Image.Input("image1", optional=True),
io.Image.Input("image2", optional=True),