feat: restore cn hijack

This commit is contained in:
Bingsu
2023-05-24 23:27:24 +09:00
parent 1bb89cf3f4
commit f24137a570
4 changed files with 50 additions and 10 deletions

View File

@@ -26,6 +26,7 @@ from adetailer.common import PredictOutput
from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes
from adetailer.ui import adui, ordinal, suffix
from controlnet_ext import ControlNetExt, controlnet_exists
from controlnet_ext.restore import CNHijackRestore, cn_restore_unet_hook
from sd_webui import images, safe, script_callbacks, scripts, shared
from sd_webui.paths import data_path, models_path
from sd_webui.processing import (
@@ -75,6 +76,7 @@ class AfterDetailerScript(scripts.Script):
super().__init__()
self.controlnet_ext = None
self.ultralytics_device = self.get_ultralytics_device()
self.cn_latest_network = None
def title(self):
return AFTER_DETAILER
@@ -270,7 +272,7 @@ class AfterDetailerScript(scripts.Script):
def script_filter(self, p, args: ADetailerArgs):
script_runner = copy(p.scripts)
script_args = deepcopy(p.script_args)
cn_used = self.disable_controlnet_units(script_args)
self.disable_controlnet_units(script_args)
ad_only_seleted_scripts = opts.data.get("ad_only_seleted_scripts", True)
if not ad_only_seleted_scripts:
@@ -283,7 +285,7 @@ class AfterDetailerScript(scripts.Script):
for name in (script_name, script_name.strip())
}
if cn_used or args.ad_controlnet_model != "None":
if args.ad_controlnet_model != "None":
script_names_set.add("controlnet")
filtered_alwayson = []
@@ -292,22 +294,20 @@ class AfterDetailerScript(scripts.Script):
filename = Path(filepath).stem
if filename in script_names_set:
filtered_alwayson.append(script_object)
if filename == "controlnet":
self.cn_latest_network = script_object.latest_network
script_runner.alwayson_scripts = filtered_alwayson
return script_runner, script_args
def disable_controlnet_units(self, script_args: list[Any]) -> bool:
cn_used = False
def disable_controlnet_units(self, script_args: list[Any]) -> None:
for obj in script_args:
if "controlnet" in obj.__class__.__name__.lower():
cn_used = True
if hasattr(obj, "enabled"):
obj.enabled = False
if hasattr(obj, "input_mode"):
obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple")
return cn_used
def get_i2i_p(self, p, args: ADetailerArgs, image):
seed, subseed = self.get_seed(p)
width, height = self.get_width_height(p, args)
@@ -476,6 +476,8 @@ class AfterDetailerScript(scripts.Script):
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
if not re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
if args.ad_controlnet_model == "None":
cn_restore_unet_hook(p2, self.cn_latest_network)
processed = process_images(p2)
p2 = copy(i2i)
@@ -505,7 +507,8 @@ class AfterDetailerScript(scripts.Script):
for n, args in enumerate(arg_list):
if args.ad_model == "None":
continue
is_processed |= self._postprocess_image(p, pp, args, n=n)
with CNHijackRestore():
is_processed |= self._postprocess_image(p, pp, args, n=n)
if is_processed:
self.save_image(