Update supported_controlnet.py

This commit is contained in:
layerdiffusion
2024-08-03 12:52:32 -07:00
parent 907d883e49
commit 0871870342

View File

@@ -1,7 +1,9 @@
import os
import torch
import ldm_patched.modules.utils
from huggingface_guess.detection import unet_config_from_diffusers_unet, model_config_from_unet
from huggingface_guess.utils import unet_to_diffusers
from backend import memory_management
from backend.operations import using_forge_operations
from backend.nn.cnets import cldm
from backend.patcher.controlnet import ControlLora, ControlNet, load_t2i_adapter
@@ -43,9 +45,9 @@ class ControlNetPatcher(ControlModelPatcher):
controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: # diffusers format
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config)
unet_dtype = memory_management.unet_dtype()
controlnet_config = unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
diffusers_keys = unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -104,11 +106,12 @@ class ControlNetPatcher(ControlModelPatcher):
return ControlNetPatcher(net)
if controlnet_config is None:
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
unet_dtype = memory_management.unet_dtype()
controlnet_config = model_config_from_unet(controlnet_data, prefix, True).unet_config
controlnet_config['dtype'] = unet_dtype
load_device = ldm_patched.modules.model_management.get_torch_device()
manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device)
load_device = memory_management.get_torch_device()
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]