mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-21 23:09:19 +00:00
control rework
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
import torch
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.controlnet
|
||||
|
||||
from ldm_patched.modules.controlnet import ControlLora, ControlNet, load_t2i_adapter
|
||||
from backend.operations import using_forge_operations
|
||||
from backend.nn.cnets import cldm
|
||||
from backend.patcher.controlnet import ControlLora, ControlNet, load_t2i_adapter
|
||||
from modules_forge.controlnet import apply_controlnet_advanced
|
||||
from modules_forge.shared import add_supported_control_model
|
||||
|
||||
@@ -43,8 +44,7 @@ class ControlNetPatcher(ControlModelPatcher):
|
||||
controlnet_config = None
|
||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: # diffusers format
|
||||
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
|
||||
controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data,
|
||||
unet_dtype)
|
||||
controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
||||
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config)
|
||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||
@@ -105,15 +105,16 @@ class ControlNetPatcher(ControlModelPatcher):
|
||||
|
||||
if controlnet_config is None:
|
||||
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
|
||||
controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix,
|
||||
unet_dtype, True).unet_config
|
||||
controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||
|
||||
load_device = ldm_patched.modules.model_management.get_torch_device()
|
||||
manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
|
||||
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
|
||||
|
||||
with using_forge_operations(parameters_manual_cast=manual_cast_dtype is not None):
|
||||
control_model = cldm.ControlNet(**controlnet_config)
|
||||
|
||||
if pth:
|
||||
if 'difference' in controlnet_data:
|
||||
@@ -136,8 +137,7 @@ class ControlNetPatcher(ControlModelPatcher):
|
||||
# TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device,
|
||||
manual_cast_dtype=manual_cast_dtype)
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return ControlNetPatcher(control)
|
||||
|
||||
def __init__(self, model_patcher):
|
||||
|
||||
Reference in New Issue
Block a user