diff --git a/adetailer/args.py b/adetailer/args.py index 3de6200..9e23ba0 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections import UserList from functools import cached_property, partial -from typing import Any, Literal, NamedTuple, Union +from typing import Any, Literal, NamedTuple, Optional, Union import pydantic from pydantic import ( @@ -13,6 +13,7 @@ from pydantic import ( PositiveInt, confloat, constr, + root_validator, ) cn_model_regex = r".*(inpaint|tile|scribble|lineart|openpose).*|^None$" @@ -57,10 +58,19 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): ad_cfg_scale: NonNegativeFloat = 7.0 ad_restore_face: bool = False ad_controlnet_model: constr(regex=cn_model_regex) = "None" + ad_controlnet_module: Optional[constr(regex=r".*inpaint.*|^None$")] = None ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0 ad_controlnet_guidance_start: confloat(ge=0.0, le=1.0) = 0.0 ad_controlnet_guidance_end: confloat(ge=0.0, le=1.0) = 1.0 + @root_validator + def ad_controlnt_module_validator(cls, values): # noqa: N805 + cn_model = values.get("ad_controlnet_model", "None") + cn_module = values.get("ad_controlnet_module", None) + if "inpaint" not in cn_model or cn_module == "None": + values["ad_controlnet_module"] = None + return values + @staticmethod def ppop( p: dict[str, Any], @@ -115,6 +125,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): "ADetailer ControlNet model", [ "ADetailer ControlNet model", + "ADetailer ControlNet module", "ADetailer ControlNet weight", "ADetailer ControlNet guidance start", "ADetailer ControlNet guidance end", @@ -169,6 +180,7 @@ _all_args = [ ("ad_cfg_scale", "ADetailer CFG scale"), ("ad_restore_face", "ADetailer restore face"), ("ad_controlnet_model", "ADetailer ControlNet model"), + ("ad_controlnet_module", "ADetailer ControlNet module"), ("ad_controlnet_weight", "ADetailer ControlNet weight"), ("ad_controlnet_guidance_start", "ADetailer ControlNet guidance start"), ("ad_controlnet_guidance_end", "ADetailer ControlNet guidance end"), diff --git a/adetailer/ui.py b/adetailer/ui.py index 2cd423d..8d53f98 100644 --- a/adetailer/ui.py +++ b/adetailer/ui.py @@ -10,6 +10,12 @@ from adetailer import AFTER_DETAILER, __version__ from adetailer.args import AD_ENABLE, ALL_ARGS, MASK_MERGE_INVERT from controlnet_ext import controlnet_exists, get_cn_models +cn_module_choices = [ + "inpaint_global_harmonious", + "inpaint_only", + "inpaint_only+lama", +] + class Widgets(SimpleNamespace): def tolist(self): @@ -40,6 +46,12 @@ def on_generate_click(state: dict, *values: Any): return state +def on_cn_model_update(cn_model: str): + if "inpaint" in cn_model: + return gr.update(visible=True, choices=cn_module_choices) + return gr.update(visible=False, choices=["None"]) + + def elem_id(item_id: str, n: int, is_img2img: bool) -> str: tap = "img2img" if is_img2img else "txt2img" suf = suffix(n, "_") @@ -416,6 +428,16 @@ def controlnet(w: Widgets, n: int, is_img2img: bool): elem_id=eid("ad_controlnet_model"), ) + w.ad_controlnet_module = gr.Dropdown( + label="ControlNet module" + suffix(n), + choices=cn_module_choices, + value="inpaint_global_harmonious", + visible=False, + type="value", + interactive=controlnet_exists, + elem_id=eid("ad_controlnet_module"), + ) + w.ad_controlnet_weight = gr.Slider( label="ControlNet weight" + suffix(n), minimum=0.0, @@ -427,6 +449,13 @@ def controlnet(w: Widgets, n: int, is_img2img: bool): elem_id=eid("ad_controlnet_weight"), ) + w.ad_controlnet_model.change( + on_cn_model_update, + inputs=w.ad_controlnet_model, + outputs=w.ad_controlnet_module, + queue=False, + ) + with gr.Column(variant="compact"): w.ad_controlnet_guidance_start = gr.Slider( label="ControlNet guidance start" + suffix(n), diff --git a/controlnet_ext/controlnet_ext.py b/controlnet_ext/controlnet_ext.py index f6ffb79..f0c8918 100644 --- a/controlnet_ext/controlnet_ext.py +++ b/controlnet_ext/controlnet_ext.py @@ -49,16 +49,22 @@ class ControlNetExt: self.cn_models.extend(m for m in models if cn_model_regex.search(m)) def update_scripts_args( - self, p, model: str, weight: float, guidance_start: float, guidance_end: float + self, + p, + model: str, + module: str | None, + weight: float, + guidance_start: float, + guidance_end: float, ): if (not self.cn_available) or model == "None": return - module = None - for m, v in cn_model_module.items(): - if m in model: - module = v - break + if module is None: + for m, v in cn_model_module.items(): + if m in model: + module = v + break cn_units = [ self.external_cn.ControlNetUnit( diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index a3c746a..e40d6df 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -141,6 +141,7 @@ class AfterDetailerScript(scripts.Script): self.controlnet_ext.update_scripts_args( p, model=args.ad_controlnet_model, + module=args.ad_controlnet_module, weight=args.ad_controlnet_weight, guidance_start=args.ad_controlnet_guidance_start, guidance_end=args.ad_controlnet_guidance_end,