diff --git a/adetailer/ui.py b/adetailer/ui.py index 7b6261e..fc4c9dc 100644 --- a/adetailer/ui.py +++ b/adetailer/ui.py @@ -2,7 +2,7 @@ from __future__ import annotations from functools import partial from types import SimpleNamespace -from typing import Any +from typing import Any, NamedTuple import gradio as gr @@ -22,6 +22,13 @@ class Widgets(SimpleNamespace): return [getattr(self, attr) for attr in ALL_ARGS.attrs] +class WebuiInfo(NamedTuple): + ad_model_list: list[str] + sampler_names: list[str] + t2i_button: gr.Button + i2i_button: gr.Button + + def gr_interactive(value: bool = True): return gr.update(interactive=value) @@ -64,10 +71,7 @@ def elem_id(item_id: str, n: int, is_img2img: bool) -> str: def adui( num_models: int, is_img2img: bool, - model_list: list[str], - samplers: list[str], - t2i_button: gr.Button, - i2i_button: gr.Button, + webui_info: WebuiInfo, ): states = [] infotext_fields = [] @@ -97,10 +101,7 @@ def adui( state, infofields = one_ui_group( n=n, is_img2img=is_img2img, - model_list=model_list, - samplers=samplers, - t2i_button=t2i_button, - i2i_button=i2i_button, + webui_info=webui_info, ) states.append(state) @@ -111,20 +112,17 @@ def adui( return components, infotext_fields -def one_ui_group( - n: int, - is_img2img: bool, - model_list: list[str], - samplers: list[str], - t2i_button: gr.Button, - i2i_button: gr.Button, -): +def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo): w = Widgets() state = gr.State({}) eid = partial(elem_id, n=n, is_img2img=is_img2img) with gr.Row(): - model_choices = [*model_list, "None"] if n == 0 else ["None", *model_list] + model_choices = ( + [*webui_info.ad_model_list, "None"] + if n == 0 + else ["None", *webui_info.ad_model_list] + ) w.ad_model = gr.Dropdown( label="ADetailer model" + suffix(n), @@ -174,13 +172,13 @@ def one_ui_group( with gr.Accordion( "Inpainting", open=False, elem_id=eid("ad_inpainting_accordion") ): - inpainting(w, n, is_img2img, samplers) + inpainting(w, n, is_img2img, webui_info.sampler_names) with gr.Group(): controlnet(w, n, is_img2img) all_inputs = [state, *w.tolist()] - target_button = i2i_button if is_img2img else t2i_button + target_button = webui_info.i2i_button if is_img2img else webui_info.t2i_button target_button.click( fn=on_generate_click, inputs=all_inputs, outputs=state, queue=False ) @@ -280,7 +278,7 @@ def mask_preprocessing(w: Widgets, n: int, is_img2img: bool): ) -def inpainting(w: Widgets, n: int, is_img2img: bool, samplers: list[str]): +def inpainting(w: Widgets, n: int, is_img2img: bool, sampler_names: list[str]): eid = partial(elem_id, n=n, is_img2img=is_img2img) with gr.Group(): @@ -427,8 +425,8 @@ def inpainting(w: Widgets, n: int, is_img2img: bool, samplers: list[str]): w.ad_sampler = gr.Dropdown( label="ADetailer sampler" + suffix(n), - choices=samplers, - value=samplers[0], + choices=sampler_names, + value=sampler_names[0], visible=True, elem_id=eid("ad_sampler"), ) diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index fee6717..33dc832 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -33,7 +33,7 @@ from adetailer.mask import ( sort_bboxes, ) from adetailer.traceback import rich_traceback -from adetailer.ui import adui, ordinal, suffix +from adetailer.ui import WebuiInfo, adui, ordinal, suffix from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models from controlnet_ext.restore import ( CNHijackRestore, @@ -118,18 +118,17 @@ class AfterDetailerScript(scripts.Script): def ui(self, is_img2img): num_models = opts.data.get("ad_max_models", 2) - model_list = list(model_mapping.keys()) - samplers = [sampler.name for sampler in all_samplers] - - components, infotext_fields = adui( - num_models, - is_img2img, - model_list, - samplers, - txt2img_submit_button, - img2img_submit_button, + ad_model_list = list(model_mapping.keys()) + sampler_names = [sampler.name for sampler in all_samplers] + webui_info = WebuiInfo( + ad_model_list=ad_model_list, + sampler_names=sampler_names, + t2i_button=txt2img_submit_button, + i2i_button=img2img_submit_button, ) + components, infotext_fields = adui(num_models, is_img2img, webui_info) + self.infotext_fields = infotext_fields return components