feat: pydantic args

This commit is contained in:
Bingsu
2023-05-04 13:31:48 +09:00
parent 8c525a3c21
commit 778b5ebe29
4 changed files with 89 additions and 80 deletions

View File

@@ -1,11 +1,15 @@
from .__version__ import __version__
from .args import ALL_ARGS, ADetailerArgs, get_args
from .common import PredictOutput, get_models
from .mediapipe import mediapipe_predict
from .ultralytics import ultralytics_predict
__all__ = [
"__version__",
"ADetailerArgs",
"ALL_ARGS",
"PredictOutput",
"get_args",
"get_models",
"mediapipe_predict",
"ultralytics_predict",

72
adetailer/args.py Normal file
View File

@@ -0,0 +1,72 @@
from typing import Any, NamedTuple
import pydantic
from pydantic import (
BaseModel,
NonNegativeFloat,
NonNegativeInt,
PositiveInt,
confloat,
validator,
)
class Arg(NamedTuple):
attr: str
name: str
_all_args = [
("ad_enable", "ADetailer enable"),
("ad_model", "ADetailer model"),
("ad_prompt", "ADetailer prompt"),
("ad_negative_prompt", "ADetailer negative prompt"),
("ad_conf", "ADetailer conf"),
("ad_dilate_erode", "ADetailer dilate/erode"),
("ad_x_offset", "ADetailer x offset"),
("ad_y_offset", "ADetailer y offset"),
("ad_mask_blur", "ADetailer mask blur"),
("ad_denoising_strength", "ADetailer denoising strength"),
("ad_inpaint_full_res", "ADetailer inpaint full"),
("ad_inpaint_full_res_padding", "ADetailer inpaint padding"),
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height"),
("ad_inpaint_width", "ADetailer inpaint width"),
("ad_inpaint_height", "ADetailer inpaint height"),
("ad_cfg_scale", "ADetailer CFG scale"),
("ad_controlnet_model", "ADetailer ControlNet model"),
("ad_controlnet_weight", "ADetailer ControlNet weight"),
]
ALL_ARGS = [Arg(*args) for args in _all_args]
class ADetailerArgs(BaseModel):
ad_enable: bool = True
ad_model: str = "None"
ad_prompt: str = ""
ad_negative_prompt: str = ""
ad_conf: confloat(ge=0.0, le=1.0) = 0.3
ad_dilate_erode: int = 32
ad_x_offset: int = 0
ad_y_offset: int = 0
ad_mask_blur: NonNegativeInt = 4
ad_denoising_strength: confloat(ge=0.0, le=1.0) = 0.4
ad_inpaint_full_res: bool = True
ad_inpaint_full_res_padding: NonNegativeInt = 0
ad_use_inpaint_width_height: bool = False
ad_inpaint_width: PositiveInt = 512
ad_inpaint_height: PositiveInt = 512
ad_cfg_scale: NonNegativeFloat = 7.0
ad_controlnet_model: str = "None"
ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0
@validator("ad_conf", pre=True)
def check_ad_conf(cls, v): # noqa: N805
if isinstance(v, int):
v = v / 100.0
return v
def get_args(*args: Any) -> ADetailerArgs:
arg_dict = {all_args.attr: args[i] for i, all_args in enumerate(ALL_ARGS)}
return ADetailerArgs(**arg_dict)

View File

@@ -46,6 +46,7 @@ def install():
("ultralytics", "8.0.87", None),
("mediapipe", "0.9.3.0", None),
("huggingface_hub", None, None),
("pydantic", None, None),
# mediapipe
("protobuf", "3.20.0", "3.20.9999"),
]

View File

