fix: add cn guidance start, var names, style, ui

This commit is contained in:
Bingsu
2023-06-01 09:34:09 +09:00
parent 39baa120a6
commit 3263bd5ed0
5 changed files with 78 additions and 60 deletions

View File

@@ -1,7 +1,7 @@
from .controlnet_ext import ControlNetExt, controlnet_exists, get_cn_inpaint_models
from .controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models
__all__ = [
"ControlNetExt",
"controlnet_exists",
"get_cn_inpaint_models",
"get_cn_models",
]

View File

@@ -12,16 +12,6 @@ ext_path = Path(data_path, "extensions")
ext_builtin_path = Path(script_path, "extensions-builtin")
is_in_builtin = False # compatibility for vladmandic/automatic
controlnet_exists = False
controlnet_enabled_models = {
"inpaint": "inpaint_global_harmonious",
"scribble": "t2ia_sketch_pidi",
"lineart": "lineart_coarse",
"openpose": "openpose_full",
"tile": None,
}
controlnet_model_regex = re.compile(
r".*(" + ("|".join(controlnet_enabled_models.keys())) + ").*"
)
if ext_path.exists():
controlnet_exists = any(
@@ -38,6 +28,15 @@ if not controlnet_exists and ext_builtin_path.exists():
if controlnet_exists:
is_in_builtin = True
cn_model_module = {
"inpaint": "inpaint_global_harmonious",
"scribble": "t2ia_sketch_pidi",
"lineart": "lineart_coarse",
"openpose": "openpose_full",
"tile": None,
}
cn_model_regex = re.compile("|".join(cn_model_module.keys()))
class ControlNetExt:
def __init__(self):
@@ -54,11 +53,16 @@ class ControlNetExt:
self.external_cn = importlib.import_module(import_path, "external_code")
self.cn_available = True
models = self.external_cn.get_models()
self.cn_models.extend(m for m in models if controlnet_model_regex.match(m))
self.cn_models.extend(m for m in models if cn_model_regex.search(m))
def update_scripts_args(
self, p, model: str, weight: float, guidance_start: float, guidance_end: float
):
if (not self.cn_available) or model == "None":
return
def _update_scripts_args(self, p, model: str, weight: float, guidance_end: float):
module = None
for m, v in controlnet_enabled_models.items():
for m, v in cn_model_module.items():
if m in model:
module = v
break
@@ -69,6 +73,7 @@ class ControlNetExt:
weight=weight,
control_mode=self.external_cn.ControlMode.BALANCED,
module=module,
guidance_start=guidance_start,
guidance_end=guidance_end,
pixel_perfect=True,
)
@@ -76,10 +81,6 @@ class ControlNetExt:
self.external_cn.update_cn_script_in_processing(p, cn_units)
def update_scripts_args(self, p, model: str, weight: float, guidance_end: float):
if self.cn_available and model != "None":
self._update_scripts_args(p, model, weight, guidance_end)
def get_cn_model_dirs() -> list[Path]:
cn_model_dir = Path(models_path, "ControlNet")
@@ -99,7 +100,7 @@ def get_cn_model_dirs() -> list[Path]:
@lru_cache
def _get_cn_inpaint_models() -> list[str]:
def _get_cn_models() -> list[str]:
"""
Since we can't import ControlNet, we use a function that does something like
controlnet's `list(global_state.cn_models_names.values())`.
@@ -119,7 +120,7 @@ def _get_cn_inpaint_models() -> list[str]:
if (
p.is_file()
and p.suffix in cn_model_exts
and controlnet_model_regex.match(p.name)
and cn_model_regex.search(p.name)
):
if name_filter and name_filter not in p.name.lower():
continue
@@ -134,7 +135,7 @@ def _get_cn_inpaint_models() -> list[str]:
return models
def get_cn_inpaint_models() -> list[str]:
def get_cn_models() -> list[str]:
if controlnet_exists:
return _get_cn_inpaint_models()
return _get_cn_models()
return []