mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-04 12:49:55 +00:00
Add working Qwen 2512 ControlNet (Fun ControlNet) support (#12359)
This commit is contained in:
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
|
||||
class QwenFunControlNet(ControlNet):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
# Fun checkpoints are more sensitive to high strengths in the generic
|
||||
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
|
||||
# unchanged while >1 grows more gently.
|
||||
original_strength = self.strength
|
||||
self.strength = math.sqrt(max(self.strength, 0.0))
|
||||
try:
|
||||
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
finally:
|
||||
self.strength = original_strength
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.set_extra_arg("base_model", model.diffusion_model)
|
||||
|
||||
def copy(self):
|
||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
@@ -606,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet_qwen_fun(sd, model_options={}):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
unet_dtype = model_options.get("dtype", weight_dtype)
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
in_features = sd["control_img_in.weight"].shape[1]
|
||||
inner_dim = sd["control_img_in.weight"].shape[0]
|
||||
|
||||
block_weight = sd["control_blocks.0.attn.to_q.weight"]
|
||||
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
|
||||
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
|
||||
|
||||
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
|
||||
control_in_features=in_features,
|
||||
inner_dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
operations=operations,
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
dtype=unet_dtype,
|
||||
)
|
||||
model = controlnet_load_state_dict(model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
control = QwenFunControlNet(
|
||||
model,
|
||||
compression_ratio=1,
|
||||
latent_format=latent_format,
|
||||
# Fun checkpoints already expect their own 33-channel context handling.
|
||||
# Enabling generic concat_mask injects an extra mask channel at apply-time
|
||||
# and breaks the intended fallback packing path.
|
||||
concat_mask=False,
|
||||
load_device=load_device,
|
||||
manual_cast_dtype=manual_cast_dtype,
|
||||
extra_conds=[],
|
||||
)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@@ -683,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
Reference in New Issue
Block a user