@@ -10,7 +10,15 @@ import gradio as gr
import torch
import modules # noqa: F401
from adetailer import __version__, get_models, mediapipe_predict, ultralytics_predict
from adetailer import (
ALL_ARGS,
ADetailerArgs,
__version__,
get_args,
get_models,
mediapipe_predict,
ultralytics_predict,
)
from adetailer.common import dilate_erode, is_all_black, offset
from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_inpaint_models
from modules import images, safe, script_callbacks, scripts, shared
@@ -38,78 +46,6 @@ print(
f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}"
)
ALL_ARGS = [
("ad_enable", "ADetailer enable", bool),
("ad_model", "ADetailer model", str),
("ad_prompt", "ADetailer prompt", str),
("ad_negative_prompt", "ADetailer negative prompt", str),
("ad_conf", "ADetailer conf", int),
("ad_dilate_erode", "ADetailer dilate/erode", int),
("ad_x_offset", "ADetailer x offset", int),
("ad_y_offset", "ADetailer y offset", int),
("ad_mask_blur", "ADetailer mask blur", int),
("ad_denoising_strength", "ADetailer denoising strength", float),
("ad_inpaint_full_res", "ADetailer inpaint full", bool),
("ad_inpaint_full_res_padding", "ADetailer inpaint padding", int),
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height", bool),
("ad_inpaint_width", "ADetailer inpaint width", int),
("ad_inpaint_height", "ADetailer inpaint height", int),
("ad_cfg_scale", "ADetailer CFG scale", float),
("ad_controlnet_model", "ADetailer ControlNet model", str),
("ad_controlnet_weight", "ADetailer ControlNet weight", float),
]
class ADetailerArgs:
ad_enable: bool
ad_model: str
ad_prompt: str
ad_negative_prompt: str
ad_conf: float
ad_dilate_erode: int
ad_x_offset: int
ad_y_offset: int
ad_mask_blur: int
ad_denoising_strength: float
ad_inpaint_full_res: bool
ad_inpaint_full_res_padding: int
ad_use_inpaint_width_height: bool
ad_inpaint_width: int
ad_inpaint_height: int
ad_cfg_scale: float
ad_controlnet_model: str
ad_controlnet_weight: float
def __init__(self, *args):
args = self.ensure_dtype(args)
for i, (attr, *_) in enumerate(ALL_ARGS):
if attr == "ad_conf":
setattr(self, attr, args[i] / 100.0)
else:
setattr(self, attr, args[i])
def asdict(self):
return self.__dict__
def ensure_dtype(self, args):
args = list(args)
for i, (attr, _, dtype) in enumerate(ALL_ARGS):
if not isinstance(args[i], dtype):
try:
if dtype is bool:
args[i] = self.is_true(args[i])
else:
args[i] = dtype(args[i])
except ValueError as e:
msg = f"Error converting {args[i]!r}({attr}) to {dtype}: {e}"
raise ValueError(msg) from e
return args
def is_true(self, value: Any):
if isinstance(value, bool):
return value
return str(value).lower() == "true"
class Widgets:
def tolist(self):
@@ -325,7 +261,7 @@ class AfterDetailerScript(scripts.Script):
return args.ad_enable is True and args.ad_model != "None"
def extra_params(self, args: ADetailerArgs):
params = {name: getattr(args, attr) for attr, name, *_ in ALL_ARGS[1:]}
params = {name: getattr(args, attr) for attr, name in ALL_ARGS[1:]}
params["ADetailer conf"] = int(params["ADetailer conf"] * 100)
params["ADetailer version"] = __version__
@@ -344,10 +280,6 @@ class AfterDetailerScript(scripts.Script):
return params
@staticmethod
def get_args(*args):
return ADetailerArgs(*args)
@staticmethod
def get_ultralytics_device():
'`device = ""` means autodetect'
@@ -495,7 +427,7 @@ class AfterDetailerScript(scripts.Script):
)
def process(self, p, *args_):
args = self.get_args(*args_)
args = get_args(*args_)
if self.is_ad_enabled(args):
extra_params = self.extra_params(args)
p.extra_generation_params.update(extra_params)
@@ -504,7 +436,7 @@ class AfterDetailerScript(scripts.Script):
if getattr(p, "_disable_adetailer", False):
return
args = self.get_args(*args_)
args = get_args(*args_)
if not self.is_ad_enabled(args):
return