Added option to autocomplete style names

To be used in tandem with https://github.com/SirVeggie/extension-style-vars
Closes #268
This commit is contained in:
DominikDoom
2024-01-26 16:16:04 +01:00
parent 7778142520
commit d37e37acfa
6 changed files with 144 additions and 10 deletions

View File

@@ -9,7 +9,7 @@ from pathlib import Path
import gradio as gr
import yaml
from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
from fastapi.responses import Response, FileResponse, JSONResponse
from modules import script_callbacks, sd_hijack, shared, hashes
from scripts.model_keyword_support import (get_lora_simple_hash,
@@ -307,6 +307,13 @@ def get_lyco():
# Sort
return sort_models(lycos_with_hash)
def get_style_names():
try:
style_names: list[str] = shared.prompt_styles.styles.keys()
style_names = sorted(style_names, key=len, reverse=True)
return style_names
except Exception:
return None
def write_tag_base_path():
"""Writes the tag base path to a fixed location temporary file"""
@@ -361,6 +368,7 @@ write_to_temp_file('umi_tags.txt', [])
write_to_temp_file('hyp.txt', [])
write_to_temp_file('lora.txt', [])
write_to_temp_file('lyco.txt', [])
write_to_temp_file('styles.txt', [])
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
if not TEMP_PATH.joinpath("emb.txt").exists():
write_to_temp_file('emb.txt', [])
@@ -391,6 +399,11 @@ def refresh_temp_files(*args, **kwargs):
write_temp_files(skip_wildcard_refresh)
refresh_embeddings(force=True)
def write_style_names(*args, **kwargs):
styles = get_style_names()
if styles:
write_to_temp_file('styles.txt', styles)
def write_temp_files(skip_wildcard_refresh = False):
# Write wildcards to wc.txt if found
if WILDCARD_PATH.exists() and not skip_wildcard_refresh:
@@ -431,6 +444,8 @@ def write_temp_files(skip_wildcard_refresh = False):
if model_keyword_installed:
update_hash_cache()
if shared.prompt_styles is not None:
write_style_names()
write_temp_files()
@@ -480,6 +495,7 @@ def on_ui_settings():
"tac_showWikiLinks": shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page").info("Warning: This is an external site and very likely contains NSFW examples!"),
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
"tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like <a href=\"https://github.com/SirVeggie/extension-style-vars\" target=\"_blank\">style-vars</a> to actually apply the styles before generating."),
# Insertion related settings
"tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"),
"tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"),
@@ -561,10 +577,18 @@ def on_ui_settings():
script_callbacks.on_ui_settings(on_ui_settings)
def get_style_mtime():
style_file = getattr(shared, "styles_filename", "styles.csv")
style_file = Path(FILE_DIR).joinpath(style_file)
if Path.exists(style_file):
return style_file.stat().st_mtime
last_style_mtime = get_style_mtime()
def api_tac(_: gr.Blocks, app: FastAPI):
async def get_json_info(base_path: Path, filename: str = None):
if base_path is None or (not base_path.exists()):
return JSONResponse({}, status_code=404)
return Response(status_code=404)
try:
json_candidates = glob.glob(base_path.as_posix() + f"/**/{filename}.json", recursive=True)
@@ -575,7 +599,7 @@ def api_tac(_: gr.Blocks, app: FastAPI):
async def get_preview_thumbnail(base_path: Path, filename: str = None, blob: bool = False):
if base_path is None or (not base_path.exists()):
return JSONResponse({}, status_code=404)
return Response(status_code=404)
try:
img_glob = glob.glob(base_path.as_posix() + f"/**/{filename}.*", recursive=True)
@@ -639,20 +663,34 @@ def api_tac(_: gr.Blocks, app: FastAPI):
@app.get("/tacapi/v1/wildcard-contents")
async def get_wildcard_contents(basepath: str, filename: str):
if basepath is None or basepath == "":
return JSONResponse({}, status_code=404)
return Response(status_code=404)
base = Path(basepath)
if base is None or (not base.exists()):
return JSONResponse({}, status_code=404)
return Response(status_code=404)
try:
wildcard_path = base.joinpath(filename)
if wildcard_path.exists() and wildcard_path.is_file():
return FileResponse(wildcard_path)
else:
return JSONResponse({}, status_code=404)
return Response(status_code=404)
except Exception as e:
return JSONResponse({"error": e}, status_code=500)
@app.get("/tacapi/v1/refresh-styles-if-changed")
async def refresh_styles_if_changed():
global last_style_mtime
mtime = get_style_mtime()
if mtime > last_style_mtime:
last_style_mtime = mtime
# Update temp file
if shared.prompt_styles is not None:
write_style_names()
return Response(status_code=200) # Success
else:
return Response(status_code=304) # Not modified
script_callbacks.on_app_started(api_tac)