From 3263bd5ed0dbae18efda6f0104d48d3b8a887646 Mon Sep 17 00:00:00 2001 From: Bingsu Date: Thu, 1 Jun 2023 09:34:09 +0900 Subject: [PATCH] fix: add cn guidance start, var names, style, ui --- adetailer/args.py | 9 ++-- adetailer/ui.py | 75 +++++++++++++++++++------------- controlnet_ext/__init__.py | 4 +- controlnet_ext/controlnet_ext.py | 43 +++++++++--------- scripts/!adetailer.py | 7 +-- 5 files changed, 78 insertions(+), 60 deletions(-) diff --git a/adetailer/args.py b/adetailer/args.py index b0fec68..923913f 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -15,6 +15,8 @@ from pydantic import ( constr, ) +cn_model_regex = r".*(inpaint|tile|scribble|lineart|openpose).*|^None$" + class Arg(NamedTuple): attr: str @@ -54,10 +56,9 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): ad_use_cfg_scale: bool = False ad_cfg_scale: NonNegativeFloat = 7.0 ad_restore_face: bool = False - ad_controlnet_model: constr( - regex=r".*(inpaint|tile|scribble|lineart|openpose).*|^None$" - ) = "None" + ad_controlnet_model: constr(regex=cn_model_regex) = "None" ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0 + ad_controlnet_guidance_start: confloat(ge=0.0, le=1.0) = 0.0 ad_controlnet_guidance_end: confloat(ge=0.0, le=1.0) = 1.0 @staticmethod @@ -113,6 +114,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): [ "ADetailer ControlNet model", "ADetailer ControlNet weight", + "ADetailer ControlNet guidance start", "ADetailer ControlNet guidance end", ], cond="None", @@ -163,6 +165,7 @@ _all_args = [ ("ad_restore_face", "ADetailer restore face"), ("ad_controlnet_model", "ADetailer ControlNet model"), ("ad_controlnet_weight", "ADetailer ControlNet weight"), + ("ad_controlnet_guidance_start", "ADetailer ControlNet guidance start"), ("ad_controlnet_guidance_end", "ADetailer ControlNet guidance end"), ] diff --git a/adetailer/ui.py b/adetailer/ui.py index 72900a4..1c6ceda 100644 --- a/adetailer/ui.py +++ b/adetailer/ui.py @@ -8,7 +8,7 @@ import gradio as gr from adetailer import AFTER_DETAILER, __version__ from adetailer.args import AD_ENABLE, ALL_ARGS, MASK_MERGE_INVERT -from controlnet_ext import controlnet_exists, get_cn_inpaint_models +from controlnet_ext import controlnet_exists, get_cn_models class Widgets(SimpleNamespace): @@ -155,39 +155,52 @@ def one_ui_group( inpainting(w, n, is_img2img) with gr.Group(), gr.Row(variant="panel"): - cn_inpaint_models = ["None"] + get_cn_inpaint_models() + cn_models = ["None"] + get_cn_models() - w.ad_controlnet_model = gr.Dropdown( - label="ControlNet model" + suffix(n), - choices=cn_inpaint_models, - value="None", - visible=True, - type="value", - interactive=controlnet_exists, - elem_id=eid("ad_controlnet_model"), - ) + with gr.Column(variant="compact"): + w.ad_controlnet_model = gr.Dropdown( + label="ControlNet model" + suffix(n), + choices=cn_models, + value="None", + visible=True, + type="value", + interactive=controlnet_exists, + elem_id=eid("ad_controlnet_model"), + ) - w.ad_controlnet_weight = gr.Slider( - label="ControlNet weight" + suffix(n), - minimum=0.0, - maximum=1.0, - step=0.05, - value=1.0, - visible=True, - interactive=controlnet_exists, - elem_id=eid("ad_controlnet_weight"), - ) + w.ad_controlnet_weight = gr.Slider( + label="ControlNet weight" + suffix(n), + minimum=0.0, + maximum=1.0, + step=0.01, + value=1.0, + visible=True, + interactive=controlnet_exists, + elem_id=eid("ad_controlnet_weight"), + ) - w.ad_controlnet_guidance_end = gr.Slider( - label="ControlNet guidance end" + suffix(n), - minimum=0.0, - maximum=1.0, - step=0.05, - value=1.0, - visible=True, - interactive=controlnet_exists, - elem_id=eid("ad_controlnet_guidance_end"), - ) + with gr.Column(variant="compact"): + w.ad_controlnet_guidance_start = gr.Slider( + label="ControlNet guidance start" + suffix(n), + minimum=0.0, + maximum=1.0, + step=0.01, + value=0.0, + visible=True, + interactive=controlnet_exists, + elem_id=eid("ad_controlnet_guidance_start"), + ) + + w.ad_controlnet_guidance_end = gr.Slider( + label="ControlNet guidance end" + suffix(n), + minimum=0.0, + maximum=1.0, + step=0.01, + value=1.0, + visible=True, + interactive=controlnet_exists, + elem_id=eid("ad_controlnet_guidance_end"), + ) for attr in ALL_ARGS.attrs: widget = getattr(w, attr) diff --git a/controlnet_ext/__init__.py b/controlnet_ext/__init__.py index e032d2d..0ab6668 100644 --- a/controlnet_ext/__init__.py +++ b/controlnet_ext/__init__.py @@ -1,7 +1,7 @@ -from .controlnet_ext import ControlNetExt, controlnet_exists, get_cn_inpaint_models +from .controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models __all__ = [ "ControlNetExt", "controlnet_exists", - "get_cn_inpaint_models", + "get_cn_models", ] diff --git a/controlnet_ext/controlnet_ext.py b/controlnet_ext/controlnet_ext.py index d3e294e..68fd186 100644 --- a/controlnet_ext/controlnet_ext.py +++ b/controlnet_ext/controlnet_ext.py @@ -12,16 +12,6 @@ ext_path = Path(data_path, "extensions") ext_builtin_path = Path(script_path, "extensions-builtin") is_in_builtin = False # compatibility for vladmandic/automatic controlnet_exists = False -controlnet_enabled_models = { - "inpaint": "inpaint_global_harmonious", - "scribble": "t2ia_sketch_pidi", - "lineart": "lineart_coarse", - "openpose": "openpose_full", - "tile": None, -} -controlnet_model_regex = re.compile( - r".*(" + ("|".join(controlnet_enabled_models.keys())) + ").*" -) if ext_path.exists(): controlnet_exists = any( @@ -38,6 +28,15 @@ if not controlnet_exists and ext_builtin_path.exists(): if controlnet_exists: is_in_builtin = True +cn_model_module = { + "inpaint": "inpaint_global_harmonious", + "scribble": "t2ia_sketch_pidi", + "lineart": "lineart_coarse", + "openpose": "openpose_full", + "tile": None, +} +cn_model_regex = re.compile("|".join(cn_model_module.keys())) + class ControlNetExt: def __init__(self): @@ -54,11 +53,16 @@ class ControlNetExt: self.external_cn = importlib.import_module(import_path, "external_code") self.cn_available = True models = self.external_cn.get_models() - self.cn_models.extend(m for m in models if controlnet_model_regex.match(m)) + self.cn_models.extend(m for m in models if cn_model_regex.search(m)) + + def update_scripts_args( + self, p, model: str, weight: float, guidance_start: float, guidance_end: float + ): + if (not self.cn_available) or model == "None": + return - def _update_scripts_args(self, p, model: str, weight: float, guidance_end: float): module = None - for m, v in controlnet_enabled_models.items(): + for m, v in cn_model_module.items(): if m in model: module = v break @@ -69,6 +73,7 @@ class ControlNetExt: weight=weight, control_mode=self.external_cn.ControlMode.BALANCED, module=module, + guidance_start=guidance_start, guidance_end=guidance_end, pixel_perfect=True, ) @@ -76,10 +81,6 @@ class ControlNetExt: self.external_cn.update_cn_script_in_processing(p, cn_units) - def update_scripts_args(self, p, model: str, weight: float, guidance_end: float): - if self.cn_available and model != "None": - self._update_scripts_args(p, model, weight, guidance_end) - def get_cn_model_dirs() -> list[Path]: cn_model_dir = Path(models_path, "ControlNet") @@ -99,7 +100,7 @@ def get_cn_model_dirs() -> list[Path]: @lru_cache -def _get_cn_inpaint_models() -> list[str]: +def _get_cn_models() -> list[str]: """ Since we can't import ControlNet, we use a function that does something like controlnet's `list(global_state.cn_models_names.values())`. @@ -119,7 +120,7 @@ def _get_cn_inpaint_models() -> list[str]: if ( p.is_file() and p.suffix in cn_model_exts - and controlnet_model_regex.match(p.name) + and cn_model_regex.search(p.name) ): if name_filter and name_filter not in p.name.lower(): continue @@ -134,7 +135,7 @@ def _get_cn_inpaint_models() -> list[str]: return models -def get_cn_inpaint_models() -> list[str]: +def get_cn_models() -> list[str]: if controlnet_exists: - return _get_cn_inpaint_models() + return _get_cn_models() return [] diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 4717e7a..5b62fb3 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -139,9 +139,10 @@ class AfterDetailerScript(scripts.Script): ): self.controlnet_ext.update_scripts_args( p, - args.ad_controlnet_model, - args.ad_controlnet_weight, - args.ad_controlnet_guidance_end, + model=args.ad_controlnet_model, + weight=args.ad_controlnet_weight, + guidance_start=args.ad_controlnet_guidance_start, + guidance_end=args.ad_controlnet_guidance_end, ) def is_ad_enabled(self, *args_) -> bool: