From ac1073be99887c6dfbbec2fef4a886cd9f1e3fd8 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 6 Feb 2026 15:58:30 +0200 Subject: [PATCH] convert model_merging and video_model nodes to V3 schema --- comfy_extras/nodes_model_merging.py | 425 ++++++++------ .../nodes_model_merging_model_specific.py | 527 +++++++++++------- comfy_extras/nodes_video_model.py | 238 +++++--- 3 files changed, 709 insertions(+), 481 deletions(-) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 5384ed531..fa53f3af3 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -10,146 +10,198 @@ import json import os from comfy.cli_args import args +from comfy_api.latest import io, ComfyExtension +from typing_extensions import override -class ModelMergeSimple: + +class ModelMergeSimple(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model1": ("MODEL",), - "model2": ("MODEL",), - "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "merge" + def define_schema(cls): + return io.Schema( + node_id="ModelMergeSimple", + category="advanced/model_merging", + inputs=[ + io.Model.Input("model1"), + io.Model.Input("model2"), + io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "advanced/model_merging" - - def merge(self, model1, model2, ratio): + @classmethod + def execute(cls, model1, model2, ratio) -> io.NodeOutput: m = model1.clone() kp = model2.get_key_patches("diffusion_model.") for k in kp: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) - return (m, ) + return io.NodeOutput(m) -class ModelSubtract: + merge = execute # TODO: remove + + +class ModelSubtract(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model1": ("MODEL",), - "model2": ("MODEL",), - "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "merge" + def define_schema(cls): + return io.Schema( + node_id="ModelMergeSubtract", + category="advanced/model_merging", + inputs=[ + io.Model.Input("model1"), + io.Model.Input("model2"), + io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "advanced/model_merging" - - def merge(self, model1, model2, multiplier): + @classmethod + def execute(cls, model1, model2, multiplier) -> io.NodeOutput: m = model1.clone() kp = model2.get_key_patches("diffusion_model.") for k in kp: m.add_patches({k: kp[k]}, - multiplier, multiplier) - return (m, ) + return io.NodeOutput(m) -class ModelAdd: + merge = execute # TODO: remove + + +class ModelAdd(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model1": ("MODEL",), - "model2": ("MODEL",), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "merge" + def define_schema(cls): + return io.Schema( + node_id="ModelMergeAdd", + category="advanced/model_merging", + inputs=[ + io.Model.Input("model1"), + io.Model.Input("model2"), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "advanced/model_merging" - - def merge(self, model1, model2): + @classmethod + def execute(cls, model1, model2) -> io.NodeOutput: m = model1.clone() kp = model2.get_key_patches("diffusion_model.") for k in kp: m.add_patches({k: kp[k]}, 1.0, 1.0) - return (m, ) + return io.NodeOutput(m) + + merge = execute # TODO: remove -class CLIPMergeSimple: +class CLIPMergeSimple(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip1": ("CLIP",), - "clip2": ("CLIP",), - "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "merge" + def define_schema(cls): + return io.Schema( + node_id="CLIPMergeSimple", + category="advanced/model_merging", + inputs=[ + io.Clip.Input("clip1"), + io.Clip.Input("clip2"), + io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Clip.Output(), + ], + ) - CATEGORY = "advanced/model_merging" - - def merge(self, clip1, clip2, ratio): + @classmethod + def execute(cls, clip1, clip2, ratio) -> io.NodeOutput: m = clip1.clone() kp = clip2.get_key_patches() for k in kp: if k.endswith(".position_ids") or k.endswith(".logit_scale"): continue m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) - return (m, ) + return io.NodeOutput(m) + + merge = execute # TODO: remove -class CLIPSubtract: - SEARCH_ALIASES = ["clip difference", "text encoder subtract"] +class CLIPSubtract(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip1": ("CLIP",), - "clip2": ("CLIP",), - "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "merge" + def define_schema(cls): + return io.Schema( + node_id="CLIPMergeSubtract", + search_aliases=["clip difference", "text encoder subtract"], + category="advanced/model_merging", + inputs=[ + io.Clip.Input("clip1"), + io.Clip.Input("clip2"), + io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01), + ], + outputs=[ + io.Clip.Output(), + ], + ) - CATEGORY = "advanced/model_merging" - - def merge(self, clip1, clip2, multiplier): + @classmethod + def execute(cls, clip1, clip2, multiplier) -> io.NodeOutput: m = clip1.clone() kp = clip2.get_key_patches() for k in kp: if k.endswith(".position_ids") or k.endswith(".logit_scale"): continue m.add_patches({k: kp[k]}, - multiplier, multiplier) - return (m, ) + return io.NodeOutput(m) + + merge = execute # TODO: remove -class CLIPAdd: - SEARCH_ALIASES = ["combine clip"] +class CLIPAdd(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip1": ("CLIP",), - "clip2": ("CLIP",), - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "merge" + def define_schema(cls): + return io.Schema( + node_id="CLIPMergeAdd", + search_aliases=["combine clip"], + category="advanced/model_merging", + inputs=[ + io.Clip.Input("clip1"), + io.Clip.Input("clip2"), + ], + outputs=[ + io.Clip.Output(), + ], + ) - CATEGORY = "advanced/model_merging" - - def merge(self, clip1, clip2): + @classmethod + def execute(cls, clip1, clip2) -> io.NodeOutput: m = clip1.clone() kp = clip2.get_key_patches() for k in kp: if k.endswith(".position_ids") or k.endswith(".logit_scale"): continue m.add_patches({k: kp[k]}, 1.0, 1.0) - return (m, ) + return io.NodeOutput(m) + + merge = execute # TODO: remove -class ModelMergeBlocks: +class ModelMergeBlocks(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model1": ("MODEL",), - "model2": ("MODEL",), - "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "merge" + def define_schema(cls): + return io.Schema( + node_id="ModelMergeBlocks", + category="advanced/model_merging", + inputs=[ + io.Model.Input("model1"), + io.Model.Input("model2"), + io.Float.Input("input", default=1.0, min=0.0, max=1.0, step=0.01), + io.Float.Input("middle", default=1.0, min=0.0, max=1.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "advanced/model_merging" - - def merge(self, model1, model2, **kwargs): + @classmethod + def execute(cls, model1, model2, **kwargs) -> io.NodeOutput: m = model1.clone() kp = model2.get_key_patches("diffusion_model.") default_ratio = next(iter(kwargs.values())) @@ -165,7 +217,10 @@ class ModelMergeBlocks: last_arg_size = len(arg) m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) - return (m, ) + return io.NodeOutput(m) + + merge = execute # TODO: remove + def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None): full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir) @@ -226,59 +281,65 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys) -class CheckpointSave: - SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"] - def __init__(self): - self.output_dir = folder_paths.get_output_directory() + +class CheckpointSave(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CheckpointSave", + display_name="Save Checkpoint", + search_aliases=["save model", "export checkpoint", "merge save"], + category="advanced/model_merging", + inputs=[ + io.Model.Input("model"), + io.Clip.Input("clip"), + io.Vae.Input("vae"), + io.String.Input("filename_prefix", default="checkpoints/ComfyUI"), + ], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "clip": ("CLIP",), - "vae": ("VAE",), - "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} - RETURN_TYPES = () - FUNCTION = "save" - OUTPUT_NODE = True + def execute(cls, model, clip, vae, filename_prefix) -> io.NodeOutput: + save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo) + return io.NodeOutput() - CATEGORY = "advanced/model_merging" + save = execute # TODO: remove - def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): - save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) - return {} -class CLIPSave: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() +class CLIPSave(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPSave", + category="advanced/model_merging", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("filename_prefix", default="clip/ComfyUI"), + ], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "clip": ("CLIP",), - "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} - RETURN_TYPES = () - FUNCTION = "save" - OUTPUT_NODE = True - - CATEGORY = "advanced/model_merging" - - def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None): + def execute(cls, clip, filename_prefix) -> io.NodeOutput: prompt_info = "" - if prompt is not None: - prompt_info = json.dumps(prompt) + if cls.hidden.prompt is not None: + prompt_info = json.dumps(cls.hidden.prompt) metadata = {} if not args.disable_metadata: metadata["format"] = "pt" metadata["prompt"] = prompt_info - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True) clip_sd = clip.get_sd() + output_dir = folder_paths.get_output_directory() for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]: k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys())) current_clip_sd = {} @@ -295,7 +356,7 @@ class CLIPSave: replace_prefix[prefix] = "" replace_prefix["transformer."] = "" - full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir) + full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, output_dir) output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) @@ -303,76 +364,88 @@ class CLIPSave: current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix) comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata) - return {} + return io.NodeOutput() -class VAESave: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() + save = execute # TODO: remove + + +class VAESave(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VAESave", + category="advanced/model_merging", + inputs=[ + io.Vae.Input("vae"), + io.String.Input("filename_prefix", default="vae/ComfyUI_vae"), + ], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "vae": ("VAE",), - "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} - RETURN_TYPES = () - FUNCTION = "save" - OUTPUT_NODE = True - - CATEGORY = "advanced/model_merging" - - def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + def execute(cls, vae, filename_prefix) -> io.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) prompt_info = "" - if prompt is not None: - prompt_info = json.dumps(prompt) + if cls.hidden.prompt is not None: + prompt_info = json.dumps(cls.hidden.prompt) metadata = {} if not args.disable_metadata: metadata["prompt"] = prompt_info - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata) - return {} + return io.NodeOutput() -class ModelSave: - SEARCH_ALIASES = ["export model", "checkpoint save"] - def __init__(self): - self.output_dir = folder_paths.get_output_directory() + save = execute # TODO: remove + + +class ModelSave(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ModelSave", + search_aliases=["export model", "checkpoint save"], + category="advanced/model_merging", + inputs=[ + io.Model.Input("model"), + io.String.Input("filename_prefix", default="diffusion_models/ComfyUI"), + ], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} - RETURN_TYPES = () - FUNCTION = "save" - OUTPUT_NODE = True + def execute(cls, model, filename_prefix) -> io.NodeOutput: + save_checkpoint(model, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo) + return io.NodeOutput() - CATEGORY = "advanced/model_merging" + save = execute # TODO: remove - def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None): - save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) - return {} -NODE_CLASS_MAPPINGS = { - "ModelMergeSimple": ModelMergeSimple, - "ModelMergeBlocks": ModelMergeBlocks, - "ModelMergeSubtract": ModelSubtract, - "ModelMergeAdd": ModelAdd, - "CheckpointSave": CheckpointSave, - "CLIPMergeSimple": CLIPMergeSimple, - "CLIPMergeSubtract": CLIPSubtract, - "CLIPMergeAdd": CLIPAdd, - "CLIPSave": CLIPSave, - "VAESave": VAESave, - "ModelSave": ModelSave, -} +class ModelMergingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ModelMergeSimple, + ModelMergeBlocks, + ModelSubtract, + ModelAdd, + CheckpointSave, + CLIPMergeSimple, + CLIPSubtract, + CLIPAdd, + CLIPSave, + VAESave, + ModelSave, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "CheckpointSave": "Save Checkpoint", -} + +async def comfy_entrypoint() -> ModelMergingExtension: + return ModelMergingExtension() diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py index 55eb3ccfe..2680dcebe 100644 --- a/comfy_extras/nodes_model_merging_model_specific.py +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -1,356 +1,455 @@ import comfy_extras.nodes_model_merging +from comfy_api.latest import io, ComfyExtension +from typing_extensions import override + + class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) - arg_dict["time_embed."] = argument - arg_dict["label_emb."] = argument + inputs.append(io.Float.Input("time_embed.", **argument)) + inputs.append(io.Float.Input("label_emb.", **argument)) for i in range(12): - arg_dict["input_blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("input_blocks.{}.".format(i), **argument)) for i in range(3): - arg_dict["middle_block.{}.".format(i)] = argument + inputs.append(io.Float.Input("middle_block.{}.".format(i), **argument)) for i in range(12): - arg_dict["output_blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("output_blocks.{}.".format(i), **argument)) - arg_dict["out."] = argument + inputs.append(io.Float.Input("out.", **argument)) - return {"required": arg_dict} + return io.Schema( + node_id="ModelMergeSD1", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) + + +class ModelMergeSD2(ModelMergeSD1): + # SD1 and SD2 have the same blocks + @classmethod + def define_schema(cls): + schema = ModelMergeSD1.define_schema() + schema.node_id = "ModelMergeSD2" + return schema class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) - arg_dict["time_embed."] = argument - arg_dict["label_emb."] = argument + inputs.append(io.Float.Input("time_embed.", **argument)) + inputs.append(io.Float.Input("label_emb.", **argument)) for i in range(9): - arg_dict["input_blocks.{}".format(i)] = argument + inputs.append(io.Float.Input("input_blocks.{}".format(i), **argument)) for i in range(3): - arg_dict["middle_block.{}".format(i)] = argument + inputs.append(io.Float.Input("middle_block.{}".format(i), **argument)) for i in range(9): - arg_dict["output_blocks.{}".format(i)] = argument + inputs.append(io.Float.Input("output_blocks.{}".format(i), **argument)) - arg_dict["out."] = argument + inputs.append(io.Float.Input("out.", **argument)) + + return io.Schema( + node_id="ModelMergeSDXL", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) - arg_dict["pos_embed."] = argument - arg_dict["x_embedder."] = argument - arg_dict["context_embedder."] = argument - arg_dict["y_embedder."] = argument - arg_dict["t_embedder."] = argument + inputs.append(io.Float.Input("pos_embed.", **argument)) + inputs.append(io.Float.Input("x_embedder.", **argument)) + inputs.append(io.Float.Input("context_embedder.", **argument)) + inputs.append(io.Float.Input("y_embedder.", **argument)) + inputs.append(io.Float.Input("t_embedder.", **argument)) for i in range(24): - arg_dict["joint_blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument)) - arg_dict["final_layer."] = argument + inputs.append(io.Float.Input("final_layer.", **argument)) - return {"required": arg_dict} + return io.Schema( + node_id="ModelMergeSD3_2B", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) - arg_dict["init_x_linear."] = argument - arg_dict["positional_encoding"] = argument - arg_dict["cond_seq_linear."] = argument - arg_dict["register_tokens"] = argument - arg_dict["t_embedder."] = argument + inputs.append(io.Float.Input("init_x_linear.", **argument)) + inputs.append(io.Float.Input("positional_encoding", **argument)) + inputs.append(io.Float.Input("cond_seq_linear.", **argument)) + inputs.append(io.Float.Input("register_tokens", **argument)) + inputs.append(io.Float.Input("t_embedder.", **argument)) for i in range(4): - arg_dict["double_layers.{}.".format(i)] = argument + inputs.append(io.Float.Input("double_layers.{}.".format(i), **argument)) for i in range(32): - arg_dict["single_layers.{}.".format(i)] = argument + inputs.append(io.Float.Input("single_layers.{}.".format(i), **argument)) - arg_dict["modF."] = argument - arg_dict["final_linear."] = argument + inputs.append(io.Float.Input("modF.", **argument)) + inputs.append(io.Float.Input("final_linear.", **argument)) + + return io.Schema( + node_id="ModelMergeAuraflow", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) - arg_dict["img_in."] = argument - arg_dict["time_in."] = argument - arg_dict["guidance_in"] = argument - arg_dict["vector_in."] = argument - arg_dict["txt_in."] = argument + inputs.append(io.Float.Input("img_in.", **argument)) + inputs.append(io.Float.Input("time_in.", **argument)) + inputs.append(io.Float.Input("guidance_in", **argument)) + inputs.append(io.Float.Input("vector_in.", **argument)) + inputs.append(io.Float.Input("txt_in.", **argument)) for i in range(19): - arg_dict["double_blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("double_blocks.{}.".format(i), **argument)) for i in range(38): - arg_dict["single_blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("single_blocks.{}.".format(i), **argument)) - arg_dict["final_layer."] = argument + inputs.append(io.Float.Input("final_layer.", **argument)) + + return io.Schema( + node_id="ModelMergeFlux1", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) - arg_dict["pos_embed."] = argument - arg_dict["x_embedder."] = argument - arg_dict["context_embedder."] = argument - arg_dict["y_embedder."] = argument - arg_dict["t_embedder."] = argument + inputs.append(io.Float.Input("pos_embed.", **argument)) + inputs.append(io.Float.Input("x_embedder.", **argument)) + inputs.append(io.Float.Input("context_embedder.", **argument)) + inputs.append(io.Float.Input("y_embedder.", **argument)) + inputs.append(io.Float.Input("t_embedder.", **argument)) for i in range(38): - arg_dict["joint_blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument)) - arg_dict["final_layer."] = argument + inputs.append(io.Float.Input("final_layer.", **argument)) + + return io.Schema( + node_id="ModelMergeSD35_Large", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) - arg_dict["pos_frequencies."] = argument - arg_dict["t_embedder."] = argument - arg_dict["t5_y_embedder."] = argument - arg_dict["t5_yproj."] = argument + inputs.append(io.Float.Input("pos_frequencies.", **argument)) + inputs.append(io.Float.Input("t_embedder.", **argument)) + inputs.append(io.Float.Input("t5_y_embedder.", **argument)) + inputs.append(io.Float.Input("t5_yproj.", **argument)) for i in range(48): - arg_dict["blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("blocks.{}.".format(i), **argument)) - arg_dict["final_layer."] = argument + inputs.append(io.Float.Input("final_layer.", **argument)) + + return io.Schema( + node_id="ModelMergeMochiPreview", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) - arg_dict["patchify_proj."] = argument - arg_dict["adaln_single."] = argument - arg_dict["caption_projection."] = argument + inputs.append(io.Float.Input("patchify_proj.", **argument)) + inputs.append(io.Float.Input("adaln_single.", **argument)) + inputs.append(io.Float.Input("caption_projection.", **argument)) for i in range(28): - arg_dict["transformer_blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument)) - arg_dict["scale_shift_table"] = argument - arg_dict["proj_out."] = argument + inputs.append(io.Float.Input("scale_shift_table", **argument)) + inputs.append(io.Float.Input("proj_out.", **argument)) + + return io.Schema( + node_id="ModelMergeLTXV", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) - - arg_dict["pos_embedder."] = argument - arg_dict["extra_pos_embedder."] = argument - arg_dict["x_embedder."] = argument - arg_dict["t_embedder."] = argument - arg_dict["affline_norm."] = argument + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) + inputs.append(io.Float.Input("pos_embedder.", **argument)) + inputs.append(io.Float.Input("extra_pos_embedder.", **argument)) + inputs.append(io.Float.Input("x_embedder.", **argument)) + inputs.append(io.Float.Input("t_embedder.", **argument)) + inputs.append(io.Float.Input("affline_norm.", **argument)) for i in range(28): - arg_dict["blocks.block{}.".format(i)] = argument + inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument)) - arg_dict["final_layer."] = argument + inputs.append(io.Float.Input("final_layer.", **argument)) + + return io.Schema( + node_id="ModelMergeCosmos7B", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) - - arg_dict["pos_embedder."] = argument - arg_dict["extra_pos_embedder."] = argument - arg_dict["x_embedder."] = argument - arg_dict["t_embedder."] = argument - arg_dict["affline_norm."] = argument + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) + inputs.append(io.Float.Input("pos_embedder.", **argument)) + inputs.append(io.Float.Input("extra_pos_embedder.", **argument)) + inputs.append(io.Float.Input("x_embedder.", **argument)) + inputs.append(io.Float.Input("t_embedder.", **argument)) + inputs.append(io.Float.Input("affline_norm.", **argument)) for i in range(36): - arg_dict["blocks.block{}.".format(i)] = argument + inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument)) - arg_dict["final_layer."] = argument + inputs.append(io.Float.Input("final_layer.", **argument)) + + return io.Schema( + node_id="ModelMergeCosmos14B", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb." - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) - arg_dict["patch_embedding."] = argument - arg_dict["time_embedding."] = argument - arg_dict["time_projection."] = argument - arg_dict["text_embedding."] = argument - arg_dict["img_emb."] = argument + inputs.append(io.Float.Input("patch_embedding.", **argument)) + inputs.append(io.Float.Input("time_embedding.", **argument)) + inputs.append(io.Float.Input("time_projection.", **argument)) + inputs.append(io.Float.Input("text_embedding.", **argument)) + inputs.append(io.Float.Input("img_emb.", **argument)) for i in range(40): - arg_dict["blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("blocks.{}.".format(i), **argument)) - arg_dict["head."] = argument + inputs.append(io.Float.Input("head.", **argument)) + + return io.Schema( + node_id="ModelMergeWAN2_1", + category="advanced/model_merging/model_specific", + description="1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb.", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) - - arg_dict["pos_embedder."] = argument - arg_dict["x_embedder."] = argument - arg_dict["t_embedder."] = argument - arg_dict["t_embedding_norm."] = argument + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) + inputs.append(io.Float.Input("pos_embedder.", **argument)) + inputs.append(io.Float.Input("x_embedder.", **argument)) + inputs.append(io.Float.Input("t_embedder.", **argument)) + inputs.append(io.Float.Input("t_embedding_norm.", **argument)) for i in range(28): - arg_dict["blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("blocks.{}.".format(i), **argument)) - arg_dict["final_layer."] = argument + inputs.append(io.Float.Input("final_layer.", **argument)) + + return io.Schema( + node_id="ModelMergeCosmosPredict2_2B", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) - - arg_dict["pos_embedder."] = argument - arg_dict["x_embedder."] = argument - arg_dict["t_embedder."] = argument - arg_dict["t_embedding_norm."] = argument + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) + inputs.append(io.Float.Input("pos_embedder.", **argument)) + inputs.append(io.Float.Input("x_embedder.", **argument)) + inputs.append(io.Float.Input("t_embedder.", **argument)) + inputs.append(io.Float.Input("t_embedding_norm.", **argument)) for i in range(36): - arg_dict["blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("blocks.{}.".format(i), **argument)) - arg_dict["final_layer."] = argument + inputs.append(io.Float.Input("final_layer.", **argument)) + + return io.Schema( + node_id="ModelMergeCosmosPredict2_14B", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) - return {"required": arg_dict} class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks): - CATEGORY = "advanced/model_merging/model_specific" - @classmethod - def INPUT_TYPES(s): - arg_dict = { "model1": ("MODEL",), - "model2": ("MODEL",)} + def define_schema(cls): + inputs = [ + io.Model.Input("model1"), + io.Model.Input("model2"), + ] - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + argument = dict(default=1.0, min=0.0, max=1.0, step=0.01) - arg_dict["pos_embeds."] = argument - arg_dict["img_in."] = argument - arg_dict["txt_norm."] = argument - arg_dict["txt_in."] = argument - arg_dict["time_text_embed."] = argument + inputs.append(io.Float.Input("pos_embeds.", **argument)) + inputs.append(io.Float.Input("img_in.", **argument)) + inputs.append(io.Float.Input("txt_norm.", **argument)) + inputs.append(io.Float.Input("txt_in.", **argument)) + inputs.append(io.Float.Input("time_text_embed.", **argument)) for i in range(60): - arg_dict["transformer_blocks.{}.".format(i)] = argument + inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument)) - arg_dict["proj_out."] = argument + inputs.append(io.Float.Input("proj_out.", **argument)) - return {"required": arg_dict} + return io.Schema( + node_id="ModelMergeQwenImage", + category="advanced/model_merging/model_specific", + inputs=inputs, + outputs=[io.Model.Output()], + ) -NODE_CLASS_MAPPINGS = { - "ModelMergeSD1": ModelMergeSD1, - "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks - "ModelMergeSDXL": ModelMergeSDXL, - "ModelMergeSD3_2B": ModelMergeSD3_2B, - "ModelMergeAuraflow": ModelMergeAuraflow, - "ModelMergeFlux1": ModelMergeFlux1, - "ModelMergeSD35_Large": ModelMergeSD35_Large, - "ModelMergeMochiPreview": ModelMergeMochiPreview, - "ModelMergeLTXV": ModelMergeLTXV, - "ModelMergeCosmos7B": ModelMergeCosmos7B, - "ModelMergeCosmos14B": ModelMergeCosmos14B, - "ModelMergeWAN2_1": ModelMergeWAN2_1, - "ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B, - "ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B, - "ModelMergeQwenImage": ModelMergeQwenImage, -} + +class ModelMergingModelSpecificExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ModelMergeSD1, + ModelMergeSD2, + ModelMergeSDXL, + ModelMergeSD3_2B, + ModelMergeAuraflow, + ModelMergeFlux1, + ModelMergeSD35_Large, + ModelMergeMochiPreview, + ModelMergeLTXV, + ModelMergeCosmos7B, + ModelMergeCosmos14B, + ModelMergeWAN2_1, + ModelMergeCosmosPredict2_2B, + ModelMergeCosmosPredict2_14B, + ModelMergeQwenImage, + ] + + +async def comfy_entrypoint() -> ModelMergingModelSpecificExtension: + return ModelMergingModelSpecificExtension() diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index 0f760aa26..b10511724 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -6,44 +6,62 @@ import folder_paths import comfy_extras.nodes_model_merging import node_helpers +from comfy_api.latest import io, ComfyExtension +from typing_extensions import override -class ImageOnlyCheckpointLoader: + +class ImageOnlyCheckpointLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), - }} - RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") - FUNCTION = "load_checkpoint" + def define_schema(cls): + return io.Schema( + node_id="ImageOnlyCheckpointLoader", + display_name="Image Only Checkpoint Loader (img2vid model)", + category="loaders/video_models", + inputs=[ + io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("checkpoints")), + ], + outputs=[ + io.Model.Output(), + io.ClipVision.Output(), + io.Vae.Output(), + ], + ) - CATEGORY = "loaders/video_models" - - def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): + @classmethod + def execute(cls, ckpt_name, output_vae=True, output_clip=True) -> io.NodeOutput: ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) - return (out[0], out[3], out[2]) + return io.NodeOutput(out[0], out[3], out[2]) + + load_checkpoint = execute # TODO: remove -class SVD_img2vid_Conditioning: +class SVD_img2vid_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "init_image": ("IMAGE",), - "vae": ("VAE",), - "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}), - "motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}), - "fps": ("INT", {"default": 6, "min": 1, "max": 1024}), - "augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}) - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="SVD_img2vid_Conditioning", + category="conditioning/video_models", + inputs=[ + io.ClipVision.Input("clip_vision"), + io.Image.Input("init_image"), + io.Vae.Input("vae"), + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("video_frames", default=14, min=1, max=4096), + io.Int.Input("motion_bucket_id", default=127, min=1, max=1023), + io.Int.Input("fps", default=6, min=1, max=1024), + io.Float.Input("augmentation_level", default=0.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level): + @classmethod + def execute(cls, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level) -> io.NodeOutput: output = clip_vision.encode_image(init_image) pooled = output.image_embeds.unsqueeze(0) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) @@ -54,20 +72,28 @@ class SVD_img2vid_Conditioning: positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]] latent = torch.zeros([video_frames, 4, height // 8, width // 8]) - return (positive, negative, {"samples":latent}) + return io.NodeOutput(positive, negative, {"samples":latent}) -class VideoLinearCFGGuidance: + encode = execute # TODO: remove + + +class VideoLinearCFGGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="VideoLinearCFGGuidance", + category="sampling/video_models", + inputs=[ + io.Model.Input("model"), + io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "sampling/video_models" - - def patch(self, model, min_cfg): + @classmethod + def execute(cls, model, min_cfg) -> io.NodeOutput: def linear_cfg(args): cond = args["cond"] uncond = args["uncond"] @@ -78,20 +104,28 @@ class VideoLinearCFGGuidance: m = model.clone() m.set_model_sampler_cfg_function(linear_cfg) - return (m, ) + return io.NodeOutput(m) -class VideoTriangleCFGGuidance: + patch = execute # TODO: remove + + +class VideoTriangleCFGGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="VideoTriangleCFGGuidance", + category="sampling/video_models", + inputs=[ + io.Model.Input("model"), + io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "sampling/video_models" - - def patch(self, model, min_cfg): + @classmethod + def execute(cls, model, min_cfg) -> io.NodeOutput: def linear_cfg(args): cond = args["cond"] uncond = args["uncond"] @@ -105,57 +139,79 @@ class VideoTriangleCFGGuidance: m = model.clone() m.set_model_sampler_cfg_function(linear_cfg) - return (m, ) + return io.NodeOutput(m) -class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave): - CATEGORY = "advanced/model_merging" + patch = execute # TODO: remove + + +class ImageOnlyCheckpointSave(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageOnlyCheckpointSave", + search_aliases=["save model", "export checkpoint", "merge save"], + category="advanced/model_merging", + inputs=[ + io.Model.Input("model"), + io.ClipVision.Input("clip_vision"), + io.Vae.Input("vae"), + io.String.Input("filename_prefix", default="checkpoints/ComfyUI"), + ], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "clip_vision": ("CLIP_VISION",), - "vae": ("VAE",), - "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + def execute(cls, model, clip_vision, vae, filename_prefix) -> io.NodeOutput: + comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo) + return io.NodeOutput() - def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None): - comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) - return {} + save = execute # TODO: remove -class ConditioningSetAreaPercentageVideo: +class ConditioningSetAreaPercentageVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), - "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), - "temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), - "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}), - "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}), - "z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "append" + def define_schema(cls): + return io.Schema( + node_id="ConditioningSetAreaPercentageVideo", + category="conditioning", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Float.Input("width", default=1.0, min=0.0, max=1.0, step=0.01), + io.Float.Input("height", default=1.0, min=0.0, max=1.0, step=0.01), + io.Float.Input("temporal", default=1.0, min=0.0, max=1.0, step=0.01), + io.Float.Input("x", default=0.0, min=0.0, max=1.0, step=0.01), + io.Float.Input("y", default=0.0, min=0.0, max=1.0, step=0.01), + io.Float.Input("z", default=0.0, min=0.0, max=1.0, step=0.01), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - CATEGORY = "conditioning" - - def append(self, conditioning, width, height, temporal, x, y, z, strength): + @classmethod + def execute(cls, conditioning, width, height, temporal, x, y, z, strength) -> io.NodeOutput: c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x), "strength": strength, "set_area_to_bounds": False}) - return (c, ) + return io.NodeOutput(c) + + append = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, - "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, - "VideoLinearCFGGuidance": VideoLinearCFGGuidance, - "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, - "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, - "ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo, -} +class VideoModelExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ImageOnlyCheckpointLoader, + SVD_img2vid_Conditioning, + VideoLinearCFGGuidance, + VideoTriangleCFGGuidance, + ImageOnlyCheckpointSave, + ConditioningSetAreaPercentageVideo, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)", -} + +async def comfy_entrypoint() -> VideoModelExtension: + return VideoModelExtension()