fix: prompt sr

This commit is contained in:
Bingsu
2023-10-09 17:32:47 +09:00
parent a68affe1a8
commit 297e88b255
2 changed files with 18 additions and 15 deletions

View File

@@ -70,7 +70,6 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_use_clip_skip: bool = False
ad_clip_skip: conint(ge=1, le=12) = 1
ad_restore_face: bool = False
ad_prompt_replacements: list = []
ad_controlnet_model: constr(regex=cn_model_regex) = "None"
ad_controlnet_module: Optional[constr(regex=r".*inpaint.*|^None$")] = None
ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0

View File

@@ -10,7 +10,7 @@ from copy import copy, deepcopy
from functools import partial
from pathlib import Path
from textwrap import dedent
from typing import Any
from typing import Any, NamedTuple
import gradio as gr
import torch
@@ -275,7 +275,7 @@ class AfterDetailerScript(scripts.Script):
all_prompts: list[str],
i: int,
default: str,
replacements: list[tuple[str, str]],
replacements: list[PromptSR],
) -> list[str]:
prompts = re.split(r"\s*\[SEP\]\s*", ad_prompt)
blank_replacement = self.prompt_blank_replacement(all_prompts, i, default)
@@ -284,22 +284,22 @@ class AfterDetailerScript(scripts.Script):
prompts[n] = blank_replacement
elif "[PROMPT]" in prompts[n]:
prompts[n] = prompts[n].replace("[PROMPT]", f" {blank_replacement} ")
for pair in replacements:
prompts[n] = prompts[n].replace(pair[0], pair[1])
prompts[n] = prompts[n].replace(pair.s, pair.r)
return prompts
def get_prompt(self, p, args: ADetailerArgs) -> tuple[list[str], list[str]]:
i = p._ad_idx
prompt_sr = p._ad_xyz_prompt_sr if hasattr(p, "_ad_xyz_prompt_sr") else []
prompt = self._get_prompt(
args.ad_prompt, p.all_prompts, i, p.prompt, args.ad_prompt_replacements
)
prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt, prompt_sr)
negative_prompt = self._get_prompt(
args.ad_negative_prompt,
p.all_negative_prompts,
i,
p.negative_prompt,
args.ad_prompt_replacements,
prompt_sr,
)
return prompt, negative_prompt
@@ -798,21 +798,25 @@ def on_ui_settings():
# xyz_grid
def set_value(p, x, xs, *, field: str):
class PromptSR(NamedTuple):
s: str
r: str
def set_value(p, x: Any, xs: Any, *, field: str):
if not hasattr(p, "_ad_xyz"):
p._ad_xyz = {}
p._ad_xyz[field] = x
def search_and_replace_prompt(p, x, xs, replace_in_main_prompt):
def search_and_replace_prompt(p, x: Any, xs: Any, replace_in_main_prompt: bool):
if replace_in_main_prompt:
p.prompt = p.prompt.replace(xs[0], x)
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
if not hasattr(p, "_ad_xyz"):
p._ad_xyz = {}
if "ad_prompt_replacements" not in p._ad_xyz:
p._ad_xyz["ad_prompt_replacements"] = []
p._ad_xyz["ad_prompt_replacements"].append((xs[0], x))
if not hasattr(p, "_ad_xyz_prompt_sr"):
p._ad_xyz_prompt_sr = []
p._ad_xyz_prompt_sr.append(PromptSR(s=xs[0], r=x))
def make_axis_on_xyz_grid():