feat: split ad_prompt with [SEP]

This commit is contained in:
Bingsu
2023-05-14 22:42:13 +09:00
parent a5b82742c8
commit e0b947eb0a

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import platform
import re
import sys
import traceback
from copy import copy, deepcopy
@@ -177,9 +178,9 @@ class AfterDetailerScript(scripts.Script):
return device
def _get_prompt(self, ad_prompt: str, all_prompts: list[str], i: int, default: str):
if ad_prompt:
return ad_prompt
def prompt_blank_replacement(
self, all_prompts: list[str], i: int, default: str
) -> str:
if not all_prompts:
return default
if i < len(all_prompts):
@@ -187,7 +188,17 @@ class AfterDetailerScript(scripts.Script):
j = i % len(all_prompts)
return all_prompts[j]
def get_prompt(self, p, args: ADetailerArgs) -> tuple[str, str]:
def _get_prompt(
self, ad_prompt: str, all_prompts: list[str], i: int, default: str
) -> list[str]:
prompts = re.split(r"\s*\[SEP\]\s*", ad_prompt)
blank_replacement = self.prompt_blank_replacement(all_prompts, i, default)
for n in range(len(prompts)):
if not prompts[n]:
prompts[n] = blank_replacement
return prompts
def get_prompt(self, p, args: ADetailerArgs) -> tuple[list[str], list[str]]:
i = p._idx
prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt)
@@ -285,7 +296,6 @@ class AfterDetailerScript(scripts.Script):
obj.enabled = False
def get_i2i_p(self, p, args: ADetailerArgs, image):
prompt, negative_prompt = self.get_prompt(p, args)
seed, subseed = self.get_seed(p)
width, height = self.get_width_height(p, args)
steps = self.get_steps(p, args)
@@ -308,8 +318,8 @@ class AfterDetailerScript(scripts.Script):
sd_model=p.sd_model,
outpath_samples=p.outpath_samples,
outpath_grids=p.outpath_grids,
prompt=prompt,
negative_prompt=negative_prompt,
prompt="", # replace later
negative_prompt="",
styles=p.styles,
seed=seed,
subseed=subseed,
@@ -359,6 +369,16 @@ class AfterDetailerScript(scripts.Script):
raise ValueError(msg)
return model_mapping[name]
def i2i_prompts_replace(
self, i2i, prompts: list[str], negative_prompts: list[str], j: int
):
i1 = j % len(prompts)
i2 = j % len(negative_prompts)
prompt = prompts[i1]
negative_prompt = negative_prompts[i2]
i2i.prompt = prompt
i2i.negative_prompt = negative_prompt
def process(self, p, *args_):
if getattr(p, "_disable_adetailer", False):
return
@@ -380,6 +400,7 @@ class AfterDetailerScript(scripts.Script):
i2i = self.get_i2i_p(p, args, pp.image)
seed, subseed = self.get_seed(p)
ad_prompts, ad_negatives = self.get_prompt(p, args)
is_mediapipe = args.ad_model.lower().startswith("mediapipe")
@@ -424,6 +445,7 @@ class AfterDetailerScript(scripts.Script):
p2 = copy(i2i)
for j in range(steps):
p2.image_mask = masks[j]
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
processed = process_images(p2)
p2 = copy(i2i)