diff --git a/aaaaaa/ui.py b/aaaaaa/ui.py index 2608ee3..f7a5148 100644 --- a/aaaaaa/ui.py +++ b/aaaaaa/ui.py @@ -2,6 +2,7 @@ from __future__ import annotations from dataclasses import dataclass from functools import partial +from itertools import chain from types import SimpleNamespace from typing import Any @@ -42,6 +43,9 @@ else: "depth": ["depth_midas", "depth_hand_refiner"], } +union = list(chain.from_iterable(cn_module_choices.values())) +cn_module_choices["union"] = union + class Widgets(SimpleNamespace): def tolist(self): diff --git a/controlnet_ext/common.py b/controlnet_ext/common.py index a9e0e97..f485da5 100644 --- a/controlnet_ext/common.py +++ b/controlnet_ext/common.py @@ -8,4 +8,5 @@ cn_model_module = { "tile": "tile_resample", "depth": "depth_midas", } -cn_model_regex = re.compile("|".join(cn_model_module.keys()), flags=re.IGNORECASE) +_names = [*cn_model_module, "union"] +cn_model_regex = re.compile("|".join(_names), flags=re.IGNORECASE) diff --git a/controlnet_ext/controlnet_ext.py b/controlnet_ext/controlnet_ext.py index 9af1238..bcf7130 100644 --- a/controlnet_ext/controlnet_ext.py +++ b/controlnet_ext/controlnet_ext.py @@ -61,13 +61,13 @@ class ControlNetExt: if (not self.cn_available) or model == "None": return - if module is None or module == "None": + if module == "None": + module = None + if module is None: for m, v in cn_model_module.items(): if m in model: module = v break - else: - module = None cn_units = [ self.external_cn.ControlNetUnit(