diff --git a/adetailer/ui.py b/adetailer/ui.py index b6318ff..8e044cc 100644 --- a/adetailer/ui.py +++ b/adetailer/ui.py @@ -9,25 +9,36 @@ import gradio as gr from adetailer import AFTER_DETAILER, __version__ from adetailer.args import ALL_ARGS, MASK_MERGE_INVERT -from controlnet_ext import controlnet_exists, get_cn_models +from controlnet_ext import controlnet_exists, controlnet_forge, get_cn_models -cn_module_choices = { - "inpaint": [ - "inpaint_global_harmonious", - "inpaint_only", - "inpaint_only+lama", - ], - "lineart": [ - "lineart_coarse", - "lineart_realistic", - "lineart_anime", - "lineart_anime_denoise", - ], - "openpose": ["openpose_full", "dw_openpose_full"], - "tile": ["tile_resample", "tile_colorfix", "tile_colorfix+sharp"], - "scribble": ["t2ia_sketch_pidi"], - "depth": ["depth_midas", "depth_hand_refiner"], -} +if controlnet_forge: + from lib_controlnet import global_state + cn_module_choices = { + "inpaint": list(m for m in global_state.get_filtered_preprocessors("Inpaint")), + "lineart": list(m for m in global_state.get_filtered_preprocessors("Lineart")), + "openpose": list(m for m in global_state.get_filtered_preprocessors("OpenPose")), + "tile": list(m for m in global_state.get_filtered_preprocessors("Tile")), + "scribble": list(m for m in global_state.get_filtered_preprocessors("Scribble")), + "depth": list(m for m in global_state.get_filtered_preprocessors("Depth")), + } +else: + cn_module_choices = { + "inpaint": [ + "inpaint_global_harmonious", + "inpaint_only", + "inpaint_only+lama", + ], + "lineart": [ + "lineart_coarse", + "lineart_realistic", + "lineart_anime", + "lineart_anime_denoise", + ], + "openpose": ["openpose_full", "dw_openpose_full"], + "tile": ["tile_resample", "tile_colorfix", "tile_colorfix+sharp"], + "scribble": ["t2ia_sketch_pidi"], + "depth": ["depth_midas", "depth_hand_refiner"], + } class Widgets(SimpleNamespace): diff --git a/controlnet_ext/__init__.py b/controlnet_ext/__init__.py index 0ab6668..32efe17 100644 --- a/controlnet_ext/__init__.py +++ b/controlnet_ext/__init__.py @@ -1,7 +1,21 @@ -from .controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models +try: + from .controlnet_ext_forge import ( + ControlNetExt, + controlnet_exists, + controlnet_forge, + get_cn_models, + ) +except ImportError: + from .controlnet_ext import ( + ControlNetExt, + controlnet_exists, + controlnet_forge, + get_cn_models, + ) __all__ = [ "ControlNetExt", "controlnet_exists", + "controlnet_forge", "get_cn_models", ] diff --git a/controlnet_ext/common.py b/controlnet_ext/common.py new file mode 100644 index 0000000..beeb60e --- /dev/null +++ b/controlnet_ext/common.py @@ -0,0 +1,11 @@ +import re + +cn_model_module = { + "inpaint": "inpaint_global_harmonious", + "scribble": "t2ia_sketch_pidi", + "lineart": "lineart_coarse", + "openpose": "openpose_full", + "tile": "tile_resample", + "depth": "depth_midas", +} +cn_model_regex = re.compile("|".join(cn_model_module.keys())) diff --git a/controlnet_ext/controlnet_ext.py b/controlnet_ext/controlnet_ext.py index 9f54f12..b6785cb 100644 --- a/controlnet_ext/controlnet_ext.py +++ b/controlnet_ext/controlnet_ext.py @@ -1,7 +1,6 @@ from __future__ import annotations import importlib -import re import sys from functools import lru_cache from pathlib import Path @@ -9,6 +8,8 @@ from textwrap import dedent from modules import extensions, sd_models, shared +from .common import cn_model_regex + try: from modules.paths import extensions_builtin_dir, extensions_dir, models_path except ImportError as e: @@ -22,6 +23,7 @@ except ImportError as e: ext_path = Path(extensions_dir) ext_builtin_path = Path(extensions_builtin_dir) controlnet_exists = False +controlnet_forge = False controlnet_path = None cn_base_path = "" @@ -42,16 +44,6 @@ if controlnet_path is not None: if target_path not in sys.path: sys.path.append(target_path) -cn_model_module = { - "inpaint": "inpaint_global_harmonious", - "scribble": "t2ia_sketch_pidi", - "lineart": "lineart_coarse", - "openpose": "openpose_full", - "tile": "tile_resample", - "depth": "depth_midas", -} -cn_model_regex = re.compile("|".join(cn_model_module.keys())) - class ControlNetExt: def __init__(self): diff --git a/controlnet_ext/controlnet_ext_forge.py b/controlnet_ext/controlnet_ext_forge.py new file mode 100644 index 0000000..0c0331b --- /dev/null +++ b/controlnet_ext/controlnet_ext_forge.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import copy +import sys + +import numpy as np +from lib_controlnet import external_code, global_state +from lib_controlnet.external_code import ControlNetUnit + +from modules import scripts +from modules.processing import StableDiffusionProcessing + +from .common import cn_model_regex + +controlnet_exists = True +controlnet_forge = True + +def find_script(p : StableDiffusionProcessing, script_title : str) -> scripts.Script: + script = next((s for s in p.scripts.scripts if s.title() == script_title ), None) + if not script: + raise Exception("Script not found: " + script_title) + return script + +def add_forge_script_to_adetailer_run(p: StableDiffusionProcessing, script_title : str, script_args : list): + p.scripts = copy.copy(scripts.scripts_img2img) + p.scripts.alwayson_scripts = [] + p.script_args_value = [] + + script = copy.copy(find_script(p, script_title)) + script.args_from = len(p.script_args_value) + script.args_to = len(p.script_args_value) + len(script_args) + p.scripts.alwayson_scripts.append(script) + p.script_args_value.extend(script_args) + +class ControlNetExt: + def __init__(self): + self.cn_available = False + self.external_cn = external_code + + def init_controlnet(self): + self.cn_available = True + + def update_scripts_args( + self, + p, + model: str, + module: str | None, + weight: float, + guidance_start: float, + guidance_end: float, + ): + if (not self.cn_available) or model == "None": + return + + if controlnet_forge: + image = np.asarray(p.init_images[0]) + mask = np.zeros_like(image) + mask[:] = 255 + + cnet_image = { + "image": image, + "mask": mask + } + + pres = external_code.pixel_perfect_resolution( + image, + target_H=p.height, + target_W=p.width, + resize_mode=external_code.resize_mode_from_value(p.resize_mode) + ) + + add_forge_script_to_adetailer_run( + p, + "ControlNet", + [ + ControlNetUnit( + enabled=True, + image=cnet_image, + model=model, + module=module, + weight=weight, + guidance_start=guidance_start, + guidance_end=guidance_end, + processor_res=pres + ) + ] + ) + + return + + +def get_cn_models() -> list[str]: + models = global_state.get_all_controlnet_names() + return [m for m in models if cn_model_regex.search(m)] diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 42eb374..ece36e8 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -36,7 +36,12 @@ from adetailer.mask import ( ) from adetailer.traceback import rich_traceback from adetailer.ui import WebuiInfo, adui, ordinal, suffix -from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models +from controlnet_ext import ( + ControlNetExt, + controlnet_exists, + controlnet_forge, + get_cn_models, +) from controlnet_ext.restore import ( CNHijackRestore, cn_allow_script_control, @@ -517,13 +522,14 @@ class AfterDetailerScript(scripts.Script): i2i._ad_disabled = True i2i._ad_inner = True - if args.ad_controlnet_model != "Passthrough": - self.disable_controlnet_units(i2i.script_args) + if not controlnet_forge: + if args.ad_controlnet_model != "Passthrough": + self.disable_controlnet_units(i2i.script_args) - if args.ad_controlnet_model not in ["None", "Passthrough"]: - self.update_controlnet_args(i2i, args) - elif args.ad_controlnet_model == "None": - i2i.control_net_enabled = False + if args.ad_controlnet_model not in ["None", "Passthrough"]: + self.update_controlnet_args(i2i, args) + elif args.ad_controlnet_model == "None": + i2i.control_net_enabled = False return i2i @@ -729,6 +735,12 @@ class AfterDetailerScript(scripts.Script): p2.seed = self.get_each_tap_seed(seed, j) p2.subseed = self.get_each_tap_seed(subseed, j) + if controlnet_forge: + if args.ad_controlnet_model not in "None": + self.update_controlnet_args(p2, args) + else: + p2.control_net_enabled = False + try: processed = process_images(p2) except NansException as e: