From f12f66c298561a93bb2885c3d6c48e61fc5da086 Mon Sep 17 00:00:00 2001 From: Dowon Date: Wed, 15 May 2024 22:13:06 +0900 Subject: [PATCH] refactor: refactor some functions --- aaaaaa/p_method.py | 4 ++++ adetailer/args.py | 6 ++++++ scripts/!adetailer.py | 11 ++++++----- tests/test_args.py | 20 ++++++++++++++++++++ 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/aaaaaa/p_method.py b/aaaaaa/p_method.py index 9a87e7c..a9dfd14 100644 --- a/aaaaaa/p_method.py +++ b/aaaaaa/p_method.py @@ -28,3 +28,7 @@ def get_i(p) -> int: bs = p.batch_size i = p.batch_index return it * bs + i + + +def is_skip_img2img(p) -> bool: + return getattr(p, "_ad_skip_img2img", False) diff --git a/adetailer/args.py b/adetailer/args.py index a54ac6c..e6d0b4e 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -200,6 +200,12 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): return p + def is_mediapipe(self) -> bool: + return self.ad_model.lower().startswith("mediapipe") + + def need_skip(self) -> bool: + return self.ad_model == "None" + _all_args = [ ("ad_model", "ADetailer model"), diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index d2c40a9..696c605 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -26,6 +26,7 @@ from aaaaaa.p_method import ( get_i, is_img2img_inpaint, is_inpaint_only_masked, + is_skip_img2img, need_call_postprocess, need_call_process, ) @@ -625,7 +626,7 @@ class AfterDetailerScript(scripts.Script): @staticmethod def get_i2i_init_image(p, pp): - if getattr(p, "_ad_skip_img2img", False): + if is_skip_img2img(p): return p.init_images[0] return pp.image @@ -649,7 +650,7 @@ class AfterDetailerScript(scripts.Script): mask = ImageChops.invert(mask) mask = create_binary_mask(mask) - if getattr(p, "_ad_skip_img2img", False): + if is_skip_img2img(p): if hasattr(p, "init_images") and p.init_images: width, height = p.init_images[0].size else: @@ -712,7 +713,7 @@ class AfterDetailerScript(scripts.Script): seed, subseed = self.get_seed(p) ad_prompts, ad_negatives = self.get_prompt(p, args) - is_mediapipe = args.ad_model.lower().startswith("mediapipe") + is_mediapipe = args.is_mediapipe() kwargs = {} if is_mediapipe: @@ -800,11 +801,11 @@ class AfterDetailerScript(scripts.Script): is_processed = False with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control(): for n, args in enumerate(arg_list): - if args.ad_model == "None": + if args.need_skip(): continue is_processed |= self._postprocess_image_inner(p, pp, args, n=n) - if is_processed and not getattr(p, "_ad_skip_img2img", False): + if is_processed and not is_skip_img2img(p): self.save_image( p, init_image, condition="ad_save_images_before", suffix="-ad-before" ) diff --git a/tests/test_args.py b/tests/test_args.py index 89b0c2b..eb96330 100644 --- a/tests/test_args.py +++ b/tests/test_args.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from adetailer.args import ALL_ARGS, ADetailerArgs @@ -12,3 +14,21 @@ def test_all_args() -> None: if attr == "is_api": continue assert attr in ALL_ARGS.attrs, attr + + +@pytest.mark.parametrize( + ("ad_model", "expect"), + [("mediapipe_face_full", True), ("face_yolov8n.pt", False)], +) +def test_is_mediapipe(ad_model: str, expect: bool) -> None: + args = ADetailerArgs(ad_model=ad_model) + assert args.is_mediapipe() is expect + + +@pytest.mark.parametrize( + ("ad_model", "expect"), + [("mediapipe_face_full", False), ("face_yolov8n.pt", False), ("None", True)], +) +def test_need_skip(ad_model: str, expect: bool) -> None: + args = ADetailerArgs(ad_model=ad_model) + assert args.need_skip() is expect