diff --git a/controlnet_ext/controlnet_ext.py b/controlnet_ext/controlnet_ext.py index c95c276..be1ac80 100644 --- a/controlnet_ext/controlnet_ext.py +++ b/controlnet_ext/controlnet_ext.py @@ -5,12 +5,25 @@ from functools import lru_cache from pathlib import Path from modules import sd_models, shared -from modules.paths import data_path, models_path +from modules.paths import data_path, models_path, script_path -extensions_path = Path(data_path, "extensions") -controlnet_exists = any( - p.name == "sd-webui-controlnet" for p in extensions_path.iterdir() if p.is_dir() -) +ext_path = Path(data_path, "extensions") +ext_builtin_path = Path(script_path, "extensions-builtin") +is_in_builtin = False # compatibility for vladmandic/automatic +controlnet_exists = False + +if ext_path.exists(): + controlnet_exists = any( + p.name == "sd-webui-controlnet" for p in ext_path.iterdir() if p.is_dir() + ) + +if not controlnet_exists and ext_builtin_path.exists(): + controlnet_exists = any( + p.name == "sd-webui-controlnet" for p in ext_builtin_path.iterdir() + ) + + if controlnet_exists: + is_in_builtin = True class ControlNetExt: @@ -20,10 +33,13 @@ class ControlNetExt: self.external_cn = None def init_controlnet(self) -> bool: + if is_in_builtin: + import_path = "extensions-builtin.sd-webui-controlnet.scripts.external_code" + else: + import_path = "extensions.sd-webui-controlnet.scripts.external_code" + try: - self.external_cn = importlib.import_module( - "extensions.sd-webui-controlnet.scripts.external_code", "external_code" - ) + self.external_cn = importlib.import_module(import_path, "external_code") self.cn_available = True models = self.external_cn.get_models() self.cn_models.extend(m for m in models if "inpaint" in m) @@ -49,6 +65,23 @@ class ControlNetExt: self._update_scripts_args(p, model, weight) +def get_cn_model_dirs() -> list[Path]: + cn_model_dir = Path(models_path, "ControlNet") + if is_in_builtin: + cn_model_dir_old = Path(ext_builtin_path, "sd-webui-controlnet", "models") + else: + cn_model_dir_old = Path(ext_path, "sd-webui-controlnet", "models") + ext_dir1 = shared.opts.data.get("control_net_models_path", "") + ext_dir2 = shared.opts.data.get("controlnet_dir", "") + + dirs = [cn_model_dir, cn_model_dir_old] + for ext_dir in [ext_dir1, ext_dir2]: + if ext_dir: + dirs.append(Path(ext_dir)) + + return dirs + + @lru_cache def _get_cn_inpaint_models() -> list[str]: """ @@ -56,19 +89,13 @@ def _get_cn_inpaint_models() -> list[str]: controlnet's `list(global_state.cn_models_names.values())`. """ cn_model_exts = (".pt", ".pth", ".ckpt", ".safetensors") - cn_model_dir = Path(models_path, "ControlNet") - cn_model_dir_old = Path(extensions_path, "sd-webui-controlnet", "models") - ext_dir1 = shared.opts.data.get("control_net_models_path", "") - ext_dir2 = shared.opts.data.get("controlnet_dir", "") + dirs = get_cn_model_dirs() name_filter = shared.opts.data.get("control_net_models_name_filter", "") name_filter = name_filter.strip(" ").lower() model_paths = [] - for base in [cn_model_dir, cn_model_dir_old, ext_dir1, ext_dir2]: - if not base: - continue - base = Path(base) + for base in dirs: if not base.exists(): continue