feat: builtin contronet support

This commit is contained in:
Bingsu
2023-04-30 14:41:02 +09:00
parent dce15538d3
commit ce74bc839e

View File

@@ -5,12 +5,25 @@ from functools import lru_cache
from pathlib import Path from pathlib import Path
from modules import sd_models, shared from modules import sd_models, shared
from modules.paths import data_path, models_path from modules.paths import data_path, models_path, script_path
extensions_path = Path(data_path, "extensions") ext_path = Path(data_path, "extensions")
controlnet_exists = any( ext_builtin_path = Path(script_path, "extensions-builtin")
p.name == "sd-webui-controlnet" for p in extensions_path.iterdir() if p.is_dir() is_in_builtin = False # compatibility for vladmandic/automatic
) controlnet_exists = False
if ext_path.exists():
controlnet_exists = any(
p.name == "sd-webui-controlnet" for p in ext_path.iterdir() if p.is_dir()
)
if not controlnet_exists and ext_builtin_path.exists():
controlnet_exists = any(
p.name == "sd-webui-controlnet" for p in ext_builtin_path.iterdir()
)
if controlnet_exists:
is_in_builtin = True
class ControlNetExt: class ControlNetExt:
@@ -20,10 +33,13 @@ class ControlNetExt:
self.external_cn = None self.external_cn = None
def init_controlnet(self) -> bool: def init_controlnet(self) -> bool:
if is_in_builtin:
import_path = "extensions-builtin.sd-webui-controlnet.scripts.external_code"
else:
import_path = "extensions.sd-webui-controlnet.scripts.external_code"
try: try:
self.external_cn = importlib.import_module( self.external_cn = importlib.import_module(import_path, "external_code")
"extensions.sd-webui-controlnet.scripts.external_code", "external_code"
)
self.cn_available = True self.cn_available = True
models = self.external_cn.get_models() models = self.external_cn.get_models()
self.cn_models.extend(m for m in models if "inpaint" in m) self.cn_models.extend(m for m in models if "inpaint" in m)
@@ -49,6 +65,23 @@ class ControlNetExt:
self._update_scripts_args(p, model, weight) self._update_scripts_args(p, model, weight)
def get_cn_model_dirs() -> list[Path]:
cn_model_dir = Path(models_path, "ControlNet")
if is_in_builtin:
cn_model_dir_old = Path(ext_builtin_path, "sd-webui-controlnet", "models")
else:
cn_model_dir_old = Path(ext_path, "sd-webui-controlnet", "models")
ext_dir1 = shared.opts.data.get("control_net_models_path", "")
ext_dir2 = shared.opts.data.get("controlnet_dir", "")
dirs = [cn_model_dir, cn_model_dir_old]
for ext_dir in [ext_dir1, ext_dir2]:
if ext_dir:
dirs.append(Path(ext_dir))
return dirs
@lru_cache @lru_cache
def _get_cn_inpaint_models() -> list[str]: def _get_cn_inpaint_models() -> list[str]:
""" """
@@ -56,19 +89,13 @@ def _get_cn_inpaint_models() -> list[str]:
controlnet's `list(global_state.cn_models_names.values())`. controlnet's `list(global_state.cn_models_names.values())`.
""" """
cn_model_exts = (".pt", ".pth", ".ckpt", ".safetensors") cn_model_exts = (".pt", ".pth", ".ckpt", ".safetensors")
cn_model_dir = Path(models_path, "ControlNet") dirs = get_cn_model_dirs()
cn_model_dir_old = Path(extensions_path, "sd-webui-controlnet", "models")
ext_dir1 = shared.opts.data.get("control_net_models_path", "")
ext_dir2 = shared.opts.data.get("controlnet_dir", "")
name_filter = shared.opts.data.get("control_net_models_name_filter", "") name_filter = shared.opts.data.get("control_net_models_name_filter", "")
name_filter = name_filter.strip(" ").lower() name_filter = name_filter.strip(" ").lower()
model_paths = [] model_paths = []
for base in [cn_model_dir, cn_model_dir_old, ext_dir1, ext_dir2]: for base in dirs:
if not base:
continue
base = Path(base)
if not base.exists(): if not base.exists():
continue continue