Merge branch 'main' into feature-sort-by-frequent-use

This commit is contained in:
DominikDoom
2024-01-26 16:21:15 +01:00
14 changed files with 184 additions and 24 deletions

View File

@@ -16,6 +16,8 @@ hash_dict = {}
def load_hash_cache():
if not known_hashes_file.exists():
known_hashes_file.touch()
with open(known_hashes_file, "r", encoding="utf-8") as file:
reader = csv.reader(
file.readlines(), delimiter=",", quotechar='"', skipinitialspace=True
@@ -28,6 +30,8 @@ def load_hash_cache():
def update_hash_cache():
global file_needs_update
if file_needs_update:
if not known_hashes_file.exists():
known_hashes_file.touch()
with open(known_hashes_file, "w", encoding="utf-8", newline='') as file:
writer = csv.writer(file)
for name, (hash, mtime) in hash_dict.items():

View File

@@ -11,7 +11,7 @@ from pathlib import Path
import gradio as gr
import yaml
from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
from fastapi.responses import Response, FileResponse, JSONResponse
from modules import script_callbacks, sd_hijack, shared, hashes
from pydantic import BaseModel
@@ -85,12 +85,13 @@ def get_wildcards():
def get_ext_wildcards():
"""Returns a list of all extension wildcards. Works on nested folders."""
wildcard_files = []
excluded_folder_names = [s.strip() for s in getattr(shared.opts, "tac_wildcardExclusionList", "").split(",")]
for path in WILDCARD_EXT_PATHS:
wildcard_files.append(path.as_posix())
resolved = [(w, w.relative_to(path).as_posix())
for w in path.rglob("*.txt")
if w.name != "put wildcards here.txt"
and not any(excluded in w.parts for excluded in excluded_folder_names)
and w.is_file()]
wildcard_files.extend(sort_models(resolved, name_has_subpath=True))
wildcard_files.append("-----")
@@ -293,6 +294,8 @@ def get_lora():
valid_loras = _get_lora()
loras_with_hash = []
for l in valid_loras:
if not l.exists() or not l.is_file():
continue
name = l.relative_to(LORA_PATH).as_posix()
if model_keyword_installed:
hash = get_lora_simple_hash(l)
@@ -309,6 +312,8 @@ def get_lyco():
valid_lycos = _get_lyco()
lycos_with_hash = []
for ly in valid_lycos:
if not ly.exists() or not ly.is_file():
continue
name = ly.relative_to(LYCO_PATH).as_posix()
if model_keyword_installed:
hash = get_lora_simple_hash(ly)
@@ -318,6 +323,13 @@ def get_lyco():
# Sort
return sort_models(lycos_with_hash)
def get_style_names():
try:
style_names: list[str] = shared.prompt_styles.styles.keys()
style_names = sorted(style_names, key=len, reverse=True)
return style_names
except Exception:
return None
def write_tag_base_path():
"""Writes the tag base path to a fixed location temporary file"""
@@ -372,6 +384,7 @@ write_to_temp_file('umi_tags.txt', [])
write_to_temp_file('hyp.txt', [])
write_to_temp_file('lora.txt', [])
write_to_temp_file('lyco.txt', [])
write_to_temp_file('styles.txt', [])
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
if not TEMP_PATH.joinpath("emb.txt").exists():
write_to_temp_file('emb.txt', [])
@@ -396,19 +409,26 @@ def refresh_embeddings(force: bool, *args, **kwargs):
def refresh_temp_files(*args, **kwargs):
global WILDCARD_EXT_PATHS
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
write_temp_files()
skip_wildcard_refresh = getattr(shared.opts, "tac_skipWildcardRefresh", False)
if skip_wildcard_refresh:
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
write_temp_files(skip_wildcard_refresh)
refresh_embeddings(force=True)
def write_temp_files():
def write_style_names(*args, **kwargs):
styles = get_style_names()
if styles:
write_to_temp_file('styles.txt', styles)
def write_temp_files(skip_wildcard_refresh = False):
# Write wildcards to wc.txt if found
if WILDCARD_PATH.exists():
if WILDCARD_PATH.exists() and not skip_wildcard_refresh:
wildcards = [WILDCARD_PATH.relative_to(FILE_DIR).as_posix()] + get_wildcards()
if wildcards:
write_to_temp_file('wc.txt', wildcards)
# Write extension wildcards to wce.txt if found
if WILDCARD_EXT_PATHS is not None:
if WILDCARD_EXT_PATHS is not None and not skip_wildcard_refresh:
wildcards_ext = get_ext_wildcards()
if wildcards_ext:
write_to_temp_file('wce.txt', wildcards_ext)
@@ -440,6 +460,8 @@ def write_temp_files():
if model_keyword_installed:
update_hash_cache()
if shared.prompt_styles is not None:
write_style_names()
write_temp_files()
@@ -485,14 +507,18 @@ def on_ui_settings():
"tac_delayTime": shared.OptionInfo(100, "Time in ms to wait before triggering completion again").needs_restart(),
"tac_useWildcards": shared.OptionInfo(True, "Search for wildcards"),
"tac_sortWildcardResults": shared.OptionInfo(True, "Sort wildcard file contents alphabetically").info("If your wildcard files have a specific custom order, disable this to keep it"),
"tac_wildcardExclusionList": shared.OptionInfo("", "Wildcard folder exclusion list").info("Add folder names that shouldn't be searched for wildcards, separated by comma.").needs_restart(),
"tac_skipWildcardRefresh": shared.OptionInfo(False, "Don't re-scan for wildcard files when pressing the extra networks refresh button").info("Useful to prevent hanging if you use a very large wildcard collection."),
"tac_useEmbeddings": shared.OptionInfo(True, "Search for embeddings"),
"tac_includeEmbeddingsInNormalResults": shared.OptionInfo(False, "Include embeddings in normal tag results").info("The 'JumpTo...' keybinds (End & Home key by default) will select the first non-embedding result of their direction on the first press for quick navigation in longer lists."),
"tac_useHypernetworks": shared.OptionInfo(True, "Search for hypernetworks"),
"tac_useLoras": shared.OptionInfo(True, "Search for Loras"),
"tac_useLycos": shared.OptionInfo(True, "Search for LyCORIS/LoHa"),
"tac_useLoraPrefixForLycos": shared.OptionInfo(True, "Use the '<lora:' prefix instead of '<lyco:' for models in the LyCORIS folder").info("The lyco prefix is included for backwards compatibility and not used anymore by default. Disable this if you are on an old webui version without built-in lyco support."),
"tac_showWikiLinks": shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page").info("Warning: This is an external site and very likely contains NSFW examples!"),
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
"tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like <a href=\"https://github.com/SirVeggie/extension-style-vars\" target=\"_blank\">style-vars</a> to actually apply the styles before generating."),
# Frequency sorting settings
"tac_frequencySort": shared.OptionInfo(True, "Locally record tag usage and sort frequent tags higher").info("Will also work for extra networks, keeping the specified base order"),
"tac_frequencyFunction": shared.OptionInfo("Logarithmic (weak)", "Function to use for frequency sorting", gr.Dropdown, lambda: {"choices": list(frequency_sort_functions.keys())}).info("; ".join([f'<b>{key}</b>: {val}' for key, val in frequency_sort_functions.items()])),
@@ -580,10 +606,18 @@ def on_ui_settings():
script_callbacks.on_ui_settings(on_ui_settings)
def get_style_mtime():
style_file = getattr(shared, "styles_filename", "styles.csv")
style_file = Path(FILE_DIR).joinpath(style_file)
if Path.exists(style_file):
return style_file.stat().st_mtime
last_style_mtime = get_style_mtime()
def api_tac(_: gr.Blocks, app: FastAPI):
async def get_json_info(base_path: Path, filename: str = None):
if base_path is None or (not base_path.exists()):
return JSONResponse({}, status_code=404)
return Response(status_code=404)
try:
json_candidates = glob.glob(base_path.as_posix() + f"/**/{filename}.json", recursive=True)
@@ -594,7 +628,7 @@ def api_tac(_: gr.Blocks, app: FastAPI):
async def get_preview_thumbnail(base_path: Path, filename: str = None, blob: bool = False):
if base_path is None or (not base_path.exists()):
return JSONResponse({}, status_code=404)
return Response(status_code=404)
try:
img_glob = glob.glob(base_path.as_posix() + f"/**/{filename}.*", recursive=True)
@@ -658,21 +692,35 @@ def api_tac(_: gr.Blocks, app: FastAPI):
@app.get("/tacapi/v1/wildcard-contents")
async def get_wildcard_contents(basepath: str, filename: str):
if basepath is None or basepath == "":
return JSONResponse({}, status_code=404)
return Response(status_code=404)
base = Path(basepath)
if base is None or (not base.exists()):
return JSONResponse({}, status_code=404)
return Response(status_code=404)
try:
wildcard_path = base.joinpath(filename)
if wildcard_path.exists() and wildcard_path.is_file():
return FileResponse(wildcard_path)
else:
return JSONResponse({}, status_code=404)
return Response(status_code=404)
except Exception as e:
return JSONResponse({"error": e}, status_code=500)
@app.get("/tacapi/v1/refresh-styles-if-changed")
async def refresh_styles_if_changed():
global last_style_mtime
mtime = get_style_mtime()
if mtime > last_style_mtime:
last_style_mtime = mtime
# Update temp file
if shared.prompt_styles is not None:
write_style_names()
return Response(status_code=200) # Success
else:
return Response(status_code=304) # Not modified
def db_request(func, get = False):
if db is not None:
try: