Merge branch 'master' into dr-support-pip-cm

This commit is contained in:
Dr.Lt.Data
2025-05-19 06:04:10 +09:00
50 changed files with 2760 additions and 684 deletions

View File

@@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
API_NODE: Optional[bool]
"""Flags a node as an API node."""
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
@classmethod
@abstractmethod

View File

@@ -8,11 +8,7 @@ from typing import Callable, Tuple, List
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
# from diffusers.models.modeling_utils import ModelMixin
# from diffusers.loaders import FromOriginalModelMixin
# from diffusers.configuration_utils import ConfigMixin, register_to_config
from .music_log_mel import LogMelSpectrogram
@@ -259,7 +255,7 @@ class ResBlock1(torch.nn.Module):
self.convs1 = nn.ModuleList(
[
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@@ -269,7 +265,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@@ -279,7 +275,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@@ -294,7 +290,7 @@ class ResBlock1(torch.nn.Module):
self.convs2 = nn.ModuleList(
[
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@@ -304,7 +300,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@@ -314,7 +310,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@@ -366,7 +362,7 @@ class HiFiGANGenerator(nn.Module):
prod(upsample_rates) == hop_length
), f"hop_length must be {prod(upsample_rates)}"
self.conv_pre = weight_norm(
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
num_mels,
upsample_initial_channel,
@@ -386,7 +382,7 @@ class HiFiGANGenerator(nn.Module):
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
c_cur = upsample_initial_channel // (2 ** (i + 1))
self.ups.append(
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
@@ -421,7 +417,7 @@ class HiFiGANGenerator(nn.Module):
self.resblocks.append(ResBlock1(ch, k, d))
self.activation_post = post_activation()
self.conv_post = weight_norm(
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
ch,
1,

View File

@@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
return x
def WNConv1d(*args, **kwargs):
try:
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
try:
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":

View File

@@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
y: Tensor,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
control=None,
transformer_options={},
) -> Tensor:
@@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
img = torch.cat([ref_latent, img], dim=-2)
ref_latent_ids[..., 0] = -1
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
img[:, : img_len] += add
img = img[:, : img_len]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
@@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
def img_ids(self, x):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out

View File

@@ -247,6 +247,60 @@ class VaceWanAttentionBlock(WanAttentionBlock):
return c_skip, c
class WanCamAdapter(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
super(WanCamAdapter, self).__init__()
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
# Convolution: reduce spatial dimensions by a factor
# of 2 (without overlap)
self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# Residual blocks for feature extraction
self.residual_blocks = nn.Sequential(
*[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
)
def forward(self, x):
# Reshape to merge the frame dimension into batch
bs, c, f, h, w = x.size()
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
# Pixel Unshuffle operation
x_unshuffled = self.pixel_unshuffle(x)
# Convolution operation
x_conv = self.conv(x_unshuffled)
# Feature extraction with residual blocks
out = self.residual_blocks(x_conv)
# Reshape to restore original bf dimension
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
out = out.permute(0, 2, 1, 3, 4)
return out
class WanCamResidualBlock(nn.Module):
def __init__(self, dim, operation_settings={}):
super(WanCamResidualBlock, self).__init__()
self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.relu = nn.ReLU(inplace=True)
self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
residual = x
out = self.relu(self.conv1(x))
out = self.conv2(out)
out += residual
return out
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@@ -637,3 +691,92 @@ class VaceWanModel(WanModel):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
class CameraWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='camera',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
in_dim_control_adapter=24,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
def forward_orig(
self,
x,
t,
context,
clip_fea=None,
freqs=None,
camera_conditions = None,
transformer_options={},
**kwargs,
):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if self.control_adapter is not None and camera_conditions is not None:
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
x = x + x_camera
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

View File

@@ -286,6 +286,12 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
if isinstance(model, comfy.model_base.ACEStep):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map

View File

@@ -924,6 +924,10 @@ class HunyuanVideo(BaseModel):
if guiding_frame_index is not None:
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
ref_latent = kwargs.get("ref_latent", None)
if ref_latent is not None:
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
return out
def scale_latent_inpaint(self, latent_image, **kwargs):
@@ -1075,6 +1079,17 @@ class WAN21_Vace(WAN21):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
camera_conditions = kwargs.get("camera_conditions", None)
if camera_conditions is not None:
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

View File

@@ -361,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "vace"
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"

View File

@@ -308,10 +308,10 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
else:
scale_input = scale_input.to(input.device)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)

View File

@@ -30,7 +30,7 @@ if RMSNorm is None:
def __init__(
self,
normalized_shape,
eps=None,
eps=1e-6,
elementwise_affine=True,
device=None,
dtype=None,

View File

@@ -451,7 +451,7 @@ class VAE:
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.bfloat16, torch.float32]
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
else:

View File

@@ -992,6 +992,16 @@ class WAN21_FunControl2V(WAN21_T2V):
out = model_base.WAN21(self, image_to_video=False, device=device)
return out
class WAN21_Camera(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "camera",
"in_dim": 32,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
class WAN21_Vace(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@@ -1129,6 +1139,6 @@ class ACEStep(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models += [SVD_img2vid]

View File

@@ -78,8 +78,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else: