mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-04-25 16:59:00 +00:00
feat: change ad args to dict
This commit is contained in:
@@ -4,6 +4,7 @@ import platform
|
||||
import sys
|
||||
import traceback
|
||||
from copy import copy, deepcopy
|
||||
from functools import partial
|
||||
from itertools import zip_longest
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
@@ -14,12 +15,12 @@ import torch
|
||||
|
||||
import modules # noqa: F401
|
||||
from adetailer import (
|
||||
AD_ENABLE,
|
||||
ALL_ARGS,
|
||||
ADetailerArgs,
|
||||
EnableChecker,
|
||||
__version__,
|
||||
enable_check,
|
||||
get_models,
|
||||
get_one_args,
|
||||
mediapipe_predict,
|
||||
ultralytics_predict,
|
||||
)
|
||||
@@ -53,7 +54,7 @@ print(
|
||||
|
||||
class Widgets:
|
||||
def tolist(self):
|
||||
return [getattr(self, attr) for attr, *_ in ALL_ARGS[1:]]
|
||||
return [getattr(self, attr) for attr in ALL_ARGS.attrs]
|
||||
|
||||
|
||||
class ChangeTorchLoad:
|
||||
@@ -65,8 +66,8 @@ class ChangeTorchLoad:
|
||||
torch.load = self.orig
|
||||
|
||||
|
||||
def gr_show(visible: bool = True):
|
||||
return {"visible": visible, "__type__": "update"}
|
||||
def gr_interactive(value: bool = True):
|
||||
return gr.update(interactive=value)
|
||||
|
||||
|
||||
def ordinal(n: int) -> str:
|
||||
@@ -78,15 +79,8 @@ def suffix(n: int, c: str = " ") -> str:
|
||||
return "" if n == 0 else c + ordinal(n + 1)
|
||||
|
||||
|
||||
def on_enable_change(ad_enable: bool, *states):
|
||||
for state in states:
|
||||
state["enabled"] = ad_enable
|
||||
return states
|
||||
|
||||
|
||||
def on_widget_change(state: dict, *values):
|
||||
for (attr, *_), value in zip(ALL_ARGS[1:], values):
|
||||
state[attr] = value
|
||||
def on_widget_change(state: dict, value: Any, *, attr: str):
|
||||
state[attr] = value
|
||||
return state
|
||||
|
||||
|
||||
@@ -116,7 +110,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
visible=True,
|
||||
)
|
||||
|
||||
self.infotext_fields.append((ad_enable, ALL_ARGS[0].name))
|
||||
self.infotext_fields.append((ad_enable, AD_ENABLE.name))
|
||||
|
||||
with gr.Group(), gr.Tabs():
|
||||
for n in range(num_models):
|
||||
@@ -127,15 +121,13 @@ class AfterDetailerScript(scripts.Script):
|
||||
states.append(state)
|
||||
self.infotext_fields.extend(infofields)
|
||||
|
||||
ad_enable.change(
|
||||
fn=on_enable_change, inputs=[ad_enable] + states, outputs=states
|
||||
)
|
||||
return states
|
||||
# return: [bool, dict, dict, ...]
|
||||
return [ad_enable] + states
|
||||
|
||||
def one_ui_group(self, n: int):
|
||||
model_list = list(model_mapping.keys())
|
||||
w = Widgets()
|
||||
state = gr.State({"enabled": False})
|
||||
state = gr.State({})
|
||||
|
||||
with gr.Row():
|
||||
model_choices = model_list if n == 0 else ["None"] + model_list
|
||||
@@ -240,6 +232,12 @@ class AfterDetailerScript(scripts.Script):
|
||||
visible=True,
|
||||
)
|
||||
|
||||
w.ad_inpaint_full_res.change(
|
||||
gr_interactive,
|
||||
inputs=w.ad_inpaint_full_res,
|
||||
outputs=w.ad_inpaint_full_res_padding,
|
||||
)
|
||||
|
||||
with gr.Column(variant="compact"):
|
||||
w.ad_use_inpaint_width_height = gr.Checkbox(
|
||||
label="Use separate width/height" + suffix(n),
|
||||
@@ -265,6 +263,12 @@ class AfterDetailerScript(scripts.Script):
|
||||
visible=True,
|
||||
)
|
||||
|
||||
w.ad_use_inpaint_width_height.change(
|
||||
gr_interactive,
|
||||
inputs=w.ad_use_inpaint_width_height,
|
||||
outputs=[w.ad_inpaint_width, w.ad_inpaint_height],
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(variant="compact"):
|
||||
w.ad_use_steps = gr.Checkbox(
|
||||
@@ -282,6 +286,12 @@ class AfterDetailerScript(scripts.Script):
|
||||
visible=True,
|
||||
)
|
||||
|
||||
w.ad_use_steps.change(
|
||||
gr_interactive,
|
||||
inputs=w.ad_use_steps,
|
||||
outputs=w.ad_steps,
|
||||
)
|
||||
|
||||
with gr.Column(variant="compact"):
|
||||
w.ad_use_cfg_scale = gr.Checkbox(
|
||||
label="Use separate CFG scale" + suffix(n),
|
||||
@@ -298,6 +308,12 @@ class AfterDetailerScript(scripts.Script):
|
||||
visible=True,
|
||||
)
|
||||
|
||||
w.ad_use_cfg_scale.change(
|
||||
gr_interactive,
|
||||
inputs=w.ad_use_cfg_scale,
|
||||
outputs=w.ad_cfg_scale,
|
||||
)
|
||||
|
||||
with gr.Group(), gr.Row(variant="panel"):
|
||||
cn_inpaint_models = ["None"] + get_cn_inpaint_models()
|
||||
|
||||
@@ -320,20 +336,13 @@ class AfterDetailerScript(scripts.Script):
|
||||
interactive=controlnet_exists,
|
||||
)
|
||||
|
||||
subscribers = []
|
||||
for attr, *_ in ALL_ARGS[1:]:
|
||||
for attr in ALL_ARGS.attrs:
|
||||
widget = getattr(w, attr)
|
||||
for method in ["edit", "click", "change", "clear"]:
|
||||
if hasattr(widget, method):
|
||||
subscribers.append(getattr(widget, method))
|
||||
|
||||
sub_inputs = [state] + w.tolist()
|
||||
|
||||
for subscriber in subscribers:
|
||||
subscriber(fn=on_widget_change, inputs=sub_inputs, outputs=state)
|
||||
on_change = partial(on_widget_change, attr=attr)
|
||||
widget.change(fn=on_change, inputs=[state, widget], outputs=[state])
|
||||
|
||||
infotext_fields = [
|
||||
(getattr(w, attr), name + suffix(n)) for attr, name in ALL_ARGS[1:]
|
||||
(getattr(w, attr), name + suffix(n)) for attr, name in ALL_ARGS
|
||||
]
|
||||
|
||||
return w, state, infotext_fields
|
||||
@@ -367,36 +376,33 @@ class AfterDetailerScript(scripts.Script):
|
||||
)
|
||||
|
||||
def is_ad_enabled(self, *args_) -> bool:
|
||||
if len(args_) < 2:
|
||||
if len(args_) == 0 or (len(args_) == 1 and isinstance(args_[0], bool)):
|
||||
message = f"""
|
||||
[-] ADetailer: Not enough arguments passed to adetailer.
|
||||
[-] ADetailer: Not enough arguments passed to ADetailer.
|
||||
input: {args_!r}
|
||||
"""
|
||||
raise ValueError(dedent(message))
|
||||
checker = EnableChecker(ad_enable=args_[0], ad_model=args_[1])
|
||||
return checker.is_enabled()
|
||||
return enable_check(*args_)
|
||||
|
||||
def get_args(self, *args_) -> list[ADetailerArgs]:
|
||||
"""
|
||||
`args_` is at least 2 in length by `is_ad_enabled` immediately above
|
||||
`args_` is at least 1 in length by `is_ad_enabled` immediately above
|
||||
"""
|
||||
enabled = args_[0]
|
||||
rem = args_[1:]
|
||||
length = len(ALL_ARGS) - 1
|
||||
args = args_[1:] if isinstance(args_[0], bool) else args_
|
||||
|
||||
all_inputs = []
|
||||
iter_args = (rem[i : i + length] for i in range(0, len(rem), length))
|
||||
|
||||
for n, args in enumerate(iter_args, 1):
|
||||
for n, arg_dict in enumerate(args, 1):
|
||||
try:
|
||||
inp = get_one_args(enabled, *args)
|
||||
inp = ADetailerArgs(**arg_dict)
|
||||
except ValueError as e:
|
||||
message = [
|
||||
f"[-] ADetailer: ValidationError when validating {ordinal(n)} arguments: {e}\n"
|
||||
]
|
||||
for arg, (attr, *_) in zip_longest(args, ALL_ARGS[1:]):
|
||||
for attr in ALL_ARGS.attrs:
|
||||
arg = arg_dict.get(attr)
|
||||
dtype = type(arg)
|
||||
arg = "MISSING" if arg is None else repr(arg)
|
||||
arg = "DEFAULT" if arg is None else repr(arg)
|
||||
message.append(f" {attr}: {arg} ({dtype})")
|
||||
raise ValueError("\n".join(message)) from e
|
||||
|
||||
|
||||
Reference in New Issue
Block a user