diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml index d72ece2ce..8f07a7b1c 100644 --- a/.github/workflows/release-stable-all.yml +++ b/.github/workflows/release-stable-all.yml @@ -20,7 +20,7 @@ jobs: git_tag: ${{ inputs.git_tag }} cache_tag: "cu130" python_minor: "13" - python_patch: "9" + python_patch: "11" rel_name: "nvidia" rel_extra_name: "" test_release: true @@ -65,11 +65,11 @@ jobs: contents: "write" packages: "write" pull-requests: "read" - name: "Release AMD ROCm 7.1.1" + name: "Release AMD ROCm 7.2" uses: ./.github/workflows/stable-release.yml with: git_tag: ${{ inputs.git_tag }} - cache_tag: "rocm711" + cache_tag: "rocm72" python_minor: "12" python_patch: "10" rel_name: "amd" diff --git a/README.md b/README.md index c56e05d07..96dc2904b 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ comfy install ## Manual Install (Windows, Linux) -Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies. +Python 3.14 works but some custom nodes may have issues. The free threaded variant works but some dependencies will enable the GIL so it's not fully supported. Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 071b98332..0194b7d70 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -236,6 +236,8 @@ class ComfyNodeABC(ABC): """Flags a node as experimental, informing users that it may change or not work as expected.""" DEPRECATED: bool """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" + DEV_ONLY: bool + """Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled.""" API_NODE: Optional[bool] """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index cb4f52ce1..38f18a83f 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -8,6 +8,7 @@ class LatentFormat: latent_rgb_factors_bias = None latent_rgb_factors_reshape = None taesd_decoder_name = None + spacial_downscale_ratio = 8 def process_in(self, latent): return latent * self.scale_factor @@ -181,6 +182,7 @@ class Flux(SD3): class Flux2(LatentFormat): latent_channels = 128 + spacial_downscale_ratio = 16 def __init__(self): self.latent_rgb_factors =[ @@ -592,6 +594,7 @@ class Wan22(Wan21): class HunyuanImage21(LatentFormat): latent_channels = 64 latent_dimensions = 2 + spacial_downscale_ratio = 32 scale_factor = 0.75289 latent_rgb_factors = [ @@ -725,6 +728,7 @@ class HunyuanVideo15(LatentFormat): latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644] latent_channels = 32 latent_dimensions = 3 + spacial_downscale_ratio = 16 scale_factor = 1.03682 taesd_decoder_name = "lighttaehy1_5" @@ -749,6 +753,7 @@ class ACEAudio(LatentFormat): class ChromaRadiance(LatentFormat): latent_channels = 3 + spacial_downscale_ratio = 1 def __init__(self): self.latent_rgb_factors = [ diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index c12ace241..2c6954ecd 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -18,12 +18,12 @@ class CompressedTimestep: def __init__(self, tensor: torch.Tensor, patches_per_frame: int): """ tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame - patches_per_frame: Number of spatial patches per frame (height * width in latent space) + patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression """ self.batch_size, num_tokens, self.feature_dim = tensor.shape # Check if compression is valid (num_tokens must be divisible by patches_per_frame) - if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame: + if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame: self.patches_per_frame = patches_per_frame self.num_frames = num_tokens // patches_per_frame @@ -215,22 +215,9 @@ class BasicAVTransformerBlock(nn.Module): return (*scale_shift_ada_values, *gate_ada_values) def forward( - self, - x: Tuple[torch.Tensor, torch.Tensor], - v_context=None, - a_context=None, - attention_mask=None, - v_timestep=None, - a_timestep=None, - v_pe=None, - a_pe=None, - v_cross_pe=None, - a_cross_pe=None, - v_cross_scale_shift_timestep=None, - a_cross_scale_shift_timestep=None, - v_cross_gate_timestep=None, - a_cross_gate_timestep=None, - transformer_options=None, + self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, + v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, + v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -240,144 +227,102 @@ class BasicAVTransformerBlock(nn.Module): run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0 run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True) + # video if run_vx: - vshift_msa, vscale_msa, vgate_msa = ( - self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3)) - ) - + # video self-attention + vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2))) norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa - vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa - vx += self.attn2( - comfy.ldm.common_dit.rms_norm(vx), - context=v_context, - mask=attention_mask, - transformer_options=transformer_options, - ) - - del vshift_msa, vscale_msa, vgate_msa + del vshift_msa, vscale_msa + attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) + del norm_vx + # video cross-attention + vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] + vx.addcmul_(attn1_out, vgate_msa) + del vgate_msa, attn1_out + vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options)) + # audio if run_ax: - ashift_msa, ascale_msa, agate_msa = ( - self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3)) - ) - + # audio self-attention + ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2))) norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa - ax += ( - self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options) - * agate_msa - ) - ax += self.audio_attn2( - comfy.ldm.common_dit.rms_norm(ax), - context=a_context, - mask=attention_mask, - transformer_options=transformer_options, - ) + del ashift_msa, ascale_msa + attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options) + del norm_ax + # audio cross-attention + agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] + ax.addcmul_(attn1_out, agate_msa) + del agate_msa, attn1_out + ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options)) - del ashift_msa, ascale_msa, agate_msa - - # Audio - Video cross attention. + # video - audio cross attention. if run_a2v or run_v2a: - # norm3 vx_norm3 = comfy.ldm.common_dit.rms_norm(vx) ax_norm3 = comfy.ldm.common_dit.rms_norm(ax) - ( - scale_ca_audio_hidden_states_a2v, - shift_ca_audio_hidden_states_a2v, - scale_ca_audio_hidden_states_v2a, - shift_ca_audio_hidden_states_v2a, - gate_out_v2a, - ) = self.get_av_ca_ada_values( - self.scale_shift_table_a2v_ca_audio, - ax.shape[0], - a_cross_scale_shift_timestep, - a_cross_gate_timestep, - ) - - ( - scale_ca_video_hidden_states_a2v, - shift_ca_video_hidden_states_a2v, - scale_ca_video_hidden_states_v2a, - shift_ca_video_hidden_states_v2a, - gate_out_a2v, - ) = self.get_av_ca_ada_values( - self.scale_shift_table_a2v_ca_video, - vx.shape[0], - v_cross_scale_shift_timestep, - v_cross_gate_timestep, - ) - + # audio to video cross attention if run_a2v: - vx_scaled = ( - vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) - + shift_ca_video_hidden_states_a2v - ) - ax_scaled = ( - ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) - + shift_ca_audio_hidden_states_a2v - ) - vx += ( - self.audio_to_video_attn( - vx_scaled, - context=ax_scaled, - pe=v_cross_pe, - k_pe=a_cross_pe, - transformer_options=transformer_options, - ) - * gate_out_a2v - ) + scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values( + self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2] + scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values( + self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2] - del gate_out_a2v - del scale_ca_video_hidden_states_a2v,\ - shift_ca_video_hidden_states_a2v,\ - scale_ca_audio_hidden_states_a2v,\ - shift_ca_audio_hidden_states_a2v,\ + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v + del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v + a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options) + del vx_scaled, ax_scaled + + gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0] + vx.addcmul_(a2v_out, gate_out_a2v) + del gate_out_a2v, a2v_out + + # video to audio cross attention if run_v2a: - ax_scaled = ( - ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) - + shift_ca_audio_hidden_states_v2a - ) - vx_scaled = ( - vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) - + shift_ca_video_hidden_states_v2a - ) - ax += ( - self.video_to_audio_attn( - ax_scaled, - context=vx_scaled, - pe=a_cross_pe, - k_pe=v_cross_pe, - transformer_options=transformer_options, - ) - * gate_out_v2a - ) + scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values( + self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4] + scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values( + self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4] - del gate_out_v2a - del scale_ca_video_hidden_states_v2a,\ - shift_ca_video_hidden_states_v2a,\ - scale_ca_audio_hidden_states_v2a,\ - shift_ca_audio_hidden_states_v2a + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a + del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a + v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options) + del ax_scaled, vx_scaled + + gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0] + ax.addcmul_(v2a_out, gate_out_v2a) + del gate_out_v2a, v2a_out + + del vx_norm3, ax_norm3 + + # video feedforward if run_vx: - vshift_mlp, vscale_mlp, vgate_mlp = ( - self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None)) - ) - + vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5)) vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp - vx += self.ff(vx_scaled) * vgate_mlp - del vshift_mlp, vscale_mlp, vgate_mlp + del vshift_mlp, vscale_mlp + ff_out = self.ff(vx_scaled) + del vx_scaled + + vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0] + vx.addcmul_(ff_out, vgate_mlp) + del vgate_mlp, ff_out + + # audio feedforward if run_ax: - ashift_mlp, ascale_mlp, agate_mlp = ( - self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None)) - ) - + ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5)) ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp - ax += self.audio_ff(ax_scaled) * agate_mlp + del ashift_mlp, ascale_mlp - del ashift_mlp, ascale_mlp, agate_mlp + ff_out = self.audio_ff(ax_scaled) + del ax_scaled + agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0] + ax.addcmul_(ff_out, agate_mlp) + del agate_mlp, ff_out return vx, ax @@ -589,9 +534,20 @@ class LTXAVModel(LTXVModel): audio_length = kwargs.get("audio_length", 0) # Separate audio and video latents vx, ax = self.separate_audio_and_video_latents(x, audio_length) + + has_spatial_mask = False + if denoise_mask is not None: + # check if any frame has spatial variation (inpainting) + for frame_idx in range(denoise_mask.shape[2]): + frame_mask = denoise_mask[0, 0, frame_idx] + if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max(): + has_spatial_mask = True + break + [vx, v_pixel_coords, additional_args] = super()._process_input( vx, keyframe_idxs, denoise_mask, **kwargs ) + additional_args["has_spatial_mask"] = has_spatial_mask ax, a_latent_coords = self.a_patchifier.patchify(ax) ax = self.audio_patchify_proj(ax) @@ -618,8 +574,9 @@ class LTXAVModel(LTXVModel): # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width] # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width orig_shape = kwargs.get("orig_shape") + has_spatial_mask = kwargs.get("has_spatial_mask", None) v_patches_per_frame = None - if orig_shape is not None and len(orig_shape) == 5: + if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5: # orig_shape[3] = height, orig_shape[4] = width (in latent space) v_patches_per_frame = orig_shape[3] * orig_shape[4] @@ -662,10 +619,11 @@ class LTXAVModel(LTXVModel): ) # Compress cross-attention timesteps (only video side, audio is too small to benefit) + # v_patches_per_frame is None for spatial masks, set for temporal masks or no mask cross_av_timestep_ss = [ av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]), - CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed - CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed + CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible + CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]), ] diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index b114d9e31..77d1abc97 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -451,6 +451,7 @@ class NextDiT(nn.Module): device=None, dtype=None, operations=None, + **kwargs, ) -> None: super().__init__() self.dtype = dtype diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 08315f1a8..fd125ceed 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from comfy.ldm.modules.diffusionmodules.model import vae_attention +from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed import comfy.ops ops = comfy.ops.disable_weight_init @@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._padding = (self.padding[2], self.padding[2], self.padding[1], - self.padding[1], 2 * self.padding[0], 0) - self.padding = (0, 0, 0) + self._padding = 2 * self.padding[0] + self.padding = (0, self.padding[1], self.padding[2]) def forward(self, x, cache_x=None, cache_list=None, cache_idx=None): if cache_list is not None: cache_x = cache_list[cache_idx] cache_list[cache_idx] = None - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] + if cache_x is None and x.shape[2] == 1: + #Fast path - the op will pad for use by truncating the weight + #and save math on a pile of zeros. + return super().forward(x, autopad="causal_zero") + + if self._padding > 0: + padding_needed = self._padding + if cache_x is not None: + cache_x = cache_x.to(x.device) + padding_needed = max(0, padding_needed - cache_x.shape[2]) + padding_shape = list(x.shape) + padding_shape[2] = padding_needed + padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype) + x = torch_cat_if_needed([padding, cache_x, x], dim=2) del cache_x - x = F.pad(x, padding) return super().forward(x) @@ -472,10 +479,12 @@ class WanVAE(nn.Module): def encode(self, x): conv_idx = [0] - feat_map = [None] * count_conv3d(self.decoder) ## cache t = x.shape[2] iter_ = 1 + (t - 1) // 4 + feat_map = None + if iter_ > 1: + feat_map = [None] * count_conv3d(self.decoder) ## 对encode输入的x,按时间拆分为1、4、4、4.... for i in range(iter_): conv_idx = [0] @@ -495,10 +504,11 @@ class WanVAE(nn.Module): def decode(self, z): conv_idx = [0] - feat_map = [None] * count_conv3d(self.decoder) # z: [b,c,t,h,w] - iter_ = z.shape[2] + feat_map = None + if iter_ > 1: + feat_map = [None] * count_conv3d(self.decoder) x = self.conv2(z) for i in range(iter_): conv_idx = [0] diff --git a/comfy/lora.py b/comfy/lora.py index e8246bd66..7b31d055c 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -260,6 +260,7 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer + key_map[k[:-len(".weight")]] = to #DiffSynth lora format for k in sdk: hidden_size = model.model_config.unet_config.get("hidden_size", 0) if k.endswith(".weight") and ".linear1." in k: diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b29a033cc..8cea16e50 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -444,6 +444,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) dit_config["z_image_modulation"] = True dit_config["time_scale"] = 1000.0 + try: + dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42 + except Exception: + pass if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: dit_config["pad_tokens_multiple"] = 32 sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None) diff --git a/comfy/ops.py b/comfy/ops.py index 415c39e92..e406ba7ed 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -203,7 +203,9 @@ class disable_weight_init: def reset_parameters(self): return None - def _conv_forward(self, input, weight, bias, *args, **kwargs): + def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs): + if autopad == "causal_zero": + weight = weight[:, :, -input.shape[2]:, :, :] if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16): out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True) if bias is not None: @@ -212,15 +214,15 @@ class disable_weight_init: else: return super()._conv_forward(input, weight, bias, *args, **kwargs) - def forward_comfy_cast_weights(self, input): + def forward_comfy_cast_weights(self, input, autopad=None): weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - x = self._conv_forward(input, weight, bias) + x = self._conv_forward(input, weight, bias, autopad=autopad) uncast_bias_weight(self, weight, bias, offload_stream) return x def forward(self, *args, **kwargs): run_every_op() - if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) diff --git a/comfy/sample.py b/comfy/sample.py index 2f8f3a51c..a2a39b527 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None): return noises -def fix_empty_latent_channels(model, latent_image): +def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None): if latent_image.is_nested: return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels - if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: - latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + if torch.count_nonzero(latent_image) == 0: + if latent_format.latent_channels != latent_image.shape[1]: + latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + if downscale_ratio_spacial is not None: + if downscale_ratio_spacial != latent_format.spacial_downscale_ratio: + ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio + latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled") + if latent_format.latent_dimensions == 3 and latent_image.ndim == 4: latent_image = latent_image.unsqueeze(2) return latent_image diff --git a/comfy/sd.py b/comfy/sd.py index ce7e6bcff..f627f7d55 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder import comfy.pixel_space_convert +import comfy.weight_adapter import yaml import math import os @@ -101,6 +102,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): return (new_modelpatcher, new_clip) +def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip): + """ + Load LoRA in bypass mode without modifying base model weights. + + Instead of patching weights, this injects the LoRA computation into the + forward pass: output = base_forward(x) + lora_path(x) + + Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches. + + This is useful for training and when model weights are offloaded. + """ + key_map = {} + if model is not None: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if clip is not None: + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + + logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries") + + lora = comfy.lora_convert.convert_lora(lora) + loaded = comfy.lora.load_lora(lora, key_map) + + logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries") + + # Separate adapters (for bypass) from other patches (for regular patching) + bypass_patches = {} # WeightAdapterBase instances -> bypass mode + regular_patches = {} # diff, set, bias patches -> regular weight patching + + for key, patch_data in loaded.items(): + if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase): + bypass_patches[key] = patch_data + else: + regular_patches[key] = patch_data + + logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches") + + k = set() + k1 = set() + + if model is not None: + new_modelpatcher = model.clone() + + # Apply regular patches (bias diff, weight diff, etc.) via normal patching + if regular_patches: + patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model) + k.update(patched_keys) + + # Apply adapter patches via bypass injection + manager = comfy.weight_adapter.BypassInjectionManager() + model_sd_keys = set(new_modelpatcher.model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in model_sd_keys: + manager.add_adapter(key, adapter, strength=strength_model) + k.add(key) + else: + logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}") + + injections = manager.create_injections(new_modelpatcher.model) + + if manager.get_hook_count() > 0: + new_modelpatcher.set_injections("bypass_lora", injections) + else: + new_modelpatcher = None + + if clip is not None: + new_clip = clip.clone() + + # Apply regular patches to clip + if regular_patches: + patched_keys = new_clip.add_patches(regular_patches, strength_clip) + k1.update(patched_keys) + + # Apply adapter patches via bypass injection + clip_manager = comfy.weight_adapter.BypassInjectionManager() + clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in clip_sd_keys: + clip_manager.add_adapter(key, adapter, strength=strength_clip) + k1.add(key) + + clip_injections = clip_manager.create_injections(new_clip.cond_stage_model) + if clip_manager.get_hook_count() > 0: + new_clip.patcher.set_injections("bypass_lora", clip_injections) + else: + new_clip = None + + for x in loaded: + if (x not in k) and (x not in k1): + patch_data = loaded[x] + patch_type = type(patch_data).__name__ + if isinstance(patch_data, tuple): + patch_type = f"tuple({patch_data[0]})" + logging.warning(f"NOT LOADED: {x} (type={patch_type})") + + return (new_modelpatcher, new_clip) + + class CLIP: def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): if no_init: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c512ca5d0..d4f22120b 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, start_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) @@ -479,8 +479,15 @@ class SDTokenizer: empty = self.tokenizer('')["input_ids"] self.tokenizer_adds_end_token = has_end_token if has_start_token: - self.tokens_start = 1 - self.start_token = empty[0] + if len(empty) > 0: + self.tokens_start = 1 + self.start_token = empty[0] + else: + self.tokens_start = 0 + self.start_token = start_token + if start_token is None: + logging.warning("WARNING: There's something wrong with your tokenizers.'") + if end_token is not None: self.end_token = end_token else: @@ -488,7 +495,7 @@ class SDTokenizer: self.end_token = empty[1] else: self.tokens_start = 0 - self.start_token = None + self.start_token = start_token if end_token is not None: self.end_token = end_token else: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 45d913fa6..d25271d6e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1093,7 +1093,7 @@ class ZImage(Lumina2): def __init__(self, unet_config): super().__init__(unet_config) - if comfy.model_management.extended_fp16_support(): + if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False): self.supported_inference_dtypes = self.supported_inference_dtypes.copy() self.supported_inference_dtypes.insert(1, torch.float16) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 4075afca4..f67a5f805 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -118,7 +118,7 @@ class MistralTokenizerClass: class Mistral3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): self.tekken_data = tokenizer_data.get("tekken_model", None) - super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) def state_dict(self): return {"tekken_model": self.tekken_data} diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index b40f920e4..b9fa8d5cf 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -5,6 +5,11 @@ from .lokr import LoKrAdapter from .glora import GLoRAAdapter from .oft import OFTAdapter from .boft import BOFTAdapter +from .bypass import ( + BypassInjectionManager, + BypassForwardHook, + create_bypass_injections_from_patches, +) adapters: list[type[WeightAdapterBase]] = [ @@ -31,4 +36,7 @@ __all__ = [ "WeightAdapterTrainBase", "adapters", "adapter_maps", + "BypassInjectionManager", + "BypassForwardHook", + "create_bypass_injections_from_patches", ] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 43644b106..bce89a0e2 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional import torch import torch.nn as nn @@ -7,12 +7,35 @@ import comfy.model_management class WeightAdapterBase: + """ + Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.) + + Bypass Mode: + All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x)) + + - h(x): Additive component (LoRA path). Returns delta to add to base output. + - g(y): Output transformation. Applied after base + h(x). + + For LoRA/LoHa/LoKr: g = identity, h = adapter(x) + For OFT/BOFT: g = transform, h = 0 + """ + name: str loaded_keys: set[str] weights: list[torch.Tensor] + # Attributes set by bypass system + multiplier: float = 1.0 + shape: tuple = None # (out_features, in_features) or (out_ch, in_ch, *kernel) + @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]: + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + ) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": @@ -39,18 +62,202 @@ class WeightAdapterBase: ): raise NotImplementedError + # ===== Bypass Mode Methods ===== + # + # IMPORTANT: Bypass mode is designed for quantized models where original weights + # may not be accessible in a usable format. Therefore, h() and bypass_forward() + # do NOT take org_weight as a parameter. All necessary information (out_channels, + # in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook. + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component: h(x, base_out) + + Computes the adapter's contribution to be added to base forward output. + For adapters that only transform output (OFT/BOFT), returns zeros. + + Note: + This method does NOT access original model weights. Bypass mode is + designed for quantized models where weights may not be in a usable format. + All shape info comes from module attributes set by BypassForwardHook. + + Args: + x: Input tensor + base_out: Output from base forward f(x), can be used for shape reference + + Returns: + Delta tensor to add to base output. Shape matches base output. + + Reference: LyCORIS LoConModule.bypass_forward_diff + """ + # Default: no additive component (for OFT/BOFT) + # Simply return zeros matching base_out shape + return torch.zeros_like(base_out) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation: g(y) + + Applied after base forward + h(x). For most adapters this is identity. + OFT/BOFT override this to apply orthogonal transformation. + + Args: + y: Combined output (base + h(x)) + + Returns: + Transformed output + + Reference: LyCORIS OFTModule applies orthogonal transform here + """ + # Default: identity (for LoRA/LoHa/LoKr) + return y + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Full bypass forward: g(f(x) + h(x, f(x))) + + Note: + This method does NOT take org_weight/org_bias parameters. Bypass mode + is designed for quantized models where weights may not be accessible. + The original forward function handles weight access internally. + + Args: + org_forward: Original module forward function + x: Input tensor + *args, **kwargs: Additional arguments for org_forward + + Returns: + Output with adapter applied in bypass mode + + Reference: LyCORIS LoConModule.bypass_forward + """ + # Base forward: f(x) + base_out = org_forward(x, *args, **kwargs) + + # Additive component: h(x, base_out) - base_out provided for shape reference + h_out = self.h(x, base_out) + + # Output transformation: g(base + h) + return self.g(base_out + h_out) + class WeightAdapterTrainBase(nn.Module): - # We follow the scheme of PR #7032 + """ + Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.) + + Bypass Mode: + All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x)) + + - h(x): Additive component (LoRA path). Returns delta to add to base output. + - g(y): Output transformation. Applied after base + h(x). + + For LoRA/LoHa/LoKr: g = identity, h = adapter(x) + For OFT: g = transform, h = 0 + + Note: + Unlike WeightAdapterBase, TrainBase classes have simplified weight formats + with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition). + + We follow the scheme of PR #7032 + """ + + # Attributes set by bypass system (BypassForwardHook) + # These are set before h()/g()/bypass_forward() are called + multiplier: float = 1.0 + is_conv: bool = False + conv_dim: int = 0 # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d + kw_dict: dict = {} # Conv kwargs: stride, padding, dilation, groups + kernel_size: tuple = () + in_channels: int = None + out_channels: int = None + def __init__(self): super().__init__() def __call__(self, w): """ - w: The original weight tensor to be modified. + Weight modification mode: returns modified weight. + + Args: + w: The original weight tensor to be modified. + + Returns: + Modified weight tensor. """ raise NotImplementedError + # ===== Bypass Mode Methods ===== + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component: h(x, base_out) + + Computes the adapter's contribution to be added to base forward output. + For adapters that only transform output (OFT), returns zeros. + + Args: + x: Input tensor + base_out: Output from base forward f(x), can be used for shape reference + + Returns: + Delta tensor to add to base output. Shape matches base output. + + Subclasses should override this method. + """ + raise NotImplementedError( + f"{self.__class__.__name__}.h() not implemented. " + "Subclasses must implement h() for bypass mode." + ) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation: g(y) + + Applied after base forward + h(x). For most adapters this is identity. + OFT overrides this to apply orthogonal transformation. + + Args: + y: Combined output (base + h(x)) + + Returns: + Transformed output + """ + # Default: identity (for LoRA/LoHa/LoKr) + return y + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Full bypass forward: g(f(x) + h(x, f(x))) + + Args: + org_forward: Original module forward function + x: Input tensor + *args, **kwargs: Additional arguments for org_forward + + Returns: + Output with adapter applied in bypass mode + """ + # Base forward: f(x) + base_out = org_forward(x, *args, **kwargs) + + # Additive component: h(x, base_out) - base_out provided for shape reference + h_out = self.h(x, base_out) + + # Output transformation: g(base + h) + return self.g(base_out + h_out) + def passive_memory_usage(self): raise NotImplementedError("passive_memory_usage is not implemented") @@ -59,8 +266,12 @@ class WeightAdapterTrainBase(nn.Module): return self.passive_memory_usage() -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) +def weight_decompose( + dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function +): + dora_scale = comfy.model_management.cast_to_device( + dora_scale, weight.device, intermediate_dtype + ) lora_diff *= alpha weight_calc = weight + function(lora_diff).type(weight.dtype) @@ -106,10 +317,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten the original tensor will be truncated in that dimension. """ if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): - raise ValueError("The new shape must be larger than the original tensor in all dimensions") + raise ValueError( + "The new shape must be larger than the original tensor in all dimensions" + ) if len(new_shape) != len(tensor.shape): - raise ValueError("The new shape must have the same number of dimensions as the original tensor") + raise ValueError( + "The new shape must have the same number of dimensions as the original tensor" + ) # Create a new tensor filled with zeros padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py index b2a2f1bd4..02a8dc130 100644 --- a/comfy/weight_adapter/boft.py +++ b/comfy/weight_adapter/boft.py @@ -62,9 +62,13 @@ class BOFTAdapter(WeightAdapterBase): alpha = v[2] dora_scale = v[3] - blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + blocks = comfy.model_management.cast_to_device( + blocks, weight.device, intermediate_dtype + ) if rescale is not None: - rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + rescale = comfy.model_management.cast_to_device( + rescale, weight.device, intermediate_dtype + ) boft_m, block_num, boft_b, *_ = blocks.shape @@ -74,7 +78,7 @@ class BOFTAdapter(WeightAdapterBase): # for Q = -Q^T q = blocks - blocks.transpose(-1, -2) normed_q = q - if alpha > 0: # alpha in boft/bboft is for constraint + if alpha > 0: # alpha in boft/bboft is for constraint q_norm = torch.norm(q) + 1e-8 if q_norm > alpha: normed_q = q * alpha / q_norm @@ -83,13 +87,13 @@ class BOFTAdapter(WeightAdapterBase): r = r.to(weight) inp = org = weight - r_b = boft_b//2 + r_b = boft_b // 2 for i in range(boft_m): bi = r[i] g = 2 k = 2**i * r_b if strength != 1: - bi = bi * strength + (1-strength) * I + bi = bi * strength + (1 - strength) * I inp = ( inp.unflatten(0, (-1, g, k)) .transpose(1, 2) @@ -98,18 +102,117 @@ class BOFTAdapter(WeightAdapterBase): ) inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp) inp = ( - inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2) + inp.flatten(0, 1) + .unflatten(0, (-1, k, g)) + .transpose(1, 2) + .flatten(0, 2) ) if rescale is not None: inp = inp * rescale lora_diff = inp - org - lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype) + lora_diff = comfy.model_management.cast_to_device( + lora_diff, weight.device, intermediate_dtype + ) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _get_orthogonal_matrices(self, device, dtype): + """Compute the orthogonal rotation matrices R from BOFT blocks.""" + v = self.weights + blocks = v[0].to(device=device, dtype=dtype) + alpha = v[2] + if alpha is None: + alpha = 0 + + boft_m, block_num, boft_b, _ = blocks.shape + I = torch.eye(boft_b, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(-1, -2) + normed_q = q + + # Apply constraint if alpha > 0 + if alpha > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r, boft_m, boft_b + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for BOFT: applies butterfly orthogonal transform. + + BOFT uses multiple stages of butterfly-structured orthogonal transforms. + + Reference: LyCORIS ButterflyOFTModule._bypass_forward + """ + v = self.weights + rescale = v[1] + + r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype) + r_b = boft_b // 2 + + # Apply multiplier + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(boft_b, device=y.device, dtype=y.dtype) + + # Use module info from bypass injection to determine conv vs linear + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # Apply butterfly transform stages + inp = y + for i in range(boft_m): + bi = r[i] # (block_num, boft_b, boft_b) + g = 2 + k = 2**i * r_b + + # Interpolate with identity based on multiplier + if multiplier != 1: + bi = bi * multiplier + (1 - multiplier) * I + + # Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, boft_b)) + ) + # Apply block-diagonal orthogonal transform + inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp) + # Reshape back + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + # Apply rescale if present + if rescale is not None: + rescale = rescale.to(device=y.device, dtype=y.dtype) + inp = inp * rescale.transpose(0, -1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + inp = inp.transpose(1, -1) + + return inp diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py new file mode 100644 index 000000000..d4aaf98ca --- /dev/null +++ b/comfy/weight_adapter/bypass.py @@ -0,0 +1,437 @@ +""" +Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.) + +Bypass mode applies adapters during forward pass without modifying base weights: + bypass(f)(x) = g(f(x) + h(x)) + +Where: + - f(x): Original layer forward + - h(x): Additive component from adapter (LoRA path) + - g(y): Output transformation (identity for most adapters) + +This is useful for: + - Training with gradient checkpointing + - Avoiding weight modifications when weights are offloaded + - Supporting multiple adapters with different strengths dynamically +""" + +import logging +from typing import Optional, Union + +import torch +import torch.nn as nn + +from .base import WeightAdapterBase, WeightAdapterTrainBase +from comfy.patcher_extension import PatcherInjection + +# Type alias for adapters that support bypass mode +BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase] + + +def get_module_type_info(module: nn.Module) -> dict: + """ + Determine module type and extract conv parameters from module class. + + This is more reliable than checking weight.ndim, especially for quantized layers + where weight shape might be different. + + Returns: + dict with keys: is_conv, conv_dim, stride, padding, dilation, groups + """ + info = { + "is_conv": False, + "conv_dim": 0, + "stride": (1,), + "padding": (0,), + "dilation": (1,), + "groups": 1, + "kernel_size": (1,), + "in_channels": None, + "out_channels": None, + } + + # Determine conv type + if isinstance(module, nn.Conv1d): + info["is_conv"] = True + info["conv_dim"] = 1 + elif isinstance(module, nn.Conv2d): + info["is_conv"] = True + info["conv_dim"] = 2 + elif isinstance(module, nn.Conv3d): + info["is_conv"] = True + info["conv_dim"] = 3 + elif isinstance(module, nn.Linear): + info["is_conv"] = False + info["conv_dim"] = 0 + else: + # Try to infer from class name for custom/quantized layers + class_name = type(module).__name__.lower() + if "conv3d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 3 + elif "conv2d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 2 + elif "conv1d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 1 + elif "conv" in class_name: + info["is_conv"] = True + info["conv_dim"] = 2 + + # Extract conv parameters if it's a conv layer + if info["is_conv"]: + # Try to get stride, padding, dilation, groups, kernel_size from module + info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"]) + info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"]) + info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"]) + info["groups"] = getattr(module, "groups", 1) + info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"]) + info["in_channels"] = getattr(module, "in_channels", None) + info["out_channels"] = getattr(module, "out_channels", None) + + # Ensure they're tuples + if isinstance(info["stride"], int): + info["stride"] = (info["stride"],) * info["conv_dim"] + if isinstance(info["padding"], int): + info["padding"] = (info["padding"],) * info["conv_dim"] + if isinstance(info["dilation"], int): + info["dilation"] = (info["dilation"],) * info["conv_dim"] + if isinstance(info["kernel_size"], int): + info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"] + + return info + + +class BypassForwardHook: + """ + Hook that wraps a layer's forward to apply adapter in bypass mode. + + Stores the original forward and replaces it with bypass version. + + Supports both: + - WeightAdapterBase: Inference adapters (uses self.weights tuple) + - WeightAdapterTrainBase: Training adapters (nn.Module with parameters) + """ + + def __init__( + self, + module: nn.Module, + adapter: BypassAdapter, + multiplier: float = 1.0, + ): + self.module = module + self.adapter = adapter + self.multiplier = multiplier + self.original_forward = None + + # Determine layer type and conv params from module class (works for quantized layers) + module_info = get_module_type_info(module) + + # Set multiplier and layer type info on adapter for use in h() + adapter.multiplier = multiplier + adapter.is_conv = module_info["is_conv"] + adapter.conv_dim = module_info["conv_dim"] + adapter.kernel_size = module_info["kernel_size"] + adapter.in_channels = module_info["in_channels"] + adapter.out_channels = module_info["out_channels"] + # Store kw_dict for conv operations (like LyCORIS extra_args) + if module_info["is_conv"]: + adapter.kw_dict = { + "stride": module_info["stride"], + "padding": module_info["padding"], + "dilation": module_info["dilation"], + "groups": module_info["groups"], + } + else: + adapter.kw_dict = {} + + def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """Bypass forward: uses adapter's bypass_forward or default g(f(x) + h(x)) + + Note: + Bypass mode does NOT access original model weights (org_weight). + This is intentional - bypass mode is designed for quantized models + where weights may not be in a usable format. All necessary shape + information is provided via adapter attributes set during inject(). + """ + # Check if adapter has custom bypass_forward (e.g., GLoRA) + adapter_bypass = getattr(self.adapter, "bypass_forward", None) + if adapter_bypass is not None: + # Check if it's overridden (not the base class default) + # Need to check both base classes since adapter could be either type + adapter_type = type(self.adapter) + is_default_bypass = ( + adapter_type.bypass_forward is WeightAdapterBase.bypass_forward + or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward + ) + if not is_default_bypass: + return adapter_bypass(self.original_forward, x, *args, **kwargs) + + # Default bypass: g(f(x) + h(x, f(x))) + base_out = self.original_forward(x, *args, **kwargs) + h_out = self.adapter.h(x, base_out) + return self.adapter.g(base_out + h_out) + + def inject(self): + """Replace module forward with bypass version.""" + if self.original_forward is not None: + logging.debug( + f"[BypassHook] Already injected for {type(self.module).__name__}" + ) + return # Already injected + + # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward + device = None + dtype = None + if hasattr(self.module, "weight") and self.module.weight is not None: + device = self.module.weight.device + dtype = self.module.weight.dtype + elif hasattr(self.module, "W_q"): # Quantized layers might use different attr + device = self.module.W_q.device + dtype = self.module.W_q.dtype + + if device is not None: + self._move_adapter_weights_to_device(device, dtype) + + self.original_forward = self.module.forward + self.module.forward = self._bypass_forward + logging.debug( + f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})" + ) + + def _move_adapter_weights_to_device(self, device, dtype=None): + """Move adapter weights to specified device to avoid per-forward transfers. + + Handles both: + - WeightAdapterBase: has self.weights tuple of tensors + - WeightAdapterTrainBase: nn.Module with parameters, uses .to() method + """ + adapter = self.adapter + + # Check if adapter is an nn.Module (WeightAdapterTrainBase) + if isinstance(adapter, nn.Module): + # In training mode we don't touch dtype as trainer will handle it + adapter.to(device=device) + logging.debug( + f"[BypassHook] Moved training adapter (nn.Module) to {device}" + ) + return + + # WeightAdapterBase: handle self.weights tuple + if not hasattr(adapter, "weights") or adapter.weights is None: + return + + weights = adapter.weights + if isinstance(weights, (list, tuple)): + new_weights = [] + for w in weights: + if isinstance(w, torch.Tensor): + if dtype is not None: + new_weights.append(w.to(device=device, dtype=dtype)) + else: + new_weights.append(w.to(device=device)) + else: + new_weights.append(w) + adapter.weights = ( + tuple(new_weights) if isinstance(weights, tuple) else new_weights + ) + elif isinstance(weights, torch.Tensor): + if dtype is not None: + adapter.weights = weights.to(device=device, dtype=dtype) + else: + adapter.weights = weights.to(device=device) + + logging.debug(f"[BypassHook] Moved adapter weights to {device}") + + def eject(self): + """Restore original module forward.""" + if self.original_forward is None: + logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}") + return # Not injected + + self.module.forward = self.original_forward + self.original_forward = None + logging.debug( + f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}" + ) + + +class BypassInjectionManager: + """ + Manages bypass mode injection for a collection of adapters. + + Creates PatcherInjection objects that can be used with ModelPatcher. + + Supports both inference adapters (WeightAdapterBase) and training adapters + (WeightAdapterTrainBase). + + Usage: + manager = BypassInjectionManager() + manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8) + manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8) + + injections = manager.create_injections(model) + model_patcher.set_injections("bypass_lora", injections) + """ + + def __init__(self): + self.adapters: dict[str, tuple[BypassAdapter, float]] = {} + self.hooks: list[BypassForwardHook] = [] + + def add_adapter( + self, + key: str, + adapter: BypassAdapter, + strength: float = 1.0, + ): + """ + Add an adapter for a specific weight key. + + Args: + key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight") + adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.) + strength: Multiplier for adapter effect + """ + # Remove .weight suffix if present for module lookup + module_key = key + if module_key.endswith(".weight"): + module_key = module_key[:-7] + logging.debug( + f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}" + ) + + self.adapters[module_key] = (adapter, strength) + logging.debug( + f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})" + ) + + def clear_adapters(self): + """Remove all adapters.""" + self.adapters.clear() + + def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]: + """Get a submodule by dot-separated key.""" + parts = key.split(".") + module = model + try: + for i, part in enumerate(parts): + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + logging.debug( + f"[BypassManager] Found module for key {key}: {type(module).__name__}" + ) + return module + except (AttributeError, IndexError, KeyError) as e: + logging.error(f"[BypassManager] Failed to find module for key {key}: {e}") + logging.error( + f"[BypassManager] Failed at part index {i}, part={part}, current module type={type(module).__name__}" + ) + return None + + def create_injections(self, model: nn.Module) -> list[PatcherInjection]: + """ + Create PatcherInjection objects for all registered adapters. + + Args: + model: The model to inject into (e.g., model_patcher.model) + + Returns: + List of PatcherInjection objects to use with model_patcher.set_injections() + """ + self.hooks.clear() + + logging.debug( + f"[BypassManager] create_injections called with {len(self.adapters)} adapters" + ) + logging.debug(f"[BypassManager] Model type: {type(model).__name__}") + + for key, (adapter, strength) in self.adapters.items(): + logging.debug(f"[BypassManager] Looking for module: {key}") + module = self._get_module_by_key(model, key) + + if module is None: + logging.warning(f"[BypassManager] Module not found for key {key}") + continue + + if not hasattr(module, "weight"): + logging.warning( + f"[BypassManager] Module {key} has no weight attribute (type={type(module).__name__})" + ) + continue + + logging.debug( + f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})" + ) + hook = BypassForwardHook(module, adapter, multiplier=strength) + self.hooks.append(hook) + + logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks") + + # Create single injection that manages all hooks + def inject_all(model_patcher): + logging.debug( + f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks" + ) + for hook in self.hooks: + hook.inject() + logging.debug( + f"[BypassManager] Injected hook for {type(hook.module).__name__}" + ) + + def eject_all(model_patcher): + logging.debug( + f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks" + ) + for hook in self.hooks: + hook.eject() + + return [PatcherInjection(inject=inject_all, eject=eject_all)] + + def get_hook_count(self) -> int: + """Return number of hooks that will be/are injected.""" + return len(self.hooks) + + +def create_bypass_injections_from_patches( + model: nn.Module, + patches: dict, + strength: float = 1.0, +) -> list[PatcherInjection]: + """ + Convenience function to create bypass injections from a patches dict. + + This is useful when you have patches in the format used by model_patcher.add_patches() + and want to apply them in bypass mode instead. + + Args: + model: The model to inject into + patches: Dict mapping weight keys to adapter data + strength: Global strength multiplier + + Returns: + List of PatcherInjection objects + """ + manager = BypassInjectionManager() + + for key, patch_list in patches.items(): + if not patch_list: + continue + + # patches format: list of (strength_patch, patch_data, strength_model, offset, function) + for patch in patch_list: + patch_strength, patch_data, strength_model, offset, function = patch + + # patch_data should be a WeightAdapterBase/WeightAdapterTrainBase or tuple + if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)): + adapter = patch_data + else: + # Skip non-adapter patches + continue + + combined_strength = strength * patch_strength + manager.add_adapter(key, adapter, strength=combined_strength) + + return manager.create_injections(model) diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index 939abbba5..d6b97a23b 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -1,7 +1,8 @@ import logging -from typing import Optional +from typing import Callable, Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import WeightAdapterBase, weight_decompose @@ -29,7 +30,14 @@ class GLoRAAdapter(WeightAdapterBase): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale) + weights = ( + lora[a1_name], + lora[a2_name], + lora[b1_name], + lora[b2_name], + alpha, + dora_scale, + ) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) @@ -58,16 +66,28 @@ class GLoRAAdapter(WeightAdapterBase): old_glora = True if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: - if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: + if ( + old_glora + and v[1].shape[0] == weight.shape[0] + and weight.shape[0] == weight.shape[1] + ): pass else: old_glora = False rank = v[1].shape[0] - a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) - a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) - b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) - b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) + a1 = comfy.model_management.cast_to_device( + v[0].flatten(start_dim=1), weight.device, intermediate_dtype + ) + a2 = comfy.model_management.cast_to_device( + v[1].flatten(start_dim=1), weight.device, intermediate_dtype + ) + b1 = comfy.model_management.cast_to_device( + v[2].flatten(start_dim=1), weight.device, intermediate_dtype + ) + b2 = comfy.model_management.cast_to_device( + v[3].flatten(start_dim=1), weight.device, intermediate_dtype + ) if v[4] is not None: alpha = v[4] / rank @@ -76,18 +96,195 @@ class GLoRAAdapter(WeightAdapterBase): try: if old_glora: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora + lora_diff = ( + torch.mm(b2, b1) + + torch.mm( + torch.mm( + weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2 + ), + a1, + ) + ).reshape( + weight.shape + ) # old lycoris glora else: if weight.dim() > 2: - lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff = torch.einsum( + "o i ..., i j -> o j ...", + torch.einsum( + "o i ..., i j -> o j ...", + weight.to(dtype=intermediate_dtype), + a1, + ), + a2, + ).reshape(weight.shape) else: - lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff = torch.mm( + torch.mm(weight.to(dtype=intermediate_dtype), a1), a2 + ).reshape(weight.shape) lora_diff += torch.mm(b1, b2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _compute_paths(self, x: torch.Tensor): + """ + Compute A path and B path outputs for GLoRA bypass. + + GLoRA: f(x) = Wx + WAx + Bx + - A path: a1(a2(x)) - modifies input to base forward + - B path: b1(b2(x)) - additive component + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Returns: (a_out, b_out) + """ + v = self.weights + # v = (a1, a2, b1, b2, alpha, dora_scale) + a1 = v[0] + a2 = v[1] + b1 = v[2] + b2 = v[3] + alpha = v[4] + + dtype = x.dtype + + # Cast dtype (weights should already be on correct device from inject()) + a1 = a1.to(dtype=dtype) + a2 = a2.to(dtype=dtype) + b1 = b1.to(dtype=dtype) + b2 = b2.to(dtype=dtype) + + # Determine rank and scale + # Check for old vs new glora format + old_glora = False + if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]: + rank = a1.shape[0] + old_glora = True + + if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]: + if old_glora and a2.shape[0] == x.shape[-1] and x.shape[-1] == x.shape[-1]: + pass + else: + old_glora = False + rank = a2.shape[0] + + if alpha is not None: + scale = alpha / rank + else: + scale = 1.0 + + # Apply multiplier + multiplier = getattr(self, "multiplier", 1.0) + scale = scale * multiplier + + # Use module info from bypass injection, not input tensor shape + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + if is_conv: + # Conv case - conv_dim is 1/2/3 for conv1d/2d/3d + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + # Get module's stride/padding for spatial dimension handling + module_stride = kw_dict.get("stride", (1,) * conv_dim) + module_padding = kw_dict.get("padding", (0,) * conv_dim) + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Ensure weights are in conv shape + # a1, a2, b1 are always 1x1 kernels + if a1.ndim == 2: + a1 = a1.view(*a1.shape, *([1] * conv_dim)) + if a2.ndim == 2: + a2 = a2.view(*a2.shape, *([1] * conv_dim)) + if b1.ndim == 2: + b1 = b1.view(*b1.shape, *([1] * conv_dim)) + # b2 has actual kernel_size (like LoRA down) + if b2.ndim == 2: + if in_channels is not None: + b2 = b2.view(b2.shape[0], in_channels, *kernel_size) + else: + b2 = b2.view(*b2.shape, *([1] * conv_dim)) + + # A path: a2(x) -> a1(...) - 1x1 convs, no stride/padding needed, a_out is added to x + a2_out = conv_fn(x, a2) + a_out = conv_fn(a2_out, a1) * scale + + # B path: b2(x) with kernel/stride/padding -> b1(...) 1x1 + b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding) + b_out = conv_fn(b2_out, b1) * scale + else: + # Linear case + if old_glora: + # Old format: a1 @ a2 @ x, b2 @ b1 + a_out = F.linear(F.linear(x, a2), a1) * scale + b_out = F.linear(F.linear(x, b1), b2) * scale + else: + # New format: x @ a1 @ a2, b1 @ b2 + a_out = F.linear(F.linear(x, a1), a2) * scale + b_out = F.linear(F.linear(x, b2), b1) * scale + + return a_out, b_out + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + GLoRA bypass forward: f(x + a(x)) + b(x) + + Unlike standard adapters, GLoRA modifies the input to the base forward + AND adds the B path output. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Reference: LyCORIS GLoRAModule._bypass_forward + """ + a_out, b_out = self._compute_paths(x) + + # Call base forward with modified input + base_out = org_forward(x + a_out, *args, **kwargs) + + # Add B path + return base_out + b_out + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + For GLoRA, h() returns the B path output. + + Note: + GLoRA's full bypass requires overriding bypass_forward() since + it also modifies the input to org_forward. This h() is provided for + compatibility but bypass_forward() should be used for correct behavior. + + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + _, b_out = self._compute_paths(x) + return b_out diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 0abb2d403..8007b7b44 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -1,11 +1,22 @@ import logging +from functools import cache from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose +@cache +def _warn_loha_bypass_inefficient(): + """One-time warning about LoHa bypass inefficiency.""" + logging.warning( + "LoHa bypass mode is inefficient: full weight diff is computed each forward pass. " + "Consider using LoRA or LoKr for training with bypass mode." + ) + + class HadaWeight(torch.autograd.Function): @staticmethod def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)): @@ -105,9 +116,19 @@ class LohaDiff(WeightAdapterTrainBase): scale = self.alpha / self.rank if self.use_tucker: - diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale) + diff_weight = HadaWeightTucker.apply( + self.hada_t1, + self.hada_w1_a, + self.hada_w1_b, + self.hada_t2, + self.hada_w2_a, + self.hada_w2_b, + scale, + ) else: - diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + diff_weight = HadaWeight.apply( + self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale + ) # Add the scaled difference to the original weight weight = w.to(diff_weight) + diff_weight.reshape(w.shape) @@ -138,9 +159,7 @@ class LoHaAdapter(WeightAdapterBase): mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat3, 0.1) torch.nn.init.normal_(mat4, 0.01) - return LohaDiff( - (mat1, mat2, alpha, mat3, mat4, None, None, None) - ) + return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None)) def to_train(self): return LohaDiff(self.weights) @@ -172,7 +191,16 @@ class LoHaAdapter(WeightAdapterBase): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale) + weights = ( + lora[hada_w1_a_name], + lora[hada_w1_b_name], + alpha, + lora[hada_w2_a_name], + lora[hada_w2_b_name], + hada_t1, + hada_t2, + dora_scale, + ) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -203,30 +231,148 @@ class LoHaAdapter(WeightAdapterBase): w2a = v[3] w2b = v[4] dora_scale = v[7] - if v[5] is not None: #cp decomposition + if v[5] is not None: # cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + m1 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t1, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1a, weight.device, intermediate_dtype + ), + ) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + m2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t2, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2a, weight.device, intermediate_dtype + ), + ) else: - m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) - m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) + m1 = torch.mm( + comfy.model_management.cast_to_device( + w1a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1b, weight.device, intermediate_dtype + ), + ) + m2 = torch.mm( + comfy.model_management.cast_to_device( + w2a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2b, weight.device, intermediate_dtype + ), + ) try: lora_diff = (m1 * m2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoHa: h(x) = diff_weight @ x + + WARNING: Inefficient - computes full Hadamard product each forward. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/loha.py bypass_forward_diff + """ + _warn_loha_bypass_inefficient() + + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora + w1a = v[0] + w1b = v[1] + alpha = v[2] + w2a = v[3] + w2b = v[4] + t1 = v[5] + t2 = v[6] + + # Compute scale + rank = w1b.shape[0] + scale = (alpha / rank if alpha is not None else 1.0) * getattr( + self, "multiplier", 1.0 + ) + + # Cast dtype + w1a = w1a.to(dtype=x.dtype) + w1b = w1b.to(dtype=x.dtype) + w2a = w2a.to(dtype=x.dtype) + w2b = w2b.to(dtype=x.dtype) + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Compute diff weight using Hadamard product + if t1 is not None and t2 is not None: + t1 = t1.to(dtype=x.dtype) + t2 = t2.to(dtype=x.dtype) + m1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a) + m2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a) + diff_weight = (m1 * m2) * scale + else: + m1 = w1a @ w1b + m2 = w2a @ w2b + diff_weight = (m1 * m2) * scale + + if is_conv: + op = FUNC_LIST[conv_dim + 2] + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Reshape 2D diff_weight to conv format using kernel_size + # diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size] + if diff_weight.dim() == 2: + if in_channels is not None: + diff_weight = diff_weight.view( + diff_weight.shape[0], in_channels, *kernel_size + ) + else: + diff_weight = diff_weight.view( + *diff_weight.shape, *([1] * conv_dim) + ) + else: + op = F.linear + kw_dict = {} + + return op(x, diff_weight, **kw_dict) diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 9b2aff2d7..b83750012 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import ( WeightAdapterBase, @@ -14,7 +15,17 @@ from .base import ( class LokrDiff(WeightAdapterTrainBase): def __init__(self, weights): super().__init__() - (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights + ( + lokr_w1, + lokr_w2, + alpha, + lokr_w1_a, + lokr_w1_b, + lokr_w2_a, + lokr_w2_b, + lokr_t2, + dora_scale, + ) = weights self.use_tucker = False if lokr_w1_a is not None: _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1] @@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase): if self.w2_rebuild: if self.use_tucker: w2 = torch.einsum( - 'i j k l, j r, i p -> p r k l', + "i j k l, j r, i p -> p r k l", self.lokr_t2, self.lokr_w2_b, - self.lokr_w2_a + self.lokr_w2_a, ) else: w2 = self.lokr_w2_a @ self.lokr_w2_b @@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase): return self.lokr_w2 def __call__(self, w): - diff = torch.kron(self.w1, self.w2) + w1 = self.w1 + w2 = self.w2 + # Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron) + for _ in range(w2.dim() - w1.dim()): + w1 = w1.unsqueeze(-1) + diff = torch.kron(w1, w2) return w + diff.reshape(w.shape).to(w) + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoKr training: efficient Kronecker product. + + Uses w1/w2 properties which handle both direct and decomposed cases. + For create_train (direct w1/w2), no alpha scaling in properties. + For to_train (decomposed), alpha/rank scaling is in properties. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + # Get w1, w2 from properties (handles rebuild vs direct) + w1 = self.w1 + w2 = self.w2 + + # Multiplier from bypass injection + multiplier = getattr(self, "multiplier", 1.0) + + # Get module info from bypass injection + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Efficient Kronecker application without materializing full weight + # kron(w1, w2) @ x can be computed as nested operations + # w1: [out_l, in_m], w2: [out_k, in_n, *k_size] + # Full weight would be [out_l*out_k, in_m*in_n, *k_size] + + uq = w1.size(1) # in_m - inner grouping dimension + + if is_conv: + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + B, C_in, *spatial = x.shape + # Reshape input for grouped application: [B * uq, C_in // uq, *spatial] + h_in_group = x.reshape(B * uq, -1, *spatial) + + # Ensure w2 has conv dims + if w2.dim() == 2: + w2 = w2.view(*w2.shape, *([1] * conv_dim)) + + # Apply w2 path with stride/padding + hb = conv_fn(h_in_group, w2, **kw_dict) + + # Reshape for cross-group operation + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross = hb.transpose(1, -1) + + # Apply w1 (always 2D, applied as linear on channel dim) + hc = F.linear(h_cross, w1) + hc = hc.transpose(1, -1) + + # Reshape to output + out = hc.reshape(B, -1, *hc.shape[3:]) + else: + # Linear case + # Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n] + h_in_group = x.reshape(*x.shape[:-1], uq, -1) + + # Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k] + hb = F.linear(h_in_group, w2) + + # Transpose for w1: [..., uq, out_k] -> [..., out_k, uq] + h_cross = hb.transpose(-1, -2) + + # Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l] + hc = F.linear(h_cross, w1) + + # Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k] + hc = hc.transpose(-1, -2) + out = hc.reshape(*hc.shape[:-2], -1) + + return out * multiplier + def passive_memory_usage(self): return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase): @classmethod def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] - in_dim = weight.shape[1:].numel() - out1, out2 = factorization(out_dim, rank) - in1, in2 = factorization(in_dim, rank) - mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32) - mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32) + in_dim = weight.shape[1] # Just in_channels, not flattened with kernel + k_size = weight.shape[2:] if weight.dim() > 2 else () + + out_l, out_k = factorization(out_dim, rank) + in_m, in_n = factorization(in_dim, rank) + + # w1: [out_l, in_m] + mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32) + # w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear + mat2 = torch.empty( + out_k, in_n, *k_size, device=weight.device, dtype=torch.float32 + ) + torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) torch.nn.init.constant_(mat1, 0.0) - return LokrDiff( - (mat1, mat2, alpha, None, None, None, None, None, None) - ) + return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None)) def to_train(self): return LokrDiff(self.weights) @@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase): lokr_t2 = lora[lokr_t2_name] loaded_keys.add(lokr_t2_name) - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) + if ( + (lokr_w1 is not None) + or (lokr_w2 is not None) + or (lokr_w1_a is not None) + or (lokr_w2_a is not None) + ): + weights = ( + lokr_w1, + lokr_w2, + alpha, + lokr_w1_a, + lokr_w1_b, + lokr_w2_a, + lokr_w2_b, + lokr_t2, + dora_scale, + ) return cls(loaded_keys, weights) else: return None @@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase): if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + w1 = torch.mm( + comfy.model_management.cast_to_device( + w1_a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1_b, weight.device, intermediate_dtype + ), + ) else: - w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) + w1 = comfy.model_management.cast_to_device( + w1, weight.device, intermediate_dtype + ) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + w2 = torch.mm( + comfy.model_management.cast_to_device( + w2_a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_b, weight.device, intermediate_dtype + ), + ) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + w2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t2, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_a, weight.device, intermediate_dtype + ), + ) else: - w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) + w2 = comfy.model_management.cast_to_device( + w2, weight.device, intermediate_dtype + ) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase): try: lora_diff = torch.kron(w1, w2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoKr: efficient Kronecker product application. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/lokr.py bypass_forward_diff + """ + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora + w1 = v[0] + w2 = v[1] + alpha = v[2] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + + use_w1 = w1 is not None + use_w2 = w2 is not None + tucker = t2 is not None + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) if is_conv else {} + + if is_conv: + op = FUNC_LIST[conv_dim + 2] + else: + op = F.linear + + # Determine rank and scale + rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha + scale = (alpha / rank if alpha is not None else 1.0) * getattr( + self, "multiplier", 1.0 + ) + + # Build c (w1) + if use_w1: + c = w1.to(dtype=x.dtype) + else: + c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype) + uq = c.size(1) + + # Build w2 components + if use_w2: + ba = w2.to(dtype=x.dtype) + else: + a = w2_b.to(dtype=x.dtype) + b = w2_a.to(dtype=x.dtype) + if is_conv: + if tucker: + # Tucker: a, b get 1s appended (kernel is in t2) + if a.dim() == 2: + a = a.view(*a.shape, *([1] * conv_dim)) + if b.dim() == 2: + b = b.view(*b.shape, *([1] * conv_dim)) + else: + # Non-tucker conv: b may need 1s appended + if b.dim() == 2: + b = b.view(*b.shape, *([1] * conv_dim)) + + # Reshape input by uq groups + if is_conv: + B, _, *rest = x.shape + h_in_group = x.reshape(B * uq, -1, *rest) + else: + h_in_group = x.reshape(*x.shape[:-1], uq, -1) + + # Apply w2 path + if use_w2: + hb = op(h_in_group, ba, **kw_dict) + else: + if is_conv: + if tucker: + t = t2.to(dtype=x.dtype) + if t.dim() == 2: + t = t.view(*t.shape, *([1] * conv_dim)) + ha = op(h_in_group, a) + ht = op(ha, t, **kw_dict) + hb = op(ht, b) + else: + ha = op(h_in_group, a, **kw_dict) + hb = op(ha, b) + else: + ha = op(h_in_group, a) + hb = op(ha, b) + + # Reshape and apply c (w1) + if is_conv: + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross_group = hb.transpose(1, -1) + else: + h_cross_group = hb.transpose(-1, -2) + + hc = F.linear(h_cross_group, c) + + if is_conv: + hc = hc.transpose(1, -1) + out = hc.reshape(B, -1, *hc.shape[3:]) + else: + hc = hc.transpose(-1, -2) + out = hc.reshape(*hc.shape[:-2], -1) + + return out * scale diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 3cc60bb1b..bc4260a8f 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import ( WeightAdapterBase, @@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase): rank, in_dim = mat2.shape[0], mat2.shape[1] if mid is not None: convdim = mid.ndim - 2 - layer = ( - torch.nn.Conv1d, - torch.nn.Conv2d, - torch.nn.Conv3d - )[convdim] + layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim] else: layer = torch.nn.Linear self.lora_up = layer(rank, out_dim, bias=False) @@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase): weight = w + scale * diff.reshape(w.shape) return weight.to(org_dtype) + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoRA training: h(x) = up(down(x)) * scale + + Simple implementation using the nn.Module weights directly. + No mid/dora/reshape branches (create_train doesn't create them). + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + # Compute scale = alpha / rank * multiplier + scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0) + + # Get module info from bypass injection + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Get weights (keep in original dtype for numerical stability) + down_weight = self.lora_down.weight + up_weight = self.lora_up.weight + + if is_conv: + # Conv path: use functional conv + # conv_dim: 1=conv1d, 2=conv2d, 3=conv3d + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + # Reshape 2D weights to conv format if needed + # down: [rank, in_features] -> [rank, in_channels, *kernel_size] + # up: [out_features, rank] -> [out_features, rank, 1, 1, ...] + if down_weight.dim() == 2: + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + if in_channels is not None: + down_weight = down_weight.view( + down_weight.shape[0], in_channels, *kernel_size + ) + else: + # Fallback: assume 1x1 kernel + down_weight = down_weight.view( + *down_weight.shape, *([1] * conv_dim) + ) + if up_weight.dim() == 2: + # up always uses 1x1 kernel + up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim)) + + # down conv uses stride/padding from module, up is 1x1 + hidden = conv_fn(x, down_weight, **kw_dict) + + # mid layer if exists (tucker decomposition) + if self.lora_mid is not None: + mid_weight = self.lora_mid.weight + if mid_weight.dim() == 2: + mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim)) + hidden = conv_fn(hidden, mid_weight) + + # up conv is always 1x1 (no stride/padding) + out = conv_fn(hidden, up_weight) + else: + # Linear path: simple matmul chain + hidden = F.linear(x, down_weight) + + # mid layer if exists + if self.lora_mid is not None: + mid_weight = self.lora_mid.weight + hidden = F.linear(hidden, mid_weight) + + out = F.linear(hidden, up_weight) + + return out * scale + def passive_memory_usage(self): return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase): mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.constant_(mat2, 0.0) - return LoraDiff( - (mat1, mat2, alpha, None, None, None) - ) + return LoraDiff((mat1, mat2, alpha, None, None, None)) def to_train(self): return LoraDiff(self.weights) @@ -210,3 +277,85 @@ class LoRAAdapter(WeightAdapterBase): except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoRA: h(x) = up(down(x)) * scale + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/locon.py bypass_forward_diff + """ + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape + up = v[0] + down = v[1] + alpha = v[2] + mid = v[3] + + # Compute scale = alpha / rank + rank = down.shape[0] + if alpha is not None: + scale = alpha / rank + else: + scale = 1.0 + scale = scale * getattr(self, "multiplier", 1.0) + + # Cast dtype + up = up.to(dtype=x.dtype) + down = down.to(dtype=x.dtype) + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + if is_conv: + op = FUNC_LIST[ + conv_dim + 2 + ] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5) + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Reshape 2D weights to conv format using kernel_size + # down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size] + # up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel) + if down.dim() == 2: + # down.shape[1] = in_channels * prod(kernel_size) + if in_channels is not None: + down = down.view(down.shape[0], in_channels, *kernel_size) + else: + # Fallback: assume 1x1 kernel if in_channels unknown + down = down.view(*down.shape, *([1] * conv_dim)) + if up.dim() == 2: + # up always uses 1x1 kernel + up = up.view(*up.shape, *([1] * conv_dim)) + if mid is not None: + mid = mid.to(dtype=x.dtype) + if mid.dim() == 2: + mid = mid.view(*mid.shape, *([1] * conv_dim)) + else: + op = F.linear + kw_dict = {} # linear doesn't take stride/padding + + # Simple chain: down -> mid (if tucker) -> up + if mid is not None: + if not is_conv: + mid = mid.to(dtype=x.dtype) + hidden = op(x, down) + hidden = op(hidden, mid, **kw_dict) + out = op(hidden, up) + else: + hidden = op(x, down, **kw_dict) + out = op(hidden, up) + + return out * scale diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index c0aab9635..bc83cf8e8 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -3,13 +3,18 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + factorization, +) class OFTDiff(WeightAdapterTrainBase): def __init__(self, weights): super().__init__() - # Unpack weights tuple from LoHaAdapter + # Unpack weights tuple from OFTAdapter blocks, rescale, alpha, _ = weights # Create trainable parameters @@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase): weight = self.rescale * weight return weight.to(org_dtype) + def _get_orthogonal_matrix(self, device, dtype): + """Compute the orthogonal rotation matrix R from OFT blocks.""" + blocks = self.oft_blocks.to(device=device, dtype=dtype) + I = torch.eye(self.block_size, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(1, 2) + normed_q = q + + # Apply constraint if set + if self.constraint: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r.to(dtype) + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + OFT has no additive component - returns zeros matching base_out shape. + + OFT only transforms the output via g(), it doesn't add to it. + """ + return torch.zeros_like(base_out) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for OFT: applies orthogonal rotation. + + OFT transforms output channels using block-diagonal orthogonal matrices. + """ + r = self._get_orthogonal_matrix(y.device, y.dtype) + + # Apply multiplier to interpolate between identity and full transform + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(self.block_size, device=y.device, dtype=y.dtype) + r = r * multiplier + (1 - multiplier) * I + + # Use module info from bypass injection + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # y now has channels in last dim + *batch_shape, out_features = y.shape + + # Reshape to apply block-diagonal transform + # (*, out_features) -> (*, block_num, block_size) + y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size) + + # Apply orthogonal transform: R @ y for each block + # r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size) + out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked) + + # Reshape back: (*, block_num, block_size) -> (*, out_features) + out = out_blocked.reshape(*batch_shape, out_features) + + # Apply rescale if present + if self.rescaled: + rescale = self.rescale.to(device=y.device, dtype=y.dtype) + out = out * rescale.view(-1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + out = out.transpose(1, -1) + + return out + def passive_memory_usage(self): """Calculates memory usage of the trainable parameters.""" return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] block_size, block_num = factorization(out_dim, rank) - block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32) - return OFTDiff( - (block, None, alpha, None) + block = torch.zeros( + block_num, block_size, block_size, device=weight.device, dtype=torch.float32 ) + return OFTDiff((block, None, alpha, None)) def to_train(self): return OFTDiff(self.weights) @@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase): alpha = 0 dora_scale = v[3] - blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + blocks = comfy.model_management.cast_to_device( + blocks, weight.device, intermediate_dtype + ) if rescale is not None: - rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + rescale = comfy.model_management.cast_to_device( + rescale, weight.device, intermediate_dtype + ) block_num, block_size, *_ = blocks.shape @@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase): # for Q = -Q^T q = blocks - blocks.transpose(1, 2) normed_q = q - if alpha > 0: # alpha in oft/boft is for constraint + if alpha > 0: # alpha in oft/boft is for constraint q_norm = torch.norm(q) + 1e-8 if q_norm > alpha: normed_q = q * alpha / q_norm # use float() to prevent unsupported type in .inverse() r = (I + normed_q) @ (I - normed_q).float().inverse() r = r.to(weight) + # Create I in weight's dtype for the einsum + I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype) _, *shape = weight.shape lora_diff = torch.einsum( "k n m, k n ... -> k m ...", - (r * strength) - strength * I, + (r * strength) - strength * I_w, weight.view(block_num, block_size, *shape), ).view(-1, *shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _get_orthogonal_matrix(self, device, dtype): + """Compute the orthogonal rotation matrix R from OFT blocks.""" + v = self.weights + blocks = v[0].to(device=device, dtype=dtype) + alpha = v[2] + if alpha is None: + alpha = 0 + + block_num, block_size, _ = blocks.shape + I = torch.eye(block_size, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(1, 2) + normed_q = q + + # Apply constraint if alpha > 0 + if alpha > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r, block_num, block_size + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for OFT: applies orthogonal rotation to output. + + OFT transforms the output channels using block-diagonal orthogonal matrices. + + Reference: LyCORIS DiagOFTModule._bypass_forward + """ + v = self.weights + rescale = v[1] + + r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype) + + # Apply multiplier to interpolate between identity and full transform + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(block_size, device=y.device, dtype=y.dtype) + r = r * multiplier + (1 - multiplier) * I + + # Use module info from bypass injection to determine conv vs linear + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # y now has channels in last dim + *batch_shape, out_features = y.shape + + # Reshape to apply block-diagonal transform + # (*, out_features) -> (*, block_num, block_size) + y_blocked = y.view(*batch_shape, block_num, block_size) + + # Apply orthogonal transform: R @ y for each block + # r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size) + out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked) + + # Reshape back: (*, block_num, block_size) -> (*, out_features) + out = out_blocked.view(*batch_shape, out_features) + + # Apply rescale if present + if rescale is not None: + rescale = rescale.to(device=y.device, dtype=y.dtype) + out = out * rescale.view(-1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + out = out.transpose(1, -1) + + return out diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 2ec8d6e4b..be759952e 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1247,6 +1247,7 @@ class NodeInfoV1: output_node: bool=None deprecated: bool=None experimental: bool=None + dev_only: bool=None api_node: bool=None price_badge: dict | None = None search_aliases: list[str]=None @@ -1264,6 +1265,7 @@ class NodeInfoV3: output_node: bool=None deprecated: bool=None experimental: bool=None + dev_only: bool=None api_node: bool=None price_badge: dict | None = None @@ -1375,6 +1377,8 @@ class Schema: """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" is_experimental: bool=False """Flags a node as experimental, informing users that it may change or not work as expected.""" + is_dev_only: bool=False + """Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled.""" is_api_node: bool=False """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" price_badge: PriceBadge | None = None @@ -1383,6 +1387,8 @@ class Schema: """Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph.""" enable_expand: bool=False """Flags a node as expandable, allowing NodeOutput to include 'expand' property.""" + accept_all_inputs: bool=False + """When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema.""" def validate(self): '''Validate the schema: @@ -1483,6 +1489,7 @@ class Schema: output_node=self.is_output_node, deprecated=self.is_deprecated, experimental=self.is_experimental, + dev_only=self.is_dev_only, api_node=self.is_api_node, python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, @@ -1517,6 +1524,7 @@ class Schema: output_node=self.is_output_node, deprecated=self.is_deprecated, experimental=self.is_experimental, + dev_only=self.is_dev_only, api_node=self.is_api_node, python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, @@ -1789,6 +1797,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls.GET_SCHEMA() return cls._DEPRECATED + _DEV_ONLY = None + @final + @classproperty + def DEV_ONLY(cls): # noqa + if cls._DEV_ONLY is None: + cls.GET_SCHEMA() + return cls._DEV_ONLY + _API_NODE = None @final @classproperty @@ -1853,6 +1869,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls.GET_SCHEMA() return cls._NOT_IDEMPOTENT + _ACCEPT_ALL_INPUTS = None + @final + @classproperty + def ACCEPT_ALL_INPUTS(cls): # noqa + if cls._ACCEPT_ALL_INPUTS is None: + cls.GET_SCHEMA() + return cls._ACCEPT_ALL_INPUTS + @final @classmethod def INPUT_TYPES(cls) -> dict[str, dict]: @@ -1883,6 +1907,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls._EXPERIMENTAL = schema.is_experimental if cls._DEPRECATED is None: cls._DEPRECATED = schema.is_deprecated + if cls._DEV_ONLY is None: + cls._DEV_ONLY = schema.is_dev_only if cls._API_NODE is None: cls._API_NODE = schema.is_api_node if cls._OUTPUT_NODE is None: @@ -1891,6 +1917,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls._INPUT_IS_LIST = schema.is_input_list if cls._NOT_IDEMPOTENT is None: cls._NOT_IDEMPOTENT = schema.not_idempotent + if cls._ACCEPT_ALL_INPUTS is None: + cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs if cls._RETURN_TYPES is None: output = [] diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 400648cca..23cbe2372 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -13,17 +13,6 @@ class Text2ImageTaskCreationRequest(BaseModel): watermark: bool | None = Field(False) -class Image2ImageTaskCreationRequest(BaseModel): - model: str = Field(...) - prompt: str = Field(...) - response_format: str | None = Field("url") - image: str = Field(..., description="Base64 encoded string or image URL") - size: str | None = Field("adaptive") - seed: int | None = Field(..., ge=0, le=2147483647) - guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(False) - - class Seedream4Options(BaseModel): max_images: int = Field(15) diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py new file mode 100644 index 000000000..8e3c79ab9 --- /dev/null +++ b/comfy_api_nodes/apis/grok.py @@ -0,0 +1,67 @@ +from pydantic import BaseModel, Field + + +class ImageGenerationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + aspect_ratio: str = Field(...) + n: int = Field(...) + seed: int = Field(...) + response_for: str = Field("url") + + +class InputUrlObject(BaseModel): + url: str = Field(...) + + +class ImageEditRequest(BaseModel): + model: str = Field(...) + image: InputUrlObject = Field(...) + prompt: str = Field(...) + resolution: str = Field(...) + n: int = Field(...) + seed: int = Field(...) + response_for: str = Field("url") + + +class VideoGenerationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + image: InputUrlObject | None = Field(...) + duration: int = Field(...) + aspect_ratio: str | None = Field(...) + resolution: str = Field(...) + seed: int = Field(...) + + +class VideoEditRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + video: InputUrlObject = Field(...) + seed: int = Field(...) + + +class ImageResponseObject(BaseModel): + url: str | None = Field(None) + b64_json: str | None = Field(None) + revised_prompt: str | None = Field(None) + + +class ImageGenerationResponse(BaseModel): + data: list[ImageResponseObject] = Field(...) + + +class VideoGenerationResponse(BaseModel): + request_id: str = Field(...) + + +class VideoResponseObject(BaseModel): + url: str = Field(...) + upsampled_prompt: str | None = Field(None) + duration: int = Field(...) + + +class VideoStatusResponse(BaseModel): + status: str | None = Field(None) + video: VideoResponseObject | None = Field(None) + model: str | None = Field(None) diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py new file mode 100644 index 000000000..6421c9bd5 --- /dev/null +++ b/comfy_api_nodes/apis/hunyuan3d.py @@ -0,0 +1,66 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field, model_validator + + +class InputGenerateType(TypedDict): + generate_type: str + polygon_type: str + pbr: bool + + +class Hunyuan3DViewImage(BaseModel): + ViewType: str = Field(..., description="Valid values: back, left, right.") + ViewImageUrl: str = Field(...) + + +class To3DProTaskRequest(BaseModel): + Model: str = Field(...) + Prompt: str | None = Field(None) + ImageUrl: str | None = Field(None) + MultiViewImages: list[Hunyuan3DViewImage] | None = Field(None) + EnablePBR: bool | None = Field(...) + FaceCount: int | None = Field(...) + GenerateType: str | None = Field(...) + PolygonType: str | None = Field(...) + + +class RequestError(BaseModel): + Code: str = Field("") + Message: str = Field("") + + +class To3DProTaskCreateResponse(BaseModel): + JobId: str | None = Field(None) + Error: RequestError | None = Field(None) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "Response" in values and isinstance(values["Response"], dict): + return values["Response"] + return values + + +class ResultFile3D(BaseModel): + Type: str = Field(...) + Url: str = Field(...) + PreviewImageUrl: str = Field("") + + +class To3DProTaskResultResponse(BaseModel): + ErrorCode: str = Field("") + ErrorMessage: str = Field("") + ResultFile3Ds: list[ResultFile3D] = Field([]) + Status: str = Field(...) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "Response" in values and isinstance(values["Response"], dict): + return values["Response"] + return values + + +class To3DProTaskQueryRequest(BaseModel): + JobId: str = Field(...) diff --git a/comfy_api_nodes/apis/magnific.py b/comfy_api_nodes/apis/magnific.py new file mode 100644 index 000000000..b9f148def --- /dev/null +++ b/comfy_api_nodes/apis/magnific.py @@ -0,0 +1,122 @@ +from typing import TypedDict + +from pydantic import AliasChoices, BaseModel, Field, model_validator + + +class InputPortraitMode(TypedDict): + portrait_mode: str + portrait_style: str + portrait_beautifier: str + + +class InputAdvancedSettings(TypedDict): + advanced_settings: str + whites: int + blacks: int + brightness: int + contrast: int + saturation: int + engine: str + transfer_light_a: str + transfer_light_b: str + fixed_generation: bool + + +class InputSkinEnhancerMode(TypedDict): + mode: str + skin_detail: int + optimized_for: str + + +class ImageUpscalerCreativeRequest(BaseModel): + image: str = Field(...) + scale_factor: str = Field(...) + optimized_for: str = Field(...) + prompt: str | None = Field(None) + creativity: int = Field(...) + hdr: int = Field(...) + resemblance: int = Field(...) + fractality: int = Field(...) + engine: str = Field(...) + + +class ImageUpscalerPrecisionV2Request(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + ultra_detail: int = Field(...) + flavor: str = Field(...) + scale_factor: int = Field(...) + + +class ImageRelightAdvancedSettingsRequest(BaseModel): + whites: int = Field(...) + blacks: int = Field(...) + brightness: int = Field(...) + contrast: int = Field(...) + saturation: int = Field(...) + engine: str = Field(...) + transfer_light_a: str = Field(...) + transfer_light_b: str = Field(...) + fixed_generation: bool = Field(...) + + +class ImageRelightRequest(BaseModel): + image: str = Field(...) + prompt: str | None = Field(None) + transfer_light_from_reference_image: str | None = Field(None) + light_transfer_strength: int = Field(...) + interpolate_from_original: bool = Field(...) + change_background: bool = Field(...) + style: str = Field(...) + preserve_details: bool = Field(...) + advanced_settings: ImageRelightAdvancedSettingsRequest | None = Field(...) + + +class ImageStyleTransferRequest(BaseModel): + image: str = Field(...) + reference_image: str = Field(...) + prompt: str | None = Field(None) + style_strength: int = Field(...) + structure_strength: int = Field(...) + is_portrait: bool = Field(...) + portrait_style: str | None = Field(...) + portrait_beautifier: str | None = Field(...) + flavor: str = Field(...) + engine: str = Field(...) + fixed_generation: bool = Field(...) + + +class ImageSkinEnhancerCreativeRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + + +class ImageSkinEnhancerFaithfulRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + skin_detail: int = Field(...) + + +class ImageSkinEnhancerFlexibleRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + optimized_for: str = Field(...) + + +class TaskResponse(BaseModel): + """Unified response model that handles both wrapped and unwrapped API responses.""" + + task_id: str = Field(...) + status: str = Field(validation_alias=AliasChoices("status", "task_status")) + generated: list[str] | None = Field(None) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "data" in values and isinstance(values["data"], dict): + return values["data"] + return values diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 486801150..0cb5e3be8 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -9,7 +9,6 @@ from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, VIDEO_TASKS_EXECUTION_TIME, - Image2ImageTaskCreationRequest, Image2VideoTaskCreationRequest, ImageTaskCreationResponse, Seedream4Options, @@ -174,99 +173,6 @@ class ByteDanceImageNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) -class ByteDanceImageEditNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ByteDanceImageEditNode", - display_name="ByteDance Image Edit", - category="api node/image/ByteDance", - description="Edit images using ByteDance models via api based on prompt", - inputs=[ - IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]), - IO.Image.Input( - "image", - tooltip="The base image to edit", - ), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Instruction to edit image", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=2147483647, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Seed to use for generation", - optional=True, - ), - IO.Float.Input( - "guidance_scale", - default=5.5, - min=1.0, - max=10.0, - step=0.01, - display_mode=IO.NumberDisplay.number, - tooltip="Higher value makes the image follow the prompt more closely", - optional=True, - ), - IO.Boolean.Input( - "watermark", - default=False, - tooltip='Whether to add an "AI generated" watermark to the image', - optional=True, - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_deprecated=True, - ) - - @classmethod - async def execute( - cls, - model: str, - image: Input.Image, - prompt: str, - seed: int, - guidance_scale: float, - watermark: bool, - ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=True, min_length=1) - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") - validate_image_aspect_ratio(image, (1, 3), (3, 1)) - source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] - payload = Image2ImageTaskCreationRequest( - model=model, - prompt=prompt, - image=source_url, - seed=seed, - guidance_scale=guidance_scale, - watermark=watermark, - ) - response = await sync_op( - cls, - ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), - data=payload, - response_model=ImageTaskCreationResponse, - ) - return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) - - class ByteDanceSeedreamNode(IO.ComfyNode): @classmethod @@ -1101,7 +1007,6 @@ class ByteDanceExtension(ComfyExtension): async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ ByteDanceImageNode, - ByteDanceImageEditNode, ByteDanceSeedreamNode, ByteDanceTextToVideoNode, ByteDanceImageToVideoNode, diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py new file mode 100644 index 000000000..da15e97ea --- /dev/null +++ b/comfy_api_nodes/nodes_grok.py @@ -0,0 +1,417 @@ +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.grok import ( + ImageEditRequest, + ImageGenerationRequest, + ImageGenerationResponse, + InputUrlObject, + VideoEditRequest, + VideoGenerationRequest, + VideoGenerationResponse, + VideoStatusResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_fs_object_size, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + upload_video_to_comfyapi, + validate_string, + validate_video_duration, +) + + +class GrokImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokImageNode", + display_name="Grok Image", + category="api node/image/Grok", + description="Generate images using Grok based on a text prompt", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "9:16", + "16:9", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + ], + ), + IO.Int.Input( + "number_of_images", + default=1, + min=1, + max=10, + step=1, + tooltip="Number of images to generate", + display_mode=IO.NumberDisplay.number, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), + expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + aspect_ratio: str, + number_of_images: int, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/images/generations", method="POST"), + data=ImageGenerationRequest( + model=model, + prompt=prompt, + aspect_ratio=aspect_ratio, + n=number_of_images, + seed=seed, + ), + response_model=ImageGenerationResponse, + ) + if len(response.data) == 1: + return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) + return IO.NodeOutput( + torch.cat( + [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]], + ) + ) + + +class GrokImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokImageEditNode", + display_name="Grok Image Edit", + category="api node/image/Grok", + description="Modify an existing image based on a text prompt", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + IO.Combo.Input("resolution", options=["1K"]), + IO.Int.Input( + "number_of_images", + default=1, + min=1, + max=10, + step=1, + tooltip="Number of edited images to generate", + display_mode=IO.NumberDisplay.number, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), + expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + resolution: str, + number_of_images: int, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if get_number_of_images(image) != 1: + raise ValueError("Only one input image is supported.") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), + data=ImageEditRequest( + model=model, + image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"), + prompt=prompt, + resolution=resolution.lower(), + n=number_of_images, + seed=seed, + ), + response_model=ImageGenerationResponse, + ) + if len(response.data) == 1: + return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) + return IO.NodeOutput( + torch.cat( + [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]], + ) + ) + + +class GrokVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoNode", + display_name="Grok Video", + category="api node/video/Grok", + description="Generate video from a prompt or an image", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=6, + min=1, + max=15, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Image.Input("image", optional=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]), + expr=""" + ( + $base := 0.181 * widgets.duration; + {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + image: Input.Image | None = None, + ) -> IO.NodeOutput: + image_url = None + if image is not None: + if get_number_of_images(image) != 1: + raise ValueError("Only one input image is supported.") + image_url = InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}") + validate_string(prompt, strip_whitespace=True, min_length=1) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), + data=VideoGenerationRequest( + model=model, + image=image_url, + prompt=prompt, + resolution=resolution, + duration=duration, + aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokVideoEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoEditNode", + display_name="Grok Video Edit", + category="api node/video/Grok", + description="Edit an existing video based on a text prompt.", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.Video.Input("video", tooltip="Maximum supported duration is 8.7 seconds and 50MB file size."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + video: Input.Video, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + validate_video_duration(video, min_duration=1, max_duration=8.7) + video_stream = video.get_stream_source() + video_size = get_fs_object_size(video_stream) + if video_size > 50 * 1024 * 1024: + raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.") + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/edits", method="POST"), + data=VideoEditRequest( + model=model, + video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)), + prompt=prompt, + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + GrokImageNode, + GrokImageEditNode, + GrokVideoNode, + GrokVideoEditNode, + ] + + +async def comfy_entrypoint() -> GrokExtension: + return GrokExtension() diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py new file mode 100644 index 000000000..b3a736643 --- /dev/null +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -0,0 +1,297 @@ +import os + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.hunyuan3d import ( + Hunyuan3DViewImage, + InputGenerateType, + ResultFile3D, + To3DProTaskCreateResponse, + To3DProTaskQueryRequest, + To3DProTaskRequest, + To3DProTaskResultResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_bytesio, + downscale_image_tensor_by_max_side, + poll_op, + sync_op, + upload_image_to_comfyapi, + validate_image_dimensions, + validate_string, +) +from folder_paths import get_output_directory + + +def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D: + for i in response_objs: + if i.Type.lower() == "glb": + return i + raise ValueError("No GLB file found in response. Please report this to the developers.") + + +class TencentTextToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentTextToModelNode", + display_name="Hunyuan3D: Text to Model (Pro)", + category="api node/3d/Tencent", + inputs=[ + IO.Combo.Input( + "model", + options=["3.0", "3.1"], + tooltip="The LowPoly option is unavailable for the `3.1` model.", + ), + IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."), + IO.Int.Input("face_count", default=500000, min=40000, max=1500000), + IO.DynamicCombo.Input( + "generate_type", + options=[ + IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]), + IO.DynamicCombo.Option( + "LowPoly", + [ + IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]), + IO.Boolean.Input("pbr", default=False), + ], + ), + IO.DynamicCombo.Option("Geometry", []), + ], + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["generate_type", "generate_type.pbr", "face_count"]), + expr=""" + ( + $base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15; + $pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0; + $face := widgets.face_count != 500000 ? 10 : 0; + {"type":"usd","usd": ($base + $pbr + $face) * 0.02} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + face_count: int, + generate_type: InputGenerateType, + seed: int, + ) -> IO.NodeOutput: + _ = seed + validate_string(prompt, field_name="prompt", min_length=1, max_length=1024) + if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly": + raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DProTaskRequest( + Model=model, + Prompt=prompt, + FaceCount=face_count, + GenerateType=generate_type["generate_type"], + EnablePBR=generate_type.get("pbr", None), + PolygonType=generate_type.get("polygon_type", None), + ), + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + model_file = f"hunyuan_model_{response.JobId}.glb" + await download_url_to_bytesio( + get_glb_obj_from_response(result.ResultFile3Ds).Url, + os.path.join(get_output_directory(), model_file), + ) + return IO.NodeOutput(model_file) + + +class TencentImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentImageToModelNode", + display_name="Hunyuan3D: Image(s) to Model (Pro)", + category="api node/3d/Tencent", + inputs=[ + IO.Combo.Input( + "model", + options=["3.0", "3.1"], + tooltip="The LowPoly option is unavailable for the `3.1` model.", + ), + IO.Image.Input("image"), + IO.Image.Input("image_left", optional=True), + IO.Image.Input("image_right", optional=True), + IO.Image.Input("image_back", optional=True), + IO.Int.Input("face_count", default=500000, min=40000, max=1500000), + IO.DynamicCombo.Input( + "generate_type", + options=[ + IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]), + IO.DynamicCombo.Option( + "LowPoly", + [ + IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]), + IO.Boolean.Input("pbr", default=False), + ], + ), + IO.DynamicCombo.Option("Geometry", []), + ], + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["generate_type", "generate_type.pbr", "face_count"], + inputs=["image_left", "image_right", "image_back"], + ), + expr=""" + ( + $base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15; + $multiview := ( + inputs.image_left.connected or inputs.image_right.connected or inputs.image_back.connected + ) ? 10 : 0; + $pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0; + $face := widgets.face_count != 500000 ? 10 : 0; + {"type":"usd","usd": ($base + $multiview + $pbr + $face) * 0.02} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + face_count: int, + generate_type: InputGenerateType, + seed: int, + image_left: Input.Image | None = None, + image_right: Input.Image | None = None, + image_back: Input.Image | None = None, + ) -> IO.NodeOutput: + _ = seed + if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly": + raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.") + validate_image_dimensions(image, min_width=128, min_height=128) + multiview_images = [] + for k, v in { + "left": image_left, + "right": image_right, + "back": image_back, + }.items(): + if v is None: + continue + validate_image_dimensions(v, min_width=128, min_height=128) + multiview_images.append( + Hunyuan3DViewImage( + ViewType=k, + ViewImageUrl=await upload_image_to_comfyapi( + cls, + downscale_image_tensor_by_max_side(v, max_side=4900), + mime_type="image/webp", + total_pixels=24_010_000, + ), + ) + ) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DProTaskRequest( + Model=model, + FaceCount=face_count, + GenerateType=generate_type["generate_type"], + ImageUrl=await upload_image_to_comfyapi( + cls, + downscale_image_tensor_by_max_side(image, max_side=4900), + mime_type="image/webp", + total_pixels=24_010_000, + ), + MultiViewImages=multiview_images if multiview_images else None, + EnablePBR=generate_type.get("pbr", None), + PolygonType=generate_type.get("polygon_type", None), + ), + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + model_file = f"hunyuan_model_{response.JobId}.glb" + await download_url_to_bytesio( + get_glb_obj_from_response(result.ResultFile3Ds).Url, + os.path.join(get_output_directory(), model_file), + ) + return IO.NodeOutput(model_file) + + +class TencentHunyuan3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TencentTextToModelNode, + TencentImageToModelNode, + ] + + +async def comfy_entrypoint() -> TencentHunyuan3DExtension: + return TencentHunyuan3DExtension() diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 3ec71530b..739fe1855 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -249,7 +249,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), response_model=TaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), - max_poll_attempts=160, ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py new file mode 100644 index 000000000..013e71cc8 --- /dev/null +++ b/comfy_api_nodes/nodes_magnific.py @@ -0,0 +1,889 @@ +import math + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.magnific import ( + ImageRelightAdvancedSettingsRequest, + ImageRelightRequest, + ImageSkinEnhancerCreativeRequest, + ImageSkinEnhancerFaithfulRequest, + ImageSkinEnhancerFlexibleRequest, + ImageStyleTransferRequest, + ImageUpscalerCreativeRequest, + ImageUpscalerPrecisionV2Request, + InputAdvancedSettings, + InputPortraitMode, + InputSkinEnhancerMode, + TaskResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + downscale_image_tensor, + get_image_dimensions, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_image_aspect_ratio, + validate_image_dimensions, +) + + +class MagnificImageUpscalerCreativeNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageUpscalerCreativeNode", + display_name="Magnific Image Upscale (Creative)", + category="api node/image/Magnific", + description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. " + "Maximum output: 25.3 megapixels.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", multiline=True, default=""), + IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]), + IO.Combo.Input( + "optimized_for", + options=[ + "standard", + "soft_portraits", + "hard_portraits", + "art_n_illustration", + "videogame_assets", + "nature_n_landscapes", + "films_n_photography", + "3d_renders", + "science_fiction_n_horror", + ], + ), + IO.Int.Input("creativity", min=-10, max=10, default=0, display_mode=IO.NumberDisplay.slider), + IO.Int.Input( + "hdr", + min=-10, + max=10, + default=0, + tooltip="The level of definition and detail.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "resemblance", + min=-10, + max=10, + default=0, + tooltip="The level of resemblance to the original image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "fractality", + min=-10, + max=10, + default=0, + tooltip="The strength of the prompt and intricacy per square pixel.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "engine", + options=["automatic", "magnific_illusio", "magnific_sharpy", "magnific_sparkle"], + ), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed maximum pixel limit.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), + expr=""" + ( + $max := widgets.scale_factor = "2x" ? 1.326 : 1.657; + {"type": "range_usd", "min_usd": 0.11, "max_usd": $max} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + scale_factor: str, + optimized_for: str, + creativity: int, + hdr: int, + resemblance: int, + fractality: int, + engine: str, + auto_downscale: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + max_output_pixels = 25_300_000 + height, width = get_image_dimensions(image) + requested_scale = int(scale_factor.rstrip("x")) + output_pixels = height * width * requested_scale * requested_scale + + if output_pixels > max_output_pixels: + if auto_downscale: + # Find optimal scale factor that doesn't require >2x downscale. + # Server upscales in 2x steps, so aggressive downscaling degrades quality. + input_pixels = width * height + scale = 2 + max_input_pixels = max_output_pixels // 4 + for candidate in [16, 8, 4, 2]: + if candidate > requested_scale: + continue + scale_output_pixels = input_pixels * candidate * candidate + if scale_output_pixels <= max_output_pixels: + scale = candidate + max_input_pixels = None + break + downscale_ratio = math.sqrt(scale_output_pixels / max_output_pixels) + if downscale_ratio <= 2.0: + scale = candidate + max_input_pixels = max_output_pixels // (candidate * candidate) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + scale_factor = f"{scale}x" + else: + raise ValueError( + f"Output size ({width * requested_scale}x{height * requested_scale} = {output_pixels:,} pixels) " + f"exceeds maximum allowed size of {max_output_pixels:,} pixels. " + f"Use a smaller input image or lower scale factor." + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"), + response_model=TaskResponse, + data=ImageUpscalerCreativeRequest( + image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0], + scale_factor=scale_factor, + optimized_for=optimized_for, + creativity=creativity, + hdr=hdr, + resemblance=resemblance, + fractality=fractality, + engine=engine, + prompt=prompt if prompt else None, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageUpscalerPreciseV2Node", + display_name="Magnific Image Upscale (Precise V2)", + category="api node/image/Magnific", + description="High-fidelity upscaling with fine control over sharpness, grain, and detail. " + "Maximum output: 10060×10060 pixels.", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]), + IO.Combo.Input( + "flavor", + options=["sublime", "photo", "photo_denoiser"], + tooltip="Processing style: " + "sublime for general use, photo for photographs, photo_denoiser for noisy photos.", + ), + IO.Int.Input( + "sharpen", + min=0, + max=100, + default=7, + tooltip="Image sharpness intensity. Higher values increase edge definition and clarity.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "smart_grain", + min=0, + max=100, + default=7, + tooltip="Intelligent grain/texture enhancement to prevent the image from " + "looking too smooth or artificial.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "ultra_detail", + min=0, + max=100, + default=30, + tooltip="Controls fine detail, textures, and micro-details added during upscaling.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed maximum resolution.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), + expr=""" + ( + $max := widgets.scale_factor = "2x" ? 1.326 : 1.657; + {"type": "range_usd", "min_usd": 0.11, "max_usd": $max} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + scale_factor: str, + flavor: str, + sharpen: int, + smart_grain: int, + ultra_detail: int, + auto_downscale: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + max_output_dimension = 10060 + height, width = get_image_dimensions(image) + requested_scale = int(scale_factor.strip("x")) + output_width = width * requested_scale + output_height = height * requested_scale + + if output_width > max_output_dimension or output_height > max_output_dimension: + if auto_downscale: + # Find optimal scale factor that doesn't require >2x downscale. + # Server upscales in 2x steps, so aggressive downscaling degrades quality. + max_dim = max(width, height) + scale = 2 + max_input_dim = max_output_dimension // 2 + scale_ratio = max_input_dim / max_dim + max_input_pixels = int(width * height * scale_ratio * scale_ratio) + for candidate in [16, 8, 4, 2]: + if candidate > requested_scale: + continue + output_dim = max_dim * candidate + if output_dim <= max_output_dimension: + scale = candidate + max_input_pixels = None + break + downscale_ratio = output_dim / max_output_dimension + if downscale_ratio <= 2.0: + scale = candidate + max_input_dim = max_output_dimension // candidate + scale_ratio = max_input_dim / max_dim + max_input_pixels = int(width * height * scale_ratio * scale_ratio) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + requested_scale = scale + else: + raise ValueError( + f"Output dimensions ({output_width}x{output_height}) exceed maximum allowed " + f"resolution of {max_output_dimension}x{max_output_dimension} pixels. " + f"Use a smaller input image or lower scale factor." + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"), + response_model=TaskResponse, + data=ImageUpscalerPrecisionV2Request( + image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0], + scale_factor=requested_scale, + flavor=flavor, + sharpen=sharpen, + smart_grain=smart_grain, + ultra_detail=ultra_detail, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageStyleTransferNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageStyleTransferNode", + display_name="Magnific Image Style Transfer", + category="api node/image/Magnific", + description="Transfer the style from a reference image to your input image.", + inputs=[ + IO.Image.Input("image", tooltip="The image to apply style transfer to."), + IO.Image.Input("reference_image", tooltip="The reference image to extract style from."), + IO.String.Input("prompt", multiline=True, default=""), + IO.Int.Input( + "style_strength", + min=0, + max=100, + default=100, + tooltip="Percentage of style strength.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "structure_strength", + min=0, + max=100, + default=50, + tooltip="Maintains the structure of the original image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "flavor", + options=["faithful", "gen_z", "psychedelia", "detaily", "clear", "donotstyle", "donotstyle_sharp"], + tooltip="Style transfer flavor.", + ), + IO.Combo.Input( + "engine", + options=[ + "balanced", + "definio", + "illusio", + "3d_cartoon", + "colorful_anime", + "caricature", + "real", + "super_real", + "softy", + ], + tooltip="Processing engine selection.", + ), + IO.DynamicCombo.Input( + "portrait_mode", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Combo.Input( + "portrait_style", + options=["standard", "pop", "super_pop"], + tooltip="Visual style applied to portrait images.", + ), + IO.Combo.Input( + "portrait_beautifier", + options=["none", "beautify_face", "beautify_face_max"], + tooltip="Facial beautification intensity on portraits.", + ), + ], + ), + ], + tooltip="Enable portrait mode for facial enhancements.", + ), + IO.Boolean.Input( + "fixed_generation", + default=True, + tooltip="When disabled, expect each generation to introduce a degree of randomness, " + "leading to more diverse outcomes.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + reference_image: Input.Image, + prompt: str, + style_strength: int, + structure_strength: int, + flavor: str, + engine: str, + portrait_mode: InputPortraitMode, + fixed_generation: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if get_number_of_images(reference_image) != 1: + raise ValueError("Exactly one reference image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + validate_image_dimensions(reference_image, min_height=160, min_width=160) + + is_portrait = portrait_mode["portrait_mode"] == "enabled" + portrait_style = portrait_mode.get("portrait_style", "standard") + portrait_beautifier = portrait_mode.get("portrait_beautifier", "none") + + uploaded_urls = await upload_images_to_comfyapi(cls, [image, reference_image], max_images=2) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-style-transfer", method="POST"), + response_model=TaskResponse, + data=ImageStyleTransferRequest( + image=uploaded_urls[0], + reference_image=uploaded_urls[1], + prompt=prompt if prompt else None, + style_strength=style_strength, + structure_strength=structure_strength, + is_portrait=is_portrait, + portrait_style=portrait_style if is_portrait else None, + portrait_beautifier=portrait_beautifier if is_portrait and portrait_beautifier != "none" else None, + flavor=flavor, + engine=engine, + fixed_generation=fixed_generation, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-style-transfer/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageRelightNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageRelightNode", + display_name="Magnific Image Relight", + category="api node/image/Magnific", + description="Relight an image with lighting adjustments and optional reference-based light transfer.", + inputs=[ + IO.Image.Input("image", tooltip="The image to relight."), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Descriptive guidance for lighting. Supports emphasis notation (1-1.4).", + ), + IO.Int.Input( + "light_transfer_strength", + min=0, + max=100, + default=100, + tooltip="Intensity of light transfer application.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "style", + options=[ + "standard", + "darker_but_realistic", + "clean", + "smooth", + "brighter", + "contrasted_n_hdr", + "just_composition", + ], + tooltip="Stylistic output preference.", + ), + IO.Boolean.Input( + "interpolate_from_original", + default=False, + tooltip="Restricts generation freedom to match original more closely.", + ), + IO.Boolean.Input( + "change_background", + default=True, + tooltip="Modifies background based on prompt/reference.", + ), + IO.Boolean.Input( + "preserve_details", + default=True, + tooltip="Maintains texture and fine details from original.", + ), + IO.DynamicCombo.Input( + "advanced_settings", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Int.Input( + "whites", + min=0, + max=100, + default=50, + tooltip="Adjusts the brightest tones in the image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "blacks", + min=0, + max=100, + default=50, + tooltip="Adjusts the darkest tones in the image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "brightness", + min=0, + max=100, + default=50, + tooltip="Overall brightness adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "contrast", + min=0, + max=100, + default=50, + tooltip="Contrast adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "saturation", + min=0, + max=100, + default=50, + tooltip="Color saturation adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "engine", + options=[ + "automatic", + "balanced", + "cool", + "real", + "illusio", + "fairy", + "colorful_anime", + "hard_transform", + "softy", + ], + tooltip="Processing engine selection.", + ), + IO.Combo.Input( + "transfer_light_a", + options=["automatic", "low", "medium", "normal", "high", "high_on_faces"], + tooltip="The intensity of light transfer.", + ), + IO.Combo.Input( + "transfer_light_b", + options=[ + "automatic", + "composition", + "straight", + "smooth_in", + "smooth_out", + "smooth_both", + "reverse_both", + "soft_in", + "soft_out", + "soft_mid", + # "strong_mid", # Commented out because requests fail when this is set. + "style_shift", + "strong_shift", + ], + tooltip="Also modifies light transfer intensity. " + "Can be combined with the previous control for varied effects.", + ), + IO.Boolean.Input( + "fixed_generation", + default=True, + tooltip="Ensures consistent output with the same settings.", + ), + ], + ), + ], + tooltip="Fine-tuning options for advanced lighting control.", + ), + IO.Image.Input( + "reference_image", + optional=True, + tooltip="Optional reference image to transfer lighting from.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + light_transfer_strength: int, + style: str, + interpolate_from_original: bool, + change_background: bool, + preserve_details: bool, + advanced_settings: InputAdvancedSettings, + reference_image: Input.Image | None = None, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if reference_image is not None and get_number_of_images(reference_image) != 1: + raise ValueError("Exactly one reference image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + if reference_image is not None: + validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(reference_image, min_height=160, min_width=160) + + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] + reference_url = None + if reference_image is not None: + reference_url = (await upload_images_to_comfyapi(cls, reference_image, max_images=1))[0] + + adv_settings = None + if advanced_settings["advanced_settings"] == "enabled": + adv_settings = ImageRelightAdvancedSettingsRequest( + whites=advanced_settings["whites"], + blacks=advanced_settings["blacks"], + brightness=advanced_settings["brightness"], + contrast=advanced_settings["contrast"], + saturation=advanced_settings["saturation"], + engine=advanced_settings["engine"], + transfer_light_a=advanced_settings["transfer_light_a"], + transfer_light_b=advanced_settings["transfer_light_b"], + fixed_generation=advanced_settings["fixed_generation"], + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-relight", method="POST"), + response_model=TaskResponse, + data=ImageRelightRequest( + image=image_url, + prompt=prompt if prompt else None, + transfer_light_from_reference_image=reference_url, + light_transfer_strength=light_transfer_strength, + interpolate_from_original=interpolate_from_original, + change_background=change_background, + style=style, + preserve_details=preserve_details, + advanced_settings=adv_settings, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-relight/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageSkinEnhancerNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageSkinEnhancerNode", + display_name="Magnific Image Skin Enhancer", + category="api node/image/Magnific", + description="Skin enhancement for portraits with multiple processing modes.", + inputs=[ + IO.Image.Input("image", tooltip="The portrait image to enhance."), + IO.Int.Input( + "sharpen", + min=0, + max=100, + default=0, + tooltip="Sharpening intensity level.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "smart_grain", + min=0, + max=100, + default=2, + tooltip="Smart grain intensity level.", + display_mode=IO.NumberDisplay.slider, + ), + IO.DynamicCombo.Input( + "mode", + options=[ + IO.DynamicCombo.Option("creative", []), + IO.DynamicCombo.Option( + "faithful", + [ + IO.Int.Input( + "skin_detail", + min=0, + max=100, + default=80, + tooltip="Skin detail enhancement level.", + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + IO.DynamicCombo.Option( + "flexible", + [ + IO.Combo.Input( + "optimized_for", + options=[ + "enhance_skin", + "improve_lighting", + "enhance_everything", + "transform_to_real", + "no_make_up", + ], + tooltip="Enhancement optimization target.", + ), + ], + ), + ], + tooltip="Processing mode: creative for artistic enhancement, " + "faithful for preserving original appearance, " + "flexible for targeted optimization.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $rates := {"creative": 0.29, "faithful": 0.37, "flexible": 0.45}; + {"type":"usd","usd": $lookup($rates, widgets.mode)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + sharpen: int, + smart_grain: int, + mode: InputSkinEnhancerMode, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=4096 * 4096))[0] + selected_mode = mode["mode"] + + if selected_mode == "creative": + endpoint = "creative" + data = ImageSkinEnhancerCreativeRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + ) + elif selected_mode == "faithful": + endpoint = "faithful" + data = ImageSkinEnhancerFaithfulRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + skin_detail=mode["skin_detail"], + ) + else: # flexible + endpoint = "flexible" + data = ImageSkinEnhancerFlexibleRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + optimized_for=mode["optimized_for"], + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{endpoint}", method="POST"), + response_model=TaskResponse, + data=data, + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + # MagnificImageUpscalerCreativeNode, + # MagnificImageUpscalerPreciseV2Node, + MagnificImageStyleTransferNode, + MagnificImageRelightNode, + MagnificImageSkinEnhancerNode, + ] + + +async def comfy_entrypoint() -> MagnificExtension: + return MagnificExtension() diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index 87e663845..afc18bb25 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -149,7 +149,6 @@ class OpenAIVideoSora2(IO.ComfyNode): response_model=Sora2GenerationResponse, status_extractor=lambda x: x.status, poll_interval=8.0, - max_poll_attempts=160, estimated_duration=int(45 * (duration / 4) * model_time_multiplier), ) return IO.NodeOutput( diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index c052e7656..8fccde25a 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -203,7 +203,6 @@ class TopazImageEnhance(IO.ComfyNode): progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: x.credits * 0.08, poll_interval=8.0, - max_poll_attempts=160, estimated_duration=60, ) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 364976000..c3c9ff4bf 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -13,6 +13,7 @@ from .conversions import ( bytesio_to_image_tensor, convert_mask_to_image, downscale_image_tensor, + downscale_image_tensor_by_max_side, image_tensor_pair_to_batch, pil_to_bytesio, resize_mask_to_image, @@ -33,6 +34,7 @@ from .download_helpers import ( from .upload_helpers import ( upload_audio_to_comfyapi, upload_file_to_comfyapi, + upload_image_to_comfyapi, upload_images_to_comfyapi, upload_video_to_comfyapi, ) @@ -61,6 +63,7 @@ __all__ = [ # Upload helpers "upload_audio_to_comfyapi", "upload_file_to_comfyapi", + "upload_image_to_comfyapi", "upload_images_to_comfyapi", "upload_video_to_comfyapi", # Download helpers @@ -75,6 +78,7 @@ __all__ = [ "bytesio_to_image_tensor", "convert_mask_to_image", "downscale_image_tensor", + "downscale_image_tensor_by_max_side", "image_tensor_pair_to_batch", "pil_to_bytesio", "resize_mask_to_image", diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index f372ec7b5..8a1259506 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -141,7 +141,7 @@ async def poll_op( queued_statuses: list[str | int] | None = None, data: BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 120, + max_poll_attempts: int = 160, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, @@ -238,7 +238,7 @@ async def poll_op_raw( queued_statuses: list[str | int] | None = None, data: dict[str, Any] | BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 120, + max_poll_attempts: int = 160, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 546741b7b..3e37e8a8c 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -56,15 +56,14 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to def tensor_to_bytesio( image: torch.Tensor, *, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, mime_type: str = "image/png", ) -> BytesIO: """Converts a torch.Tensor image to a named BytesIO object. Args: image: Input torch.Tensor image. - name: Optional filename for the BytesIO object. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). Returns: @@ -79,13 +78,14 @@ def tensor_to_bytesio( return img_binary -def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: +def tensor_to_pil(image: torch.Tensor, total_pixels: int | None = 2048 * 2048) -> Image.Image: """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" if len(image.shape) > 3: image = image[0] # TODO: remove alpha if not allowed and present input_tensor = image.cpu() - input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() + if total_pixels is not None: + input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() image_np = (input_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) return img @@ -93,14 +93,14 @@ def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image def tensor_to_base64_string( image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, mime_type: str = "image/png", ) -> str: """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. Args: image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). Returns: @@ -144,16 +144,31 @@ def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) return s +def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor: + """Downscale input image tensor so the largest dimension is at most max_side pixels.""" + samples = image.movedim(-1, 1) + height, width = samples.shape[2], samples.shape[3] + max_dim = max(width, height) + if max_dim <= max_side: + return image + scale_by = max_side / max_dim + new_width = round(width * scale_by) + new_height = round(height * scale_by) + s = common_upscale(samples, new_width, new_height, "lanczos", "disabled") + s = s.movedim(1, -1) + return s + + def tensor_to_data_uri( image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, mime_type: str = "image/png", ) -> str: """Converts a tensor image to a Data URI string. Args: image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). Returns: diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 2794be35c..3153f2b98 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -49,7 +49,7 @@ async def upload_images_to_comfyapi( mime_type: str | None = None, wait_label: str | None = "Uploading", show_batch_index: bool = True, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, ) -> list[str]: """ Uploads images to ComfyUI API and returns download URLs. @@ -88,6 +88,28 @@ async def upload_images_to_comfyapi( return download_urls +async def upload_image_to_comfyapi( + cls: type[IO.ComfyNode], + image: torch.Tensor, + *, + mime_type: str | None = None, + wait_label: str | None = "Uploading", + total_pixels: int = 2048 * 2048, +) -> str: + """Uploads a single image to ComfyUI API and returns its download URL.""" + return ( + await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type=mime_type, + wait_label=wait_label, + show_batch_index=False, + total_pixels=total_pixels, + ) + )[0] + + async def upload_audio_to_comfyapi( cls: type[IO.ComfyNode], audio: Input.Audio, diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 3eb40e937..8afd13acf 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -701,7 +701,14 @@ class Noise_EmptyNoise: def generate_noise(self, input_latent): latent_image = input_latent["samples"] - return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + if latent_image.is_nested: + tensors = latent_image.unbind() + zeros = [] + for t in tensors: + zeros.append(torch.zeros(t.shape, dtype=t.dtype, layout=t.layout, device="cpu")) + return comfy.nested_tensor.NestedTensor(zeros) + else: + return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") class Noise_RandomNoise: @@ -741,7 +748,7 @@ class SamplerCustom(io.ComfyNode): latent = latent_image latent_image = latent["samples"] latent = latent.copy() - latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None)) latent["samples"] = latent_image if not add_noise: @@ -760,6 +767,7 @@ class SamplerCustom(io.ComfyNode): samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples if "x0" in x0_output: x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) @@ -939,7 +947,7 @@ class SamplerCustomAdvanced(io.ComfyNode): latent = latent_image latent_image = latent["samples"] latent = latent.copy() - latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None)) latent["samples"] = latent_image noise_mask = None @@ -954,6 +962,7 @@ class SamplerCustomAdvanced(io.ComfyNode): samples = samples.to(comfy.model_management.intermediate_device()) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples if "x0" in x0_output: x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 1ed060205..c066064ac 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -104,19 +104,23 @@ class CustomComboNode(io.ComfyNode): category="utils", is_experimental=True, inputs=[io.Combo.Input("choice", options=[])], - outputs=[io.String.Output()] + outputs=[ + io.String.Output(display_name="STRING"), + io.Int.Output(display_name="INDEX"), + ], + accept_all_inputs=True, ) @classmethod - def validate_inputs(cls, choice: io.Combo.Type) -> bool: + def validate_inputs(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> bool: # NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs. # I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined. # I need to skip checking that the chosen combo option is in the options list, since those are defined by the user. return True @classmethod - def execute(cls, choice: io.Combo.Type) -> io.NodeOutput: - return io.NodeOutput(choice) + def execute(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> io.NodeOutput: + return io.NodeOutput(choice, index) class DCTestNode(io.ComfyNode): diff --git a/comfy_extras/nodes_lora_debug.py b/comfy_extras/nodes_lora_debug.py new file mode 100644 index 000000000..937a0fbfb --- /dev/null +++ b/comfy_extras/nodes_lora_debug.py @@ -0,0 +1,79 @@ +import folder_paths +import comfy.utils +import comfy.sd + + +class LoraLoaderBypass: + """ + Apply LoRA in bypass mode without modifying base model weights. + + Bypass mode computes: output = base_forward(x) + lora_path(x) + This is useful for training and when model weights are offloaded. + """ + + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL", "CLIP") + OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") + FUNCTION = "load_lora" + + CATEGORY = "loaders" + DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios." + EXPERIMENTAL = True + + def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + + lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + self.loaded_lora = None + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip) + return (model_lora, clip_lora) + + +class LoraLoaderBypassModelOnly(LoraLoaderBypass): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lora_model_only" + + def load_lora_model_only(self, model, lora_name, strength_model): + return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + + +NODE_CLASS_MAPPINGS = { + "LoraLoaderBypass": LoraLoaderBypass, + "LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "LoraLoaderBypass": "Load LoRA (Bypass) (For debugging)", + "LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only) (for debugging)", +} diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index b91a22309..2aec62f61 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -223,11 +223,24 @@ class LTXVAddGuide(io.ComfyNode): return frame_idx, latent_idx @classmethod - def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors): + def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1): keyframe_idxs, _ = get_keyframe_idxs(cond) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 pixel_coords[:, 0] += frame_idx + + # The following adjusts keyframe end positions for small grid IC-LoRA. + # After dilation, the small grid has the same size and position as the large grid, + # but each token encodes a larger image patch. We adjust the end position (not start) + # so that RoPE represents the correct middle point of each token. + # keyframe_idxs dims: (batch, spatial_dim [t,h,w], token_id, [start, end]) + # We only adjust h,w (not t) in dim 1, and only end (not start) in dim 3. + spatial_end_offset = (latent_downscale_factor - 1) * torch.tensor( + scale_factors[1:], + device=pixel_coords.device, + ).view(1, -1, 1, 1) + pixel_coords[:, 1:, :, 1:] += spatial_end_offset.to(pixel_coords.dtype) + if keyframe_idxs is None: keyframe_idxs = pixel_coords else: @@ -235,12 +248,12 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128): + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1): if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: raise ValueError("Adding guide to a combined AV latent is not supported.") - positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) - negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) if guide_mask is not None: target_h = max(noise_mask.shape[3], guide_mask.shape[3]) diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 02e5e7dd8..736213a47 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -55,7 +55,7 @@ class EmptySD3LatentImage(io.ComfyNode): @classmethod def execute(cls, width, height, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples":latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8}) generate = execute # TODO: remove diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 68a73cf13..024a89391 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers from comfy.weight_adapter import adapters, adapter_maps +from comfy.weight_adapter.bypass import BypassInjectionManager from comfy_api.latest import ComfyExtension, io, ui from comfy.utils import ProgressBar @@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler): self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: + for param_groups in self.optimizer.param_groups: + for param in param_groups["params"]: + if param.grad is None: + continue + param.grad.data = param.grad.data.to(param.data.dtype) self.optimizer.step() self.optimizer.zero_grad() ui_pbar.update(1) @@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): num_images = sum(t.shape[0] for t in latents) multi_res = False # Not using multi_res path in bucket mode - logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") + logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") for i, lat in enumerate(latents): - logging.info(f" Bucket {i}: shape {lat.shape}") + logging.debug(f" Bucket {i}: shape {lat.shape}") return latents, num_images, multi_res # Non-bucket mode @@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): latents = [t.to(dtype) for t in latents] for latent in latents: all_shapes.add(latent.shape) - logging.info(f"Latent shapes: {all_shapes}") + logging.debug(f"Latent shapes: {all_shapes}") if len(all_shapes) > 1: multi_res = True else: @@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode): if bucket_mode: return positive # Skip validation in bucket mode - logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}") if len(positive) == 1 and num_images > 1: return positive * num_images elif len(positive) != num_images: @@ -596,6 +602,8 @@ def _create_weight_adapter( shape = module.weight.shape lora_params = {} + logging.debug(f"Creating weight adapter for {key} with shape {shape}") + if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) dora_scale = existing_weights.get(f"{key}.dora_scale", None) @@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank): return lora_sd, all_weight_adapters +def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank): + """Setup LoRA adapters in bypass mode. + + In bypass mode: + - Weight adapters (lora/lokr/oft) use bypass injection (forward hook) + - Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification) + + This is useful when the base model weights are quantized and cannot be + directly modified. + + Args: + mp: Model patcher + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (lora_sd dict, all_weight_adapters list, bypass_manager) + """ + lora_sd = {} + all_weight_adapters = [] + bypass_manager = BypassInjectionManager() + + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + adapter, params = _create_weight_adapter( + m, n, existing_weights, algorithm, lora_dtype, rank + ) + lora_sd.update(params) + all_weight_adapters.append(adapter) + + key = f"{n}.weight" + # BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass + # Only use bypass for adapters that have h() method (lora/lokr/oft) + if isinstance(adapter, BiasDiff): + mp.add_weight_wrapper(key, adapter) + logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}") + else: + bypass_manager.add_adapter(key, adapter, strength=1.0) + logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}") + + if hasattr(m, "bias") and m.bias is not None: + # Bias adapters still use weight wrapper (bias is usually not quantized) + bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype) + lora_sd.update(bias_params) + key = f"{n}.bias" + mp.add_weight_wrapper(key, bias_adapter) + all_weight_adapters.append(bias_adapter) + logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}") + + return lora_sd, all_weight_adapters, bypass_manager + + def _create_optimizer(optimizer_name, parameters, learning_rate): """Create optimizer based on name. @@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode): default=False, tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.", ), + io.Boolean.Input( + "bypass_mode", + default=False, + tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.", + ), ], outputs=[ - io.Model.Output( - display_name="model", tooltip="Model with LoRA applied" - ), io.Custom("LORA_MODEL").Output( display_name="lora", tooltip="LoRA weights" ), @@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode): gradient_checkpointing, existing_lora, bucket_mode, + bypass_mode, ): # Extract scalars from lists (due to is_input_list=True) model = model[0] @@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode): gradient_checkpointing = gradient_checkpointing[0] existing_lora = existing_lora[0] bucket_mode = bucket_mode[0] + bypass_mode = bypass_mode[0] # Process latents based on mode if bucket_mode: @@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode): existing_weights, existing_steps = _load_existing_lora(existing_lora) # Setup LoRA adapters - lora_sd, all_weight_adapters = _setup_lora_adapters( - mp, existing_weights, algorithm, lora_dtype, rank - ) + bypass_manager = None + if bypass_mode: + logging.debug("Using bypass mode for training") + lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass( + mp, existing_weights, algorithm, lora_dtype, rank + ) + else: + lora_sd, all_weight_adapters = _setup_lora_adapters( + mp, existing_weights, algorithm, lora_dtype, rank + ) # Create optimizer and loss function optimizer = _create_optimizer( @@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode): guider = TrainGuider(mp) guider.set_conds(positive) + # Inject bypass hooks if bypass mode is enabled + bypass_injections = None + if bypass_manager is not None: + bypass_injections = bypass_manager.create_injections(mp.model) + for injection in bypass_injections: + injection.inject(mp) + logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks") + # Run training loop try: _run_training_loop( @@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode): multi_res, ) finally: + # Eject bypass hooks if they were injected + if bypass_injections is not None: + for injection in bypass_injections: + injection.eject(mp) + logging.debug("[BypassMode] Ejected bypass hooks") for m in mp.model.modules(): unpatch(m) del train_sampler, optimizer @@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode): for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) + # mp in train node is highly specialized for training + # use it in inference will result in bad behavior so we don't return it + return io.NodeOutput(lora_sd, loss_map, steps + existing_steps) class LoraModelLoader(io.ComfyNode):# diff --git a/comfyui_version.py b/comfyui_version.py index 952d413db..d56466db2 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.10.0" +__version__ = "0.11.0" diff --git a/execution.py b/execution.py index 648f204ec..4b4f63c80 100644 --- a/execution.py +++ b/execution.py @@ -175,7 +175,7 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= continue obj = cached.outputs[output_index] input_data_all[x] = obj - elif input_category is not None: + elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS): input_data_all[x] = [input_data] if is_v3: diff --git a/manager_requirements.txt b/manager_requirements.txt index bea6d4927..c420cc48e 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.5 +comfyui_manager==4.1b1 diff --git a/nodes.py b/nodes.py index 29e7776fc..794969753 100644 --- a/nodes.py +++ b/nodes.py @@ -1230,7 +1230,7 @@ class EmptyLatentImage: def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) - return ({"samples":latent}, ) + return ({"samples": latent, "downscale_ratio_spacial": 8}, ) class LatentFromBatch: @@ -1538,7 +1538,7 @@ class SetLatentNoiseMask: def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): latent_image = latent["samples"] - latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None)) if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") @@ -1556,6 +1556,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples return (out, ) @@ -2104,7 +2105,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)", "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", - "LoraLoader": "Load LoRA", + "LoraLoader": "Load LoRA (Model and CLIP)", + "LoraLoaderModelOnly": "Load LoRA", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", "DiffControlNetLoader": "Load ControlNet Model (diff)", @@ -2431,6 +2433,7 @@ async def init_builtin_extra_nodes(): "nodes_image_compare.py", "nodes_zimage.py", "nodes_glsl.py", + "nodes_lora_debug.py" ] import_failed = [] diff --git a/pyproject.toml b/pyproject.toml index 120b6c751..c0e787abd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.10.0" +version = "0.11.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index fa2393e19..b29d429ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.37.11 -comfyui-workflow-templates==0.8.15 +comfyui-workflow-templates==0.8.27 comfyui-embedded-docs==0.4.0 torch torchsde @@ -22,6 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 +requests #non essential dependencies: kornia>=0.7.1 diff --git a/server.py b/server.py index 1888745b7..2aee5cc06 100644 --- a/server.py +++ b/server.py @@ -679,6 +679,8 @@ class PromptServer(): info['deprecated'] = True if getattr(obj_class, "EXPERIMENTAL", False): info['experimental'] = True + if getattr(obj_class, "DEV_ONLY", False): + info['dev_only'] = True if hasattr(obj_class, 'API_NODE'): info['api_node'] = obj_class.API_NODE