feat: ad enable checker

이제 최소 2개의 인자가 주어지면 동작합니다.
This commit is contained in:
Bingsu
2023-05-06 01:53:45 +09:00
parent 111f04ee20
commit 10482a1da1
3 changed files with 26 additions and 12 deletions

View File

@@ -1,5 +1,5 @@
from .__version__ import __version__
from .args import ALL_ARGS, ADetailerArgs, get_args
from .args import ALL_ARGS, ADetailerArgs, EnableChecker, get_args
from .common import PredictOutput, get_models
from .mediapipe import mediapipe_predict
from .ultralytics import ultralytics_predict
@@ -8,6 +8,7 @@ __all__ = [
"__version__",
"ADetailerArgs",
"ALL_ARGS",
"EnableChecker",
"PredictOutput",
"get_args",
"get_models",

View File

@@ -98,6 +98,14 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
return params
class EnableChecker(BaseModel):
ad_enable: bool = False
ad_model: str = "None"
def is_enabled(self):
return self.ad_enable and self.ad_model != "None"
def get_args(*args: Any) -> ADetailerArgs:
arg_dict = {all_args.attr: args[i] for i, all_args in enumerate(ALL_ARGS)}
arg_dict = {attr: arg for arg, (attr, *_) in zip(args, ALL_ARGS)}
return ADetailerArgs(**arg_dict)

View File

@@ -14,6 +14,7 @@ import modules # noqa: F401
from adetailer import (
ALL_ARGS,
ADetailerArgs,
EnableChecker,
__version__,
get_args,
get_models,
@@ -269,19 +270,24 @@ class AfterDetailerScript(scripts.Script):
file=sys.stderr,
)
def is_ad_enabled(self, args: ADetailerArgs) -> bool:
return args.ad_enable is True and args.ad_model != "None"
def is_ad_enabled(self, *args_) -> bool:
if len(args_) < 2:
return False
checker = EnableChecker(ad_enable=args_[0], ad_model=args_[1])
return checker.is_enabled()
def get_args(self, *args_) -> ADetailerArgs:
try:
args = get_args(*args_)
except IndexError as e:
message = [f"[-] ADetailer: IndexError during get_args: {e}"]
except ValueError as e:
message = [
f"[-] ADetailer: ValidationError when validating arguments: {e}\n"
]
for arg, (attr, *_) in zip_longest(args_, ALL_ARGS):
dtype = type(arg)
arg = "MISSING" if arg is None else repr(arg)
message.append(f" {attr}: {arg} ({dtype})")
raise IndexError("\n".join(message)) from e
raise ValueError("\n".join(message)) from e
return args
@@ -461,8 +467,8 @@ class AfterDetailerScript(scripts.Script):
if getattr(p, "_disable_adetailer", False):
return
args = self.get_args(*args_)
if self.is_ad_enabled(args):
if self.is_ad_enabled(*args_):
args = self.get_args(*args_)
extra_params = self.extra_params(args)
p.extra_generation_params.update(extra_params)
@@ -531,11 +537,10 @@ class AfterDetailerScript(scripts.Script):
if getattr(p, "_disable_adetailer", False):
return
args = self.get_args(*args_)
if not self.is_ad_enabled(args):
if not self.is_ad_enabled(*args_):
return
args = self.get_args(*args_)
self._postprocess_image(p, pp, args)
try: