mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-04-27 09:41:40 +00:00
feat: split ad_prompt with [SEP]
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from copy import copy, deepcopy
|
||||
@@ -177,9 +178,9 @@ class AfterDetailerScript(scripts.Script):
|
||||
|
||||
return device
|
||||
|
||||
def _get_prompt(self, ad_prompt: str, all_prompts: list[str], i: int, default: str):
|
||||
if ad_prompt:
|
||||
return ad_prompt
|
||||
def prompt_blank_replacement(
|
||||
self, all_prompts: list[str], i: int, default: str
|
||||
) -> str:
|
||||
if not all_prompts:
|
||||
return default
|
||||
if i < len(all_prompts):
|
||||
@@ -187,7 +188,17 @@ class AfterDetailerScript(scripts.Script):
|
||||
j = i % len(all_prompts)
|
||||
return all_prompts[j]
|
||||
|
||||
def get_prompt(self, p, args: ADetailerArgs) -> tuple[str, str]:
|
||||
def _get_prompt(
|
||||
self, ad_prompt: str, all_prompts: list[str], i: int, default: str
|
||||
) -> list[str]:
|
||||
prompts = re.split(r"\s*\[SEP\]\s*", ad_prompt)
|
||||
blank_replacement = self.prompt_blank_replacement(all_prompts, i, default)
|
||||
for n in range(len(prompts)):
|
||||
if not prompts[n]:
|
||||
prompts[n] = blank_replacement
|
||||
return prompts
|
||||
|
||||
def get_prompt(self, p, args: ADetailerArgs) -> tuple[list[str], list[str]]:
|
||||
i = p._idx
|
||||
|
||||
prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt)
|
||||
@@ -285,7 +296,6 @@ class AfterDetailerScript(scripts.Script):
|
||||
obj.enabled = False
|
||||
|
||||
def get_i2i_p(self, p, args: ADetailerArgs, image):
|
||||
prompt, negative_prompt = self.get_prompt(p, args)
|
||||
seed, subseed = self.get_seed(p)
|
||||
width, height = self.get_width_height(p, args)
|
||||
steps = self.get_steps(p, args)
|
||||
@@ -308,8 +318,8 @@ class AfterDetailerScript(scripts.Script):
|
||||
sd_model=p.sd_model,
|
||||
outpath_samples=p.outpath_samples,
|
||||
outpath_grids=p.outpath_grids,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt="", # replace later
|
||||
negative_prompt="",
|
||||
styles=p.styles,
|
||||
seed=seed,
|
||||
subseed=subseed,
|
||||
@@ -359,6 +369,16 @@ class AfterDetailerScript(scripts.Script):
|
||||
raise ValueError(msg)
|
||||
return model_mapping[name]
|
||||
|
||||
def i2i_prompts_replace(
|
||||
self, i2i, prompts: list[str], negative_prompts: list[str], j: int
|
||||
):
|
||||
i1 = j % len(prompts)
|
||||
i2 = j % len(negative_prompts)
|
||||
prompt = prompts[i1]
|
||||
negative_prompt = negative_prompts[i2]
|
||||
i2i.prompt = prompt
|
||||
i2i.negative_prompt = negative_prompt
|
||||
|
||||
def process(self, p, *args_):
|
||||
if getattr(p, "_disable_adetailer", False):
|
||||
return
|
||||
@@ -380,6 +400,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
|
||||
i2i = self.get_i2i_p(p, args, pp.image)
|
||||
seed, subseed = self.get_seed(p)
|
||||
ad_prompts, ad_negatives = self.get_prompt(p, args)
|
||||
|
||||
is_mediapipe = args.ad_model.lower().startswith("mediapipe")
|
||||
|
||||
@@ -424,6 +445,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
p2 = copy(i2i)
|
||||
for j in range(steps):
|
||||
p2.image_mask = masks[j]
|
||||
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
|
||||
processed = process_images(p2)
|
||||
|
||||
p2 = copy(i2i)
|
||||
|
||||
Reference in New Issue
Block a user