From 778b5ebe295fd802f9d20eaaaba0aaf70a371acd Mon Sep 17 00:00:00 2001 From: Bingsu Date: Thu, 4 May 2023 13:31:48 +0900 Subject: [PATCH] feat: pydantic args --- adetailer/__init__.py | 4 ++ adetailer/args.py | 72 +++++++++++++++++++++++++++++++++ install.py | 1 + scripts/!adetailer.py | 92 ++++++------------------------------------- 4 files changed, 89 insertions(+), 80 deletions(-) create mode 100644 adetailer/args.py diff --git a/adetailer/__init__.py b/adetailer/__init__.py index cab12ad..5d7109c 100644 --- a/adetailer/__init__.py +++ b/adetailer/__init__.py @@ -1,11 +1,15 @@ from .__version__ import __version__ +from .args import ALL_ARGS, ADetailerArgs, get_args from .common import PredictOutput, get_models from .mediapipe import mediapipe_predict from .ultralytics import ultralytics_predict __all__ = [ "__version__", + "ADetailerArgs", + "ALL_ARGS", "PredictOutput", + "get_args", "get_models", "mediapipe_predict", "ultralytics_predict", diff --git a/adetailer/args.py b/adetailer/args.py new file mode 100644 index 0000000..f5d7337 --- /dev/null +++ b/adetailer/args.py @@ -0,0 +1,72 @@ +from typing import Any, NamedTuple + +import pydantic +from pydantic import ( + BaseModel, + NonNegativeFloat, + NonNegativeInt, + PositiveInt, + confloat, + validator, +) + + +class Arg(NamedTuple): + attr: str + name: str + + +_all_args = [ + ("ad_enable", "ADetailer enable"), + ("ad_model", "ADetailer model"), + ("ad_prompt", "ADetailer prompt"), + ("ad_negative_prompt", "ADetailer negative prompt"), + ("ad_conf", "ADetailer conf"), + ("ad_dilate_erode", "ADetailer dilate/erode"), + ("ad_x_offset", "ADetailer x offset"), + ("ad_y_offset", "ADetailer y offset"), + ("ad_mask_blur", "ADetailer mask blur"), + ("ad_denoising_strength", "ADetailer denoising strength"), + ("ad_inpaint_full_res", "ADetailer inpaint full"), + ("ad_inpaint_full_res_padding", "ADetailer inpaint padding"), + ("ad_use_inpaint_width_height", "ADetailer use inpaint width/height"), + ("ad_inpaint_width", "ADetailer inpaint width"), + ("ad_inpaint_height", "ADetailer inpaint height"), + ("ad_cfg_scale", "ADetailer CFG scale"), + ("ad_controlnet_model", "ADetailer ControlNet model"), + ("ad_controlnet_weight", "ADetailer ControlNet weight"), +] + +ALL_ARGS = [Arg(*args) for args in _all_args] + + +class ADetailerArgs(BaseModel): + ad_enable: bool = True + ad_model: str = "None" + ad_prompt: str = "" + ad_negative_prompt: str = "" + ad_conf: confloat(ge=0.0, le=1.0) = 0.3 + ad_dilate_erode: int = 32 + ad_x_offset: int = 0 + ad_y_offset: int = 0 + ad_mask_blur: NonNegativeInt = 4 + ad_denoising_strength: confloat(ge=0.0, le=1.0) = 0.4 + ad_inpaint_full_res: bool = True + ad_inpaint_full_res_padding: NonNegativeInt = 0 + ad_use_inpaint_width_height: bool = False + ad_inpaint_width: PositiveInt = 512 + ad_inpaint_height: PositiveInt = 512 + ad_cfg_scale: NonNegativeFloat = 7.0 + ad_controlnet_model: str = "None" + ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0 + + @validator("ad_conf", pre=True) + def check_ad_conf(cls, v): # noqa: N805 + if isinstance(v, int): + v = v / 100.0 + return v + + +def get_args(*args: Any) -> ADetailerArgs: + arg_dict = {all_args.attr: args[i] for i, all_args in enumerate(ALL_ARGS)} + return ADetailerArgs(**arg_dict) diff --git a/install.py b/install.py index d91fee4..0e8e188 100644 --- a/install.py +++ b/install.py @@ -46,6 +46,7 @@ def install(): ("ultralytics", "8.0.87", None), ("mediapipe", "0.9.3.0", None), ("huggingface_hub", None, None), + ("pydantic", None, None), # mediapipe ("protobuf", "3.20.0", "3.20.9999"), ] diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index d05009f..9fa9b3c 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -10,7 +10,15 @@ import gradio as gr import torch import modules # noqa: F401 -from adetailer import __version__, get_models, mediapipe_predict, ultralytics_predict +from adetailer import ( + ALL_ARGS, + ADetailerArgs, + __version__, + get_args, + get_models, + mediapipe_predict, + ultralytics_predict, +) from adetailer.common import dilate_erode, is_all_black, offset from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_inpaint_models from modules import images, safe, script_callbacks, scripts, shared @@ -38,78 +46,6 @@ print( f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}" ) -ALL_ARGS = [ - ("ad_enable", "ADetailer enable", bool), - ("ad_model", "ADetailer model", str), - ("ad_prompt", "ADetailer prompt", str), - ("ad_negative_prompt", "ADetailer negative prompt", str), - ("ad_conf", "ADetailer conf", int), - ("ad_dilate_erode", "ADetailer dilate/erode", int), - ("ad_x_offset", "ADetailer x offset", int), - ("ad_y_offset", "ADetailer y offset", int), - ("ad_mask_blur", "ADetailer mask blur", int), - ("ad_denoising_strength", "ADetailer denoising strength", float), - ("ad_inpaint_full_res", "ADetailer inpaint full", bool), - ("ad_inpaint_full_res_padding", "ADetailer inpaint padding", int), - ("ad_use_inpaint_width_height", "ADetailer use inpaint width/height", bool), - ("ad_inpaint_width", "ADetailer inpaint width", int), - ("ad_inpaint_height", "ADetailer inpaint height", int), - ("ad_cfg_scale", "ADetailer CFG scale", float), - ("ad_controlnet_model", "ADetailer ControlNet model", str), - ("ad_controlnet_weight", "ADetailer ControlNet weight", float), -] - - -class ADetailerArgs: - ad_enable: bool - ad_model: str - ad_prompt: str - ad_negative_prompt: str - ad_conf: float - ad_dilate_erode: int - ad_x_offset: int - ad_y_offset: int - ad_mask_blur: int - ad_denoising_strength: float - ad_inpaint_full_res: bool - ad_inpaint_full_res_padding: int - ad_use_inpaint_width_height: bool - ad_inpaint_width: int - ad_inpaint_height: int - ad_cfg_scale: float - ad_controlnet_model: str - ad_controlnet_weight: float - - def __init__(self, *args): - args = self.ensure_dtype(args) - for i, (attr, *_) in enumerate(ALL_ARGS): - if attr == "ad_conf": - setattr(self, attr, args[i] / 100.0) - else: - setattr(self, attr, args[i]) - - def asdict(self): - return self.__dict__ - - def ensure_dtype(self, args): - args = list(args) - for i, (attr, _, dtype) in enumerate(ALL_ARGS): - if not isinstance(args[i], dtype): - try: - if dtype is bool: - args[i] = self.is_true(args[i]) - else: - args[i] = dtype(args[i]) - except ValueError as e: - msg = f"Error converting {args[i]!r}({attr}) to {dtype}: {e}" - raise ValueError(msg) from e - return args - - def is_true(self, value: Any): - if isinstance(value, bool): - return value - return str(value).lower() == "true" - class Widgets: def tolist(self): @@ -325,7 +261,7 @@ class AfterDetailerScript(scripts.Script): return args.ad_enable is True and args.ad_model != "None" def extra_params(self, args: ADetailerArgs): - params = {name: getattr(args, attr) for attr, name, *_ in ALL_ARGS[1:]} + params = {name: getattr(args, attr) for attr, name in ALL_ARGS[1:]} params["ADetailer conf"] = int(params["ADetailer conf"] * 100) params["ADetailer version"] = __version__ @@ -344,10 +280,6 @@ class AfterDetailerScript(scripts.Script): return params - @staticmethod - def get_args(*args): - return ADetailerArgs(*args) - @staticmethod def get_ultralytics_device(): '`device = ""` means autodetect' @@ -495,7 +427,7 @@ class AfterDetailerScript(scripts.Script): ) def process(self, p, *args_): - args = self.get_args(*args_) + args = get_args(*args_) if self.is_ad_enabled(args): extra_params = self.extra_params(args) p.extra_generation_params.update(extra_params) @@ -504,7 +436,7 @@ class AfterDetailerScript(scripts.Script): if getattr(p, "_disable_adetailer", False): return - args = self.get_args(*args_) + args = get_args(*args_) if not self.is_ad_enabled(args): return