Fix ControlNet UI preset (#87)

This commit is contained in:
Chenlei Hu
2024-02-07 02:52:04 +00:00
committed by GitHub
parent e62631350a
commit 1110183943

View File

@@ -7,7 +7,9 @@ from modules import scripts
from lib_controlnet.infotext import parse_unit, serialize_unit
from lib_controlnet.controlnet_ui.tool_button import ToolButton
from lib_controlnet.logging import logger
from lib_controlnet import external_code
from lib_controlnet.external_code import ControlNetUnit, UiControlNetUnit
from lib_controlnet.global_state import get_preprocessor
from modules_forge.supported_preprocessor import Preprocessor
save_symbol = "\U0001f4be" # 💾
delete_symbol = "\U0001f5d1\ufe0f" # 🗑️
@@ -33,24 +35,10 @@ def load_presets(preset_dir: str) -> Dict[str, str]:
return presets
def infer_control_type(module: str, model: str) -> str:
def matches_control_type(input_string: str, control_type: str) -> bool:
return any(t.lower() in input_string for t in control_type.split("/"))
control_types = preprocessor_filters.keys()
control_type_candidates = [
control_type
for control_type in control_types
if (
matches_control_type(module, control_type)
or matches_control_type(model, control_type)
)
]
if len(control_type_candidates) != 1:
raise ValueError(
f"Unable to infer control type from module {module} and model {model}"
)
return control_type_candidates[0]
def infer_control_type(module: str) -> str:
preprocessor: Preprocessor = get_preprocessor(module)
assert preprocessor is not None
return preprocessor.tags[0] if preprocessor.tags else "All"
class ControlNetPresetUI(object):
@@ -111,13 +99,19 @@ class ControlNetPresetUI(object):
control_type: gr.Radio,
*ui_states,
):
def init_with_ui_states(*ui_states) -> ControlNetUnit:
return ControlNetUnit(**{
field: value
for field, value in zip(ControlNetUnit.infotext_fields(), ui_states)
})
def apply_preset(name: str, control_type: str, *ui_states):
if name == NEW_PRESET:
return (
gr.update(visible=False),
*(
(gr.skip(),)
* (len(external_code.ControlNetUnit.infotext_fields()) + 1)
* (len(ControlNetUnit.infotext_fields()) + 1)
),
)
@@ -125,7 +119,7 @@ class ControlNetPresetUI(object):
infotext = ControlNetPresetUI.presets[name]
preset_unit = parse_unit(infotext)
current_unit = external_code.ControlNetUnit(*ui_states)
current_unit = init_with_ui_states(*ui_states)
preset_unit.image = None
current_unit.image = None
@@ -140,14 +134,14 @@ class ControlNetPresetUI(object):
gr.update(visible=False),
*(
(gr.skip(),)
* (len(external_code.ControlNetUnit.infotext_fields()) + 1)
* (len(ControlNetUnit.infotext_fields()) + 1)
),
)
unit = preset_unit
try:
new_control_type = infer_control_type(unit.module, unit.model)
new_control_type = infer_control_type(unit.module)
except ValueError as e:
logger.error(e)
new_control_type = control_type
@@ -166,7 +160,7 @@ class ControlNetPresetUI(object):
gr.update(value=new_control_type),
*[
gr.update(value=value) if value is not None else gr.update()
for field in external_code.ControlNetUnit.infotext_fields()
for field in ControlNetUnit.infotext_fields()
for value in (getattr(unit, field),)
],
)
@@ -191,7 +185,7 @@ class ControlNetPresetUI(object):
return gr.update(visible=True), gr.update(), gr.update()
ControlNetPresetUI.save_preset(
name, external_code.ControlNetUnit(*ui_states)
name, init_with_ui_states(*ui_states)
)
return (
gr.update(), # name dialog
@@ -236,7 +230,7 @@ class ControlNetPresetUI(object):
return gr.update(visible=False), gr.update()
ControlNetPresetUI.save_preset(
new_name, external_code.ControlNetUnit(*ui_states)
new_name, init_with_ui_states(*ui_states)
)
return gr.update(visible=False), gr.update(
choices=ControlNetPresetUI.dropdown_choices(), value=new_name
@@ -262,7 +256,7 @@ class ControlNetPresetUI(object):
infotext = ControlNetPresetUI.presets[preset_name]
preset_unit = parse_unit(infotext)
current_unit = external_code.ControlNetUnit(*ui_states)
current_unit = init_with_ui_states(*ui_states)
preset_unit.image = None
current_unit.image = None
@@ -293,7 +287,7 @@ class ControlNetPresetUI(object):
return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET]
@staticmethod
def save_preset(name: str, unit: external_code.ControlNetUnit):
def save_preset(name: str, unit: ControlNetUnit):
infotext = serialize_unit(unit)
with open(
os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w"