fix(scripts): fix is ad enabled

This commit is contained in:
Dowon
2024-08-24 19:14:10 +09:00
parent a3935fcc4f
commit 6090dcdaa9
2 changed files with 22 additions and 26 deletions

View File

@@ -219,6 +219,7 @@ def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo):
with gr.Group(): with gr.Group():
with gr.Row(elem_id=eid("ad_toprow_prompt")): with gr.Row(elem_id=eid("ad_toprow_prompt")):
w.ad_prompt = gr.Textbox( w.ad_prompt = gr.Textbox(
value="",
label="ad_prompt" + suffix(n), label="ad_prompt" + suffix(n),
show_label=False, show_label=False,
lines=3, lines=3,
@@ -230,6 +231,7 @@ def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo):
with gr.Row(elem_id=eid("ad_toprow_negative_prompt")): with gr.Row(elem_id=eid("ad_toprow_negative_prompt")):
w.ad_negative_prompt = gr.Textbox( w.ad_negative_prompt = gr.Textbox(
value="",
label="ad_negative_prompt" + suffix(n), label="ad_negative_prompt" + suffix(n),
show_label=False, show_label=False,
lines=2, lines=2,

View File

@@ -8,7 +8,6 @@ from collections.abc import Sequence
from copy import copy from copy import copy
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING, Any, NamedTuple, cast from typing import TYPE_CHECKING, Any, NamedTuple, cast
import gradio as gr import gradio as gr
@@ -177,25 +176,23 @@ class AfterDetailerScript(scripts.Script):
guidance_end=args.ad_controlnet_guidance_end, guidance_end=args.ad_controlnet_guidance_end,
) )
def is_ad_enabled(self, *args_) -> bool: def is_ad_enabled(self, *args) -> bool:
arg_list = [arg for arg in args_ if isinstance(arg, dict)] arg_list = [arg for arg in args if isinstance(arg, dict)]
if not args_ or not arg_list: if not arg_list:
message = f"""
[-] ADetailer: Invalid arguments passed to ADetailer.
input: {args_!r}
ADetailer disabled.
"""
print(dedent(message), file=sys.stderr)
return False return False
ad_enabled = args_[0] if isinstance(args_[0], bool) else True ad_enabled = args[0] if isinstance(args[0], bool) else True
pydantic_args = []
not_none = False
for arg in arg_list: for arg in arg_list:
try: try:
pydantic_args.append(ADetailerArgs(**arg)) adarg = ADetailerArgs(**arg)
except ValueError: # noqa: PERF203 except ValueError: # noqa: PERF203
continue continue
not_none = not all(arg.need_skip() for arg in pydantic_args) else:
if not adarg.need_skip():
not_none = True
break
return ad_enabled and not_none return ad_enabled and not_none
def set_skip_img2img(self, p, *args_) -> None: def set_skip_img2img(self, p, *args_) -> None:
@@ -232,9 +229,6 @@ class AfterDetailerScript(scripts.Script):
p.height = 128 p.height = 128
def get_args(self, p, *args_) -> list[ADetailerArgs]: def get_args(self, p, *args_) -> list[ADetailerArgs]:
"""
`args_` is at least 1 in length by `is_ad_enabled` immediately above
"""
args = [arg for arg in args_ if isinstance(arg, dict)] args = [arg for arg in args_ if isinstance(arg, dict)]
if not args: if not args:
@@ -244,21 +238,21 @@ class AfterDetailerScript(scripts.Script):
if hasattr(p, "_ad_xyz"): if hasattr(p, "_ad_xyz"):
args[0] = {**args[0], **p._ad_xyz} args[0] = {**args[0], **p._ad_xyz}
all_inputs = [] all_inputs: list[ADetailerArgs] = []
for n, arg_dict in enumerate(args, 1): for n, arg_dict in enumerate(args, 1):
try: try:
inp = ADetailerArgs(**arg_dict) inp = ADetailerArgs(**arg_dict)
except ValueError as e: except ValueError:
msg = f"[-] ADetailer: ValidationError when validating {ordinal(n)} arguments" msg = f"[-] ADetailer: ValidationError when validating {ordinal(n)} arguments:"
if hasattr(e, "add_note"): print(msg, arg_dict, file=sys.stderr)
e.add_note(msg) continue
else:
print(msg, file=sys.stderr)
raise
all_inputs.append(inp) all_inputs.append(inp)
if not all_inputs:
msg = "[-] ADetailer: No valid arguments found."
raise ValueError(msg)
return all_inputs return all_inputs
def extra_params(self, arg_list: list[ADetailerArgs]) -> dict: def extra_params(self, arg_list: list[ADetailerArgs]) -> dict:
@@ -643,7 +637,7 @@ class AfterDetailerScript(scripts.Script):
) )
@staticmethod @staticmethod
def get_i2i_init_image(p, pp): def get_i2i_init_image(p, pp: PPImage):
if is_skip_img2img(p): if is_skip_img2img(p):
return p.init_images[0] return p.init_images[0]
return pp.image return pp.image