feat: pydantic args

This commit is contained in:
Bingsu
2023-05-04 13:31:48 +09:00
parent 8c525a3c21
commit 778b5ebe29
4 changed files with 89 additions and 80 deletions

View File

@@ -10,7 +10,15 @@ import gradio as gr
import torch
import modules # noqa: F401
from adetailer import __version__, get_models, mediapipe_predict, ultralytics_predict
from adetailer import (
ALL_ARGS,
ADetailerArgs,
__version__,
get_args,
get_models,
mediapipe_predict,
ultralytics_predict,
)
from adetailer.common import dilate_erode, is_all_black, offset
from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_inpaint_models
from modules import images, safe, script_callbacks, scripts, shared
@@ -38,78 +46,6 @@ print(
f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}"
)
ALL_ARGS = [
("ad_enable", "ADetailer enable", bool),
("ad_model", "ADetailer model", str),
("ad_prompt", "ADetailer prompt", str),
("ad_negative_prompt", "ADetailer negative prompt", str),
("ad_conf", "ADetailer conf", int),
("ad_dilate_erode", "ADetailer dilate/erode", int),
("ad_x_offset", "ADetailer x offset", int),
("ad_y_offset", "ADetailer y offset", int),
("ad_mask_blur", "ADetailer mask blur", int),
("ad_denoising_strength", "ADetailer denoising strength", float),
("ad_inpaint_full_res", "ADetailer inpaint full", bool),
("ad_inpaint_full_res_padding", "ADetailer inpaint padding", int),
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height", bool),
("ad_inpaint_width", "ADetailer inpaint width", int),
("ad_inpaint_height", "ADetailer inpaint height", int),
("ad_cfg_scale", "ADetailer CFG scale", float),
("ad_controlnet_model", "ADetailer ControlNet model", str),
("ad_controlnet_weight", "ADetailer ControlNet weight", float),
]
class ADetailerArgs:
ad_enable: bool
ad_model: str
ad_prompt: str
ad_negative_prompt: str
ad_conf: float
ad_dilate_erode: int
ad_x_offset: int
ad_y_offset: int
ad_mask_blur: int
ad_denoising_strength: float
ad_inpaint_full_res: bool
ad_inpaint_full_res_padding: int
ad_use_inpaint_width_height: bool
ad_inpaint_width: int
ad_inpaint_height: int
ad_cfg_scale: float
ad_controlnet_model: str
ad_controlnet_weight: float
def __init__(self, *args):
args = self.ensure_dtype(args)
for i, (attr, *_) in enumerate(ALL_ARGS):
if attr == "ad_conf":
setattr(self, attr, args[i] / 100.0)
else:
setattr(self, attr, args[i])
def asdict(self):
return self.__dict__
def ensure_dtype(self, args):
args = list(args)
for i, (attr, _, dtype) in enumerate(ALL_ARGS):
if not isinstance(args[i], dtype):
try:
if dtype is bool:
args[i] = self.is_true(args[i])
else:
args[i] = dtype(args[i])
except ValueError as e:
msg = f"Error converting {args[i]!r}({attr}) to {dtype}: {e}"
raise ValueError(msg) from e
return args
def is_true(self, value: Any):
if isinstance(value, bool):
return value
return str(value).lower() == "true"
class Widgets:
def tolist(self):
@@ -325,7 +261,7 @@ class AfterDetailerScript(scripts.Script):
return args.ad_enable is True and args.ad_model != "None"
def extra_params(self, args: ADetailerArgs):
params = {name: getattr(args, attr) for attr, name, *_ in ALL_ARGS[1:]}
params = {name: getattr(args, attr) for attr, name in ALL_ARGS[1:]}
params["ADetailer conf"] = int(params["ADetailer conf"] * 100)
params["ADetailer version"] = __version__
@@ -344,10 +280,6 @@ class AfterDetailerScript(scripts.Script):
return params
@staticmethod
def get_args(*args):
return ADetailerArgs(*args)
@staticmethod
def get_ultralytics_device():
'`device = ""` means autodetect'
@@ -495,7 +427,7 @@ class AfterDetailerScript(scripts.Script):
)
def process(self, p, *args_):
args = self.get_args(*args_)
args = get_args(*args_)
if self.is_ad_enabled(args):
extra_params = self.extra_params(args)
p.extra_generation_params.update(extra_params)
@@ -504,7 +436,7 @@ class AfterDetailerScript(scripts.Script):
if getattr(p, "_disable_adetailer", False):
return
args = self.get_args(*args_)
args = get_args(*args_)
if not self.is_ad_enabled(args):
return