diff --git a/comfy_extras/nodes_lt_upsampler.py b/comfy_extras/nodes_lt_upsampler.py index f99ba13fb..eb94fc528 100644 --- a/comfy_extras/nodes_lt_upsampler.py +++ b/comfy_extras/nodes_lt_upsampler.py @@ -1,32 +1,32 @@ from comfy import model_management +from comfy_api.latest import ComfyExtension, IO +from typing_extensions import override import math -class LTXVLatentUpsampler: + +class LTXVLatentUpsampler(IO.ComfyNode): """ Upsamples a video latent by a factor of 2. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "samples": ("LATENT",), - "upscale_model": ("LATENT_UPSCALE_MODEL",), - "vae": ("VAE",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="LTXVLatentUpsampler", + category="latent/video", + is_experimental=True, + inputs=[ + IO.Latent.Input("samples"), + IO.LatentUpscaleModel.Input("upscale_model"), + IO.Vae.Input("vae"), + ], + outputs=[ + IO.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "upsample_latent" - CATEGORY = "latent/video" - EXPERIMENTAL = True - - def upsample_latent( - self, - samples: dict, - upscale_model, - vae, - ) -> tuple: + @classmethod + def execute(cls, samples, upscale_model, vae) -> IO.NodeOutput: """ Upsample the input latent using the provided model. @@ -34,7 +34,6 @@ class LTXVLatentUpsampler: samples (dict): Input latent samples upscale_model (LatentUpsampler): Loaded upscale model vae: VAE model for normalization - auto_tiling (bool): Whether to automatically tile the input for processing Returns: tuple: Tuple containing the upsampled latent @@ -67,9 +66,16 @@ class LTXVLatentUpsampler: return_dict = samples.copy() return_dict["samples"] = upsampled_latents return_dict.pop("noise_mask", None) - return (return_dict,) + return IO.NodeOutput(return_dict) + + upsample_latent = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "LTXVLatentUpsampler": LTXVLatentUpsampler, -} +class LTXVLatentUpsamplerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [LTXVLatentUpsampler] + + +async def comfy_entrypoint() -> LTXVLatentUpsamplerExtension: + return LTXVLatentUpsamplerExtension()