mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-26 09:18:59 +00:00
Cnet (#22)
* ini * remove shit * Create control_model.py * i * i * Update controlnet_supported.py * Update controlnet_supported.py * Update controlnet_supported.py * i * i * Update controlnet_supported.py * i * Update controlnet_supported.py * remove shits * remove shit * Update global_state.py * i * i * Update legacy_preprocessors.py * Update legacy_preprocessors.py * remove shit * Update batch_hijack.py * remove shit * remove shit * i * i * i * Update external_code.py * Update global_state.py * Update infotext.py * Update utils.py * Update external_code.py * i * i * i * Update controlnet_ui_group.py * remove shit * remove shit * i * Update controlnet.py * Update controlnet.py * Update controlnet.py * Update controlnet.py * Update controlnet.py * i * Update global_state.py * Update global_state.py * i * Update global_state.py * Update global_state.py * Update global_state.py * Update global_state.py * Update controlnet_ui_group.py * i * Update global_state.py * Update controlnet_ui_group.py * Update controlnet_ui_group.py * i * Update controlnet_ui_group.py * Update controlnet_ui_group.py * Update controlnet_ui_group.py * Update controlnet_ui_group.py
This commit is contained in:
158
modules_forge/supported_controlnet.py
Normal file
158
modules_forge/supported_controlnet.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import os
|
||||
import torch
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.controlnet
|
||||
|
||||
from ldm_patched.modules.controlnet import ControlLora, ControlNet, load_t2i_adapter
|
||||
from modules_forge.controlnet import apply_controlnet_advanced
|
||||
from modules_forge.shared import add_supported_control_model
|
||||
|
||||
|
||||
class ControlModelPatcher:
|
||||
@staticmethod
|
||||
def try_build_from_state_dict(state_dict, ckpt_path):
|
||||
return None
|
||||
|
||||
def __init__(self, model_patcher):
|
||||
self.model_patcher = model_patcher
|
||||
|
||||
def patch_to_process(self, p, control_image):
|
||||
return
|
||||
|
||||
|
||||
class ControlNetPatcher(ControlModelPatcher):
|
||||
@staticmethod
|
||||
def try_build_from_state_dict(controlnet_data, ckpt_path):
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
return ControlNetPatcher(ControlLora(controlnet_data))
|
||||
|
||||
controlnet_config = None
|
||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: # diffusers format
|
||||
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
|
||||
controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data,
|
||||
unet_dtype)
|
||||
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config)
|
||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||
|
||||
count = 0
|
||||
loop = True
|
||||
while loop:
|
||||
suffix = [".weight", ".bias"]
|
||||
for s in suffix:
|
||||
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
||||
k_out = "zero_convs.{}.0{}".format(count, s)
|
||||
if k_in not in controlnet_data:
|
||||
loop = False
|
||||
break
|
||||
diffusers_keys[k_in] = k_out
|
||||
count += 1
|
||||
|
||||
count = 0
|
||||
loop = True
|
||||
while loop:
|
||||
suffix = [".weight", ".bias"]
|
||||
for s in suffix:
|
||||
if count == 0:
|
||||
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
||||
else:
|
||||
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
||||
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
||||
if k_in not in controlnet_data:
|
||||
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
||||
loop = False
|
||||
diffusers_keys[k_in] = k_out
|
||||
count += 1
|
||||
|
||||
new_sd = {}
|
||||
for k in diffusers_keys:
|
||||
if k in controlnet_data:
|
||||
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
||||
|
||||
leftover_keys = controlnet_data.keys()
|
||||
if len(leftover_keys) > 0:
|
||||
print("leftover keys:", leftover_keys)
|
||||
controlnet_data = new_sd
|
||||
|
||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||
pth = False
|
||||
key = 'zero_convs.0.0.weight'
|
||||
if pth_key in controlnet_data:
|
||||
pth = True
|
||||
key = pth_key
|
||||
prefix = "control_model."
|
||||
elif key in controlnet_data:
|
||||
prefix = ""
|
||||
else:
|
||||
net = load_t2i_adapter(controlnet_data)
|
||||
if net is None:
|
||||
return None
|
||||
return ControlNetPatcher(net)
|
||||
|
||||
if controlnet_config is None:
|
||||
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
|
||||
controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix,
|
||||
unet_dtype, True).unet_config
|
||||
load_device = ldm_patched.modules.model_management.get_torch_device()
|
||||
manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
|
||||
|
||||
if pth:
|
||||
if 'difference' in controlnet_data:
|
||||
print("WARNING: Your controlnet model is diff version rather than official float16 model. "
|
||||
"Please use an official float16/float32 model for robust performance.")
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
w = WeightsLoader()
|
||||
w.control_model = control_model
|
||||
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
||||
else:
|
||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||
print(missing, unexpected)
|
||||
|
||||
global_average_pooling = False
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"):
|
||||
# TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device,
|
||||
manual_cast_dtype=manual_cast_dtype)
|
||||
return ControlNetPatcher(control)
|
||||
|
||||
def __init__(self, model_patcher):
|
||||
super().__init__(model_patcher)
|
||||
self.strength = 1.0
|
||||
self.start_percent = 0.0
|
||||
self.end_percent = 1.0
|
||||
self.positive_advanced_weighting = None
|
||||
self.negative_advanced_weighting = None
|
||||
self.advanced_frame_weighting = None
|
||||
self.advanced_sigma_weighting = None
|
||||
|
||||
def patch_to_process(self, p, control_image):
|
||||
unet = p.sd_model.forge_objects.unet
|
||||
|
||||
unet = apply_controlnet_advanced(
|
||||
unet=unet,
|
||||
controlnet=self.model_patcher,
|
||||
image_bchw=control_image,
|
||||
strength=self.strength,
|
||||
start_percent=self.start_percent,
|
||||
end_percent=self.end_percent,
|
||||
positive_advanced_weighting=self.positive_advanced_weighting,
|
||||
negative_advanced_weighting=self.negative_advanced_weighting,
|
||||
advanced_frame_weighting=self.advanced_frame_weighting,
|
||||
advanced_sigma_weighting=self.advanced_sigma_weighting)
|
||||
|
||||
p.sd_model.forge_objects.unet = unet
|
||||
return
|
||||
|
||||
|
||||
add_supported_control_model(ControlNetPatcher)
|
||||
Reference in New Issue
Block a user