From dc9822b7df4785e77690e93b0e09feaff01e2e12 Mon Sep 17 00:00:00 2001
From: krigeta <75309361+krigeta@users.noreply.github.com>
Date: Sat, 14 Feb 2026 08:53:52 +0530
Subject: [PATCH] Add working Qwen 2512 ControlNet (Fun ControlNet) support
 (#12359)

---
 comfy/controlnet.py                |  73 +++++++++++
 comfy/ldm/qwen_image/controlnet.py | 190 +++++++++++++++++++++++++++++
 2 files changed, 263 insertions(+)

diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 8336412f2..ba670b16d 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
         self.model_sampling_current = None
         super().cleanup()
+
+class QwenFunControlNet(ControlNet):
+    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
+        # Fun checkpoints are more sensitive to high strengths in the generic
+        # ControlNet merge path. Use a soft response curve so strength=1.0 stays
+        # unchanged while >1 grows more gently.
+        original_strength = self.strength
+        self.strength = math.sqrt(max(self.strength, 0.0))
+        try:
+            return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
+        finally:
+            self.strength = original_strength
+
+    def pre_run(self, model, percent_to_timestep_function):
+        super().pre_run(model, percent_to_timestep_function)
+        self.set_extra_arg("base_model", model.diffusion_model)
+
+    def copy(self):
+        c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
 class ControlLoraOps:
     class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
         def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -606,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
+
+def load_controlnet_qwen_fun(sd, model_options={}):
+    load_device = comfy.model_management.get_torch_device()
+    weight_dtype = comfy.utils.weight_dtype(sd)
+    unet_dtype = model_options.get("dtype", weight_dtype)
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+    in_features = sd["control_img_in.weight"].shape[1]
+    inner_dim = sd["control_img_in.weight"].shape[0]
+
+    block_weight = sd["control_blocks.0.attn.to_q.weight"]
+    attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
+    num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
+
+    model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
+        control_in_features=in_features,
+        inner_dim=inner_dim,
+        num_attention_heads=num_attention_heads,
+        attention_head_dim=attention_head_dim,
+        num_control_blocks=5,
+        main_model_double=60,
+        injection_layers=(0, 12, 24, 36, 48),
+        operations=operations,
+        device=comfy.model_management.unet_offload_device(),
+        dtype=unet_dtype,
+    )
+    model = controlnet_load_state_dict(model, sd)
+
+    latent_format = comfy.latent_formats.Wan21()
+    control = QwenFunControlNet(
+        model,
+        compression_ratio=1,
+        latent_format=latent_format,
+        # Fun checkpoints already expect their own 33-channel context handling.
+        # Enabling generic concat_mask injects an extra mask channel at apply-time
+        # and breaks the intended fallback packing path.
+        concat_mask=False,
+        load_device=load_device,
+        manual_cast_dtype=manual_cast_dtype,
+        extra_conds=[],
+    )
+    return control
+
 def convert_mistoline(sd):
     return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})

@@ -683,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
             return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
         elif "controlnet_x_embedder.weight" in controlnet_data:
             return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
+        elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
+            return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
         elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
             return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py
index a6d408104..c0aae9240 100644
--- a/comfy/ldm/qwen_image/controlnet.py
+++ b/comfy/ldm/qwen_image/controlnet.py
@@ -2,6 +2,196 @@ import torch
 import math

 from .model import QwenImageTransformer2DModel
