diff --git a/adetailer/args.py b/adetailer/args.py index 7cb2e6a..ba92038 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -54,8 +54,9 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): ad_use_cfg_scale: bool = False ad_cfg_scale: NonNegativeFloat = 7.0 ad_restore_face: bool = False - ad_controlnet_model: constr(regex=r".*inpaint.*|^None$") = "None" + ad_controlnet_model: constr(regex=r".*(inpaint|tile|scribble|lineart|openpose).*|^None$") = "None" ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0 + ad_controlnet_guidance_end: confloat(ge=0.0, le=1.0) = 1.0 @staticmethod def ppop( @@ -107,7 +108,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): ppop("ADetailer restore face") ppop( "ADetailer ControlNet model", - ["ADetailer ControlNet model", "ADetailer ControlNet weight"], + ["ADetailer ControlNet model", "ADetailer ControlNet weight", "ADetailer ControlNet guidance end"], cond="None", ) @@ -156,6 +157,7 @@ _all_args = [ ("ad_restore_face", "ADetailer restore face"), ("ad_controlnet_model", "ADetailer ControlNet model"), ("ad_controlnet_weight", "ADetailer ControlNet weight"), + ("ad_controlnet_guidance_end", "ADetailer ControlNet guidance end"), ] AD_ENABLE = Arg(*_all_args[0]) diff --git a/adetailer/ui.py b/adetailer/ui.py index c7372cf..4b7cd46 100644 --- a/adetailer/ui.py +++ b/adetailer/ui.py @@ -177,6 +177,17 @@ def one_ui_group( interactive=controlnet_exists, elem_id=eid("ad_controlnet_weight"), ) + + w.ad_controlnet_guidance_end = gr.Slider( + label="ControlNet guidance end" + suffix(n), + minimum=0.0, + maximum=1.0, + step=0.05, + value=1.0, + visible=True, + interactive=controlnet_exists, + elem_id=eid("ad_controlnet_guidance_end"), + ) for attr in ALL_ARGS.attrs: widget = getattr(w, attr) diff --git a/controlnet_ext/controlnet_ext.py b/controlnet_ext/controlnet_ext.py index 22ec198..e45c2c4 100644 --- a/controlnet_ext/controlnet_ext.py +++ b/controlnet_ext/controlnet_ext.py @@ -3,6 +3,7 @@ from __future__ import annotations import importlib from functools import lru_cache from pathlib import Path +import re from modules import sd_models, shared from modules.paths import data_path, models_path, script_path @@ -11,6 +12,14 @@ ext_path = Path(data_path, "extensions") ext_builtin_path = Path(script_path, "extensions-builtin") is_in_builtin = False # compatibility for vladmandic/automatic controlnet_exists = False +controlnet_enabled_models = { + 'inpaint': 'inpaint_global_harmonious', + 'scribble': 't2ia_sketch_pidi', + 'lineart': 'lineart_coarse', + 'openpose': 'openpose_full', + 'tile': None, +} +controlnet_model_regex = re.compile(r'.*('+('|'.join(controlnet_enabled_models.keys()))+').*') if ext_path.exists(): controlnet_exists = any( @@ -43,24 +52,31 @@ class ControlNetExt: self.external_cn = importlib.import_module(import_path, "external_code") self.cn_available = True models = self.external_cn.get_models() - self.cn_models.extend(m for m in models if "inpaint" in m) + self.cn_models.extend(m for m in models if controlnet_model_regex.match(m)) + + def _update_scripts_args(self, p, model: str, weight: float, guidance_end: float): + module = None + for m, v in controlnet_enabled_models.items(): + if m in model: + module = v + break - def _update_scripts_args(self, p, model: str, weight: float): cn_units = [ self.external_cn.ControlNetUnit( model=model, weight=weight, control_mode=self.external_cn.ControlMode.BALANCED, - module="inpaint_global_harmonious", + module=module, + guidance_end=guidance_end, pixel_perfect=True, ) ] self.external_cn.update_cn_script_in_processing(p, cn_units) - def update_scripts_args(self, p, model: str, weight: float): + def update_scripts_args(self, p, model: str, weight: float, guidance_end: float): if self.cn_available and model != "None": - self._update_scripts_args(p, model, weight) + self._update_scripts_args(p, model, weight, guidance_end) def get_cn_model_dirs() -> list[Path]: @@ -98,7 +114,7 @@ def _get_cn_inpaint_models() -> list[str]: continue for p in base.rglob("*"): - if p.is_file() and p.suffix in cn_model_exts and "inpaint" in p.name: + if p.is_file() and p.suffix in cn_model_exts and controlnet_model_regex.match(p.name): if name_filter and name_filter not in p.name.lower(): continue model_paths.append(p) diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 8af2cb8..8f2d149 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -138,7 +138,7 @@ class AfterDetailerScript(scripts.Script): and args.ad_controlnet_model != "None" ): self.controlnet_ext.update_scripts_args( - p, args.ad_controlnet_model, args.ad_controlnet_weight + p, args.ad_controlnet_model, args.ad_controlnet_weight, args.ad_controlnet_guidance_end ) def is_ad_enabled(self, *args_) -> bool: