diff --git a/adetailer/__version__.py b/adetailer/__version__.py index 21311d5..fe3a6a0 100644 --- a/adetailer/__version__.py +++ b/adetailer/__version__.py @@ -1 +1 @@ -__version__ = "23.5.18.dev0" +__version__ = "23.5.18.dev1" diff --git a/controlnet_ext/__init__.py b/controlnet_ext/__init__.py index 66eab52..e032d2d 100644 --- a/controlnet_ext/__init__.py +++ b/controlnet_ext/__init__.py @@ -1,3 +1,7 @@ from .controlnet_ext import ControlNetExt, controlnet_exists, get_cn_inpaint_models -__all__ = ["ControlNetExt", "controlnet_exists", "get_cn_inpaint_models"] +__all__ = [ + "ControlNetExt", + "controlnet_exists", + "get_cn_inpaint_models", +] diff --git a/controlnet_ext/restore.py b/controlnet_ext/restore.py new file mode 100644 index 0000000..c218e07 --- /dev/null +++ b/controlnet_ext/restore.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from modules import img2img, processing + + +def cn_restore_unet_hook(p, cn_latest_network): + if cn_latest_network is not None: + unet = p.sd_model.model.diffusion_model + cn_latest_network.restore(unet) + + +class CNHijackRestore: + def __init__(self): + self.process = hasattr(processing, "__controlnet_original_process_images_inner") + self.img2img = hasattr(img2img, "__controlnet_original_process_batch") + + def __enter__(self): + if self.process: + self.orig_process = processing.process_images_inner + processing.process_images_inner = getattr( + processing, "__controlnet_original_process_images_inner" + ) + if self.img2img: + self.orig_img2img = img2img.process_batch + img2img.process_batch = getattr( + img2img, "__controlnet_original_process_batch" + ) + + def __exit__(self, *args, **kwargs): + if self.process: + processing.process_images_inner = self.orig_process + if self.img2img: + img2img.process_batch = self.orig_img2img diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 1e9ec59..4d2e574 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -26,6 +26,7 @@ from adetailer.common import PredictOutput from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes from adetailer.ui import adui, ordinal, suffix from controlnet_ext import ControlNetExt, controlnet_exists +from controlnet_ext.restore import CNHijackRestore, cn_restore_unet_hook from sd_webui import images, safe, script_callbacks, scripts, shared from sd_webui.paths import data_path, models_path from sd_webui.processing import ( @@ -75,6 +76,7 @@ class AfterDetailerScript(scripts.Script): super().__init__() self.controlnet_ext = None self.ultralytics_device = self.get_ultralytics_device() + self.cn_latest_network = None def title(self): return AFTER_DETAILER @@ -270,7 +272,7 @@ class AfterDetailerScript(scripts.Script): def script_filter(self, p, args: ADetailerArgs): script_runner = copy(p.scripts) script_args = deepcopy(p.script_args) - cn_used = self.disable_controlnet_units(script_args) + self.disable_controlnet_units(script_args) ad_only_seleted_scripts = opts.data.get("ad_only_seleted_scripts", True) if not ad_only_seleted_scripts: @@ -283,7 +285,7 @@ class AfterDetailerScript(scripts.Script): for name in (script_name, script_name.strip()) } - if cn_used or args.ad_controlnet_model != "None": + if args.ad_controlnet_model != "None": script_names_set.add("controlnet") filtered_alwayson = [] @@ -292,22 +294,20 @@ class AfterDetailerScript(scripts.Script): filename = Path(filepath).stem if filename in script_names_set: filtered_alwayson.append(script_object) + if filename == "controlnet": + self.cn_latest_network = script_object.latest_network script_runner.alwayson_scripts = filtered_alwayson return script_runner, script_args - def disable_controlnet_units(self, script_args: list[Any]) -> bool: - cn_used = False + def disable_controlnet_units(self, script_args: list[Any]) -> None: for obj in script_args: if "controlnet" in obj.__class__.__name__.lower(): - cn_used = True if hasattr(obj, "enabled"): obj.enabled = False if hasattr(obj, "input_mode"): obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple") - return cn_used - def get_i2i_p(self, p, args: ADetailerArgs, image): seed, subseed = self.get_seed(p) width, height = self.get_width_height(p, args) @@ -476,6 +476,8 @@ class AfterDetailerScript(scripts.Script): self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j) if not re.match(r"^\s*\[SKIP\]\s*$", p2.prompt): + if args.ad_controlnet_model == "None": + cn_restore_unet_hook(p2, self.cn_latest_network) processed = process_images(p2) p2 = copy(i2i) @@ -505,7 +507,8 @@ class AfterDetailerScript(scripts.Script): for n, args in enumerate(arg_list): if args.ad_model == "None": continue - is_processed |= self._postprocess_image(p, pp, args, n=n) + with CNHijackRestore(): + is_processed |= self._postprocess_image(p, pp, args, n=n) if is_processed: self.save_image(