mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-02-21 07:34:05 +00:00
feat: ad enable checker
이제 최소 2개의 인자가 주어지면 동작합니다.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from .__version__ import __version__
|
||||
from .args import ALL_ARGS, ADetailerArgs, get_args
|
||||
from .args import ALL_ARGS, ADetailerArgs, EnableChecker, get_args
|
||||
from .common import PredictOutput, get_models
|
||||
from .mediapipe import mediapipe_predict
|
||||
from .ultralytics import ultralytics_predict
|
||||
@@ -8,6 +8,7 @@ __all__ = [
|
||||
"__version__",
|
||||
"ADetailerArgs",
|
||||
"ALL_ARGS",
|
||||
"EnableChecker",
|
||||
"PredictOutput",
|
||||
"get_args",
|
||||
"get_models",
|
||||
|
||||
@@ -98,6 +98,14 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
||||
return params
|
||||
|
||||
|
||||
class EnableChecker(BaseModel):
|
||||
ad_enable: bool = False
|
||||
ad_model: str = "None"
|
||||
|
||||
def is_enabled(self):
|
||||
return self.ad_enable and self.ad_model != "None"
|
||||
|
||||
|
||||
def get_args(*args: Any) -> ADetailerArgs:
|
||||
arg_dict = {all_args.attr: args[i] for i, all_args in enumerate(ALL_ARGS)}
|
||||
arg_dict = {attr: arg for arg, (attr, *_) in zip(args, ALL_ARGS)}
|
||||
return ADetailerArgs(**arg_dict)
|
||||
|
||||
@@ -14,6 +14,7 @@ import modules # noqa: F401
|
||||
from adetailer import (
|
||||
ALL_ARGS,
|
||||
ADetailerArgs,
|
||||
EnableChecker,
|
||||
__version__,
|
||||
get_args,
|
||||
get_models,
|
||||
@@ -269,19 +270,24 @@ class AfterDetailerScript(scripts.Script):
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
def is_ad_enabled(self, args: ADetailerArgs) -> bool:
|
||||
return args.ad_enable is True and args.ad_model != "None"
|
||||
def is_ad_enabled(self, *args_) -> bool:
|
||||
if len(args_) < 2:
|
||||
return False
|
||||
checker = EnableChecker(ad_enable=args_[0], ad_model=args_[1])
|
||||
return checker.is_enabled()
|
||||
|
||||
def get_args(self, *args_) -> ADetailerArgs:
|
||||
try:
|
||||
args = get_args(*args_)
|
||||
except IndexError as e:
|
||||
message = [f"[-] ADetailer: IndexError during get_args: {e}"]
|
||||
except ValueError as e:
|
||||
message = [
|
||||
f"[-] ADetailer: ValidationError when validating arguments: {e}\n"
|
||||
]
|
||||
for arg, (attr, *_) in zip_longest(args_, ALL_ARGS):
|
||||
dtype = type(arg)
|
||||
arg = "MISSING" if arg is None else repr(arg)
|
||||
message.append(f" {attr}: {arg} ({dtype})")
|
||||
raise IndexError("\n".join(message)) from e
|
||||
raise ValueError("\n".join(message)) from e
|
||||
|
||||
return args
|
||||
|
||||
@@ -461,8 +467,8 @@ class AfterDetailerScript(scripts.Script):
|
||||
if getattr(p, "_disable_adetailer", False):
|
||||
return
|
||||
|
||||
args = self.get_args(*args_)
|
||||
if self.is_ad_enabled(args):
|
||||
if self.is_ad_enabled(*args_):
|
||||
args = self.get_args(*args_)
|
||||
extra_params = self.extra_params(args)
|
||||
p.extra_generation_params.update(extra_params)
|
||||
|
||||
@@ -531,11 +537,10 @@ class AfterDetailerScript(scripts.Script):
|
||||
if getattr(p, "_disable_adetailer", False):
|
||||
return
|
||||
|
||||
args = self.get_args(*args_)
|
||||
|
||||
if not self.is_ad_enabled(args):
|
||||
if not self.is_ad_enabled(*args_):
|
||||
return
|
||||
|
||||
args = self.get_args(*args_)
|
||||
self._postprocess_image(p, pp, args)
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user