diff --git a/adetailer/args.py b/adetailer/args.py index 503eebb..9e739f9 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -70,6 +70,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): ad_use_clip_skip: bool = False ad_clip_skip: conint(ge=1, le=12) = 1 ad_restore_face: bool = False + ad_prompt_replacements: list = [] ad_controlnet_model: constr(regex=cn_model_regex) = "None" ad_controlnet_module: Optional[constr(regex=r".*inpaint.*|^None$")] = None ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0 diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index c507068..9f74d80 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -270,7 +270,7 @@ class AfterDetailerScript(scripts.Script): return all_prompts[j] def _get_prompt( - self, ad_prompt: str, all_prompts: list[str], i: int, default: str + self, ad_prompt: str, all_prompts: list[str], i: int, default: str, replacements: list[tuple[str,str]] ) -> list[str]: prompts = re.split(r"\s*\[SEP\]\s*", ad_prompt) blank_replacement = self.prompt_blank_replacement(all_prompts, i, default) @@ -279,14 +279,16 @@ class AfterDetailerScript(scripts.Script): prompts[n] = blank_replacement elif "[PROMPT]" in prompts[n]: prompts[n] = prompts[n].replace("[PROMPT]", f" {blank_replacement} ") + for pair in replacements: + prompts[n] = prompts[n].replace(pair[0], pair[1]) return prompts def get_prompt(self, p, args: ADetailerArgs) -> tuple[list[str], list[str]]: i = p._ad_idx - prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt) + prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt, args.ad_prompt_replacements) negative_prompt = self._get_prompt( - args.ad_negative_prompt, p.all_negative_prompts, i, p.negative_prompt + args.ad_negative_prompt, p.all_negative_prompts, i, p.negative_prompt, args.ad_prompt_replacements ) return prompt, negative_prompt @@ -803,6 +805,16 @@ def make_axis_on_xyz_grid(): p._ad_xyz = {} p._ad_xyz[field] = x + def search_and_replace_prompt(p, x, xs, replace_in_main_prompt): + if replace_in_main_prompt: + p.prompt = p.prompt.replace(xs[0], x) + p.negative_prompt = p.negative_prompt.replace(xs[0], x) + if not hasattr(p, "_ad_xyz"): + p._ad_xyz = {} + if not "ad_prompt_replacements" in p._ad_xyz: + p._ad_xyz["ad_prompt_replacements"] = [] + p._ad_xyz["ad_prompt_replacements"].append((xs[0],x)) + axis = [ xyz_grid.AxisOption( "[ADetailer] ADetailer model 1st", @@ -820,6 +832,16 @@ def make_axis_on_xyz_grid(): str, partial(set_value, field="ad_negative_prompt"), ), + xyz_grid.AxisOption( + "[ADetailer] Prompt S/R (AD 1st)", + str, + partial(search_and_replace_prompt, replace_in_main_prompt=False), + ), + xyz_grid.AxisOption( + "[ADetailer] Prompt S/R (AD 1st and main prompt)", + str, + partial(search_and_replace_prompt, replace_in_main_prompt=True) + ), xyz_grid.AxisOption( "[ADetailer] Mask erosion / dilation 1st", int,