From 200f2b69edda97314ce133e7298824603a719f5e Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 9 Feb 2024 21:34:09 +0000 Subject: [PATCH] Add back ControlNet model version filter (#131) * Add back ControlNet model version filter * Update choice after sd model changes --- .../controlnet_ui/controlnet_ui_group.py | 32 +++++++++++++------ .../lib_controlnet/global_state.py | 8 +++-- modules/script_callbacks.py | 30 +++++++++++++++++ modules/ui_settings.py | 14 ++++++-- 4 files changed, 70 insertions(+), 14 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py index 9c3f2df6..389788bf 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -17,7 +17,7 @@ from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI from lib_controlnet.controlnet_ui.tool_button import ToolButton from lib_controlnet.controlnet_ui.photopea import Photopea from lib_controlnet.enums import InputMode, HiResFixOption -from modules import shared +from modules import shared, script_callbacks from modules.ui_components import FormRow from modules_forge.forge_util import HWC3 from lib_controlnet.external_code import UiControlNetUnit @@ -47,7 +47,6 @@ class A1111Context: img2img_inpaint_area: Optional[gr.components.IOComponent] = None txt2img_enable_hr: Optional[gr.components.IOComponent] = None - setting_sd_model_checkpoint: Optional[gr.components.IOComponent] = None @property def img2img_inpaint_tabs(self) -> Tuple[gr.components.IOComponent]: @@ -75,10 +74,6 @@ class A1111Context: "img2img_inpaint_tab": "img2img_inpaint_tab", "img2img_inpaint_sketch_tab": "img2img_inpaint_sketch_tab", "img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab", - # SDNext does not have this field. Temporarily disable the callback on - # the checkpoint change until we find a way to register an event when - # all A1111 UI components are ready. - "setting_sd_model_checkpoint": "setting_sd_model_checkpoint", } return all( c @@ -104,8 +99,6 @@ class A1111Context: "img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab", "img2img_inpaint_full_res": "img2img_inpaint_area", "txt2img_hr-checkbox": "txt2img_enable_hr", - # setting_sd_model_checkpoint is expected to be initialized last. - # "setting_sd_model_checkpoint": "setting_sd_model_checkpoint", } elem_id = getattr(component, "elem_id", None) # Do not set component if it has already been set. @@ -1141,7 +1134,6 @@ class ControlNetUiGroup(object): self.register_refresh_all_models() self.register_build_sliders() self.register_shift_preview() - self.register_shift_upload_mask() self.register_create_canvas() self.register_clear_preview() self.register_multi_images_upload() @@ -1162,6 +1154,26 @@ class ControlNetUiGroup(object): if self.is_img2img: self.register_img2img_same_input() + def register_sd_model_changed(self): + def sd_version_changed(type_filter: str, current_model: str, setting_value: str, setting_name: str): + """When SD version changes, update model dropdown choices.""" + if setting_name != "sd_model_checkpoint": + return gr.update() + + filtered_model_list = global_state.get_filtered_controlnet_names(type_filter) + assert len(filtered_model_list) > 0 + default_model = filtered_model_list[1] if len(filtered_model_list) > 1 else filtered_model_list[0] + return gr.Dropdown.update( + choices=filtered_model_list, + value=current_model if current_model in filtered_model_list else default_model + ) + + script_callbacks.on_setting_updated_subscriber(dict( + fn=sd_version_changed, + inputs=[self.type_filter, self.model], + outputs=[self.model], + )) + def register_callbacks(self): """Register callbacks that involves A1111 context gradio components.""" # Prevent infinite recursion. @@ -1172,6 +1184,8 @@ class ControlNetUiGroup(object): self.register_send_dimensions() self.register_run_annotator() self.register_sync_batch_dir() + self.register_shift_upload_mask() + self.register_sd_model_changed() if self.is_img2img: self.register_shift_crop_input_image() else: diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py index 05e8d41e..0f52ac8e 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py @@ -98,7 +98,7 @@ def get_filtered_preprocessor_names(tag): return list(get_filtered_preprocessors(tag).keys()) -def get_filtered_controlnet_names(tag, filter_version: bool = True): +def get_filtered_controlnet_names(tag): filtered_preprocessors = get_filtered_preprocessors(tag) model_filename_filters = [] for p in filtered_preprocessors.values(): @@ -106,8 +106,8 @@ def get_filtered_controlnet_names(tag, filter_version: bool = True): return [ x for x in controlnet_names if x == 'None' or ( - any(f.lower() in x.lower() for f in model_filename_filters) # and - # get_sd_version().is_compatible_with(StableDiffusionVersion.detect_from_model_name(x)) + any(f.lower() in x.lower() for f in model_filename_filters) and + get_sd_version().is_compatible_with(StableDiffusionVersion.detect_from_model_name(x)) ) ] @@ -134,6 +134,8 @@ def update_controlnet_filenames(): def get_sd_version() -> StableDiffusionVersion: + if not shared.sd_model: + return StableDiffusionVersion.UNKNOWN if shared.sd_model.is_sdxl: return StableDiffusionVersion.SDXL elif shared.sd_model.is_sd2: diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a54cb3eb..48ab289c 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -129,6 +129,9 @@ callback_map = dict( callbacks_list_optimizers=[], callbacks_list_unets=[], ) +event_subscriber_map = dict( + callbacks_setting_updated=[], +) def clear_callbacks(): @@ -309,6 +312,23 @@ def list_unets_callback(): return res +def setting_updated_event_subscriber_chain(handler, component, setting_name: str): + """ + Arguments: + - handler: The returned handler from calling an event subscriber. + - component: The component that is updated. The component should provide + the value of setting after update. + - setting_name: The name of the setting. + """ + for param in event_subscriber_map['callbacks_setting_updated']: + handler = handler.then( + fn=lambda *args: param["fn"](*args, setting_name), + inputs=param["inputs"] + [component], + outputs=param["outputs"], + show_progress=False, + ) + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if stack else 'unknown file' @@ -483,3 +503,13 @@ def on_list_unets(callback): The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" add_callback(callback_map['callbacks_list_unets'], callback) + + +def on_setting_updated_subscriber(subscriber_params): + """register a function to be called after settings update. `subscriber_params` + should contain necessary fields to register an gradio event handler. Necessary + fields are ["fn", "outputs", "inputs"]. + Setting name and setting value after update will be append to inputs. So be + sure to handle these extra params when defining the callback function. + """ + event_subscriber_map['callbacks_setting_updated'].append(subscriber_params) diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 24a7f2aa..f2576dc5 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -303,20 +303,30 @@ class UiSettings: methods = [component.change] for method in methods: - method( + handler = method( fn=lambda value, k=k: self.run_settings_single(value, key=k), inputs=[component], outputs=[component, self.text_settings], show_progress=False, ) + script_callbacks.setting_updated_event_subscriber_chain( + handler=handler, + component=component, + setting_name=k, + ) button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) - button_set_checkpoint.click( + handler = button_set_checkpoint.click( fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'), _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component], outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings], ) + script_callbacks.setting_updated_event_subscriber_chain( + handler=handler, + component=self.component_dict['sd_model_checkpoint'], + setting_name="sd_model_checkpoint" + ) component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]