mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 10:59:47 +00:00
(#1862) A fix for the 0.01% who use comments in prompts. Before this, styles could be treated as part of a comment. Comments are now stripped from the prompt first, then from each applied style before merging; the same process is used when extracting styles from prompts. Updated the tooltips of the tool buttons that apply styles. Removed code from modules.processing_scripts.comments that this change made redundant.
235 lines
9.5 KiB
Python
235 lines
9.5 KiB
Python
from pathlib import Path
|
|
from modules import errors
|
|
import csv
|
|
import os
|
|
import typing
|
|
import shutil
|
|
import modules.processing_scripts.comments as comments
|
|
|
|
|
|
class PromptStyle(typing.NamedTuple):
    """Immutable record describing one named prompt style.

    Instances are built with ``name, prompt, negative_prompt`` and an optional
    ``path``. List-divider entries use ``None`` for both prompt texts.
    """

    # Display name of the style.
    name: str
    # Positive prompt text; may contain a "{prompt}" placeholder.
    prompt: typing.Optional[str]
    # Negative prompt text; may contain a "{prompt}" placeholder.
    negative_prompt: typing.Optional[str]
    # CSV file this style belongs to; None when no file has been assigned yet.
    path: typing.Optional[str] = None
|
|
|
|
|
|
def apply_styles_to_prompt(prompt, styles):
    """Merge a sequence of style texts into ``prompt`` and return the result.

    Comments are stripped from the prompt and from each style before merging.
    A style containing ``{prompt}`` has every placeholder replaced with the
    prompt built so far. Any other non-empty style is appended, separated by
    ", " (or used as-is when the prompt is still empty).

    Args:
        prompt: The user's prompt text.
        styles: Style texts applied in order. ``None`` entries (list dividers
            carry no text) and empty strings are skipped.

    Returns:
        The merged prompt string.
    """
    prompt = comments.strip_comments(prompt).strip()

    for style_text in styles:
        # Skip None (divider entries) before it reaches strip_comments; empty
        # strings would contribute nothing anyway.
        if not style_text:
            continue

        style_text = comments.strip_comments(style_text).strip()

        if "{prompt}" in style_text:
            prompt = style_text.replace("{prompt}", prompt)
        elif style_text:
            prompt = f"{prompt}, {style_text}" if prompt else style_text

    return prompt
|
|
|
|
|
|
def extract_style_text_from_prompt(style_text, prompt):
    """Remove one style's text from a prompt if the prompt was produced by it.

    Comments are stripped from ``style_text`` first. A ``None`` or empty style
    always matches and returns the stripped prompt.

    - If the style contains ``{prompt}``, the prompt must start with the text
      before the placeholder and end with the text after it, without the two
      overlapping. The text in between is returned.
    - Otherwise the prompt must end with the style text, and that text must not
      begin in the middle of a word. A trailing ", " separator is removed from
      what remains.

    extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
    extract_style_text_from_prompt("piece", "masterpiece") outputs (False, "masterpiece")

    Returns:
        ``(True, remaining_prompt)`` on a match. Otherwise ``(False, prompt)``,
        with the original, unstripped ``prompt``.
    """
    stripped_prompt = prompt.strip()
    # None (e.g. the text of a list-divider style) is treated as an empty style.
    stripped_style_text = comments.strip_comments(style_text).strip() if style_text else ""

    if not stripped_style_text:
        # Applying an empty style changes nothing, so it trivially matches.
        return True, stripped_prompt

    if "{prompt}" in stripped_style_text:
        left, _, right = stripped_style_text.partition("{prompt}")
        # The length check keeps `left` and `right` from overlapping: style
        # "a{prompt}a" must not match the prompt "a".
        fits = len(stripped_prompt) >= len(left) + len(right)
        if fits and stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
            return True, stripped_prompt[len(left):len(stripped_prompt) - len(right)]
    elif stripped_prompt.endswith(stripped_style_text):
        remainder = stripped_prompt[:len(stripped_prompt) - len(stripped_style_text)]
        # Reject matches that start mid-word, e.g. style "piece" at the end of
        # "masterpiece". Word characters follow the regex \w notion (alnum or _).
        starts_mid_word = (
            remainder != ""
            and (remainder[-1].isalnum() or remainder[-1] == "_")
            and (stripped_style_text[0].isalnum() or stripped_style_text[0] == "_")
        )
        if not starts_mid_word:
            # Drop the ", " separator that apply_styles_to_prompt inserts.
            if remainder.endswith(', '):
                remainder = remainder[:-2]
            return True, remainder

    return False, prompt
|
|
|
|
|
|
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
    """
    Takes a style and compares it to the prompt and negative prompt. If the style
    matches, returns True plus the prompt and negative prompt with the style text
    removed. Otherwise, returns False with the original prompt and negative prompt.

    A style with neither a positive nor a negative text never matches. If only
    one of the two texts is missing (None or empty), that half is treated as
    empty and matches trivially.
    """
    if not style.prompt and not style.negative_prompt:
        return False, prompt, negative_prompt

    # One text may be None (csv.DictReader fills a short row's missing columns
    # with None). Pass "" instead so the extractor never receives None.
    match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt or "", prompt)
    if not match_positive:
        return False, prompt, negative_prompt

    match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt or "", negative_prompt)
    if not match_negative:
        return False, prompt, negative_prompt

    return True, extracted_positive, extracted_negative
