feat: cn inpaint modules

This commit is contained in:
Bingsu
2023-06-19 23:04:20 +09:00
parent a6467ec968
commit ad13b03fa3
4 changed files with 55 additions and 7 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from collections import UserList
from functools import cached_property, partial
from typing import Any, Literal, NamedTuple, Union
from typing import Any, Literal, NamedTuple, Optional, Union
import pydantic
from pydantic import (
@@ -13,6 +13,7 @@ from pydantic import (
PositiveInt,
confloat,
constr,
root_validator,
)
cn_model_regex = r".*(inpaint|tile|scribble|lineart|openpose).*|^None$"
@@ -57,10 +58,19 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_cfg_scale: NonNegativeFloat = 7.0
ad_restore_face: bool = False
ad_controlnet_model: constr(regex=cn_model_regex) = "None"
ad_controlnet_module: Optional[constr(regex=r".*inpaint.*|^None$")] = None
ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0
ad_controlnet_guidance_start: confloat(ge=0.0, le=1.0) = 0.0
ad_controlnet_guidance_end: confloat(ge=0.0, le=1.0) = 1.0
@root_validator
def ad_controlnt_module_validator(cls, values): # noqa: N805
cn_model = values.get("ad_controlnet_model", "None")
cn_module = values.get("ad_controlnet_module", None)
if "inpaint" not in cn_model or cn_module == "None":
values["ad_controlnet_module"] = None
return values
@staticmethod
def ppop(
p: dict[str, Any],
@@ -115,6 +125,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
"ADetailer ControlNet model",
[
"ADetailer ControlNet model",
"ADetailer ControlNet module",
"ADetailer ControlNet weight",
"ADetailer ControlNet guidance start",
"ADetailer ControlNet guidance end",
@@ -169,6 +180,7 @@ _all_args = [
("ad_cfg_scale", "ADetailer CFG scale"),
("ad_restore_face", "ADetailer restore face"),
("ad_controlnet_model", "ADetailer ControlNet model"),
("ad_controlnet_module", "ADetailer ControlNet module"),
("ad_controlnet_weight", "ADetailer ControlNet weight"),
("ad_controlnet_guidance_start", "ADetailer ControlNet guidance start"),
("ad_controlnet_guidance_end", "ADetailer ControlNet guidance end"),

View File

@@ -10,6 +10,12 @@ from adetailer import AFTER_DETAILER, __version__
from adetailer.args import AD_ENABLE, ALL_ARGS, MASK_MERGE_INVERT
from controlnet_ext import controlnet_exists, get_cn_models
cn_module_choices = [
"inpaint_global_harmonious",
"inpaint_only",
"inpaint_only+lama",
]
class Widgets(SimpleNamespace):
def tolist(self):
@@ -40,6 +46,12 @@ def on_generate_click(state: dict, *values: Any):
return state
def on_cn_model_update(cn_model: str):
if "inpaint" in cn_model:
return gr.update(visible=True, choices=cn_module_choices)
return gr.update(visible=False, choices=["None"])
def elem_id(item_id: str, n: int, is_img2img: bool) -> str:
tap = "img2img" if is_img2img else "txt2img"
suf = suffix(n, "_")
@@ -416,6 +428,16 @@ def controlnet(w: Widgets, n: int, is_img2img: bool):
elem_id=eid("ad_controlnet_model"),
)
w.ad_controlnet_module = gr.Dropdown(
label="ControlNet module" + suffix(n),
choices=cn_module_choices,
value="inpaint_global_harmonious",
visible=False,
type="value",
interactive=controlnet_exists,
elem_id=eid("ad_controlnet_module"),
)
w.ad_controlnet_weight = gr.Slider(
label="ControlNet weight" + suffix(n),
minimum=0.0,
@@ -427,6 +449,13 @@ def controlnet(w: Widgets, n: int, is_img2img: bool):
elem_id=eid("ad_controlnet_weight"),
)
w.ad_controlnet_model.change(
on_cn_model_update,
inputs=w.ad_controlnet_model,
outputs=w.ad_controlnet_module,
queue=False,
)
with gr.Column(variant="compact"):
w.ad_controlnet_guidance_start = gr.Slider(
label="ControlNet guidance start" + suffix(n),

View File

@@ -49,16 +49,22 @@ class ControlNetExt:
self.cn_models.extend(m for m in models if cn_model_regex.search(m))
def update_scripts_args(
self, p, model: str, weight: float, guidance_start: float, guidance_end: float
self,
p,
model: str,
module: str | None,
weight: float,
guidance_start: float,
guidance_end: float,
):
if (not self.cn_available) or model == "None":
return
module = None
for m, v in cn_model_module.items():
if m in model:
module = v
break
if module is None:
for m, v in cn_model_module.items():
if m in model:
module = v
break
cn_units = [
self.external_cn.ControlNetUnit(

View File

@@ -141,6 +141,7 @@ class AfterDetailerScript(scripts.Script):
self.controlnet_ext.update_scripts_args(
p,
model=args.ad_controlnet_model,
module=args.ad_controlnet_module,
weight=args.ad_controlnet_weight,
guidance_start=args.ad_controlnet_guidance_start,
guidance_end=args.ad_controlnet_guidance_end,