feat: restore cn hijack

This commit is contained in:
Bingsu
2023-05-24 23:27:24 +09:00
parent 1bb89cf3f4
commit f24137a570
4 changed files with 50 additions and 10 deletions

View File

@@ -1 +1 @@
__version__ = "23.5.18.dev0"
__version__ = "23.5.18.dev1"

View File

@@ -1,3 +1,7 @@
from .controlnet_ext import ControlNetExt, controlnet_exists, get_cn_inpaint_models
__all__ = ["ControlNetExt", "controlnet_exists", "get_cn_inpaint_models"]
__all__ = [
"ControlNetExt",
"controlnet_exists",
"get_cn_inpaint_models",
]

33
controlnet_ext/restore.py Normal file
View File

@@ -0,0 +1,33 @@
from __future__ import annotations
from modules import img2img, processing
def cn_restore_unet_hook(p, cn_latest_network):
if cn_latest_network is not None:
unet = p.sd_model.model.diffusion_model
cn_latest_network.restore(unet)
class CNHijackRestore:
def __init__(self):
self.process = hasattr(processing, "__controlnet_original_process_images_inner")
self.img2img = hasattr(img2img, "__controlnet_original_process_batch")
def __enter__(self):
if self.process:
self.orig_process = processing.process_images_inner
processing.process_images_inner = getattr(
processing, "__controlnet_original_process_images_inner"
)
if self.img2img:
self.orig_img2img = img2img.process_batch
img2img.process_batch = getattr(
img2img, "__controlnet_original_process_batch"
)
def __exit__(self, *args, **kwargs):
if self.process:
processing.process_images_inner = self.orig_process
if self.img2img:
img2img.process_batch = self.orig_img2img

View File

@@ -26,6 +26,7 @@ from adetailer.common import PredictOutput
from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes
from adetailer.ui import adui, ordinal, suffix
from controlnet_ext import ControlNetExt, controlnet_exists
from controlnet_ext.restore import CNHijackRestore, cn_restore_unet_hook
from sd_webui import images, safe, script_callbacks, scripts, shared
from sd_webui.paths import data_path, models_path
from sd_webui.processing import (
@@ -75,6 +76,7 @@ class AfterDetailerScript(scripts.Script):
super().__init__()
self.controlnet_ext = None
self.ultralytics_device = self.get_ultralytics_device()
self.cn_latest_network = None
def title(self):
return AFTER_DETAILER
@@ -270,7 +272,7 @@ class AfterDetailerScript(scripts.Script):
def script_filter(self, p, args: ADetailerArgs):
script_runner = copy(p.scripts)
script_args = deepcopy(p.script_args)
cn_used = self.disable_controlnet_units(script_args)
self.disable_controlnet_units(script_args)
ad_only_seleted_scripts = opts.data.get("ad_only_seleted_scripts", True)
if not ad_only_seleted_scripts:
@@ -283,7 +285,7 @@ class AfterDetailerScript(scripts.Script):
for name in (script_name, script_name.strip())
}
if cn_used or args.ad_controlnet_model != "None":
if args.ad_controlnet_model != "None":
script_names_set.add("controlnet")
filtered_alwayson = []
@@ -292,22 +294,20 @@ class AfterDetailerScript(scripts.Script):
filename = Path(filepath).stem
if filename in script_names_set:
filtered_alwayson.append(script_object)
if filename == "controlnet":
self.cn_latest_network = script_object.latest_network
script_runner.alwayson_scripts = filtered_alwayson
return script_runner, script_args
def disable_controlnet_units(self, script_args: list[Any]) -> bool:
cn_used = False
def disable_controlnet_units(self, script_args: list[Any]) -> None:
for obj in script_args:
if "controlnet" in obj.__class__.__name__.lower():
cn_used = True
if hasattr(obj, "enabled"):
obj.enabled = False
if hasattr(obj, "input_mode"):
obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple")
return cn_used
def get_i2i_p(self, p, args: ADetailerArgs, image):
seed, subseed = self.get_seed(p)
width, height = self.get_width_height(p, args)
@@ -476,6 +476,8 @@ class AfterDetailerScript(scripts.Script):
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
if not re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
if args.ad_controlnet_model == "None":
cn_restore_unet_hook(p2, self.cn_latest_network)
processed = process_images(p2)
p2 = copy(i2i)
@@ -505,7 +507,8 @@ class AfterDetailerScript(scripts.Script):
for n, args in enumerate(arg_list):
if args.ad_model == "None":
continue
is_processed |= self._postprocess_image(p, pp, args, n=n)
with CNHijackRestore():
is_processed |= self._postprocess_image(p, pp, args, n=n)
if is_processed:
self.save_image(