|
|
|
|
|
|
class StyleDatabase:
|
|
def __init__(self, paths: list[str | Path]):
|
|
self.no_style = PromptStyle("None", "", "", None)
|
|
self.styles = {}
|
|
self.paths = paths
|
|
self.all_styles_files: list[Path] = []
|
|
|
|
folder, file = os.path.split(self.paths[0])
|
|
if '*' in file or '?' in file:
|
|
# if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
|
|
self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
|
|
self.paths.insert(0, self.default_path)
|
|
else:
|
|
self.default_path = Path(self.paths[0])
|
|
|
|
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
|
|
|
self.reload()
|
|
|
|
def reload(self):
|
|
"""
|
|
Clears the style database and reloads the styles from the CSV file(s)
|
|
matching the path used to initialize the database.
|
|
"""
|
|
self.styles.clear()
|
|
|
|
# scans for all styles files
|
|
all_styles_files = []
|
|
for pattern in self.paths:
|
|
folder, file = os.path.split(pattern)
|
|
if '*' in file or '?' in file:
|
|
found_files = Path(folder).glob(file)
|
|
[all_styles_files.append(file) for file in found_files]
|
|
else:
|
|
# if os.path.exists(pattern):
|
|
all_styles_files.append(Path(pattern))
|
|
|
|
# Remove any duplicate entries
|
|
seen = set()
|
|
self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
|
|
|
|
for styles_file in self.all_styles_files:
|
|
if len(all_styles_files) > 1:
|
|
# add divider when more than styles file
|
|
# '---------------- STYLES ----------------'
|
|
divider = f' {styles_file.stem.upper()} '.center(40, '-')
|
|
self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
|
|
if styles_file.is_file():
|
|
self.load_from_csv(styles_file)
|
|
|
|
def load_from_csv(self, path: str | Path):
|
|
try:
|
|
with open(path, "r", encoding="utf-8-sig", newline="") as file:
|
|
reader = csv.DictReader(file, skipinitialspace=True)
|
|
for row in reader:
|
|
# Ignore empty rows or rows starting with a comment
|
|
if not row or row["name"].startswith("#"):
|
|
continue
|
|
# Support loading old CSV format with "name, text"-columns
|
|
prompt = row["prompt"] if "prompt" in row else row["text"]
|
|
negative_prompt = row.get("negative_prompt", "")
|
|
# Add style to database
|
|
self.styles[row["name"]] = PromptStyle(
|
|
row["name"], prompt, negative_prompt, str(path)
|
|
)
|
|
except Exception:
|
|
errors.report(f'Error loading styles from {path}: ', exc_info=True)
|
|
|
|
def get_style_paths(self) -> set:
|
|
"""Returns a set of all distinct paths of files that styles are loaded from."""
|
|
# Update any styles without a path to the default path
|
|
for style in list(self.styles.values()):
|
|
if not style.path:
|
|
self.styles[style.name] = style._replace(path=str(self.default_path))
|
|
|
|
# Create a list of all distinct paths, including the default path
|
|
style_paths = set()
|
|
style_paths.add(str(self.default_path))
|
|
for _, style in self.styles.items():
|
|
if style.path:
|
|
style_paths.add(style.path)
|
|
|
|
# Remove any paths for styles that are just list dividers
|
|
style_paths.discard("do_not_save")
|
|
|
|
return style_paths
|
|
|
|
def get_style_prompts(self, styles):
|
|
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
|
|
|
def get_negative_style_prompts(self, styles):
|
|
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
|
|
|
def apply_styles_to_prompt(self, prompt, styles):
|
|
return apply_styles_to_prompt(
|
|
prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
|
|
)
|
|
|
|
def apply_negative_styles_to_prompt(self, prompt, styles):
|
|
return apply_styles_to_prompt(
|
|
prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
|
)
|
|
|
|
def save_styles(self, path: str = None) -> None:
|
|
# The path argument is deprecated, but kept for backwards compatibility
|
|
|
|
style_paths = self.get_style_paths()
|
|
|
|
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
|
|
|
|
for style_path in style_paths:
|
|
# Always keep a backup file around
|
|
if os.path.exists(style_path):
|
|
shutil.copy(style_path, f"{style_path}.bak")
|
|
|
|
# Write the styles to the CSV file
|
|
with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
|
|
writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
|
|
writer.writeheader()
|
|
for style in (s for s in self.styles.values() if s.path == style_path):
|
|
# Skip style list dividers, e.g. "STYLES.CSV"
|
|
if style.name.lower().strip("# ") in csv_names:
|
|
continue
|
|
# Write style fields, ignoring the path field
|
|
writer.writerow(
|
|
{k: v for k, v in style._asdict().items() if k != "path"}
|
|
)
|
|
|
|
def extract_styles_from_prompt(self, positive, negative):
|
|
extracted = []
|
|
|
|
applicable_styles = list(self.styles.values())
|
|
|
|
positive = comments.strip_comments(positive)
|
|
negative = comments.strip_comments(negative)
|
|
|
|
while True:
|
|
found_style = None
|
|
|
|
for style in applicable_styles:
|
|
is_match, new_positive, new_negative = extract_original_prompts(
|
|
style, positive, negative
|
|
)
|
|
if is_match:
|
|
found_style = style
|
|
positive = new_positive
|
|
negative = new_negative
|
|
applicable_styles.remove(found_style)
|
|
extracted.append(found_style.name)
|
|
break
|
|
|
|
if not found_style:
|
|
break
|
|
|
|
return list(reversed(extracted)), positive, negative |