+from .model import QwenImageTransformerBlock
+
+
+class QwenImageFunControlBlock(QwenImageTransformerBlock):
+    def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
+        super().__init__(
+            dim=dim,
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=attention_head_dim,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+        self.has_before_proj = has_before_proj
+        if has_before_proj:
+            self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+        self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+
+
+class QwenImageFunControlNetModel(torch.nn.Module):
+    def __init__(
+        self,
+        control_in_features=132,
+        inner_dim=3072,
+        num_attention_heads=24,
+        attention_head_dim=128,
+        num_control_blocks=5,
+        main_model_double=60,
+        injection_layers=(0, 12, 24, 36, 48),
+        dtype=None,
+        device=None,
+        operations=None,
+    ):
+        super().__init__()
+        self.dtype = dtype
+        self.main_model_double = main_model_double
+        self.injection_layers = tuple(injection_layers)
+        # Keep base hint scaling at 1.0 so user-facing strength behaves similarly
+        # to the reference Gen2/VideoX implementation around strength=1.
+        self.hint_scale = 1.0
+        self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
+
+        self.control_blocks = torch.nn.ModuleList([])
+        for i in range(num_control_blocks):
+            self.control_blocks.append(
+                QwenImageFunControlBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    has_before_proj=(i == 0),
+                    dtype=dtype,
+                    device=device,
+                    operations=operations,
+                )
+            )
+
+    def _process_hint_tokens(self, hint):
+        if hint is None:
+            return None
+        if hint.ndim == 4:
+            hint = hint.unsqueeze(2)
+
+        # Fun checkpoints are trained with 33 latent channels before 2x2 packing:
+        # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
+        # Default behavior (no inpaint input in stock Apply ControlNet) should use
+        # zeros for mask/inpaint branches, matching VideoX fallback semantics.
+        expected_c = self.control_img_in.weight.shape[1] // 4
+        if hint.shape[1] == 16 and expected_c == 33:
+            zeros_mask = torch.zeros_like(hint[:, :1])
+            zeros_inpaint = torch.zeros_like(hint)
+            hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
+
+        bs, c, t, h, w = hint.shape
+        hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(
+            orig_shape[0],
+            orig_shape[1],
+            orig_shape[-3],
+            orig_shape[-2] // 2,
+            2,
+            orig_shape[-1] // 2,
+            2,
+        )
+        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+        hidden_states = hidden_states.reshape(
+            bs,
+            t * ((h + 1) // 2) * ((w + 1) // 2),
+            c * 4,
+        )
+
+        expected_in = self.control_img_in.weight.shape[1]
+        cur_in = hidden_states.shape[-1]
+        if cur_in < expected_in:
+            pad = torch.zeros(
+                (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            hidden_states = torch.cat([hidden_states, pad], dim=-1)
+        elif cur_in > expected_in:
+            hidden_states = hidden_states[:, :, :expected_in]
+
+        return hidden_states
+
+    def forward(
+        self,
+        x,
+        timesteps,
+        context,
+        attention_mask=None,
+        guidance: torch.Tensor = None,
+        hint=None,
+        transformer_options={},
+        base_model=None,
+        **kwargs,
+    ):
+        if base_model is None:
+            raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
+
+        encoder_hidden_states_mask = attention_mask
+        # Keep attention mask disabled inside Fun control blocks to mirror
+        # VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
+        encoder_hidden_states_mask = None
+
+        hidden_states, img_ids, _ = base_model.process_img(x)
+        hint_tokens = self._process_hint_tokens(hint)
+        if hint_tokens is None:
+            raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
+
+        if hint_tokens.shape[1] != hidden_states.shape[1]:
+            max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
+            hint_tokens = hint_tokens[:, :max_tokens]
+            hidden_states = hidden_states[:, :max_tokens]
+            img_ids = img_ids[:, :max_tokens]
+
+        txt_start = round(
+            max(
+                ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+                ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+            )
+        )
+        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
+
+        hidden_states = base_model.img_in(hidden_states)
+        encoder_hidden_states = base_model.txt_norm(context)
+        encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
+
+        if guidance is not None:
+            guidance = guidance * 1000
+
+        temb = (
+            base_model.time_text_embed(timesteps, hidden_states)
+            if guidance is None
+            else base_model.time_text_embed(timesteps, guidance, hidden_states)
+        )
+
+        c = self.control_img_in(hint_tokens)
+
+        for i, block in enumerate(self.control_blocks):
+            if i == 0:
+                c_in = block.before_proj(c) + hidden_states
+                all_c = []
+            else:
+                all_c = list(torch.unbind(c, dim=0))
+                c_in = all_c.pop(-1)
+
+            encoder_hidden_states, c_out = block(
+                hidden_states=c_in,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                transformer_options=transformer_options,
+            )
+
+            c_skip = block.after_proj(c_out) * self.hint_scale
+            all_c += [c_skip, c_out]
+            c = torch.stack(all_c, dim=0)
+
+        hints = torch.unbind(c, dim=0)[:-1]
+
+        controlnet_block_samples = [None] * self.main_model_double
+        for local_idx, base_idx in enumerate(self.injection_layers):
+            if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
+                controlnet_block_samples[base_idx] = hints[local_idx]
+
+        return {"input": controlnet_block_samples}


 class QwenImageControlNetModel(QwenImageTransformer2DModel):