feat: change ad args to dict

This commit is contained in:
Bingsu
2023-05-13 13:05:04 +09:00
parent b29711a490
commit 7da441a133
3 changed files with 104 additions and 81 deletions

View File

@@ -1,16 +1,16 @@
from .__version__ import __version__
from .args import ALL_ARGS, ADetailerArgs, EnableChecker, get_one_args
from .args import AD_ENABLE, ALL_ARGS, ADetailerArgs, enable_check
from .common import PredictOutput, get_models
from .mediapipe import mediapipe_predict
from .ultralytics import ultralytics_predict
__all__ = [
"__version__",
"AD_ENABLE",
"ADetailerArgs",
"ALL_ARGS",
"EnableChecker",
"PredictOutput",
"get_one_args",
"enable_check",
"get_models",
"mediapipe_predict",
"ultralytics_predict",

View File

@@ -1,3 +1,6 @@
from collections import UserList
from collections.abc import Mapping
from functools import cached_property
from typing import Any, NamedTuple
import pydantic
@@ -17,35 +20,17 @@ class Arg(NamedTuple):
name: str
_all_args = [
("ad_enable", "ADetailer enable"),
("ad_model", "ADetailer model"),
("ad_prompt", "ADetailer prompt"),
("ad_negative_prompt", "ADetailer negative prompt"),
("ad_conf", "ADetailer conf"),
("ad_dilate_erode", "ADetailer dilate/erode"),
("ad_x_offset", "ADetailer x offset"),
("ad_y_offset", "ADetailer y offset"),
("ad_mask_blur", "ADetailer mask blur"),
("ad_denoising_strength", "ADetailer denoising strength"),
("ad_inpaint_full_res", "ADetailer inpaint full"),
("ad_inpaint_full_res_padding", "ADetailer inpaint padding"),
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height"),
("ad_inpaint_width", "ADetailer inpaint width"),
("ad_inpaint_height", "ADetailer inpaint height"),
("ad_use_steps", "ADetailer use separate steps"),
("ad_steps", "ADetailer steps"),
("ad_use_cfg_scale", "ADetailer use separate CFG scale"),
("ad_cfg_scale", "ADetailer CFG scale"),
("ad_controlnet_model", "ADetailer ControlNet model"),
("ad_controlnet_weight", "ADetailer ControlNet weight"),
]
class ArgsList(UserList):
@cached_property
def attrs(self) -> tuple[str]:
return tuple(attr for attr, _ in self)
ALL_ARGS = [Arg(*args) for args in _all_args]
@cached_property
def names(self) -> tuple[str]:
return tuple(name for _, name in self)
class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_enable: bool = False
ad_model: str = "None"
ad_prompt: str = ""
ad_negative_prompt: str = ""
@@ -83,7 +68,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
if self.ad_model == "None":
return {}
params = {name: getattr(self, attr) for attr, name in ALL_ARGS[1:]}
params = {name: getattr(self, attr) for attr, name in ALL_ARGS}
params["ADetailer conf"] = int(params["ADetailer conf"] * 100)
if not params["ADetailer prompt"]:
@@ -122,14 +107,46 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
return params
class EnableChecker(BaseModel):
ad_enable: bool = False
ad_model: str = "None"
def enable_check(*args: Any) -> bool:
if not args:
return False
a0 = args[0]
ad_model = ALL_ARGS[0].attr
def is_enabled(self):
return self.ad_enable and self.ad_model != "None"
if isinstance(a0, Mapping):
return a0.get(ad_model, "None") != "None"
if len(args) == 1:
return False
a1 = args[1]
a1_model = a1.get(ad_model, "None")
return a0 and a1_model != "None"
def get_one_args(*args: Any) -> ADetailerArgs:
arg_dict = {attr: arg for arg, (attr, *_) in zip(args, ALL_ARGS)}
return ADetailerArgs(**arg_dict)
_all_args = [
("ad_enable", "ADetailer enable"),
("ad_model", "ADetailer model"),
("ad_prompt", "ADetailer prompt"),
("ad_negative_prompt", "ADetailer negative prompt"),
("ad_conf", "ADetailer conf"),
("ad_dilate_erode", "ADetailer dilate/erode"),
("ad_x_offset", "ADetailer x offset"),
("ad_y_offset", "ADetailer y offset"),
("ad_mask_blur", "ADetailer mask blur"),
("ad_denoising_strength", "ADetailer denoising strength"),
("ad_inpaint_full_res", "ADetailer inpaint full"),
("ad_inpaint_full_res_padding", "ADetailer inpaint padding"),
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height"),
("ad_inpaint_width", "ADetailer inpaint width"),
("ad_inpaint_height", "ADetailer inpaint height"),
("ad_use_steps", "ADetailer use separate steps"),
("ad_steps", "ADetailer steps"),
("ad_use_cfg_scale", "ADetailer use separate CFG scale"),
("ad_cfg_scale", "ADetailer CFG scale"),
("ad_controlnet_model", "ADetailer ControlNet model"),
("ad_controlnet_weight", "ADetailer ControlNet weight"),
]
AD_ENABLE = Arg(*_all_args[0])
_args = [Arg(*args) for args in _all_args[1:]]
ALL_ARGS = ArgsList(_args)

View File

@@ -4,6 +4,7 @@ import platform
import sys
import traceback
from copy import copy, deepcopy
from functools import partial
from itertools import zip_longest
from pathlib import Path
from textwrap import dedent
@@ -14,12 +15,12 @@ import torch
import modules # noqa: F401
from adetailer import (
AD_ENABLE,
ALL_ARGS,
ADetailerArgs,
EnableChecker,
__version__,
enable_check,
get_models,
get_one_args,
mediapipe_predict,
ultralytics_predict,
)
@@ -53,7 +54,7 @@ print(
class Widgets:
def tolist(self):
return [getattr(self, attr) for attr, *_ in ALL_ARGS[1:]]
return [getattr(self, attr) for attr in ALL_ARGS.attrs]
class ChangeTorchLoad:
@@ -65,8 +66,8 @@ class ChangeTorchLoad:
torch.load = self.orig
def gr_show(visible: bool = True):
return {"visible": visible, "__type__": "update"}
def gr_interactive(value: bool = True):
return gr.update(interactive=value)
def ordinal(n: int) -> str:
@@ -78,15 +79,8 @@ def suffix(n: int, c: str = " ") -> str:
return "" if n == 0 else c + ordinal(n + 1)
def on_enable_change(ad_enable: bool, *states):
for state in states:
state["enabled"] = ad_enable
return states
def on_widget_change(state: dict, *values):
for (attr, *_), value in zip(ALL_ARGS[1:], values):
state[attr] = value
def on_widget_change(state: dict, value: Any, *, attr: str):
state[attr] = value
return state
@@ -116,7 +110,7 @@ class AfterDetailerScript(scripts.Script):
visible=True,
)
self.infotext_fields.append((ad_enable, ALL_ARGS[0].name))
self.infotext_fields.append((ad_enable, AD_ENABLE.name))
with gr.Group(), gr.Tabs():
for n in range(num_models):
@@ -127,15 +121,13 @@ class AfterDetailerScript(scripts.Script):
states.append(state)
self.infotext_fields.extend(infofields)
ad_enable.change(
fn=on_enable_change, inputs=[ad_enable] + states, outputs=states
)
return states
# return: [bool, dict, dict, ...]
return [ad_enable] + states
def one_ui_group(self, n: int):
model_list = list(model_mapping.keys())
w = Widgets()
state = gr.State({"enabled": False})
state = gr.State({})
with gr.Row():
model_choices = model_list if n == 0 else ["None"] + model_list
@@ -240,6 +232,12 @@ class AfterDetailerScript(scripts.Script):
visible=True,
)
w.ad_inpaint_full_res.change(
gr_interactive,
inputs=w.ad_inpaint_full_res,
outputs=w.ad_inpaint_full_res_padding,
)
with gr.Column(variant="compact"):
w.ad_use_inpaint_width_height = gr.Checkbox(
label="Use separate width/height" + suffix(n),
@@ -265,6 +263,12 @@ class AfterDetailerScript(scripts.Script):
visible=True,
)
w.ad_use_inpaint_width_height.change(
gr_interactive,
inputs=w.ad_use_inpaint_width_height,
outputs=[w.ad_inpaint_width, w.ad_inpaint_height],
)
with gr.Row():
with gr.Column(variant="compact"):
w.ad_use_steps = gr.Checkbox(
@@ -282,6 +286,12 @@ class AfterDetailerScript(scripts.Script):
visible=True,
)
w.ad_use_steps.change(
gr_interactive,
inputs=w.ad_use_steps,
outputs=w.ad_steps,
)
with gr.Column(variant="compact"):
w.ad_use_cfg_scale = gr.Checkbox(
label="Use separate CFG scale" + suffix(n),
@@ -298,6 +308,12 @@ class AfterDetailerScript(scripts.Script):
visible=True,
)
w.ad_use_cfg_scale.change(
gr_interactive,
inputs=w.ad_use_cfg_scale,
outputs=w.ad_cfg_scale,
)
with gr.Group(), gr.Row(variant="panel"):
cn_inpaint_models = ["None"] + get_cn_inpaint_models()
@@ -320,20 +336,13 @@ class AfterDetailerScript(scripts.Script):
interactive=controlnet_exists,
)
subscribers = []
for attr, *_ in ALL_ARGS[1:]:
for attr in ALL_ARGS.attrs:
widget = getattr(w, attr)
for method in ["edit", "click", "change", "clear"]:
if hasattr(widget, method):
subscribers.append(getattr(widget, method))
sub_inputs = [state] + w.tolist()
for subscriber in subscribers:
subscriber(fn=on_widget_change, inputs=sub_inputs, outputs=state)
on_change = partial(on_widget_change, attr=attr)
widget.change(fn=on_change, inputs=[state, widget], outputs=[state])
infotext_fields = [
(getattr(w, attr), name + suffix(n)) for attr, name in ALL_ARGS[1:]
(getattr(w, attr), name + suffix(n)) for attr, name in ALL_ARGS
]
return w, state, infotext_fields
@@ -367,36 +376,33 @@ class AfterDetailerScript(scripts.Script):
)
def is_ad_enabled(self, *args_) -> bool:
if len(args_) < 2:
if len(args_) == 0 or (len(args_) == 1 and isinstance(args_[0], bool)):
message = f"""
[-] ADetailer: Not enough arguments passed to adetailer.
[-] ADetailer: Not enough arguments passed to ADetailer.
input: {args_!r}
"""
raise ValueError(dedent(message))
checker = EnableChecker(ad_enable=args_[0], ad_model=args_[1])
return checker.is_enabled()
return enable_check(*args_)
def get_args(self, *args_) -> list[ADetailerArgs]:
"""
`args_` is at least 2 in length by `is_ad_enabled` immediately above
`args_` is at least 1 in length by `is_ad_enabled` immediately above
"""
enabled = args_[0]
rem = args_[1:]
length = len(ALL_ARGS) - 1
args = args_[1:] if isinstance(args_[0], bool) else args_
all_inputs = []
iter_args = (rem[i : i + length] for i in range(0, len(rem), length))
for n, args in enumerate(iter_args, 1):
for n, arg_dict in enumerate(args, 1):
try:
inp = get_one_args(enabled, *args)
inp = ADetailerArgs(**arg_dict)
except ValueError as e:
message = [
f"[-] ADetailer: ValidationError when validating {ordinal(n)} arguments: {e}\n"
]
for arg, (attr, *_) in zip_longest(args, ALL_ARGS[1:]):
for attr in ALL_ARGS.attrs:
arg = arg_dict.get(attr)
dtype = type(arg)
arg = "MISSING" if arg is None else repr(arg)
arg = "DEFAULT" if arg is None else repr(arg)
message.append(f" {attr}: {arg} ({dtype})")
raise ValueError("\n".join(message)) from e