mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-28 18:31:31 +00:00
Merge branch 'master' into worksplit-multigpu
This commit is contained in:
100
comfy_extras/nodes_fresca.py
Normal file
100
comfy_extras/nodes_fresca.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Code based on https://github.com/WikiChao/FreSca (MIT License)
|
||||
import torch
|
||||
import torch.fft as fft
|
||||
|
||||
|
||||
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
|
||||
"""
|
||||
Apply frequency-dependent scaling to an image tensor using Fourier transforms.
|
||||
|
||||
Parameters:
|
||||
x: Input tensor of shape (B, C, H, W)
|
||||
scale_low: Scaling factor for low-frequency components (default: 1.0)
|
||||
scale_high: Scaling factor for high-frequency components (default: 1.5)
|
||||
freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
|
||||
|
||||
Returns:
|
||||
x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
|
||||
"""
|
||||
# Preserve input dtype and device
|
||||
dtype, device = x.dtype, x.device
|
||||
|
||||
# Convert to float32 for FFT computations
|
||||
x = x.to(torch.float32)
|
||||
|
||||
# 1) Apply FFT and shift low frequencies to center
|
||||
x_freq = fft.fftn(x, dim=(-2, -1))
|
||||
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
|
||||
|
||||
# Initialize mask with high-frequency scaling factor
|
||||
mask = torch.ones(x_freq.shape, device=device) * scale_high
|
||||
m = mask
|
||||
for d in range(len(x_freq.shape) - 2):
|
||||
dim = d + 2
|
||||
cc = x_freq.shape[dim] // 2
|
||||
f_c = min(freq_cutoff, cc)
|
||||
m = m.narrow(dim, cc - f_c, f_c * 2)
|
||||
|
||||
# Apply low-frequency scaling factor to center region
|
||||
m[:] = scale_low
|
||||
|
||||
# 3) Apply frequency-specific scaling
|
||||
x_freq = x_freq * mask
|
||||
|
||||
# 4) Convert back to spatial domain
|
||||
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
|
||||
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
|
||||
|
||||
# 5) Restore original dtype
|
||||
x_filtered = x_filtered.to(dtype)
|
||||
|
||||
return x_filtered
|
||||
|
||||
|
||||
class FreSca:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
|
||||
"tooltip": "Scaling factor for low-frequency components"}),
|
||||
"scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
|
||||
"tooltip": "Scaling factor for high-frequency components"}),
|
||||
"freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
|
||||
"tooltip": "Number of frequency indices around center to consider as low-frequency"}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "_for_testing"
|
||||
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
|
||||
def patch(self, model, scale_low, scale_high, freq_cutoff):
|
||||
def custom_cfg_function(args):
|
||||
cond = args["conds_out"][0]
|
||||
uncond = args["conds_out"][1]
|
||||
|
||||
guidance = cond - uncond
|
||||
filtered_guidance = Fourier_filter(
|
||||
guidance,
|
||||
scale_low=scale_low,
|
||||
scale_high=scale_high,
|
||||
freq_cutoff=freq_cutoff,
|
||||
)
|
||||
filtered_cond = filtered_guidance + uncond
|
||||
|
||||
return [filtered_cond, uncond]
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
|
||||
|
||||
return (m,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"FreSca": FreSca,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"FreSca": "FreSca",
|
||||
}
|
||||
@@ -21,8 +21,8 @@ class Load3D():
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@@ -41,7 +41,7 @@ class Load3D():
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
|
||||
|
||||
class Load3DAnimation():
|
||||
@classmethod
|
||||
@@ -59,8 +59,8 @@ class Load3DAnimation():
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@@ -77,13 +77,16 @@ class Load3DAnimation():
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image
|
||||
return output_image, output_mask, model_file, normal_image, image['camera_info']
|
||||
|
||||
class Preview3D():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
},
|
||||
"optional": {
|
||||
"camera_info": ("LOAD3D_CAMERA", {})
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
@@ -95,13 +98,22 @@ class Preview3D():
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def process(self, model_file, **kwargs):
|
||||
return {"ui": {"model_file": [model_file]}, "result": ()}
|
||||
camera_info = kwargs.get("camera_info", None)
|
||||
|
||||
return {
|
||||
"ui": {
|
||||
"result": [model_file, camera_info]
|
||||
}
|
||||
}
|
||||
|
||||
class Preview3DAnimation():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
},
|
||||
"optional": {
|
||||
"camera_info": ("LOAD3D_CAMERA", {})
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
@@ -113,7 +125,13 @@ class Preview3DAnimation():
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def process(self, model_file, **kwargs):
|
||||
return {"ui": {"model_file": [model_file]}, "result": ()}
|
||||
camera_info = kwargs.get("camera_info", None)
|
||||
|
||||
return {
|
||||
"ui": {
|
||||
"result": [model_file, camera_info]
|
||||
}
|
||||
}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"Load3D": Load3D,
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.latent_formats
|
||||
import comfy.clip_vision
|
||||
|
||||
|
||||
class WanImageToVideo:
|
||||
@@ -99,6 +100,72 @@ class WanFunControlToVideo:
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
|
||||
class WanFirstLastFrameToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
|
||||
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"end_image": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
if end_image is not None:
|
||||
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
|
||||
image = torch.ones((length, height, width, 3)) * 0.5
|
||||
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
|
||||
|
||||
if start_image is not None:
|
||||
image[:start_image.shape[0]] = start_image
|
||||
mask[:, :, :start_image.shape[0] + 3] = 0.0
|
||||
|
||||
if end_image is not None:
|
||||
image[-end_image.shape[0]:] = end_image
|
||||
mask[:, :, -end_image.shape[0]:] = 0.0
|
||||
|
||||
concat_latent_image = vae.encode(image[:, :, :, :3])
|
||||
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
|
||||
if clip_vision_start_image is not None:
|
||||
clip_vision_output = clip_vision_start_image
|
||||
|
||||
if clip_vision_end_image is not None:
|
||||
if clip_vision_output is not None:
|
||||
states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2)
|
||||
clip_vision_output = comfy.clip_vision.Output()
|
||||
clip_vision_output.penultimate_hidden_states = states
|
||||
else:
|
||||
clip_vision_output = clip_vision_end_image
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanFunInpaintToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -122,38 +189,13 @@ class WanFunInpaintToVideo:
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
if end_image is not None:
|
||||
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
flfv = WanFirstLastFrameToVideo()
|
||||
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
||||
|
||||
image = torch.ones((length, height, width, 3)) * 0.5
|
||||
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
|
||||
|
||||
if start_image is not None:
|
||||
image[:start_image.shape[0]] = start_image
|
||||
mask[:, :, :start_image.shape[0] + 3] = 0.0
|
||||
|
||||
if end_image is not None:
|
||||
image[-end_image.shape[0]:] = end_image
|
||||
mask[:, :, -end_image.shape[0]:] = 0.0
|
||||
|
||||
concat_latent_image = vae.encode(image[:, :, :, :3])
|
||||
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanImageToVideo": WanImageToVideo,
|
||||
"WanFunControlToVideo": WanFunControlToVideo,
|
||||
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
||||
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user