From 2ace2759b8c330dd0cdfa67198fc5dcccad018ed Mon Sep 17 00:00:00 2001 From: Dowon Date: Sun, 19 May 2024 16:25:06 +0900 Subject: [PATCH] feat(scripts): enable each tap --- aaaaaa/traceback.py | 30 +++++++++++++++++++----------- aaaaaa/ui.py | 20 ++++++++++++++------ adetailer/args.py | 7 +++++-- scripts/!adetailer.py | 8 +++++++- tests/test_args.py | 14 ++++++++++++++ 5 files changed, 59 insertions(+), 20 deletions(-) diff --git a/aaaaaa/traceback.py b/aaaaaa/traceback.py index 74d1848..f4f653d 100644 --- a/aaaaaa/traceback.py +++ b/aaaaaa/traceback.py @@ -12,6 +12,7 @@ from rich.table import Table from rich.traceback import Traceback from adetailer.__version__ import __version__ +from adetailer.args import ADetailerArgs def processing(*args: Any) -> dict[str, Any]: @@ -66,23 +67,30 @@ def sd_models() -> dict[str, str]: def ad_args(*args: Any) -> dict[str, Any]: - ad_args = [ - arg - for arg in args - if isinstance(arg, dict) and arg.get("ad_model", "None") != "None" - ] + ad_args = [] + for arg in args: + if not isinstance(arg, dict): + continue + + try: + a = ADetailerArgs(**arg) + except ValueError: + continue + + if not a.need_skip(): + ad_args.append(a) + if not ad_args: return {} arg0 = ad_args[0] - is_api = arg0.get("is_api", True) return { "version": __version__, - "ad_model": arg0["ad_model"], - "ad_prompt": arg0.get("ad_prompt", ""), - "ad_negative_prompt": arg0.get("ad_negative_prompt", ""), - "ad_controlnet_model": arg0.get("ad_controlnet_model", "None"), - "is_api": type(is_api) is not tuple, + "ad_model": arg0.ad_model, + "ad_prompt": arg0.ad_prompt, + "ad_negative_prompt": arg0.ad_negative_prompt, + "ad_controlnet_model": arg0.ad_controlnet_model, + "is_api": arg0.is_api, } diff --git a/aaaaaa/ui.py b/aaaaaa/ui.py index e7e2ec6..b63760e 100644 --- a/aaaaaa/ui.py +++ b/aaaaaa/ui.py @@ -162,7 +162,7 @@ def adui( states.append(state) infotext_fields.extend(infofields) - # components: [bool, dict, dict, ...] + # components: [bool, bool, dict, dict, ...] components = [ad_enable, ad_skip_img2img, *states] return components, infotext_fields @@ -171,14 +171,22 @@ def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo): w = Widgets() eid = partial(elem_id, n=n, is_img2img=is_img2img) + model_choices = ( + [*webui_info.ad_model_list, "None"] + if n == 0 + else ["None", *webui_info.ad_model_list] + ) + with gr.Group(): - with gr.Row(): - model_choices = ( - [*webui_info.ad_model_list, "None"] - if n == 0 - else ["None", *webui_info.ad_model_list] + with gr.Row(variant="compact"): + w.ad_tap_enable = gr.Checkbox( + label=f"Enable this tap ({ordinal(n + 1)})", + value=True, + visible=True, + elem_id=eid("ad_tap_enable"), ) + with gr.Row(): w.ad_model = gr.Dropdown( label="ADetailer detector" + suffix(n), choices=model_choices, diff --git a/adetailer/args.py b/adetailer/args.py index 8ae74e0..ba0808c 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -55,6 +55,7 @@ class ArgsList(UserList): class ADetailerArgs(BaseModel, extra=Extra.forbid): ad_model: str = "None" ad_model_classes: str = "" + ad_tap_enable: bool = True ad_prompt: str = "" ad_negative_prompt: str = "" ad_confidence: confloat(ge=0.0, le=1.0) = 0.3 @@ -119,7 +120,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): p.pop(k, None) def extra_params(self, suffix: str = "") -> dict[str, Any]: - if self.ad_model == "None": + if self.need_skip(): return {} p = {name: getattr(self, attr) for attr, name in ALL_ARGS} @@ -128,6 +129,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): ppop("ADetailer model classes") ppop("ADetailer prompt") ppop("ADetailer negative prompt") + p.pop("ADetailer tap enable", None) # always pop ppop("ADetailer mask only top k largest", cond=0) ppop("ADetailer mask min ratio", cond=0.0) ppop("ADetailer mask max ratio", cond=1.0) @@ -204,12 +206,13 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): return self.ad_model.lower().startswith("mediapipe") def need_skip(self) -> bool: - return self.ad_model == "None" + return self.ad_model == "None" or self.ad_tap_enable is False _all_args = [ ("ad_model", "ADetailer model"), ("ad_model_classes", "ADetailer model classes"), + ("ad_tap_enable", "ADetailer tap enable"), ("ad_prompt", "ADetailer prompt"), ("ad_negative_prompt", "ADetailer negative prompt"), ("ad_confidence", "ADetailer confidence"), diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index b12637d..c5a5f16 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -187,7 +187,13 @@ class AfterDetailerScript(scripts.Script): return False ad_enabled = args_[0] if isinstance(args_[0], bool) else True - not_none = any(arg.get("ad_model", "None") != "None" for arg in arg_list) + pydantic_args = [] + for arg in arg_list: + try: + pydantic_args.append(ADetailerArgs(**arg)) + except ValueError: # noqa: PERF203 + continue + not_none = not all(arg.need_skip() for arg in pydantic_args) return ad_enabled and not_none def set_skip_img2img(self, p, *args_) -> None: diff --git a/tests/test_args.py b/tests/test_args.py index eb96330..19db115 100644 --- a/tests/test_args.py +++ b/tests/test_args.py @@ -32,3 +32,17 @@ def test_is_mediapipe(ad_model: str, expect: bool) -> None: def test_need_skip(ad_model: str, expect: bool) -> None: args = ADetailerArgs(ad_model=ad_model) assert args.need_skip() is expect + + +@pytest.mark.parametrize( + ("ad_model", "ad_tap_enable", "expect"), + [ + ("face_yolov8n.pt", False, True), + ("mediapipe_face_full", False, True), + ("None", True, True), + ("ace_yolov8s.pt", True, False), + ], +) +def test_need_skip_tap_enable(ad_model: str, ad_tap_enable: bool, expect: bool) -> None: + args = ADetailerArgs(ad_model=ad_model, ad_tap_enable=ad_tap_enable) + assert args.need_skip() is expect