Merge branch 'master' into worksplit-multigpu

This commit is contained in:
Jedrzej Kosinski
2025-09-11 20:59:47 -07:00
75 changed files with 7128 additions and 1305 deletions

View File

@@ -1,6 +1,10 @@
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps):
"""
@@ -19,25 +23,30 @@ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.694615152
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
class AlignYourStepsScheduler:
class AlignYourStepsScheduler(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model_type": (["SD1", "SDXL", "SVD"], ),
"steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AlignYourStepsScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
io.Int.Input("steps", default=10, min=1, max=10000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[io.Sigmas.Output()],
)
def get_sigmas(self, model_type, steps, denoise):
# Deprecated: use the V3 schema's `execute` method instead of this.
return AlignYourStepsScheduler().execute(model_type, steps, denoise).result
@classmethod
def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:]
@@ -46,8 +55,15 @@ class AlignYourStepsScheduler:
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )
return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"AlignYourStepsScheduler": AlignYourStepsScheduler,
}
class AlignYourStepsExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
AlignYourStepsScheduler,
]
async def comfy_entrypoint() -> AlignYourStepsExtension:
return AlignYourStepsExtension()

View File

@@ -162,7 +162,12 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
total_steps = len(args[3])-1
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
# catch division by zero for log statement; sucks to crash after all sampling is done
try:
speedup = total_steps/(total_steps-easycache.total_steps_skipped)
except ZeroDivisionError:
speedup = 1.0
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
easycache.reset()
guider.model_options = orig_model_options

View File

@@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index"), ),
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
}}
RETURN_TYPES = ("CONDITIONING",)
@@ -115,6 +115,8 @@ class FluxKontextMultiReferenceLatentMethod:
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
reference_latents_method = "uxo"
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )

View File

@@ -113,6 +113,20 @@ class HunyuanImageToVideo:
out_latent["samples"] = latent
return (positive, out_latent)
class EmptyHunyuanImageLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent"
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
NODE_CLASS_MAPPINGS = {
@@ -120,4 +134,5 @@ NODE_CLASS_MAPPINGS = {
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
"HunyuanImageToVideo": HunyuanImageToVideo,
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
}

View File

@@ -8,13 +8,16 @@ import folder_paths
import comfy.model_management
from comfy.cli_args import args
class EmptyLatentHunyuan3Dv2:
@classmethod
def INPUT_TYPES(s):
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
return {
"required": {
"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@@ -24,7 +27,6 @@ class EmptyLatentHunyuan3Dv2:
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
return ({"samples": latent, "type": "hunyuan3dv2"}, )
class Hunyuan3Dv2Conditioning:
@classmethod
def INPUT_TYPES(s):
@@ -81,7 +83,6 @@ class VOXEL:
def __init__(self, data):
self.data = data
class VAEDecodeHunyuan3D:
@classmethod
def INPUT_TYPES(s):
@@ -99,7 +100,6 @@ class VAEDecodeHunyuan3D:
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return (voxels, )
def voxel_to_mesh(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
@@ -230,13 +230,9 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
], device=device)
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
for c, (dz, dy, dx) in enumerate(corner_offsets):
corner_values[:, c] = padded[
cell_positions[:, 0] + dz,
cell_positions[:, 1] + dy,
cell_positions[:, 2] + dx
]
pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0)
z_idx, y_idx, x_idx = pos.unbind(-1)
corner_values = padded[z_idx, y_idx, x_idx]
corner_signs = corner_values > threshold
has_inside = torch.any(corner_signs, dim=1)

View File

@@ -625,6 +625,37 @@ class ImageFlip:
return (image,)
class ImageScaleToMaxDimension:
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE",),
"upscale_method": (s.upscale_methods,),
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
CATEGORY = "image/upscaling"
def upscale(self, image, upscale_method, largest_size):
height = image.shape[1]
width = image.shape[2]
if height > width:
width = round((width / height) * largest_size)
height = largest_size
elif width > height:
height = round((height / width) * largest_size)
width = largest_size
else:
height = largest_size
width = largest_size
samples = image.movedim(-1, 1)
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1, -1)
return (s,)
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
@@ -639,4 +670,5 @@ NODE_CLASS_MAPPINGS = {
"GetImageSize": GetImageSize,
"ImageRotate": ImageRotate,
"ImageFlip": ImageFlip,
"ImageScaleToMaxDimension": ImageScaleToMaxDimension,
}

View File

@@ -1,4 +1,5 @@
import torch
from torch import nn
import folder_paths
import comfy.utils
import comfy.ops
@@ -58,6 +59,136 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
class SigLIPMultiFeatProjModel(torch.nn.Module):
"""
SigLIP Multi-Feature Projection Model for processing style features from different layers
and projecting them into a unified hidden space.
Args:
siglip_token_nums (int): Number of SigLIP tokens, default 257
style_token_nums (int): Number of style tokens, default 256
siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
hidden_size (int): Hidden layer size, default 3072
context_layer_norm (bool): Whether to use context layer normalization, default False
"""
def __init__(
self,
siglip_token_nums: int = 729,
style_token_nums: int = 64,
siglip_token_dims: int = 1152,
hidden_size: int = 3072,
context_layer_norm: bool = True,
device=None, dtype=None, operations=None
):
super().__init__()
# High-level feature processing (layer -2)
self.high_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.high_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
# Mid-level feature processing (layer -11)
self.mid_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.mid_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
# Low-level feature processing (layer -20)
self.low_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.low_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
def forward(self, siglip_outputs):
"""
Forward pass function
Args:
siglip_outputs: Output from SigLIP model, containing hidden_states
Returns:
torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
"""
dtype = next(self.high_embedding_linear.parameters()).dtype
# Process high-level features (layer -2)
high_embedding = self._process_layer_features(
siglip_outputs[2],
self.high_embedding_linear,
self.high_layer_norm,
self.high_projection,
dtype
)
# Process mid-level features (layer -11)
mid_embedding = self._process_layer_features(
siglip_outputs[1],
self.mid_embedding_linear,
self.mid_layer_norm,
self.mid_projection,
dtype
)
# Process low-level features (layer -20)
low_embedding = self._process_layer_features(
siglip_outputs[0],
self.low_embedding_linear,
self.low_layer_norm,
self.low_projection,
dtype
)
# Concatenate features from all layersmodel_patch
return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
def _process_layer_features(
self,
hidden_states: torch.Tensor,
embedding_linear: nn.Module,
layer_norm: nn.Module,
projection: nn.Module,
dtype: torch.dtype
) -> torch.Tensor:
"""
Helper function to process features from a single layer
Args:
hidden_states: Input hidden states [bs, seq_len, dim]
embedding_linear: Embedding linear layer
layer_norm: Layer normalization
projection: Projection layer
dtype: Target data type
Returns:
torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
"""
# Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
embedding = embedding_linear(
hidden_states.to(dtype).transpose(1, 2)
).transpose(1, 2)
# Apply layer normalization
embedding = layer_norm(embedding)
# Project to target hidden space
embedding = projection(embedding)
return embedding
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
@@ -73,9 +204,14 @@ class ModelPatchLoader:
model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
dtype = comfy.utils.weight_dtype(sd)
# TODO: this node will work with more types of model patches
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
if 'controlnet_blocks.0.y_rms.weight' in sd:
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'feature_embedder.mid_layer_norm.bias' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)
@@ -157,7 +293,51 @@ class QwenImageDiffsynthControlnet:
return (model_patched,)
class UsoStyleProjectorPatch:
def __init__(self, model_patch, encoded_image):
self.model_patch = model_patch
self.encoded_image = encoded_image
def __call__(self, kwargs):
txt_ids = kwargs.get("txt_ids")
txt = kwargs.get("txt")
siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype)
txt = torch.cat([siglip_embedding, txt], dim=1)
kwargs['txt'] = txt
kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1)
return kwargs
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype)
return self
def models(self):
return [self.model_patch]
class USOStyleReference:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_patch"
EXPERIMENTAL = True
CATEGORY = "advanced/model_patches/flux"
def apply_patch(self, model, model_patch, clip_vision_output):
encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states))
model_patched = model.clone()
model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image))
return (model_patched,)
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"USOStyleReference": USOStyleReference,
}

View File

@@ -1,98 +1,109 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations
import sys
from typing_extensions import override
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
from comfy_api.latest import ComfyExtension, io
class String(ComfyNodeABC):
class String(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveString",
display_name="String",
category="utils/primitive",
inputs=[
io.String.Input("value"),
],
outputs=[io.String.Output()],
)
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
class StringMultiline(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {"multiline": True,},)},
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
def execute(cls, value: str) -> io.NodeOutput:
return io.NodeOutput(value)
class Int(ComfyNodeABC):
class StringMultiline(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveStringMultiline",
display_name="String (Multiline)",
category="utils/primitive",
inputs=[
io.String.Input("value", multiline=True),
],
outputs=[io.String.Output()],
)
RETURN_TYPES = (IO.INT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: int) -> tuple[int]:
return (value,)
class Float(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
}
RETURN_TYPES = (IO.FLOAT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: float) -> tuple[float]:
return (value,)
def execute(cls, value: str) -> io.NodeOutput:
return io.NodeOutput(value)
class Boolean(ComfyNodeABC):
class Int(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.BOOLEAN, {})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveInt",
display_name="Int",
category="utils/primitive",
inputs=[
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
],
outputs=[io.Int.Output()],
)
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: bool) -> tuple[bool]:
return (value,)
@classmethod
def execute(cls, value: int) -> io.NodeOutput:
return io.NodeOutput(value)
NODE_CLASS_MAPPINGS = {
"PrimitiveString": String,
"PrimitiveStringMultiline": StringMultiline,
"PrimitiveInt": Int,
"PrimitiveFloat": Float,
"PrimitiveBoolean": Boolean,
}
class Float(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PrimitiveFloat",
display_name="Float",
category="utils/primitive",
inputs=[
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
],
outputs=[io.Float.Output()],
)
NODE_DISPLAY_NAME_MAPPINGS = {
"PrimitiveString": "String",
"PrimitiveStringMultiline": "String (Multiline)",
"PrimitiveInt": "Int",
"PrimitiveFloat": "Float",
"PrimitiveBoolean": "Boolean",
}
@classmethod
def execute(cls, value: float) -> io.NodeOutput:
return io.NodeOutput(value)
class Boolean(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PrimitiveBoolean",
display_name="Boolean",
category="utils/primitive",
inputs=[
io.Boolean.Input("value"),
],
outputs=[io.Boolean.Output()],
)
@classmethod
def execute(cls, value: bool) -> io.NodeOutput:
return io.NodeOutput(value)
class PrimitivesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
String,
StringMultiline,
Int,
Float,
Boolean,
]
async def comfy_entrypoint() -> PrimitivesExtension:
return PrimitivesExtension()

View File

@@ -17,55 +17,61 @@
"""
import torch
import nodes
from typing_extensions import override
import comfy.utils
import nodes
from comfy_api.latest import ComfyExtension, io
class StableCascade_EmptyLatentImage:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_EmptyLatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_EmptyLatentImage",
category="latent/stable_cascade",
inputs=[
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, width, height, compression, batch_size=1):
def execute(cls, width, height, compression, batch_size=1):
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
return ({
return io.NodeOutput({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageC_VAEEncode:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_StageC_VAEEncode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_StageC_VAEEncode",
category="latent/stable_cascade",
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
],
outputs=[
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, image, vae, compression):
def execute(cls, image, vae, compression):
width = image.shape[-2]
height = image.shape[-3]
out_width = (width // compression) * vae.downscale_ratio
@@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode:
c_latent = vae.encode(s[:,:,:,:3])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return ({
return io.NodeOutput({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageB_Conditioning:
class StableCascade_StageB_Conditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "conditioning": ("CONDITIONING",),
"stage_c": ("LATENT",),
}}
RETURN_TYPES = ("CONDITIONING",)
def define_schema(cls):
return io.Schema(
node_id="StableCascade_StageB_Conditioning",
category="conditioning/stable_cascade",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("stage_c"),
],
outputs=[
io.Conditioning.Output(),
],
)
FUNCTION = "set_prior"
CATEGORY = "conditioning/stable_cascade"
def set_prior(self, conditioning, stage_c):
@classmethod
def execute(cls, conditioning, stage_c):
c = []
for t in conditioning:
d = t[1].copy()
d['stable_cascade_prior'] = stage_c['samples']
d["stable_cascade_prior"] = stage_c["samples"]
n = [t[0], d]
c.append(n)
return (c, )
return io.NodeOutput(c)
class StableCascade_SuperResolutionControlnet:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_SuperResolutionControlnet(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_SuperResolutionControlnet",
category="_for_testing/stable_cascade",
is_experimental=True,
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
],
outputs=[
io.Image.Output(display_name="controlnet_input"),
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
}}
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
EXPERIMENTAL = True
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):
def execute(cls, image, vae):
width = image.shape[-2]
height = image.shape[-3]
batch_size = image.shape[0]
@@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet:
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
return (controlnet_input, {
return io.NodeOutput(controlnet_input, {
"samples": c_latent,
}, {
"samples": b_latent,
})
NODE_CLASS_MAPPINGS = {
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
}
class StableCascadeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
StableCascade_EmptyLatentImage,
StableCascade_StageB_Conditioning,
StableCascade_StageC_VAEEncode,
StableCascade_SuperResolutionControlnet,
]
async def comfy_entrypoint() -> StableCascadeExtension:
return StableCascadeExtension()

View File

@@ -5,52 +5,49 @@ import av
import torch
import folder_paths
import json
from typing import Optional, Literal
from typing import Optional
from typing_extensions import override
from fractions import Fraction
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
from comfy_api.latest import Input, InputImpl, Types
from comfy_api.input import AudioInput, ImageInput, VideoInput
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
from comfy_api.latest import ComfyExtension, io, ui
from comfy.cli_args import args
class SaveWEBM:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
class SaveWEBM(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveWEBM",
category="image/video",
is_experimental=True,
inputs=[
io.Image.Input("images"),
io.String.Input("filename_prefix", default="ComfyUI"),
io.Combo.Input("codec", options=["vp9", "av1"]),
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"codec": (["vp9", "av1"],),
"fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "image/video"
EXPERIMENTAL = True
def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
)
file = f"{filename}_{counter:05}_.webm"
container = av.open(os.path.join(full_output_folder, file), mode="w")
if prompt is not None:
container.metadata["prompt"] = json.dumps(prompt)
if cls.hidden.prompt is not None:
container.metadata["prompt"] = json.dumps(cls.hidden.prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
container.metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
@@ -69,63 +66,46 @@ class SaveWEBM:
container.mux(stream.encode())
container.close()
results: list[FileLocator] = [{
"filename": file,
"subfolder": subfolder,
"type": self.type
}]
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
class SaveVideo(ComfyNodeABC):
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type: Literal["output"] = "output"
self.prefix_append = ""
class SaveVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveVideo",
display_name="Save Video",
category="image/video",
description="Saves the input images to your ComfyUI output directory.",
inputs=[
io.Video.Input("video", tooltip="The video to save."),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
},
}
RETURN_TYPES = ()
FUNCTION = "save_video"
OUTPUT_NODE = True
CATEGORY = "image/video"
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
self.output_dir,
folder_paths.get_output_directory(),
width,
height
)
results: list[FileLocator] = list()
saved_metadata = None
if not args.disable_metadata:
metadata = {}
if extra_pnginfo is not None:
metadata.update(extra_pnginfo)
if prompt is not None:
metadata["prompt"] = prompt
if cls.hidden.extra_pnginfo is not None:
metadata.update(cls.hidden.extra_pnginfo)
if cls.hidden.prompt is not None:
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
format=format,
@@ -133,83 +113,82 @@ class SaveVideo(ComfyNodeABC):
metadata=saved_metadata
)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
return { "ui": { "images": results, "animated": (True,) } }
class CreateVideo(ComfyNodeABC):
class CreateVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
},
"optional": {
"audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
}
}
def define_schema(cls):
return io.Schema(
node_id="CreateVideo",
display_name="Create Video",
category="image/video",
description="Create a video from images.",
inputs=[
io.Image.Input("images", tooltip="The images to create a video from."),
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
],
outputs=[
io.Video.Output(),
],
)
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "create_video"
CATEGORY = "image/video"
DESCRIPTION = "Create a video from images."
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
return (InputImpl.VideoFromComponents(
Types.VideoComponents(
images=images,
audio=audio,
frame_rate=Fraction(fps),
)
),)
class GetVideoComponents(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
}
}
RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
RETURN_NAMES = ("images", "audio", "fps")
FUNCTION = "get_components"
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
return io.NodeOutput(
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
)
CATEGORY = "image/video"
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
class GetVideoComponents(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GetVideoComponents",
display_name="Get Video Components",
category="image/video",
description="Extracts all components from a video: frames, audio, and framerate.",
inputs=[
io.Video.Input("video", tooltip="The video to extract components from."),
],
outputs=[
io.Image.Output(display_name="images"),
io.Audio.Output(display_name="audio"),
io.Float.Output(display_name="fps"),
],
)
def get_components(self, video: Input.Video):
@classmethod
def execute(cls, video: VideoInput) -> io.NodeOutput:
components = video.get_components()
return (components.images, components.audio, float(components.frame_rate))
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class LoadVideo(ComfyNodeABC):
class LoadVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
def define_schema(cls):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["video"])
return {"required":
{"file": (sorted(files), {"video_upload": True})},
}
CATEGORY = "image/video"
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "load_video"
def load_video(self, file):
video_path = folder_paths.get_annotated_filepath(file)
return (InputImpl.VideoFromFile(video_path),)
return io.Schema(
node_id="LoadVideo",
display_name="Load Video",
category="image/video",
inputs=[
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
],
outputs=[
io.Video.Output(),
],
)
@classmethod
def IS_CHANGED(cls, file):
def execute(cls, file) -> io.NodeOutput:
video_path = folder_paths.get_annotated_filepath(file)
return io.NodeOutput(VideoFromFile(video_path))
@classmethod
def fingerprint_inputs(s, file):
video_path = folder_paths.get_annotated_filepath(file)
mod_time = os.path.getmtime(video_path)
# Instead of hashing the file, we can just use the modification time to avoid
@@ -217,24 +196,23 @@ class LoadVideo(ComfyNodeABC):
return mod_time
@classmethod
def VALIDATE_INPUTS(cls, file):
def validate_inputs(s, file):
if not folder_paths.exists_annotated_filepath(file):
return "Invalid video file: {}".format(file)
return True
NODE_CLASS_MAPPINGS = {
"SaveWEBM": SaveWEBM,
"SaveVideo": SaveVideo,
"CreateVideo": CreateVideo,
"GetVideoComponents": GetVideoComponents,
"LoadVideo": LoadVideo,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SaveVideo": "Save Video",
"CreateVideo": "Create Video",
"GetVideoComponents": "Get Video Components",
"LoadVideo": "Load Video",
}
class VideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SaveWEBM,
SaveVideo,
CreateVideo,
GetVideoComponents,
LoadVideo,
]
async def comfy_entrypoint() -> VideoExtension:
return VideoExtension()