diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py index 49b04c19..2cc90843 100644 --- a/modules_forge/supported_controlnet.py +++ b/modules_forge/supported_controlnet.py @@ -1,7 +1,9 @@ import os import torch -import ldm_patched.modules.utils +from huggingface_guess.detection import unet_config_from_diffusers_unet, model_config_from_unet +from huggingface_guess.utils import unet_to_diffusers +from backend import memory_management from backend.operations import using_forge_operations from backend.nn.cnets import cldm from backend.patcher.controlnet import ControlLora, ControlNet, load_t2i_adapter @@ -43,9 +45,9 @@ class ControlNetPatcher(ControlModelPatcher): controlnet_config = None if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: # diffusers format - unet_dtype = ldm_patched.modules.model_management.unet_dtype() - controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype) - diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config) + unet_dtype = memory_management.unet_dtype() + controlnet_config = unet_config_from_diffusers_unet(controlnet_data, unet_dtype) + diffusers_keys = unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" @@ -104,11 +106,12 @@ class ControlNetPatcher(ControlModelPatcher): return ControlNetPatcher(net) if controlnet_config is None: - unet_dtype = ldm_patched.modules.model_management.unet_dtype() - controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config + unet_dtype = memory_management.unet_dtype() + controlnet_config = model_config_from_unet(controlnet_data, prefix, True).unet_config + controlnet_config['dtype'] = unet_dtype - load_device = ldm_patched.modules.model_management.get_torch_device() - manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device) + load_device = memory_management.get_torch_device() + manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device) controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]