feat: change ad args to dict

This commit is contained in:
Bingsu
2023-05-13 13:05:04 +09:00
parent b29711a490
commit 7da441a133
3 changed files with 104 additions and 81 deletions

View File

@@ -4,6 +4,7 @@ import platform
import sys
import traceback
from copy import copy, deepcopy
from functools import partial
from itertools import zip_longest
from pathlib import Path
from textwrap import dedent
@@ -14,12 +15,12 @@ import torch
import modules # noqa: F401
from adetailer import (
AD_ENABLE,
ALL_ARGS,
ADetailerArgs,
EnableChecker,
__version__,
enable_check,
get_models,
get_one_args,
mediapipe_predict,
ultralytics_predict,
)
@@ -53,7 +54,7 @@ print(
class Widgets:
def tolist(self):
return [getattr(self, attr) for attr, *_ in ALL_ARGS[1:]]
return [getattr(self, attr) for attr in ALL_ARGS.attrs]
class ChangeTorchLoad:
@@ -65,8 +66,8 @@ class ChangeTorchLoad:
torch.load = self.orig
def gr_show(visible: bool = True):
return {"visible": visible, "__type__": "update"}
def gr_interactive(value: bool = True):
return gr.update(interactive=value)
def ordinal(n: int) -> str:
@@ -78,15 +79,8 @@ def suffix(n: int, c: str = " ") -> str:
return "" if n == 0 else c + ordinal(n + 1)
def on_enable_change(ad_enable: bool, *states):
for state in states:
state["enabled"] = ad_enable
return states
def on_widget_change(state: dict, *values):
for (attr, *_), value in zip(ALL_ARGS[1:], values):
state[attr] = value
def on_widget_change(state: dict, value: Any, *, attr: str):
state[attr] = value
return state
@@ -116,7 +110,7 @@ class AfterDetailerScript(scripts.Script):
visible=True,
)
self.infotext_fields.append((ad_enable, ALL_ARGS[0].name))
self.infotext_fields.append((ad_enable, AD_ENABLE.name))
with gr.Group(), gr.Tabs():
for n in range(num_models):
@@ -127,15 +121,13 @@ class AfterDetailerScript(scripts.Script):
states.append(state)
self.infotext_fields.extend(infofields)
ad_enable.change(
fn=on_enable_change, inputs=[ad_enable] + states, outputs=states
)
return states
# return: [bool, dict, dict, ...]
return [ad_enable] + states
def one_ui_group(self, n: int):
model_list = list(model_mapping.keys())
w = Widgets()
state = gr.State({"enabled": False})
state = gr.State({})
with gr.Row():
model_choices = model_list if n == 0 else ["None"] + model_list
@@ -240,6 +232,12 @@ class AfterDetailerScript(scripts.Script):
visible=True,
)
w.ad_inpaint_full_res.change(
gr_interactive,
inputs=w.ad_inpaint_full_res,
outputs=w.ad_inpaint_full_res_padding,
)
with gr.Column(variant="compact"):
w.ad_use_inpaint_width_height = gr.Checkbox(
label="Use separate width/height" + suffix(n),
@@ -265,6 +263,12 @@ class AfterDetailerScript(scripts.Script):
visible=True,
)
w.ad_use_inpaint_width_height.change(
gr_interactive,
inputs=w.ad_use_inpaint_width_height,
outputs=[w.ad_inpaint_width, w.ad_inpaint_height],
)
with gr.Row():
with gr.Column(variant="compact"):
w.ad_use_steps = gr.Checkbox(
@@ -282,6 +286,12 @@ class AfterDetailerScript(scripts.Script):
visible=True,
)
w.ad_use_steps.change(
gr_interactive,
inputs=w.ad_use_steps,
outputs=w.ad_steps,
)
with gr.Column(variant="compact"):
w.ad_use_cfg_scale = gr.Checkbox(
label="Use separate CFG scale" + suffix(n),
@@ -298,6 +308,12 @@ class AfterDetailerScript(scripts.Script):
visible=True,
)
w.ad_use_cfg_scale.change(
gr_interactive,
inputs=w.ad_use_cfg_scale,
outputs=w.ad_cfg_scale,
)
with gr.Group(), gr.Row(variant="panel"):
cn_inpaint_models = ["None"] + get_cn_inpaint_models()
@@ -320,20 +336,13 @@ class AfterDetailerScript(scripts.Script):
interactive=controlnet_exists,
)
subscribers = []
for attr, *_ in ALL_ARGS[1:]:
for attr in ALL_ARGS.attrs:
widget = getattr(w, attr)
for method in ["edit", "click", "change", "clear"]:
if hasattr(widget, method):
subscribers.append(getattr(widget, method))
sub_inputs = [state] + w.tolist()
for subscriber in subscribers:
subscriber(fn=on_widget_change, inputs=sub_inputs, outputs=state)
on_change = partial(on_widget_change, attr=attr)
widget.change(fn=on_change, inputs=[state, widget], outputs=[state])
infotext_fields = [
(getattr(w, attr), name + suffix(n)) for attr, name in ALL_ARGS[1:]
(getattr(w, attr), name + suffix(n)) for attr, name in ALL_ARGS
]
return w, state, infotext_fields
@@ -367,36 +376,33 @@ class AfterDetailerScript(scripts.Script):
)
def is_ad_enabled(self, *args_) -> bool:
if len(args_) < 2:
if len(args_) == 0 or (len(args_) == 1 and isinstance(args_[0], bool)):
message = f"""
[-] ADetailer: Not enough arguments passed to adetailer.
[-] ADetailer: Not enough arguments passed to ADetailer.
input: {args_!r}
"""
raise ValueError(dedent(message))
checker = EnableChecker(ad_enable=args_[0], ad_model=args_[1])
return checker.is_enabled()
return enable_check(*args_)
def get_args(self, *args_) -> list[ADetailerArgs]:
"""
`args_` is at least 2 in length by `is_ad_enabled` immediately above
`args_` is at least 1 in length by `is_ad_enabled` immediately above
"""
enabled = args_[0]
rem = args_[1:]
length = len(ALL_ARGS) - 1
args = args_[1:] if isinstance(args_[0], bool) else args_
all_inputs = []
iter_args = (rem[i : i + length] for i in range(0, len(rem), length))
for n, args in enumerate(iter_args, 1):
for n, arg_dict in enumerate(args, 1):
try:
inp = get_one_args(enabled, *args)
inp = ADetailerArgs(**arg_dict)
except ValueError as e:
message = [
f"[-] ADetailer: ValidationError when validating {ordinal(n)} arguments: {e}\n"
]
for arg, (attr, *_) in zip_longest(args, ALL_ARGS[1:]):
for attr in ALL_ARGS.attrs:
arg = arg_dict.get(attr)
dtype = type(arg)
arg = "MISSING" if arg is None else repr(arg)
arg = "DEFAULT" if arg is None else repr(arg)
message.append(f" {attr}: {arg} ({dtype})")
raise ValueError("\n".join(message)) from e