diff --git a/javascript/__globals.js b/javascript/__globals.js index cf75e81..32cf94c 100644 --- a/javascript/__globals.js +++ b/javascript/__globals.js @@ -19,6 +19,7 @@ var loras = []; var lycos = []; var modelKeywordDict = new Map(); var chants = []; +var styleNames = []; // Selected model info for black/whitelisting var currentModelHash = ""; diff --git a/javascript/_result.js b/javascript/_result.js index 823f26d..73dab12 100644 --- a/javascript/_result.js +++ b/javascript/_result.js @@ -12,7 +12,8 @@ const ResultType = Object.freeze({ "hypernetwork": 8, "lora": 9, "lyco": 10, - "chant": 11 + "chant": 11, + "styleName": 12 }); // Class to hold result data and annotations to make it clearer to use diff --git a/javascript/_utils.js b/javascript/_utils.js index 9a93104..e924928 100644 --- a/javascript/_utils.js +++ b/javascript/_utils.js @@ -108,6 +108,29 @@ async function getExtraNetworkPreviewURL(filename, type) { } } +lastStyleRefresh = 0; +// Refresh style file if needed +async function refreshStyleNamesIfChanged() { + // Only refresh once per second + currentTimestamp = new Date().getTime(); + if (currentTimestamp - lastStyleRefresh < 1000) return; + lastStyleRefresh = currentTimestamp; + + const response = await fetch(`tacapi/v1/refresh-styles-if-changed?${new Date().getTime()}`) + if (response.status === 304) { + // Not modified + } else if (response.status === 200) { + // Reload + QUEUE_FILE_LOAD.forEach(async fn => { + if (fn.toString().includes("styleNames")) + await fn.call(null, true); + }) + } else { + // Error + console.error(`Error refreshing styles.txt: ` + response.status, response.statusText); + } +} + // Debounce function to prevent spamming the autocomplete function var dbTimeOut; const debounce = (func, wait = 300) => { diff --git a/javascript/ext_styles.js b/javascript/ext_styles.js new file mode 100644 index 0000000..7f26a9c --- /dev/null +++ b/javascript/ext_styles.js @@ -0,0 +1,67 @@ +const STYLE_REGEX = /(\$(\d*)\(?)[^$|\]\s]*\)?/; +const STYLE_TRIGGER = () => TAC_CFG.useStyleVars && tagword.match(STYLE_REGEX); + +var lastStyleVarIndex = ""; + +class StyleParser extends BaseTagParser { + async parse() { + // Refresh if needed + await refreshStyleNamesIfChanged(); + + // Show styles + let tempResults = []; + let matchGroups = tagword.match(STYLE_REGEX); + + // Save index to insert again later or clear last one + lastStyleVarIndex = matchGroups[2] ? matchGroups[2] : ""; + + if (tagword !== matchGroups[1]) { + let searchTerm = tagword.replace(matchGroups[1], ""); + + let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm); + tempResults = styleNames.filter(x => filterCondition(x)); // Filter by tagword + } else { + tempResults = styleNames; + } + + // Add final results + let finalResults = []; + tempResults.forEach(t => { + let result = new AutocompleteResult(t[0].trim(), ResultType.styleName) + result.meta = "Style"; + finalResults.push(result); + }); + + return finalResults; + } +} + +async function load(force = false) { + if (styleNames.length === 0 || force) { + try { + styleNames = (await loadCSV(`${tagBasePath}/temp/styles.txt`)) + .filter(x => x[0]?.trim().length > 0) // Remove empty lines + .filter(x => x[0] !== "None") // Remove "None" style + .map(x => [x[0].trim()]); // Trim name + } catch (e) { + console.error("Error loading styles.txt: " + e); + } + } +} + +function sanitize(tagType, text) { + if (tagType === ResultType.styleName) { + if (text.includes(" ")) { + return `$${lastStyleVarIndex}(${text})`; + } else { + return`$${lastStyleVarIndex}${text}` + } + } + return null; +} + +PARSERS.push(new StyleParser(STYLE_TRIGGER)); + +// Add our utility functions to their respective queues +QUEUE_FILE_LOAD.push(load); +QUEUE_SANITIZE.push(sanitize); \ No newline at end of file diff --git a/javascript/tagAutocomplete.js b/javascript/tagAutocomplete.js index b25050e..e03553e 100644 --- a/javascript/tagAutocomplete.js +++ b/javascript/tagAutocomplete.js @@ -221,6 +221,7 @@ async function syncOptions() { showWikiLinks: opts["tac_showWikiLinks"], showExtraNetworkPreviews: opts["tac_showExtraNetworkPreviews"], modelSortOrder: opts["tac_modelSortOrder"], + useStyleVars: opts["tac_useStyleVars"], // Insertion related settings replaceUnderscores: opts["tac_replaceUnderscores"], escapeParentheses: opts["tac_escapeParentheses"], @@ -399,9 +400,10 @@ function isEnabled() { const WEIGHT_REGEX = /[([]([^()[\]:|]+)(?::(?:\d+(?:\.\d+)?|\.\d+))?[)\]]/g; const POINTY_REGEX = /<[^\s,<](?:[^\t\n\r,<>]*>|[^\t\n\r,> ]*)/g; const COMPLETED_WILDCARD_REGEX = /__[^\s,_][^\t\n\r,_]*[^\s,_]__[^\s,_]*/g; +const STYLE_VAR_REGEX = /\$\(?[^$|\]\s]*\)?/g; const NORMAL_TAG_REGEX = /[^\s,|<>\]:]+_\([^\s,|<>\]:]*\)?|[^\s,|<>():\]]+|?/g; -const TAG_REGEX = new RegExp(`${POINTY_REGEX.source}|${COMPLETED_WILDCARD_REGEX.source}|${NORMAL_TAG_REGEX.source}`, "g"); +const TAG_REGEX = new RegExp(`${POINTY_REGEX.source}|${COMPLETED_WILDCARD_REGEX.source}|${STYLE_VAR_REGEX.source}|${NORMAL_TAG_REGEX.source}`, "g"); // On click, insert the tag into the prompt textbox with respect to the cursor position async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithoutChoice = false) { @@ -567,9 +569,11 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length + keywordsLength; textArea.selectionEnd = textArea.selectionStart + // Set self trigger flag to show wildcard contents after the filename was inserted + if ([ResultType.wildcardFile, ResultType.yamlWildcard, ResultType.umiWildcard].includes(result.type)) + tacSelfTrigger = true; // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python. // Uses a built-in method from the webui's ui.js which also already accounts for event target - tacSelfTrigger = true; updateInput(textArea); // Update previous tags with the edited prompt to prevent re-searching the same term @@ -993,7 +997,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) { .map(match => match[1]); let tags = prompt.match(TAG_REGEX) if (weightedTags !== null && tags !== null) { - tags = tags.filter(tag => !weightedTags.some(weighted => tag.includes(weighted) && !tag.startsWith("<["))) + tags = tags.filter(tag => !weightedTags.some(weighted => tag.includes(weighted) && !tag.startsWith("<[") && !tag.startsWith("$("))) .concat(weightedTags); } diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index deccc0e..3a4cefb 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -9,7 +9,7 @@ from pathlib import Path import gradio as gr import yaml from fastapi import FastAPI -from fastapi.responses import FileResponse, JSONResponse +from fastapi.responses import Response, FileResponse, JSONResponse from modules import script_callbacks, sd_hijack, shared, hashes from scripts.model_keyword_support import (get_lora_simple_hash, @@ -307,6 +307,13 @@ def get_lyco(): # Sort return sort_models(lycos_with_hash) +def get_style_names(): + try: + style_names: list[str] = shared.prompt_styles.styles.keys() + style_names = sorted(style_names, key=len, reverse=True) + return style_names + except Exception: + return None def write_tag_base_path(): """Writes the tag base path to a fixed location temporary file""" @@ -361,6 +368,7 @@ write_to_temp_file('umi_tags.txt', []) write_to_temp_file('hyp.txt', []) write_to_temp_file('lora.txt', []) write_to_temp_file('lyco.txt', []) +write_to_temp_file('styles.txt', []) # Only reload embeddings if the file doesn't exist, since they are already re-written on model load if not TEMP_PATH.joinpath("emb.txt").exists(): write_to_temp_file('emb.txt', []) @@ -391,6 +399,11 @@ def refresh_temp_files(*args, **kwargs): write_temp_files(skip_wildcard_refresh) refresh_embeddings(force=True) +def write_style_names(*args, **kwargs): + styles = get_style_names() + if styles: + write_to_temp_file('styles.txt', styles) + def write_temp_files(skip_wildcard_refresh = False): # Write wildcards to wc.txt if found if WILDCARD_PATH.exists() and not skip_wildcard_refresh: @@ -431,6 +444,8 @@ def write_temp_files(skip_wildcard_refresh = False): if model_keyword_installed: update_hash_cache() + if shared.prompt_styles is not None: + write_style_names() write_temp_files() @@ -480,6 +495,7 @@ def on_ui_settings(): "tac_showWikiLinks": shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page").info("Warning: This is an external site and very likely contains NSFW examples!"), "tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"), "tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"), + "tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like style-vars to actually apply the styles before generating."), # Insertion related settings "tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"), "tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"), @@ -561,10 +577,18 @@ def on_ui_settings(): script_callbacks.on_ui_settings(on_ui_settings) +def get_style_mtime(): + style_file = getattr(shared, "styles_filename", "styles.csv") + style_file = Path(FILE_DIR).joinpath(style_file) + if Path.exists(style_file): + return style_file.stat().st_mtime + +last_style_mtime = get_style_mtime() + def api_tac(_: gr.Blocks, app: FastAPI): async def get_json_info(base_path: Path, filename: str = None): if base_path is None or (not base_path.exists()): - return JSONResponse({}, status_code=404) + return Response(status_code=404) try: json_candidates = glob.glob(base_path.as_posix() + f"/**/{filename}.json", recursive=True) @@ -575,7 +599,7 @@ def api_tac(_: gr.Blocks, app: FastAPI): async def get_preview_thumbnail(base_path: Path, filename: str = None, blob: bool = False): if base_path is None or (not base_path.exists()): - return JSONResponse({}, status_code=404) + return Response(status_code=404) try: img_glob = glob.glob(base_path.as_posix() + f"/**/{filename}.*", recursive=True) @@ -639,20 +663,34 @@ def api_tac(_: gr.Blocks, app: FastAPI): @app.get("/tacapi/v1/wildcard-contents") async def get_wildcard_contents(basepath: str, filename: str): if basepath is None or basepath == "": - return JSONResponse({}, status_code=404) + return Response(status_code=404) base = Path(basepath) if base is None or (not base.exists()): - return JSONResponse({}, status_code=404) + return Response(status_code=404) try: wildcard_path = base.joinpath(filename) if wildcard_path.exists() and wildcard_path.is_file(): return FileResponse(wildcard_path) else: - return JSONResponse({}, status_code=404) + return Response(status_code=404) except Exception as e: return JSONResponse({"error": e}, status_code=500) + @app.get("/tacapi/v1/refresh-styles-if-changed") + async def refresh_styles_if_changed(): + global last_style_mtime + + mtime = get_style_mtime() + if mtime > last_style_mtime: + last_style_mtime = mtime + # Update temp file + if shared.prompt_styles is not None: + write_style_names() + + return Response(status_code=200) # Success + else: + return Response(status_code=304) # Not modified script_callbacks.on_app_started(api_tac)