mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-04-21 06:48:53 +00:00
feat: change ad args to dict
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
from .__version__ import __version__
|
||||
from .args import ALL_ARGS, ADetailerArgs, EnableChecker, get_one_args
|
||||
from .args import AD_ENABLE, ALL_ARGS, ADetailerArgs, enable_check
|
||||
from .common import PredictOutput, get_models
|
||||
from .mediapipe import mediapipe_predict
|
||||
from .ultralytics import ultralytics_predict
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"AD_ENABLE",
|
||||
"ADetailerArgs",
|
||||
"ALL_ARGS",
|
||||
"EnableChecker",
|
||||
"PredictOutput",
|
||||
"get_one_args",
|
||||
"enable_check",
|
||||
"get_models",
|
||||
"mediapipe_predict",
|
||||
"ultralytics_predict",
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from collections import UserList
|
||||
from collections.abc import Mapping
|
||||
from functools import cached_property
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import pydantic
|
||||
@@ -17,35 +20,17 @@ class Arg(NamedTuple):
|
||||
name: str
|
||||
|
||||
|
||||
_all_args = [
|
||||
("ad_enable", "ADetailer enable"),
|
||||
("ad_model", "ADetailer model"),
|
||||
("ad_prompt", "ADetailer prompt"),
|
||||
("ad_negative_prompt", "ADetailer negative prompt"),
|
||||
("ad_conf", "ADetailer conf"),
|
||||
("ad_dilate_erode", "ADetailer dilate/erode"),
|
||||
("ad_x_offset", "ADetailer x offset"),
|
||||
("ad_y_offset", "ADetailer y offset"),
|
||||
("ad_mask_blur", "ADetailer mask blur"),
|
||||
("ad_denoising_strength", "ADetailer denoising strength"),
|
||||
("ad_inpaint_full_res", "ADetailer inpaint full"),
|
||||
("ad_inpaint_full_res_padding", "ADetailer inpaint padding"),
|
||||
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height"),
|
||||
("ad_inpaint_width", "ADetailer inpaint width"),
|
||||
("ad_inpaint_height", "ADetailer inpaint height"),
|
||||
("ad_use_steps", "ADetailer use separate steps"),
|
||||
("ad_steps", "ADetailer steps"),
|
||||
("ad_use_cfg_scale", "ADetailer use separate CFG scale"),
|
||||
("ad_cfg_scale", "ADetailer CFG scale"),
|
||||
("ad_controlnet_model", "ADetailer ControlNet model"),
|
||||
("ad_controlnet_weight", "ADetailer ControlNet weight"),
|
||||
]
|
||||
class ArgsList(UserList):
|
||||
@cached_property
|
||||
def attrs(self) -> tuple[str]:
|
||||
return tuple(attr for attr, _ in self)
|
||||
|
||||
ALL_ARGS = [Arg(*args) for args in _all_args]
|
||||
@cached_property
|
||||
def names(self) -> tuple[str]:
|
||||
return tuple(name for _, name in self)
|
||||
|
||||
|
||||
class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
||||
ad_enable: bool = False
|
||||
ad_model: str = "None"
|
||||
ad_prompt: str = ""
|
||||
ad_negative_prompt: str = ""
|
||||
@@ -83,7 +68,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
||||
if self.ad_model == "None":
|
||||
return {}
|
||||
|
||||
params = {name: getattr(self, attr) for attr, name in ALL_ARGS[1:]}
|
||||
params = {name: getattr(self, attr) for attr, name in ALL_ARGS}
|
||||
params["ADetailer conf"] = int(params["ADetailer conf"] * 100)
|
||||
|
||||
if not params["ADetailer prompt"]:
|
||||
@@ -122,14 +107,46 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
||||
return params
|
||||
|
||||
|
||||
class EnableChecker(BaseModel):
|
||||
ad_enable: bool = False
|
||||
ad_model: str = "None"
|
||||
def enable_check(*args: Any) -> bool:
|
||||
if not args:
|
||||
return False
|
||||
a0 = args[0]
|
||||
ad_model = ALL_ARGS[0].attr
|
||||
|
||||
def is_enabled(self):
|
||||
return self.ad_enable and self.ad_model != "None"
|
||||
if isinstance(a0, Mapping):
|
||||
return a0.get(ad_model, "None") != "None"
|
||||
if len(args) == 1:
|
||||
return False
|
||||
|
||||
a1 = args[1]
|
||||
a1_model = a1.get(ad_model, "None")
|
||||
return a0 and a1_model != "None"
|
||||
|
||||
|
||||
def get_one_args(*args: Any) -> ADetailerArgs:
|
||||
arg_dict = {attr: arg for arg, (attr, *_) in zip(args, ALL_ARGS)}
|
||||
return ADetailerArgs(**arg_dict)
|
||||
_all_args = [
|
||||
("ad_enable", "ADetailer enable"),
|
||||
("ad_model", "ADetailer model"),
|
||||
("ad_prompt", "ADetailer prompt"),
|
||||
("ad_negative_prompt", "ADetailer negative prompt"),
|
||||
("ad_conf", "ADetailer conf"),
|
||||
("ad_dilate_erode", "ADetailer dilate/erode"),
|
||||
("ad_x_offset", "ADetailer x offset"),
|
||||
("ad_y_offset", "ADetailer y offset"),
|
||||
("ad_mask_blur", "ADetailer mask blur"),
|
||||
("ad_denoising_strength", "ADetailer denoising strength"),
|
||||
("ad_inpaint_full_res", "ADetailer inpaint full"),
|
||||
("ad_inpaint_full_res_padding", "ADetailer inpaint padding"),
|
||||
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height"),
|
||||
("ad_inpaint_width", "ADetailer inpaint width"),
|
||||
("ad_inpaint_height", "ADetailer inpaint height"),
|
||||
("ad_use_steps", "ADetailer use separate steps"),
|
||||
("ad_steps", "ADetailer steps"),
|
||||
("ad_use_cfg_scale", "ADetailer use separate CFG scale"),
|
||||
("ad_cfg_scale", "ADetailer CFG scale"),
|
||||
("ad_controlnet_model", "ADetailer ControlNet model"),
|
||||
("ad_controlnet_weight", "ADetailer ControlNet weight"),
|
||||
]
|
||||
|
||||
AD_ENABLE = Arg(*_all_args[0])
|
||||
_args = [Arg(*args) for args in _all_args[1:]]
|
||||
ALL_ARGS = ArgsList(_args)
|
||||
|
||||
@@ -4,6 +4,7 @@ import platform
|
||||
import sys
|
||||
import traceback
|
||||
from copy import copy, deepcopy
|
||||
from functools import partial
|
||||
from itertools import zip_longest
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
@@ -14,12 +15,12 @@ import torch
|
||||
|
||||
import modules # noqa: F401
|
||||
from adetailer import (
|
||||
AD_ENABLE,
|
||||
ALL_ARGS,
|
||||
ADetailerArgs,
|
||||
EnableChecker,
|
||||
__version__,
|
||||
enable_check,
|
||||
get_models,
|
||||
get_one_args,
|
||||
mediapipe_predict,
|
||||
ultralytics_predict,
|
||||
)
|
||||
@@ -53,7 +54,7 @@ print(
|
||||
|
||||
class Widgets:
|
||||
def tolist(self):
|
||||
return [getattr(self, attr) for attr, *_ in ALL_ARGS[1:]]
|
||||
return [getattr(self, attr) for attr in ALL_ARGS.attrs]
|
||||
|
||||
|
||||
class ChangeTorchLoad:
|
||||
@@ -65,8 +66,8 @@ class ChangeTorchLoad:
|
||||
torch.load = self.orig
|
||||
|
||||
|
||||
def gr_show(visible: bool = True):
|
||||
return {"visible": visible, "__type__": "update"}
|
||||
def gr_interactive(value: bool = True):
|
||||
return gr.update(interactive=value)
|
||||
|
||||
|
||||
def ordinal(n: int) -> str:
|
||||
@@ -78,15 +79,8 @@ def suffix(n: int, c: str = " ") -> str:
|
||||
return "" if n == 0 else c + ordinal(n + 1)
|
||||
|
||||
|
||||
def on_enable_change(ad_enable: bool, *states):
|
||||
for state in states:
|
||||
state["enabled"] = ad_enable
|
||||
return states
|
||||
|
||||
|
||||
def on_widget_change(state: dict, *values):
|
||||
for (attr, *_), value in zip(ALL_ARGS[1:], values):
|
||||
state[attr] = value
|
||||
def on_widget_change(state: dict, value: Any, *, attr: str):
|
||||
state[attr] = value
|
||||
return state
|
||||
|
||||
|
||||
@@ -116,7 +110,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
visible=True,
|
||||
)
|
||||
|
||||
self.infotext_fields.append((ad_enable, ALL_ARGS[0].name))
|
||||
self.infotext_fields.append((ad_enable, AD_ENABLE.name))
|
||||
|
||||
with gr.Group(), gr.Tabs():
|
||||
for n in range(num_models):
|
||||
@@ -127,15 +121,13 @@ class AfterDetailerScript(scripts.Script):
|
||||
states.append(state)
|
||||
self.infotext_fields.extend(infofields)
|
||||
|
||||
ad_enable.change(
|
||||
fn=on_enable_change, inputs=[ad_enable] + states, outputs=states
|
||||
)
|
||||
return states
|
||||
# return: [bool, dict, dict, ...]
|
||||
return [ad_enable] + states
|
||||
|
||||
def one_ui_group(self, n: int):
|
||||
model_list = list(model_mapping.keys())
|
||||
w = Widgets()
|
||||
state = gr.State({"enabled": False})
|
||||
state = gr.State({})
|
||||
|
||||
with gr.Row():
|
||||
model_choices = model_list if n == 0 else ["None"] + model_list
|
||||
@@ -240,6 +232,12 @@ class AfterDetailerScript(scripts.Script):
|
||||
visible=True,
|
||||
)
|
||||
|
||||
w.ad_inpaint_full_res.change(
|
||||
gr_interactive,
|
||||
inputs=w.ad_inpaint_full_res,
|
||||
outputs=w.ad_inpaint_full_res_padding,
|
||||
)
|
||||
|
||||
with gr.Column(variant="compact"):
|
||||
w.ad_use_inpaint_width_height = gr.Checkbox(
|
||||
label="Use separate width/height" + suffix(n),
|
||||
@@ -265,6 +263,12 @@ class AfterDetailerScript(scripts.Script):
|
||||
visible=True,
|
||||
)
|
||||
|
||||
w.ad_use_inpaint_width_height.change(
|
||||
gr_interactive,
|
||||
inputs=w.ad_use_inpaint_width_height,
|
||||
outputs=[w.ad_inpaint_width, w.ad_inpaint_height],
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(variant="compact"):
|
||||
w.ad_use_steps = gr.Checkbox(
|
||||
@@ -282,6 +286,12 @@ class AfterDetailerScript(scripts.Script):
|
||||
visible=True,
|
||||
)
|
||||
|
||||
w.ad_use_steps.change(
|
||||
gr_interactive,
|
||||
inputs=w.ad_use_steps,
|
||||
outputs=w.ad_steps,
|
||||
)
|
||||
|
||||
with gr.Column(variant="compact"):
|
||||
w.ad_use_cfg_scale = gr.Checkbox(
|
||||
label="Use separate CFG scale" + suffix(n),
|
||||
@@ -298,6 +308,12 @@ class AfterDetailerScript(scripts.Script):
|
||||
visible=True,
|
||||
)
|
||||
|
||||
w.ad_use_cfg_scale.change(
|
||||
gr_interactive,
|
||||
inputs=w.ad_use_cfg_scale,
|
||||
outputs=w.ad_cfg_scale,
|
||||
)
|
||||
|
||||
with gr.Group(), gr.Row(variant="panel"):
|
||||
cn_inpaint_models = ["None"] + get_cn_inpaint_models()
|
||||
|
||||
@@ -320,20 +336,13 @@ class AfterDetailerScript(scripts.Script):
|
||||
interactive=controlnet_exists,
|
||||
)
|
||||
|
||||
subscribers = []
|
||||
for attr, *_ in ALL_ARGS[1:]:
|
||||
for attr in ALL_ARGS.attrs:
|
||||
widget = getattr(w, attr)
|
||||
for method in ["edit", "click", "change", "clear"]:
|
||||
if hasattr(widget, method):
|
||||
subscribers.append(getattr(widget, method))
|
||||
|
||||
sub_inputs = [state] + w.tolist()
|
||||
|
||||
for subscriber in subscribers:
|
||||
subscriber(fn=on_widget_change, inputs=sub_inputs, outputs=state)
|
||||
on_change = partial(on_widget_change, attr=attr)
|
||||
widget.change(fn=on_change, inputs=[state, widget], outputs=[state])
|
||||
|
||||
infotext_fields = [
|
||||
(getattr(w, attr), name + suffix(n)) for attr, name in ALL_ARGS[1:]
|
||||
(getattr(w, attr), name + suffix(n)) for attr, name in ALL_ARGS
|
||||
]
|
||||
|
||||
return w, state, infotext_fields
|
||||
@@ -367,36 +376,33 @@ class AfterDetailerScript(scripts.Script):
|
||||
)
|
||||
|
||||
def is_ad_enabled(self, *args_) -> bool:
|
||||
if len(args_) < 2:
|
||||
if len(args_) == 0 or (len(args_) == 1 and isinstance(args_[0], bool)):
|
||||
message = f"""
|
||||
[-] ADetailer: Not enough arguments passed to adetailer.
|
||||
[-] ADetailer: Not enough arguments passed to ADetailer.
|
||||
input: {args_!r}
|
||||
"""
|
||||
raise ValueError(dedent(message))
|
||||
checker = EnableChecker(ad_enable=args_[0], ad_model=args_[1])
|
||||
return checker.is_enabled()
|
||||
return enable_check(*args_)
|
||||
|
||||
def get_args(self, *args_) -> list[ADetailerArgs]:
|
||||
"""
|
||||
`args_` is at least 2 in length by `is_ad_enabled` immediately above
|
||||
`args_` is at least 1 in length by `is_ad_enabled` immediately above
|
||||
"""
|
||||
enabled = args_[0]
|
||||
rem = args_[1:]
|
||||
length = len(ALL_ARGS) - 1
|
||||
args = args_[1:] if isinstance(args_[0], bool) else args_
|
||||
|
||||
all_inputs = []
|
||||
iter_args = (rem[i : i + length] for i in range(0, len(rem), length))
|
||||
|
||||
for n, args in enumerate(iter_args, 1):
|
||||
for n, arg_dict in enumerate(args, 1):
|
||||
try:
|
||||
inp = get_one_args(enabled, *args)
|
||||
inp = ADetailerArgs(**arg_dict)
|
||||
except ValueError as e:
|
||||
message = [
|
||||
f"[-] ADetailer: ValidationError when validating {ordinal(n)} arguments: {e}\n"
|
||||
]
|
||||
for arg, (attr, *_) in zip_longest(args, ALL_ARGS[1:]):
|
||||
for attr in ALL_ARGS.attrs:
|
||||
arg = arg_dict.get(attr)
|
||||
dtype = type(arg)
|
||||
arg = "MISSING" if arg is None else repr(arg)
|
||||
arg = "DEFAULT" if arg is None else repr(arg)
|
||||
message.append(f" {attr}: {arg} ({dtype})")
|
||||
raise ValueError("\n".join(message)) from e
|
||||
|
||||
|
||||
Reference in New Issue
Block a user