diff --git a/controlnet_ext/restore.py b/controlnet_ext/restore.py index c218e07..5b9bfa6 100644 --- a/controlnet_ext/restore.py +++ b/controlnet_ext/restore.py @@ -1,6 +1,8 @@ from __future__ import annotations -from modules import img2img, processing +from contextlib import contextmanager + +from modules import img2img, processing, shared def cn_restore_unet_hook(p, cn_latest_network): @@ -31,3 +33,17 @@ class CNHijackRestore: processing.process_images_inner = self.orig_process if self.img2img: img2img.process_batch = self.orig_img2img + + +@contextmanager +def cn_allow_script_control(): + orig = False + if "control_net_allow_script_control" in shared.opts.data: + try: + orig = shared.opts.data["control_net_allow_script_control"] + shared.opts.data["control_net_allow_script_control"] = True + yield + finally: + shared.opts.data["control_net_allow_script_control"] = orig + else: + yield diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index d1a5a6d..3d31171 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -27,7 +27,11 @@ from adetailer.common import PredictOutput from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes from adetailer.ui import adui, ordinal, suffix from controlnet_ext import ControlNetExt, controlnet_exists -from controlnet_ext.restore import CNHijackRestore, cn_restore_unet_hook +from controlnet_ext.restore import ( + CNHijackRestore, + cn_allow_script_control, + cn_restore_unet_hook, +) from sd_webui import images, safe, script_callbacks, scripts, shared from sd_webui.paths import data_path, models_path from sd_webui.processing import ( @@ -373,6 +377,9 @@ class AfterDetailerScript(scripts.Script): if args.ad_controlnet_model != "None": self.update_controlnet_args(i2i, args) + else: + i2i.control_net_enabled = False + return i2i def save_image(self, p, image, *, condition: str, suffix: str) -> None: @@ -519,7 +526,7 @@ class AfterDetailerScript(scripts.Script): arg_list = self.get_args(*args_) is_processed = False - with CNHijackRestore(), pause_total_tqdm(): + with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control(): for n, args in enumerate(arg_list): if args.ad_model == "None": continue