feat: Adding AD S/R options to X/Y/Z plot (#356)

This commit is contained in:
Rush
2023-10-09 03:22:25 -05:00
committed by GitHub
parent d51f4b33cb
commit 9489435867
2 changed files with 26 additions and 3 deletions

View File

@@ -70,6 +70,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_use_clip_skip: bool = False
ad_clip_skip: conint(ge=1, le=12) = 1
ad_restore_face: bool = False
ad_prompt_replacements: list = []
ad_controlnet_model: constr(regex=cn_model_regex) = "None"
ad_controlnet_module: Optional[constr(regex=r".*inpaint.*|^None$")] = None
ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0

View File

@@ -270,7 +270,7 @@ class AfterDetailerScript(scripts.Script):
return all_prompts[j]
def _get_prompt(
self, ad_prompt: str, all_prompts: list[str], i: int, default: str
self, ad_prompt: str, all_prompts: list[str], i: int, default: str, replacements: list[tuple[str,str]]
) -> list[str]:
prompts = re.split(r"\s*\[SEP\]\s*", ad_prompt)
blank_replacement = self.prompt_blank_replacement(all_prompts, i, default)
@@ -279,14 +279,16 @@ class AfterDetailerScript(scripts.Script):
prompts[n] = blank_replacement
elif "[PROMPT]" in prompts[n]:
prompts[n] = prompts[n].replace("[PROMPT]", f" {blank_replacement} ")
for pair in replacements:
prompts[n] = prompts[n].replace(pair[0], pair[1])
return prompts
def get_prompt(self, p, args: ADetailerArgs) -> tuple[list[str], list[str]]:
i = p._ad_idx
prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt)
prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt, args.ad_prompt_replacements)
negative_prompt = self._get_prompt(
args.ad_negative_prompt, p.all_negative_prompts, i, p.negative_prompt
args.ad_negative_prompt, p.all_negative_prompts, i, p.negative_prompt, args.ad_prompt_replacements
)
return prompt, negative_prompt
@@ -803,6 +805,16 @@ def make_axis_on_xyz_grid():
p._ad_xyz = {}
p._ad_xyz[field] = x
def search_and_replace_prompt(p, x, xs, replace_in_main_prompt):
if replace_in_main_prompt:
p.prompt = p.prompt.replace(xs[0], x)
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
if not hasattr(p, "_ad_xyz"):
p._ad_xyz = {}
if not "ad_prompt_replacements" in p._ad_xyz:
p._ad_xyz["ad_prompt_replacements"] = []
p._ad_xyz["ad_prompt_replacements"].append((xs[0],x))
axis = [
xyz_grid.AxisOption(
"[ADetailer] ADetailer model 1st",
@@ -820,6 +832,16 @@ def make_axis_on_xyz_grid():
str,
partial(set_value, field="ad_negative_prompt"),
),
xyz_grid.AxisOption(
"[ADetailer] Prompt S/R (AD 1st)",
str,
partial(search_and_replace_prompt, replace_in_main_prompt=False),
),
xyz_grid.AxisOption(
"[ADetailer] Prompt S/R (AD 1st and main prompt)",
str,
partial(search_and_replace_prompt, replace_in_main_prompt=True)
),
xyz_grid.AxisOption(
"[ADetailer] Mask erosion / dilation 1st",
int,