Merge branch 'master' into worksplit-multigpu

This commit is contained in:
Jedrzej Kosinski
2025-09-24 23:45:26 -07:00
29 changed files with 1931 additions and 79 deletions

View File

@@ -11,6 +11,7 @@ import json
import random
import hashlib
import node_helpers
import logging
from comfy.cli_args import args
from comfy.comfy_types import FileLocator
@@ -364,6 +365,216 @@ class RecordAudio:
        return (audio, )


class TrimAudioDuration:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "audio": ("AUDIO",),
                "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
                "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
            },
        }

    FUNCTION = "trim"
    RETURN_TYPES = ("AUDIO",)
    CATEGORY = "audio"
    DESCRIPTION = "Trim audio tensor into chosen time range."

    def trim(self, audio, start_index, duration):
        waveform = audio["waveform"]
        sample_rate = audio["sample_rate"]
        audio_length = waveform.shape[-1]

        if start_index < 0:
            start_frame = audio_length + int(round(start_index * sample_rate))
        else:
            start_frame = int(round(start_index * sample_rate))
        start_frame = max(0, min(start_frame, audio_length - 1))

        end_frame = start_frame + int(round(duration * sample_rate))
        end_frame = max(0, min(end_frame, audio_length))

        if start_frame >= end_frame:
            raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")

        return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
class SplitAudioChannels:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "audio": ("AUDIO",),
        }}

    RETURN_TYPES = ("AUDIO", "AUDIO")
    RETURN_NAMES = ("left", "right")
    FUNCTION = "separate"
    CATEGORY = "audio"
    DESCRIPTION = "Separates the audio into left and right channels."

    def separate(self, audio):
        waveform = audio["waveform"]
        sample_rate = audio["sample_rate"]

        if waveform.shape[1] != 2:
            raise ValueError("SplitAudioChannels: Input audio must have exactly 2 channels (stereo).")

        left_channel = waveform[..., 0:1, :]
        right_channel = waveform[..., 1:2, :]

        return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
    if sample_rate_1 != sample_rate_2:
        if sample_rate_1 > sample_rate_2:
            waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1)
            output_sample_rate = sample_rate_1
            logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.")
        else:
            waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2)
            output_sample_rate = sample_rate_2
            logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.")
    else:
        output_sample_rate = sample_rate_1

    return waveform_1, waveform_2, output_sample_rate
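# Illustrative sketch (not part of this commit), assuming torch and torchaudio as
# imported above: the lower-rate input is resampled up to the higher rate, so a 1 s
# clip at 22050 Hz paired with a 1 s clip at 44100 Hz comes back as 44100 samples each.
def _match_sample_rates_example():
    a = torch.zeros((1, 2, 22050))
    b = torch.zeros((1, 2, 44100))
    a2, b2, rate = match_audio_sample_rates(a, 22050, b, 44100)
    assert rate == 44100 and a2.shape[-1] == 44100 and b2.shape[-1] == 44100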
class AudioConcat:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "audio1": ("AUDIO",),
            "audio2": ("AUDIO",),
            "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
        }}

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "concat"
    CATEGORY = "audio"
    DESCRIPTION = "Concatenates audio2 to audio1 in the specified direction."

    def concat(self, audio1, audio2, direction):
        waveform_1 = audio1["waveform"]
        waveform_2 = audio2["waveform"]
        sample_rate_1 = audio1["sample_rate"]
        sample_rate_2 = audio2["sample_rate"]

        if waveform_1.shape[1] == 1:
            waveform_1 = waveform_1.repeat(1, 2, 1)
            logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.")
        if waveform_2.shape[1] == 1:
            waveform_2 = waveform_2.repeat(1, 2, 1)
            logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.")

        waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)

        if direction == 'after':
            concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2)
        elif direction == 'before':
            concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)

        return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
class AudioMerge:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "audio1": ("AUDIO",),
                "audio2": ("AUDIO",),
                "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
            },
        }

    FUNCTION = "merge"
    RETURN_TYPES = ("AUDIO",)
    CATEGORY = "audio"
    DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."

    def merge(self, audio1, audio2, merge_method):
        waveform_1 = audio1["waveform"]
        waveform_2 = audio2["waveform"]
        sample_rate_1 = audio1["sample_rate"]
        sample_rate_2 = audio2["sample_rate"]

        waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)

        length_1 = waveform_1.shape[-1]
        length_2 = waveform_2.shape[-1]

        if length_2 > length_1:
            logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
            waveform_2 = waveform_2[..., :length_1]
        elif length_2 < length_1:
            logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.")
            pad_shape = list(waveform_2.shape)
            pad_shape[-1] = length_1 - length_2
            pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device)
            waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1)

        if merge_method == "add":
            waveform = waveform_1 + waveform_2
        elif merge_method == "subtract":
            waveform = waveform_1 - waveform_2
        elif merge_method == "multiply":
            waveform = waveform_1 * waveform_2
        elif merge_method == "mean":
            waveform = (waveform_1 + waveform_2) / 2

        max_val = waveform.abs().max()
        if max_val > 1.0:
            waveform = waveform / max_val

        return ({"waveform": waveform, "sample_rate": output_sample_rate},)
class AudioAdjustVolume:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "audio": ("AUDIO",),
            "volume": ("INT", {"default": 1, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc."}),
        }}

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "adjust_volume"
    CATEGORY = "audio"

    def adjust_volume(self, audio, volume):
        if volume == 0:
            return (audio,)

        waveform = audio["waveform"]
        sample_rate = audio["sample_rate"]

        gain = 10 ** (volume / 20)
        waveform = waveform * gain

        return ({"waveform": waveform, "sample_rate": sample_rate},)
class EmptyAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
            "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
            "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
        }}

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "create_empty_audio"
    CATEGORY = "audio"

    def create_empty_audio(self, duration, sample_rate, channels):
        num_samples = int(round(duration * sample_rate))
        waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
        return ({"waveform": waveform, "sample_rate": sample_rate},)
NODE_CLASS_MAPPINGS = {
    "EmptyLatentAudio": EmptyLatentAudio,
    "VAEEncodeAudio": VAEEncodeAudio,
@@ -375,6 +586,12 @@ NODE_CLASS_MAPPINGS = {
    "PreviewAudio": PreviewAudio,
    "ConditioningStableAudio": ConditioningStableAudio,
    "RecordAudio": RecordAudio,
    "TrimAudioDuration": TrimAudioDuration,
    "SplitAudioChannels": SplitAudioChannels,
    "AudioConcat": AudioConcat,
    "AudioMerge": AudioMerge,
    "AudioAdjustVolume": AudioAdjustVolume,
    "EmptyAudio": EmptyAudio,
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -387,4 +604,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "SaveAudioMP3": "Save Audio (MP3)",
    "SaveAudioOpus": "Save Audio (Opus)",
    "RecordAudio": "Record Audio",
    "TrimAudioDuration": "Trim Audio Duration",
    "SplitAudioChannels": "Split Audio Channels",
    "AudioConcat": "Audio Concat",
    "AudioMerge": "Audio Merge",
    "AudioAdjustVolume": "Audio Adjust Volume",
    "EmptyAudio": "Empty Audio",
}

View File

@@ -5,19 +5,30 @@ import torch
class DifferentialDiffusion():
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL", ),
                             }}
        return {
            "required": {
                "model": ("MODEL", ),
            },
            "optional": {
                "strength": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                }),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply"
    CATEGORY = "_for_testing"
    INIT = False

    def apply(self, model):
    def apply(self, model, strength=1.0):
        model = model.clone()
        model.set_model_denoise_mask_function(self.forward)
        return (model,)
        model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength))
        return (model, )

    def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
    def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
        model = extra_options["model"]
        step_sigmas = extra_options["sigmas"]
        sigma_to = model.inner_model.model_sampling.sigma_min
@@ -31,7 +42,15 @@ class DifferentialDiffusion():
        threshold = (current_ts - ts_to) / (ts_from - ts_to)

        return (denoise_mask >= threshold).to(denoise_mask.dtype)
        # Generate the binary mask based on the threshold
        binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype)

        # Blend binary mask with the original denoise_mask using strength
        if strength and strength < 1:
            blended_mask = strength * binary_mask + (1 - strength) * denoise_mask
            return blended_mask
        else:
            return binary_mask
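# Illustrative sketch (not part of this commit): with strength=0.25 the returned mask
# is 0.25*binary + 0.75*original, so a pixel with denoise_mask=0.8 that clears the
# threshold blends to 0.25*1.0 + 0.75*0.8 = 0.85 instead of snapping to 1.0.
def _strength_blend_example():
    denoise_mask = torch.tensor([0.8, 0.2])
    binary_mask = (denoise_mask >= 0.5).to(denoise_mask.dtype)  # [1.0, 0.0]
    blended = 0.25 * binary_mask + 0.75 * denoise_mask          # [0.85, 0.15]
    assert torch.allclose(blended, torch.tensor([0.85, 0.15]))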
NODE_CLASS_MAPPINGS = {

View File

@@ -233,6 +233,7 @@ class Sharpen:
        kernel_size = sharpen_radius * 2 + 1
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
        kernel = kernel.to(dtype=image.dtype)
        center = kernel_size // 2
        kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

View File

@@ -43,6 +43,61 @@ class TextEncodeQwenImageEdit:
        return (conditioning, )


class TextEncodeQwenImageEditPlus:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip": ("CLIP", ),
            "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
        },
        "optional": {"vae": ("VAE", ),
                     "image1": ("IMAGE", ),
                     "image2": ("IMAGE", ),
                     "image3": ("IMAGE", ),
                     }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "advanced/conditioning"

    def encode(self, clip, prompt, vae=None, image1=None, image2=None, image3=None):
        ref_latents = []
        images = [image1, image2, image3]
        images_vl = []
        llama_template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        image_prompt = ""

        for i, image in enumerate(images):
            if image is not None:
                samples = image.movedim(-1, 1)
                total = int(384 * 384)

                scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
                width = round(samples.shape[3] * scale_by)
                height = round(samples.shape[2] * scale_by)

                s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
                images_vl.append(s.movedim(1, -1))

                if vae is not None:
                    total = int(1024 * 1024)
                    scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
                    width = round(samples.shape[3] * scale_by / 8.0) * 8
                    height = round(samples.shape[2] * scale_by / 8.0) * 8

                    s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
                    ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3]))

                image_prompt += "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1)

        tokens = clip.tokenize(image_prompt + prompt, images=images_vl, llama_template=llama_template)
        conditioning = clip.encode_from_tokens_scheduled(tokens)
        if len(ref_latents) > 0:
            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
        return (conditioning, )


NODE_CLASS_MAPPINGS = {
    "TextEncodeQwenImageEdit": TextEncodeQwenImageEdit,
    "TextEncodeQwenImageEditPlus": TextEncodeQwenImageEditPlus,
}
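# Illustrative sketch (not part of this commit): the area-preserving resize above maps
# a 1280x720 reference to roughly 384*384 pixels for the vision-language tokens and to
# roughly 1024*1024 pixels, snapped to multiples of 8, for the VAE reference latent.
def _qwen_edit_resize_example():
    h, w = 720, 1280
    scale_vl = math.sqrt((384 * 384) / (w * h))
    scale_vae = math.sqrt((1024 * 1024) / (w * h))
    vl_size = (round(w * scale_vl), round(h * scale_vl))
    vae_size = (round(w * scale_vae / 8.0) * 8, round(h * scale_vae / 8.0) * 8)
    assert vl_size == (512, 288) and vae_size == (1368, 768)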

View File

@@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
    return new_dict


def process_cond_list(d, prefix=""):
    if hasattr(d, "__iter__") and not hasattr(d, "items"):
        for index, item in enumerate(d):
            process_cond_list(item, f"{prefix}.{index}")
        return d
    elif hasattr(d, "items"):
        for k, v in list(d.items()):
            if isinstance(v, dict):
                process_cond_list(v, f"{prefix}.{k}")
            elif isinstance(v, torch.Tensor):
                d[k] = v.clone()
            elif isinstance(v, (list, tuple)):
                for index, item in enumerate(v):
                    process_cond_list(item, f"{prefix}.{k}.{index}")
        return d
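# Illustrative sketch (not part of this commit): the helper walks nested lists and
# dicts and replaces every tensor stored under a dict key with a clone, so in-place
# edits on the processed structure no longer touch the caller's original tensors.
def _process_cond_list_example():
    original = torch.zeros(4)
    conds = {"positive": [{"model_conds": {"c": original}}]}
    process_cond_list(conds)
    conds["positive"][0]["model_conds"]["c"] += 1.0
    assert original.sum().item() == 0.0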
class TrainSampler(comfy.samplers.Sampler):
    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
        self.loss_fn = loss_fn
@@ -50,6 +67,7 @@ class TrainSampler(comfy.samplers.Sampler):
        self.training_dtype = training_dtype

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        model_wrap.conds = process_cond_list(model_wrap.conds)
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()

View File

@@ -287,7 +287,6 @@ class WanVaceToVideo(io.ComfyNode):
        return io.Schema(
            node_id="WanVaceToVideo",
            category="conditioning/video_models",
            is_experimental=True,
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
@@ -375,7 +374,6 @@ class TrimVideoLatent(io.ComfyNode):
        return io.Schema(
            node_id="TrimVideoLatent",
            category="latent/video",
            is_experimental=True,
            inputs=[
                io.Latent.Input("samples"),
                io.Int.Input("trim_amount", default=0, min=0, max=99999),
@@ -969,7 +967,6 @@ class WanSoundImageToVideo(io.ComfyNode):
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
            is_experimental=True,
        )

    @classmethod
@@ -1000,7 +997,6 @@ class WanSoundImageToVideoExtend(io.ComfyNode):
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
            is_experimental=True,
        )

    @classmethod
@@ -1095,10 +1091,6 @@ class WanHuMoImageToVideo(io.ComfyNode):
            audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280]
            audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0)
            # pad for ref latent
            zero_audio_pad = torch.zeros(ref_latent.shape[2], *audio_emb.shape[1:], device=audio_emb.device, dtype=audio_emb.dtype)
            audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
            audio_emb = audio_emb.unsqueeze(0)
            audio_emb_neg = torch.zeros_like(audio_emb)

            positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb})
@@ -1112,6 +1104,146 @@ class WanHuMoImageToVideo(io.ComfyNode):
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)

class WanAnimateToVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanAnimateToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
                io.Image.Input("reference_image", optional=True),
                io.Image.Input("face_video", optional=True),
                io.Image.Input("pose_video", optional=True),
                io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Image.Input("background_video", optional=True),
                io.Mask.Input("character_mask", optional=True),
                io.Image.Input("continue_motion", optional=True),
                io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
                io.Int.Output(display_name="trim_latent"),
                io.Int.Output(display_name="trim_image"),
                io.Int.Output(display_name="video_frame_offset"),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None, background_video=None, character_mask=None) -> io.NodeOutput:
        trim_to_pose_video = False
        latent_length = ((length - 1) // 4) + 1
        latent_width = width // 8
        latent_height = height // 8
        trim_latent = 0

        if reference_image is None:
            reference_image = torch.zeros((1, height, width, 3))

        image = comfy.utils.common_upscale(reference_image[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
        concat_latent_image = vae.encode(image[:, :, :, :3])
        mask = torch.zeros((1, 4, concat_latent_image.shape[-3], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype)
        trim_latent += concat_latent_image.shape[2]

        ref_motion_latent_length = 0
        if continue_motion is None:
            image = torch.ones((length, height, width, 3)) * 0.5
        else:
            continue_motion = continue_motion[-continue_motion_max_frames:]
            video_frame_offset -= continue_motion.shape[0]
            video_frame_offset = max(0, video_frame_offset)
            continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
            image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5
            image[:continue_motion.shape[0]] = continue_motion
            ref_motion_latent_length += ((continue_motion.shape[0] - 1) // 4) + 1

        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        if pose_video is not None:
            if pose_video.shape[0] <= video_frame_offset:
                pose_video = None
            else:
                pose_video = pose_video[video_frame_offset:]

        if pose_video is not None:
            pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
            if not trim_to_pose_video:
                if pose_video.shape[0] < length:
                    pose_video = torch.cat((pose_video,) + (pose_video[-1:],) * (length - pose_video.shape[0]), dim=0)
            pose_video_latent = vae.encode(pose_video[:, :, :, :3])
            positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent})
            negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent})

            if trim_to_pose_video:
                latent_length = pose_video_latent.shape[2]
                length = latent_length * 4 - 3
                image = image[:length]

        if face_video is not None:
            if face_video.shape[0] <= video_frame_offset:
                face_video = None
            else:
                face_video = face_video[video_frame_offset:]

        if face_video is not None:
            face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0
            face_video = face_video.movedim(0, 1).unsqueeze(0)
            positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video})
            negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0})

        ref_images_num = max(0, ref_motion_latent_length * 4 - 3)
        if background_video is not None:
            if background_video.shape[0] > video_frame_offset:
                background_video = background_video[video_frame_offset:]
                background_video = comfy.utils.common_upscale(background_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
                if background_video.shape[0] > ref_images_num:
                    image[ref_images_num:background_video.shape[0]] = background_video[ref_images_num:]

        mask_refmotion = torch.ones((1, 1, latent_length * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
        if continue_motion is not None:
            mask_refmotion[:, :, :ref_motion_latent_length * 4] = 0.0

        if character_mask is not None:
            if character_mask.shape[0] > video_frame_offset or character_mask.shape[0] == 1:
                if character_mask.shape[0] == 1:
                    character_mask = character_mask.repeat((length,) + (1,) * (character_mask.ndim - 1))
                else:
                    character_mask = character_mask[video_frame_offset:]
                if character_mask.ndim == 3:
                    character_mask = character_mask.unsqueeze(1)
                character_mask = character_mask.movedim(0, 1)
                if character_mask.ndim == 4:
                    character_mask = character_mask.unsqueeze(1)
                character_mask = comfy.utils.common_upscale(character_mask[:, :, :length], concat_latent_image.shape[-1], concat_latent_image.shape[-2], "nearest-exact", "center")
                if character_mask.shape[2] > ref_images_num:
                    mask_refmotion[:, :, ref_images_num:character_mask.shape[2]] = character_mask[:, :, ref_images_num:]

        concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
        mask_refmotion = mask_refmotion.view(1, mask_refmotion.shape[2] // 4, 4, mask_refmotion.shape[3], mask_refmotion.shape[4]).transpose(1, 2)
        mask = torch.cat((mask, mask_refmotion), dim=2)
        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})

        latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device())
        out_latent = {}
        out_latent["samples"] = latent
        return io.NodeOutput(positive, negative, out_latent, trim_latent, max(0, ref_motion_latent_length * 4 - 3), video_frame_offset + length)
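# Illustrative sketch (not part of this commit): the latent-grid arithmetic above, for
# the default inputs (width=832, height=480, length=77), gives 20 temporal latent frames
# at 104x60, plus one extra frame for the encoded reference image (trim_latent=1),
# i.e. a [1, 16, 21, 60, 104] output latent.
def _wan_animate_latent_shape_example():
    width, height, length = 832, 480, 77
    latent_length = ((length - 1) // 4) + 1                 # 20
    latent_width, latent_height = width // 8, height // 8   # 104, 60
    trim_latent = 1  # assumed: a single reference image encodes to one latent frame
    latent = torch.zeros([1, 16, latent_length + trim_latent, latent_height, latent_width])
    assert latent.shape == (1, 16, 21, 60, 104)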
class Wan22ImageToVideoLatent(io.ComfyNode):
    @classmethod
    def define_schema(cls):
@@ -1173,6 +1305,7 @@ class WanExtension(ComfyExtension):
            WanSoundImageToVideo,
            WanSoundImageToVideoExtend,
            WanHuMoImageToVideo,
            WanAnimateToVideo,
            Wan22ImageToVideoLatent,
        ]