refactor: use ui components tuple

This commit is contained in:
Bingsu
2023-08-27 13:31:49 +09:00
parent 1716f55f0e
commit 55a92ec9ef
2 changed files with 31 additions and 34 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from functools import partial from functools import partial
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any from typing import Any, NamedTuple
import gradio as gr import gradio as gr
@@ -22,6 +22,13 @@ class Widgets(SimpleNamespace):
return [getattr(self, attr) for attr in ALL_ARGS.attrs] return [getattr(self, attr) for attr in ALL_ARGS.attrs]
class WebuiInfo(NamedTuple):
ad_model_list: list[str]
sampler_names: list[str]
t2i_button: gr.Button
i2i_button: gr.Button
def gr_interactive(value: bool = True): def gr_interactive(value: bool = True):
return gr.update(interactive=value) return gr.update(interactive=value)
@@ -64,10 +71,7 @@ def elem_id(item_id: str, n: int, is_img2img: bool) -> str:
def adui( def adui(
num_models: int, num_models: int,
is_img2img: bool, is_img2img: bool,
model_list: list[str], webui_info: WebuiInfo,
samplers: list[str],
t2i_button: gr.Button,
i2i_button: gr.Button,
): ):
states = [] states = []
infotext_fields = [] infotext_fields = []
@@ -97,10 +101,7 @@ def adui(
state, infofields = one_ui_group( state, infofields = one_ui_group(
n=n, n=n,
is_img2img=is_img2img, is_img2img=is_img2img,
model_list=model_list, webui_info=webui_info,
samplers=samplers,
t2i_button=t2i_button,
i2i_button=i2i_button,
) )
states.append(state) states.append(state)
@@ -111,20 +112,17 @@ def adui(
return components, infotext_fields return components, infotext_fields
def one_ui_group( def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo):
n: int,
is_img2img: bool,
model_list: list[str],
samplers: list[str],
t2i_button: gr.Button,
i2i_button: gr.Button,
):
w = Widgets() w = Widgets()
state = gr.State({}) state = gr.State({})
eid = partial(elem_id, n=n, is_img2img=is_img2img) eid = partial(elem_id, n=n, is_img2img=is_img2img)
with gr.Row(): with gr.Row():
model_choices = [*model_list, "None"] if n == 0 else ["None", *model_list] model_choices = (
[*webui_info.ad_model_list, "None"]
if n == 0
else ["None", *webui_info.ad_model_list]
)
w.ad_model = gr.Dropdown( w.ad_model = gr.Dropdown(
label="ADetailer model" + suffix(n), label="ADetailer model" + suffix(n),
@@ -174,13 +172,13 @@ def one_ui_group(
with gr.Accordion( with gr.Accordion(
"Inpainting", open=False, elem_id=eid("ad_inpainting_accordion") "Inpainting", open=False, elem_id=eid("ad_inpainting_accordion")
): ):
inpainting(w, n, is_img2img, samplers) inpainting(w, n, is_img2img, webui_info.sampler_names)
with gr.Group(): with gr.Group():
controlnet(w, n, is_img2img) controlnet(w, n, is_img2img)
all_inputs = [state, *w.tolist()] all_inputs = [state, *w.tolist()]
target_button = i2i_button if is_img2img else t2i_button target_button = webui_info.i2i_button if is_img2img else webui_info.t2i_button
target_button.click( target_button.click(
fn=on_generate_click, inputs=all_inputs, outputs=state, queue=False fn=on_generate_click, inputs=all_inputs, outputs=state, queue=False
) )
@@ -280,7 +278,7 @@ def mask_preprocessing(w: Widgets, n: int, is_img2img: bool):
) )
def inpainting(w: Widgets, n: int, is_img2img: bool, samplers: list[str]): def inpainting(w: Widgets, n: int, is_img2img: bool, sampler_names: list[str]):
eid = partial(elem_id, n=n, is_img2img=is_img2img) eid = partial(elem_id, n=n, is_img2img=is_img2img)
with gr.Group(): with gr.Group():
@@ -427,8 +425,8 @@ def inpainting(w: Widgets, n: int, is_img2img: bool, samplers: list[str]):
w.ad_sampler = gr.Dropdown( w.ad_sampler = gr.Dropdown(
label="ADetailer sampler" + suffix(n), label="ADetailer sampler" + suffix(n),
choices=samplers, choices=sampler_names,
value=samplers[0], value=sampler_names[0],
visible=True, visible=True,
elem_id=eid("ad_sampler"), elem_id=eid("ad_sampler"),
) )

View File

@@ -33,7 +33,7 @@ from adetailer.mask import (
sort_bboxes, sort_bboxes,
) )
from adetailer.traceback import rich_traceback from adetailer.traceback import rich_traceback
from adetailer.ui import adui, ordinal, suffix from adetailer.ui import WebuiInfo, adui, ordinal, suffix
from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models
from controlnet_ext.restore import ( from controlnet_ext.restore import (
CNHijackRestore, CNHijackRestore,
@@ -118,18 +118,17 @@ class AfterDetailerScript(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
num_models = opts.data.get("ad_max_models", 2) num_models = opts.data.get("ad_max_models", 2)
model_list = list(model_mapping.keys()) ad_model_list = list(model_mapping.keys())
samplers = [sampler.name for sampler in all_samplers] sampler_names = [sampler.name for sampler in all_samplers]
webui_info = WebuiInfo(
components, infotext_fields = adui( ad_model_list=ad_model_list,
num_models, sampler_names=sampler_names,
is_img2img, t2i_button=txt2img_submit_button,
model_list, i2i_button=img2img_submit_button,
samplers,
txt2img_submit_button,
img2img_submit_button,
) )
components, infotext_fields = adui(num_models, is_img2img, webui_info)
self.infotext_fields = infotext_fields self.infotext_fields = infotext_fields
return components return components