mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-01-26 19:29:54 +00:00
feat: restore cn hijack
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = "23.5.18.dev0"
|
||||
__version__ = "23.5.18.dev1"
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from .controlnet_ext import ControlNetExt, controlnet_exists, get_cn_inpaint_models
|
||||
|
||||
__all__ = ["ControlNetExt", "controlnet_exists", "get_cn_inpaint_models"]
|
||||
__all__ = [
|
||||
"ControlNetExt",
|
||||
"controlnet_exists",
|
||||
"get_cn_inpaint_models",
|
||||
]
|
||||
|
||||
33
controlnet_ext/restore.py
Normal file
33
controlnet_ext/restore.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from modules import img2img, processing
|
||||
|
||||
|
||||
def cn_restore_unet_hook(p, cn_latest_network):
|
||||
if cn_latest_network is not None:
|
||||
unet = p.sd_model.model.diffusion_model
|
||||
cn_latest_network.restore(unet)
|
||||
|
||||
|
||||
class CNHijackRestore:
|
||||
def __init__(self):
|
||||
self.process = hasattr(processing, "__controlnet_original_process_images_inner")
|
||||
self.img2img = hasattr(img2img, "__controlnet_original_process_batch")
|
||||
|
||||
def __enter__(self):
|
||||
if self.process:
|
||||
self.orig_process = processing.process_images_inner
|
||||
processing.process_images_inner = getattr(
|
||||
processing, "__controlnet_original_process_images_inner"
|
||||
)
|
||||
if self.img2img:
|
||||
self.orig_img2img = img2img.process_batch
|
||||
img2img.process_batch = getattr(
|
||||
img2img, "__controlnet_original_process_batch"
|
||||
)
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
if self.process:
|
||||
processing.process_images_inner = self.orig_process
|
||||
if self.img2img:
|
||||
img2img.process_batch = self.orig_img2img
|
||||
@@ -26,6 +26,7 @@ from adetailer.common import PredictOutput
|
||||
from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes
|
||||
from adetailer.ui import adui, ordinal, suffix
|
||||
from controlnet_ext import ControlNetExt, controlnet_exists
|
||||
from controlnet_ext.restore import CNHijackRestore, cn_restore_unet_hook
|
||||
from sd_webui import images, safe, script_callbacks, scripts, shared
|
||||
from sd_webui.paths import data_path, models_path
|
||||
from sd_webui.processing import (
|
||||
@@ -75,6 +76,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
super().__init__()
|
||||
self.controlnet_ext = None
|
||||
self.ultralytics_device = self.get_ultralytics_device()
|
||||
self.cn_latest_network = None
|
||||
|
||||
def title(self):
|
||||
return AFTER_DETAILER
|
||||
@@ -270,7 +272,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
def script_filter(self, p, args: ADetailerArgs):
|
||||
script_runner = copy(p.scripts)
|
||||
script_args = deepcopy(p.script_args)
|
||||
cn_used = self.disable_controlnet_units(script_args)
|
||||
self.disable_controlnet_units(script_args)
|
||||
|
||||
ad_only_seleted_scripts = opts.data.get("ad_only_seleted_scripts", True)
|
||||
if not ad_only_seleted_scripts:
|
||||
@@ -283,7 +285,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
for name in (script_name, script_name.strip())
|
||||
}
|
||||
|
||||
if cn_used or args.ad_controlnet_model != "None":
|
||||
if args.ad_controlnet_model != "None":
|
||||
script_names_set.add("controlnet")
|
||||
|
||||
filtered_alwayson = []
|
||||
@@ -292,22 +294,20 @@ class AfterDetailerScript(scripts.Script):
|
||||
filename = Path(filepath).stem
|
||||
if filename in script_names_set:
|
||||
filtered_alwayson.append(script_object)
|
||||
if filename == "controlnet":
|
||||
self.cn_latest_network = script_object.latest_network
|
||||
|
||||
script_runner.alwayson_scripts = filtered_alwayson
|
||||
return script_runner, script_args
|
||||
|
||||
def disable_controlnet_units(self, script_args: list[Any]) -> bool:
|
||||
cn_used = False
|
||||
def disable_controlnet_units(self, script_args: list[Any]) -> None:
|
||||
for obj in script_args:
|
||||
if "controlnet" in obj.__class__.__name__.lower():
|
||||
cn_used = True
|
||||
if hasattr(obj, "enabled"):
|
||||
obj.enabled = False
|
||||
if hasattr(obj, "input_mode"):
|
||||
obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple")
|
||||
|
||||
return cn_used
|
||||
|
||||
def get_i2i_p(self, p, args: ADetailerArgs, image):
|
||||
seed, subseed = self.get_seed(p)
|
||||
width, height = self.get_width_height(p, args)
|
||||
@@ -476,6 +476,8 @@ class AfterDetailerScript(scripts.Script):
|
||||
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
|
||||
|
||||
if not re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
|
||||
if args.ad_controlnet_model == "None":
|
||||
cn_restore_unet_hook(p2, self.cn_latest_network)
|
||||
processed = process_images(p2)
|
||||
|
||||
p2 = copy(i2i)
|
||||
@@ -505,7 +507,8 @@ class AfterDetailerScript(scripts.Script):
|
||||
for n, args in enumerate(arg_list):
|
||||
if args.ad_model == "None":
|
||||
continue
|
||||
is_processed |= self._postprocess_image(p, pp, args, n=n)
|
||||
with CNHijackRestore():
|
||||
is_processed |= self._postprocess_image(p, pp, args, n=n)
|
||||
|
||||
if is_processed:
|
||||
self.save_image(
|
||||
|
||||
Reference in New Issue
Block a user