diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py index 064e4ca6..15a9f24c 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py @@ -7,7 +7,9 @@ from modules import scripts from lib_controlnet.infotext import parse_unit, serialize_unit from lib_controlnet.controlnet_ui.tool_button import ToolButton from lib_controlnet.logging import logger -from lib_controlnet import external_code +from lib_controlnet.external_code import ControlNetUnit, UiControlNetUnit +from lib_controlnet.global_state import get_preprocessor +from modules_forge.supported_preprocessor import Preprocessor save_symbol = "\U0001f4be" # 💾 delete_symbol = "\U0001f5d1\ufe0f" # 🗑️ @@ -33,24 +35,10 @@ def load_presets(preset_dir: str) -> Dict[str, str]: return presets -def infer_control_type(module: str, model: str) -> str: - def matches_control_type(input_string: str, control_type: str) -> bool: - return any(t.lower() in input_string for t in control_type.split("/")) - - control_types = preprocessor_filters.keys() - control_type_candidates = [ - control_type - for control_type in control_types - if ( - matches_control_type(module, control_type) - or matches_control_type(model, control_type) - ) - ] - if len(control_type_candidates) != 1: - raise ValueError( - f"Unable to infer control type from module {module} and model {model}" - ) - return control_type_candidates[0] +def infer_control_type(module: str) -> str: + preprocessor: Preprocessor = get_preprocessor(module) + assert preprocessor is not None + return preprocessor.tags[0] if preprocessor.tags else "All" class ControlNetPresetUI(object): @@ -111,13 +99,19 @@ class ControlNetPresetUI(object): control_type: gr.Radio, *ui_states, ): + def init_with_ui_states(*ui_states) -> ControlNetUnit: + return ControlNetUnit(**{ + field: value + for field, value in zip(ControlNetUnit.infotext_fields(), ui_states) + }) + def apply_preset(name: str, control_type: str, *ui_states): if name == NEW_PRESET: return ( gr.update(visible=False), *( (gr.skip(),) - * (len(external_code.ControlNetUnit.infotext_fields()) + 1) + * (len(ControlNetUnit.infotext_fields()) + 1) ), ) @@ -125,7 +119,7 @@ class ControlNetPresetUI(object): infotext = ControlNetPresetUI.presets[name] preset_unit = parse_unit(infotext) - current_unit = external_code.ControlNetUnit(*ui_states) + current_unit = init_with_ui_states(*ui_states) preset_unit.image = None current_unit.image = None @@ -140,14 +134,14 @@ class ControlNetPresetUI(object): gr.update(visible=False), *( (gr.skip(),) - * (len(external_code.ControlNetUnit.infotext_fields()) + 1) + * (len(ControlNetUnit.infotext_fields()) + 1) ), ) unit = preset_unit try: - new_control_type = infer_control_type(unit.module, unit.model) + new_control_type = infer_control_type(unit.module) except ValueError as e: logger.error(e) new_control_type = control_type @@ -166,7 +160,7 @@ class ControlNetPresetUI(object): gr.update(value=new_control_type), *[ gr.update(value=value) if value is not None else gr.update() - for field in external_code.ControlNetUnit.infotext_fields() + for field in ControlNetUnit.infotext_fields() for value in (getattr(unit, field),) ], ) @@ -191,7 +185,7 @@ class ControlNetPresetUI(object): return gr.update(visible=True), gr.update(), gr.update() ControlNetPresetUI.save_preset( - name, external_code.ControlNetUnit(*ui_states) + name, init_with_ui_states(*ui_states) ) return ( gr.update(), # name dialog @@ -236,7 +230,7 @@ class ControlNetPresetUI(object): return gr.update(visible=False), gr.update() ControlNetPresetUI.save_preset( - new_name, external_code.ControlNetUnit(*ui_states) + new_name, init_with_ui_states(*ui_states) ) return gr.update(visible=False), gr.update( choices=ControlNetPresetUI.dropdown_choices(), value=new_name @@ -262,7 +256,7 @@ class ControlNetPresetUI(object): infotext = ControlNetPresetUI.presets[preset_name] preset_unit = parse_unit(infotext) - current_unit = external_code.ControlNetUnit(*ui_states) + current_unit = init_with_ui_states(*ui_states) preset_unit.image = None current_unit.image = None @@ -293,7 +287,7 @@ class ControlNetPresetUI(object): return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET] @staticmethod - def save_preset(name: str, unit: external_code.ControlNetUnit): + def save_preset(name: str, unit: ControlNetUnit): infotext = serialize_unit(unit) with open( os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w"