From 297e88b2552ac23fd6feba4bb14c34d5a6ea5d53 Mon Sep 17 00:00:00 2001 From: Bingsu Date: Mon, 9 Oct 2023 17:32:47 +0900 Subject: [PATCH] fix: prompt sr --- adetailer/args.py | 1 - scripts/!adetailer.py | 32 ++++++++++++++++++-------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/adetailer/args.py b/adetailer/args.py index 9e739f9..503eebb 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -70,7 +70,6 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): ad_use_clip_skip: bool = False ad_clip_skip: conint(ge=1, le=12) = 1 ad_restore_face: bool = False - ad_prompt_replacements: list = [] ad_controlnet_model: constr(regex=cn_model_regex) = "None" ad_controlnet_module: Optional[constr(regex=r".*inpaint.*|^None$")] = None ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0 diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 982e383..06527be 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -10,7 +10,7 @@ from copy import copy, deepcopy from functools import partial from pathlib import Path from textwrap import dedent -from typing import Any +from typing import Any, NamedTuple import gradio as gr import torch @@ -275,7 +275,7 @@ class AfterDetailerScript(scripts.Script): all_prompts: list[str], i: int, default: str, - replacements: list[tuple[str, str]], + replacements: list[PromptSR], ) -> list[str]: prompts = re.split(r"\s*\[SEP\]\s*", ad_prompt) blank_replacement = self.prompt_blank_replacement(all_prompts, i, default) @@ -284,22 +284,22 @@ class AfterDetailerScript(scripts.Script): prompts[n] = blank_replacement elif "[PROMPT]" in prompts[n]: prompts[n] = prompts[n].replace("[PROMPT]", f" {blank_replacement} ") + for pair in replacements: - prompts[n] = prompts[n].replace(pair[0], pair[1]) + prompts[n] = prompts[n].replace(pair.s, pair.r) return prompts def get_prompt(self, p, args: ADetailerArgs) -> tuple[list[str], list[str]]: i = p._ad_idx + prompt_sr = p._ad_xyz_prompt_sr if hasattr(p, "_ad_xyz_prompt_sr") else [] - prompt = self._get_prompt( - args.ad_prompt, p.all_prompts, i, p.prompt, args.ad_prompt_replacements - ) + prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt, prompt_sr) negative_prompt = self._get_prompt( args.ad_negative_prompt, p.all_negative_prompts, i, p.negative_prompt, - args.ad_prompt_replacements, + prompt_sr, ) return prompt, negative_prompt @@ -798,21 +798,25 @@ def on_ui_settings(): # xyz_grid -def set_value(p, x, xs, *, field: str): +class PromptSR(NamedTuple): + s: str + r: str + + +def set_value(p, x: Any, xs: Any, *, field: str): if not hasattr(p, "_ad_xyz"): p._ad_xyz = {} p._ad_xyz[field] = x -def search_and_replace_prompt(p, x, xs, replace_in_main_prompt): +def search_and_replace_prompt(p, x: Any, xs: Any, replace_in_main_prompt: bool): if replace_in_main_prompt: p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) - if not hasattr(p, "_ad_xyz"): - p._ad_xyz = {} - if "ad_prompt_replacements" not in p._ad_xyz: - p._ad_xyz["ad_prompt_replacements"] = [] - p._ad_xyz["ad_prompt_replacements"].append((xs[0], x)) + + if not hasattr(p, "_ad_xyz_prompt_sr"): + p._ad_xyz_prompt_sr = [] + p._ad_xyz_prompt_sr.append(PromptSR(s=xs[0], r=x)) def make_axis_on_xyz_grid():