mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-01-26 11:19:53 +00:00
feat: pydantic args
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
from .__version__ import __version__
|
||||
from .args import ALL_ARGS, ADetailerArgs, get_args
|
||||
from .common import PredictOutput, get_models
|
||||
from .mediapipe import mediapipe_predict
|
||||
from .ultralytics import ultralytics_predict
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"ADetailerArgs",
|
||||
"ALL_ARGS",
|
||||
"PredictOutput",
|
||||
"get_args",
|
||||
"get_models",
|
||||
"mediapipe_predict",
|
||||
"ultralytics_predict",
|
||||
|
||||
72
adetailer/args.py
Normal file
72
adetailer/args.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import pydantic
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
PositiveInt,
|
||||
confloat,
|
||||
validator,
|
||||
)
|
||||
|
||||
|
||||
class Arg(NamedTuple):
|
||||
attr: str
|
||||
name: str
|
||||
|
||||
|
||||
_all_args = [
|
||||
("ad_enable", "ADetailer enable"),
|
||||
("ad_model", "ADetailer model"),
|
||||
("ad_prompt", "ADetailer prompt"),
|
||||
("ad_negative_prompt", "ADetailer negative prompt"),
|
||||
("ad_conf", "ADetailer conf"),
|
||||
("ad_dilate_erode", "ADetailer dilate/erode"),
|
||||
("ad_x_offset", "ADetailer x offset"),
|
||||
("ad_y_offset", "ADetailer y offset"),
|
||||
("ad_mask_blur", "ADetailer mask blur"),
|
||||
("ad_denoising_strength", "ADetailer denoising strength"),
|
||||
("ad_inpaint_full_res", "ADetailer inpaint full"),
|
||||
("ad_inpaint_full_res_padding", "ADetailer inpaint padding"),
|
||||
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height"),
|
||||
("ad_inpaint_width", "ADetailer inpaint width"),
|
||||
("ad_inpaint_height", "ADetailer inpaint height"),
|
||||
("ad_cfg_scale", "ADetailer CFG scale"),
|
||||
("ad_controlnet_model", "ADetailer ControlNet model"),
|
||||
("ad_controlnet_weight", "ADetailer ControlNet weight"),
|
||||
]
|
||||
|
||||
ALL_ARGS = [Arg(*args) for args in _all_args]
|
||||
|
||||
|
||||
class ADetailerArgs(BaseModel):
|
||||
ad_enable: bool = True
|
||||
ad_model: str = "None"
|
||||
ad_prompt: str = ""
|
||||
ad_negative_prompt: str = ""
|
||||
ad_conf: confloat(ge=0.0, le=1.0) = 0.3
|
||||
ad_dilate_erode: int = 32
|
||||
ad_x_offset: int = 0
|
||||
ad_y_offset: int = 0
|
||||
ad_mask_blur: NonNegativeInt = 4
|
||||
ad_denoising_strength: confloat(ge=0.0, le=1.0) = 0.4
|
||||
ad_inpaint_full_res: bool = True
|
||||
ad_inpaint_full_res_padding: NonNegativeInt = 0
|
||||
ad_use_inpaint_width_height: bool = False
|
||||
ad_inpaint_width: PositiveInt = 512
|
||||
ad_inpaint_height: PositiveInt = 512
|
||||
ad_cfg_scale: NonNegativeFloat = 7.0
|
||||
ad_controlnet_model: str = "None"
|
||||
ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0
|
||||
|
||||
@validator("ad_conf", pre=True)
|
||||
def check_ad_conf(cls, v): # noqa: N805
|
||||
if isinstance(v, int):
|
||||
v = v / 100.0
|
||||
return v
|
||||
|
||||
|
||||
def get_args(*args: Any) -> ADetailerArgs:
|
||||
arg_dict = {all_args.attr: args[i] for i, all_args in enumerate(ALL_ARGS)}
|
||||
return ADetailerArgs(**arg_dict)
|
||||
@@ -46,6 +46,7 @@ def install():
|
||||
("ultralytics", "8.0.87", None),
|
||||
("mediapipe", "0.9.3.0", None),
|
||||
("huggingface_hub", None, None),
|
||||
("pydantic", None, None),
|
||||
# mediapipe
|
||||
("protobuf", "3.20.0", "3.20.9999"),
|
||||
]
|
||||
|
||||
@@ -10,7 +10,15 @@ import gradio as gr
|
||||
import torch
|
||||
|
||||
import modules # noqa: F401
|
||||
from adetailer import __version__, get_models, mediapipe_predict, ultralytics_predict
|
||||
from adetailer import (
|
||||
ALL_ARGS,
|
||||
ADetailerArgs,
|
||||
__version__,
|
||||
get_args,
|
||||
get_models,
|
||||
mediapipe_predict,
|
||||
ultralytics_predict,
|
||||
)
|
||||
from adetailer.common import dilate_erode, is_all_black, offset
|
||||
from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_inpaint_models
|
||||
from modules import images, safe, script_callbacks, scripts, shared
|
||||
@@ -38,78 +46,6 @@ print(
|
||||
f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}"
|
||||
)
|
||||
|
||||
ALL_ARGS = [
|
||||
("ad_enable", "ADetailer enable", bool),
|
||||
("ad_model", "ADetailer model", str),
|
||||
("ad_prompt", "ADetailer prompt", str),
|
||||
("ad_negative_prompt", "ADetailer negative prompt", str),
|
||||
("ad_conf", "ADetailer conf", int),
|
||||
("ad_dilate_erode", "ADetailer dilate/erode", int),
|
||||
("ad_x_offset", "ADetailer x offset", int),
|
||||
("ad_y_offset", "ADetailer y offset", int),
|
||||
("ad_mask_blur", "ADetailer mask blur", int),
|
||||
("ad_denoising_strength", "ADetailer denoising strength", float),
|
||||
("ad_inpaint_full_res", "ADetailer inpaint full", bool),
|
||||
("ad_inpaint_full_res_padding", "ADetailer inpaint padding", int),
|
||||
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height", bool),
|
||||
("ad_inpaint_width", "ADetailer inpaint width", int),
|
||||
("ad_inpaint_height", "ADetailer inpaint height", int),
|
||||
("ad_cfg_scale", "ADetailer CFG scale", float),
|
||||
("ad_controlnet_model", "ADetailer ControlNet model", str),
|
||||
("ad_controlnet_weight", "ADetailer ControlNet weight", float),
|
||||
]
|
||||
|
||||
|
||||
class ADetailerArgs:
|
||||
ad_enable: bool
|
||||
ad_model: str
|
||||
ad_prompt: str
|
||||
ad_negative_prompt: str
|
||||
ad_conf: float
|
||||
ad_dilate_erode: int
|
||||
ad_x_offset: int
|
||||
ad_y_offset: int
|
||||
ad_mask_blur: int
|
||||
ad_denoising_strength: float
|
||||
ad_inpaint_full_res: bool
|
||||
ad_inpaint_full_res_padding: int
|
||||
ad_use_inpaint_width_height: bool
|
||||
ad_inpaint_width: int
|
||||
ad_inpaint_height: int
|
||||
ad_cfg_scale: float
|
||||
ad_controlnet_model: str
|
||||
ad_controlnet_weight: float
|
||||
|
||||
def __init__(self, *args):
|
||||
args = self.ensure_dtype(args)
|
||||
for i, (attr, *_) in enumerate(ALL_ARGS):
|
||||
if attr == "ad_conf":
|
||||
setattr(self, attr, args[i] / 100.0)
|
||||
else:
|
||||
setattr(self, attr, args[i])
|
||||
|
||||
def asdict(self):
|
||||
return self.__dict__
|
||||
|
||||
def ensure_dtype(self, args):
|
||||
args = list(args)
|
||||
for i, (attr, _, dtype) in enumerate(ALL_ARGS):
|
||||
if not isinstance(args[i], dtype):
|
||||
try:
|
||||
if dtype is bool:
|
||||
args[i] = self.is_true(args[i])
|
||||
else:
|
||||
args[i] = dtype(args[i])
|
||||
except ValueError as e:
|
||||
msg = f"Error converting {args[i]!r}({attr}) to {dtype}: {e}"
|
||||
raise ValueError(msg) from e
|
||||
return args
|
||||
|
||||
def is_true(self, value: Any):
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return str(value).lower() == "true"
|
||||
|
||||
|
||||
class Widgets:
|
||||
def tolist(self):
|
||||
@@ -325,7 +261,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
return args.ad_enable is True and args.ad_model != "None"
|
||||
|
||||
def extra_params(self, args: ADetailerArgs):
|
||||
params = {name: getattr(args, attr) for attr, name, *_ in ALL_ARGS[1:]}
|
||||
params = {name: getattr(args, attr) for attr, name in ALL_ARGS[1:]}
|
||||
params["ADetailer conf"] = int(params["ADetailer conf"] * 100)
|
||||
params["ADetailer version"] = __version__
|
||||
|
||||
@@ -344,10 +280,6 @@ class AfterDetailerScript(scripts.Script):
|
||||
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def get_args(*args):
|
||||
return ADetailerArgs(*args)
|
||||
|
||||
@staticmethod
|
||||
def get_ultralytics_device():
|
||||
'`device = ""` means autodetect'
|
||||
@@ -495,7 +427,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
)
|
||||
|
||||
def process(self, p, *args_):
|
||||
args = self.get_args(*args_)
|
||||
args = get_args(*args_)
|
||||
if self.is_ad_enabled(args):
|
||||
extra_params = self.extra_params(args)
|
||||
p.extra_generation_params.update(extra_params)
|
||||
@@ -504,7 +436,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
if getattr(p, "_disable_adetailer", False):
|
||||
return
|
||||
|
||||
args = self.get_args(*args_)
|
||||
args = get_args(*args_)
|
||||
|
||||
if not self.is_ad_enabled(args):
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user