From 10482a1da11eac464c1ae65cdcea016ba71d9d00 Mon Sep 17 00:00:00 2001 From: Bingsu Date: Sat, 6 May 2023 01:53:45 +0900 Subject: [PATCH] feat: ad enable checker MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 이제 최소 2개의 인자가 주어지면 동작합니다. --- adetailer/__init__.py | 3 ++- adetailer/args.py | 10 +++++++++- scripts/!adetailer.py | 25 +++++++++++++++---------- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/adetailer/__init__.py b/adetailer/__init__.py index 5d7109c..fab3898 100644 --- a/adetailer/__init__.py +++ b/adetailer/__init__.py @@ -1,5 +1,5 @@ from .__version__ import __version__ -from .args import ALL_ARGS, ADetailerArgs, get_args +from .args import ALL_ARGS, ADetailerArgs, EnableChecker, get_args from .common import PredictOutput, get_models from .mediapipe import mediapipe_predict from .ultralytics import ultralytics_predict @@ -8,6 +8,7 @@ __all__ = [ "__version__", "ADetailerArgs", "ALL_ARGS", + "EnableChecker", "PredictOutput", "get_args", "get_models", diff --git a/adetailer/args.py b/adetailer/args.py index 2f7e11d..2c435d0 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -98,6 +98,14 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): return params +class EnableChecker(BaseModel): + ad_enable: bool = False + ad_model: str = "None" + + def is_enabled(self): + return self.ad_enable and self.ad_model != "None" + + def get_args(*args: Any) -> ADetailerArgs: - arg_dict = {all_args.attr: args[i] for i, all_args in enumerate(ALL_ARGS)} + arg_dict = {attr: arg for arg, (attr, *_) in zip(args, ALL_ARGS)} return ADetailerArgs(**arg_dict) diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 1ac4f0b..0498201 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -14,6 +14,7 @@ import modules # noqa: F401 from adetailer import ( ALL_ARGS, ADetailerArgs, + EnableChecker, __version__, get_args, get_models, @@ -269,19 +270,24 @@ class AfterDetailerScript(scripts.Script): file=sys.stderr, ) - def is_ad_enabled(self, args: ADetailerArgs) -> bool: - return args.ad_enable is True and args.ad_model != "None" + def is_ad_enabled(self, *args_) -> bool: + if len(args_) < 2: + return False + checker = EnableChecker(ad_enable=args_[0], ad_model=args_[1]) + return checker.is_enabled() def get_args(self, *args_) -> ADetailerArgs: try: args = get_args(*args_) - except IndexError as e: - message = [f"[-] ADetailer: IndexError during get_args: {e}"] + except ValueError as e: + message = [ + f"[-] ADetailer: ValidationError when validating arguments: {e}\n" + ] for arg, (attr, *_) in zip_longest(args_, ALL_ARGS): dtype = type(arg) arg = "MISSING" if arg is None else repr(arg) message.append(f" {attr}: {arg} ({dtype})") - raise IndexError("\n".join(message)) from e + raise ValueError("\n".join(message)) from e return args @@ -461,8 +467,8 @@ class AfterDetailerScript(scripts.Script): if getattr(p, "_disable_adetailer", False): return - args = self.get_args(*args_) - if self.is_ad_enabled(args): + if self.is_ad_enabled(*args_): + args = self.get_args(*args_) extra_params = self.extra_params(args) p.extra_generation_params.update(extra_params) @@ -531,11 +537,10 @@ class AfterDetailerScript(scripts.Script): if getattr(p, "_disable_adetailer", False): return - args = self.get_args(*args_) - - if not self.is_ad_enabled(args): + if not self.is_ad_enabled(*args_): return + args = self.get_args(*args_) self._postprocess_image(p, pp, args) try: