mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-11 16:20:00 +00:00
Merge branch 'master' into dr-support-pip-cm
This commit is contained in:
@@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
|
||||
DEPRECATED: bool
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
API_NODE: Optional[bool]
|
||||
"""Flags a node as an API node."""
|
||||
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
|
||||
@@ -8,11 +8,7 @@ from typing import Callable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
|
||||
# from diffusers.models.modeling_utils import ModelMixin
|
||||
# from diffusers.loaders import FromOriginalModelMixin
|
||||
# from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
|
||||
from .music_log_mel import LogMelSpectrogram
|
||||
|
||||
@@ -259,7 +255,7 @@ class ResBlock1(torch.nn.Module):
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -269,7 +265,7 @@ class ResBlock1(torch.nn.Module):
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -279,7 +275,7 @@ class ResBlock1(torch.nn.Module):
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -294,7 +290,7 @@ class ResBlock1(torch.nn.Module):
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -304,7 +300,7 @@ class ResBlock1(torch.nn.Module):
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -314,7 +310,7 @@ class ResBlock1(torch.nn.Module):
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -366,7 +362,7 @@ class HiFiGANGenerator(nn.Module):
|
||||
prod(upsample_rates) == hop_length
|
||||
), f"hop_length must be {prod(upsample_rates)}"
|
||||
|
||||
self.conv_pre = weight_norm(
|
||||
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
num_mels,
|
||||
upsample_initial_channel,
|
||||
@@ -386,7 +382,7 @@ class HiFiGANGenerator(nn.Module):
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
@@ -421,7 +417,7 @@ class HiFiGANGenerator(nn.Module):
|
||||
self.resblocks.append(ResBlock1(ch, k, d))
|
||||
|
||||
self.activation_post = post_activation()
|
||||
self.conv_post = weight_norm(
|
||||
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
ch,
|
||||
1,
|
||||
|
||||
@@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
|
||||
return x
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||
if activation == "elu":
|
||||
|
||||
@@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
@@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
ref_latent = self.img_in(ref_latent)
|
||||
img = torch.cat([ref_latent, img], dim=-2)
|
||||
ref_latent_ids[..., 0] = -1
|
||||
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
|
||||
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
|
||||
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
@@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
|
||||
img[:, : img_len] += add
|
||||
|
||||
img = img[:, : img_len]
|
||||
if ref_latent is not None:
|
||||
img = img[:, ref_latent.shape[1]:]
|
||||
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
@@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
|
||||
def img_ids(self, x):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
@@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
@@ -247,6 +247,60 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
||||
return c_skip, c
|
||||
|
||||
|
||||
class WanCamAdapter(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
|
||||
super(WanCamAdapter, self).__init__()
|
||||
|
||||
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
|
||||
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
|
||||
|
||||
# Convolution: reduce spatial dimensions by a factor
|
||||
# of 2 (without overlap)
|
||||
self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
# Residual blocks for feature extraction
|
||||
self.residual_blocks = nn.Sequential(
|
||||
*[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Reshape to merge the frame dimension into batch
|
||||
bs, c, f, h, w = x.size()
|
||||
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
|
||||
|
||||
# Pixel Unshuffle operation
|
||||
x_unshuffled = self.pixel_unshuffle(x)
|
||||
|
||||
# Convolution operation
|
||||
x_conv = self.conv(x_unshuffled)
|
||||
|
||||
# Feature extraction with residual blocks
|
||||
out = self.residual_blocks(x_conv)
|
||||
|
||||
# Reshape to restore original bf dimension
|
||||
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
|
||||
|
||||
# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class WanCamResidualBlock(nn.Module):
|
||||
def __init__(self, dim, operation_settings={}):
|
||||
super(WanCamResidualBlock, self).__init__()
|
||||
self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.relu(self.conv1(x))
|
||||
out = self.conv2(out)
|
||||
out += residual
|
||||
return out
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
|
||||
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
|
||||
@@ -637,3 +691,92 @@ class VaceWanModel(WanModel):
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
class CameraWanModel(WanModel):
|
||||
r"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='camera',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
image_model=None,
|
||||
in_dim_control_adapter=24,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
|
||||
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
clip_fea=None,
|
||||
freqs=None,
|
||||
camera_conditions = None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if self.control_adapter is not None and camera_conditions is not None:
|
||||
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
|
||||
x = x + x_camera
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None:
|
||||
if self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
@@ -286,6 +286,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
|
||||
|
||||
if isinstance(model, comfy.model_base.ACEStep):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@@ -924,6 +924,10 @@ class HunyuanVideo(BaseModel):
|
||||
if guiding_frame_index is not None:
|
||||
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
|
||||
|
||||
ref_latent = kwargs.get("ref_latent", None)
|
||||
if ref_latent is not None:
|
||||
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
|
||||
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
@@ -1075,6 +1079,17 @@ class WAN21_Vace(WAN21):
|
||||
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
||||
return out
|
||||
|
||||
class WAN21_Camera(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
camera_conditions = kwargs.get("camera_conditions", None)
|
||||
if camera_conditions is not None:
|
||||
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
|
||||
@@ -361,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "vace"
|
||||
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
||||
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "camera"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
|
||||
@@ -308,10 +308,10 @@ def fp8_linear(self, input):
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
|
||||
@@ -30,7 +30,7 @@ if RMSNorm is None:
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape,
|
||||
eps=None,
|
||||
eps=1e-6,
|
||||
elementwise_affine=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
|
||||
@@ -451,7 +451,7 @@ class VAE:
|
||||
self.latent_dim = 2
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.disable_offload = True
|
||||
self.extra_1d_channel = 16
|
||||
else:
|
||||
|
||||
@@ -992,6 +992,16 @@ class WAN21_FunControl2V(WAN21_T2V):
|
||||
out = model_base.WAN21(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_Camera(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "camera",
|
||||
"in_dim": 32,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
||||
return out
|
||||
class WAN21_Vace(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1129,6 +1139,6 @@ class ACEStep(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -78,8 +78,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
if "global_step" in pl_sd:
|
||||
logging.debug(f"Global Step: {pl_sd['global_step']}")
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user