mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-03-13 09:10:26 +00:00
Merge branch 'main' into feature-sort-by-frequent-use
This commit is contained in:
@@ -16,6 +16,8 @@ hash_dict = {}
|
||||
|
||||
|
||||
def load_hash_cache():
|
||||
if not known_hashes_file.exists():
|
||||
known_hashes_file.touch()
|
||||
with open(known_hashes_file, "r", encoding="utf-8") as file:
|
||||
reader = csv.reader(
|
||||
file.readlines(), delimiter=",", quotechar='"', skipinitialspace=True
|
||||
@@ -28,6 +30,8 @@ def load_hash_cache():
|
||||
def update_hash_cache():
|
||||
global file_needs_update
|
||||
if file_needs_update:
|
||||
if not known_hashes_file.exists():
|
||||
known_hashes_file.touch()
|
||||
with open(known_hashes_file, "w", encoding="utf-8", newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
for name, (hash, mtime) in hash_dict.items():
|
||||
|
||||
@@ -11,7 +11,7 @@ from pathlib import Path
|
||||
import gradio as gr
|
||||
import yaml
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.responses import Response, FileResponse, JSONResponse
|
||||
from modules import script_callbacks, sd_hijack, shared, hashes
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -85,12 +85,13 @@ def get_wildcards():
|
||||
def get_ext_wildcards():
|
||||
"""Returns a list of all extension wildcards. Works on nested folders."""
|
||||
wildcard_files = []
|
||||
|
||||
excluded_folder_names = [s.strip() for s in getattr(shared.opts, "tac_wildcardExclusionList", "").split(",")]
|
||||
for path in WILDCARD_EXT_PATHS:
|
||||
wildcard_files.append(path.as_posix())
|
||||
resolved = [(w, w.relative_to(path).as_posix())
|
||||
for w in path.rglob("*.txt")
|
||||
if w.name != "put wildcards here.txt"
|
||||
and not any(excluded in w.parts for excluded in excluded_folder_names)
|
||||
and w.is_file()]
|
||||
wildcard_files.extend(sort_models(resolved, name_has_subpath=True))
|
||||
wildcard_files.append("-----")
|
||||
@@ -293,6 +294,8 @@ def get_lora():
|
||||
valid_loras = _get_lora()
|
||||
loras_with_hash = []
|
||||
for l in valid_loras:
|
||||
if not l.exists() or not l.is_file():
|
||||
continue
|
||||
name = l.relative_to(LORA_PATH).as_posix()
|
||||
if model_keyword_installed:
|
||||
hash = get_lora_simple_hash(l)
|
||||
@@ -309,6 +312,8 @@ def get_lyco():
|
||||
valid_lycos = _get_lyco()
|
||||
lycos_with_hash = []
|
||||
for ly in valid_lycos:
|
||||
if not ly.exists() or not ly.is_file():
|
||||
continue
|
||||
name = ly.relative_to(LYCO_PATH).as_posix()
|
||||
if model_keyword_installed:
|
||||
hash = get_lora_simple_hash(ly)
|
||||
@@ -318,6 +323,13 @@ def get_lyco():
|
||||
# Sort
|
||||
return sort_models(lycos_with_hash)
|
||||
|
||||
def get_style_names():
|
||||
try:
|
||||
style_names: list[str] = shared.prompt_styles.styles.keys()
|
||||
style_names = sorted(style_names, key=len, reverse=True)
|
||||
return style_names
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def write_tag_base_path():
|
||||
"""Writes the tag base path to a fixed location temporary file"""
|
||||
@@ -372,6 +384,7 @@ write_to_temp_file('umi_tags.txt', [])
|
||||
write_to_temp_file('hyp.txt', [])
|
||||
write_to_temp_file('lora.txt', [])
|
||||
write_to_temp_file('lyco.txt', [])
|
||||
write_to_temp_file('styles.txt', [])
|
||||
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
|
||||
if not TEMP_PATH.joinpath("emb.txt").exists():
|
||||
write_to_temp_file('emb.txt', [])
|
||||
@@ -396,19 +409,26 @@ def refresh_embeddings(force: bool, *args, **kwargs):
|
||||
|
||||
def refresh_temp_files(*args, **kwargs):
|
||||
global WILDCARD_EXT_PATHS
|
||||
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
|
||||
write_temp_files()
|
||||
skip_wildcard_refresh = getattr(shared.opts, "tac_skipWildcardRefresh", False)
|
||||
if skip_wildcard_refresh:
|
||||
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
|
||||
write_temp_files(skip_wildcard_refresh)
|
||||
refresh_embeddings(force=True)
|
||||
|
||||
def write_temp_files():
|
||||
def write_style_names(*args, **kwargs):
|
||||
styles = get_style_names()
|
||||
if styles:
|
||||
write_to_temp_file('styles.txt', styles)
|
||||
|
||||
def write_temp_files(skip_wildcard_refresh = False):
|
||||
# Write wildcards to wc.txt if found
|
||||
if WILDCARD_PATH.exists():
|
||||
if WILDCARD_PATH.exists() and not skip_wildcard_refresh:
|
||||
wildcards = [WILDCARD_PATH.relative_to(FILE_DIR).as_posix()] + get_wildcards()
|
||||
if wildcards:
|
||||
write_to_temp_file('wc.txt', wildcards)
|
||||
|
||||
# Write extension wildcards to wce.txt if found
|
||||
if WILDCARD_EXT_PATHS is not None:
|
||||
if WILDCARD_EXT_PATHS is not None and not skip_wildcard_refresh:
|
||||
wildcards_ext = get_ext_wildcards()
|
||||
if wildcards_ext:
|
||||
write_to_temp_file('wce.txt', wildcards_ext)
|
||||
@@ -440,6 +460,8 @@ def write_temp_files():
|
||||
if model_keyword_installed:
|
||||
update_hash_cache()
|
||||
|
||||
if shared.prompt_styles is not None:
|
||||
write_style_names()
|
||||
|
||||
write_temp_files()
|
||||
|
||||
@@ -485,14 +507,18 @@ def on_ui_settings():
|
||||
"tac_delayTime": shared.OptionInfo(100, "Time in ms to wait before triggering completion again").needs_restart(),
|
||||
"tac_useWildcards": shared.OptionInfo(True, "Search for wildcards"),
|
||||
"tac_sortWildcardResults": shared.OptionInfo(True, "Sort wildcard file contents alphabetically").info("If your wildcard files have a specific custom order, disable this to keep it"),
|
||||
"tac_wildcardExclusionList": shared.OptionInfo("", "Wildcard folder exclusion list").info("Add folder names that shouldn't be searched for wildcards, separated by comma.").needs_restart(),
|
||||
"tac_skipWildcardRefresh": shared.OptionInfo(False, "Don't re-scan for wildcard files when pressing the extra networks refresh button").info("Useful to prevent hanging if you use a very large wildcard collection."),
|
||||
"tac_useEmbeddings": shared.OptionInfo(True, "Search for embeddings"),
|
||||
"tac_includeEmbeddingsInNormalResults": shared.OptionInfo(False, "Include embeddings in normal tag results").info("The 'JumpTo...' keybinds (End & Home key by default) will select the first non-embedding result of their direction on the first press for quick navigation in longer lists."),
|
||||
"tac_useHypernetworks": shared.OptionInfo(True, "Search for hypernetworks"),
|
||||
"tac_useLoras": shared.OptionInfo(True, "Search for Loras"),
|
||||
"tac_useLycos": shared.OptionInfo(True, "Search for LyCORIS/LoHa"),
|
||||
"tac_useLoraPrefixForLycos": shared.OptionInfo(True, "Use the '<lora:' prefix instead of '<lyco:' for models in the LyCORIS folder").info("The lyco prefix is included for backwards compatibility and not used anymore by default. Disable this if you are on an old webui version without built-in lyco support."),
|
||||
"tac_showWikiLinks": shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page").info("Warning: This is an external site and very likely contains NSFW examples!"),
|
||||
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
|
||||
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
|
||||
"tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like <a href=\"https://github.com/SirVeggie/extension-style-vars\" target=\"_blank\">style-vars</a> to actually apply the styles before generating."),
|
||||
# Frequency sorting settings
|
||||
"tac_frequencySort": shared.OptionInfo(True, "Locally record tag usage and sort frequent tags higher").info("Will also work for extra networks, keeping the specified base order"),
|
||||
"tac_frequencyFunction": shared.OptionInfo("Logarithmic (weak)", "Function to use for frequency sorting", gr.Dropdown, lambda: {"choices": list(frequency_sort_functions.keys())}).info("; ".join([f'<b>{key}</b>: {val}' for key, val in frequency_sort_functions.items()])),
|
||||
@@ -580,10 +606,18 @@ def on_ui_settings():
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
|
||||
def get_style_mtime():
|
||||
style_file = getattr(shared, "styles_filename", "styles.csv")
|
||||
style_file = Path(FILE_DIR).joinpath(style_file)
|
||||
if Path.exists(style_file):
|
||||
return style_file.stat().st_mtime
|
||||
|
||||
last_style_mtime = get_style_mtime()
|
||||
|
||||
def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
async def get_json_info(base_path: Path, filename: str = None):
|
||||
if base_path is None or (not base_path.exists()):
|
||||
return JSONResponse({}, status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
json_candidates = glob.glob(base_path.as_posix() + f"/**/{filename}.json", recursive=True)
|
||||
@@ -594,7 +628,7 @@ def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
|
||||
async def get_preview_thumbnail(base_path: Path, filename: str = None, blob: bool = False):
|
||||
if base_path is None or (not base_path.exists()):
|
||||
return JSONResponse({}, status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
img_glob = glob.glob(base_path.as_posix() + f"/**/{filename}.*", recursive=True)
|
||||
@@ -658,21 +692,35 @@ def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
@app.get("/tacapi/v1/wildcard-contents")
|
||||
async def get_wildcard_contents(basepath: str, filename: str):
|
||||
if basepath is None or basepath == "":
|
||||
return JSONResponse({}, status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
base = Path(basepath)
|
||||
if base is None or (not base.exists()):
|
||||
return JSONResponse({}, status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
wildcard_path = base.joinpath(filename)
|
||||
if wildcard_path.exists() and wildcard_path.is_file():
|
||||
return FileResponse(wildcard_path)
|
||||
else:
|
||||
return JSONResponse({}, status_code=404)
|
||||
return Response(status_code=404)
|
||||
except Exception as e:
|
||||
return JSONResponse({"error": e}, status_code=500)
|
||||
|
||||
@app.get("/tacapi/v1/refresh-styles-if-changed")
|
||||
async def refresh_styles_if_changed():
|
||||
global last_style_mtime
|
||||
|
||||
mtime = get_style_mtime()
|
||||
if mtime > last_style_mtime:
|
||||
last_style_mtime = mtime
|
||||
# Update temp file
|
||||
if shared.prompt_styles is not None:
|
||||
write_style_names()
|
||||
|
||||
return Response(status_code=200) # Success
|
||||
else:
|
||||
return Response(status_code=304) # Not modified
|
||||
def db_request(func, get = False):
|
||||
if db is not None:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user