mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-26 11:09:54 +00:00
Added option to autocomplete style names
To be used in tandem with https://github.com/SirVeggie/extension-style-vars Closes #268
This commit is contained in:
@@ -19,6 +19,7 @@ var loras = [];
|
||||
var lycos = [];
|
||||
var modelKeywordDict = new Map();
|
||||
var chants = [];
|
||||
var styleNames = [];
|
||||
|
||||
// Selected model info for black/whitelisting
|
||||
var currentModelHash = "";
|
||||
|
||||
@@ -12,7 +12,8 @@ const ResultType = Object.freeze({
|
||||
"hypernetwork": 8,
|
||||
"lora": 9,
|
||||
"lyco": 10,
|
||||
"chant": 11
|
||||
"chant": 11,
|
||||
"styleName": 12
|
||||
});
|
||||
|
||||
// Class to hold result data and annotations to make it clearer to use
|
||||
|
||||
@@ -108,6 +108,29 @@ async function getExtraNetworkPreviewURL(filename, type) {
|
||||
}
|
||||
}
|
||||
|
||||
lastStyleRefresh = 0;
|
||||
// Refresh style file if needed
|
||||
async function refreshStyleNamesIfChanged() {
|
||||
// Only refresh once per second
|
||||
currentTimestamp = new Date().getTime();
|
||||
if (currentTimestamp - lastStyleRefresh < 1000) return;
|
||||
lastStyleRefresh = currentTimestamp;
|
||||
|
||||
const response = await fetch(`tacapi/v1/refresh-styles-if-changed?${new Date().getTime()}`)
|
||||
if (response.status === 304) {
|
||||
// Not modified
|
||||
} else if (response.status === 200) {
|
||||
// Reload
|
||||
QUEUE_FILE_LOAD.forEach(async fn => {
|
||||
if (fn.toString().includes("styleNames"))
|
||||
await fn.call(null, true);
|
||||
})
|
||||
} else {
|
||||
// Error
|
||||
console.error(`Error refreshing styles.txt: ` + response.status, response.statusText);
|
||||
}
|
||||
}
|
||||
|
||||
// Debounce function to prevent spamming the autocomplete function
|
||||
var dbTimeOut;
|
||||
const debounce = (func, wait = 300) => {
|
||||
|
||||
67
javascript/ext_styles.js
Normal file
67
javascript/ext_styles.js
Normal file
@@ -0,0 +1,67 @@
|
||||
const STYLE_REGEX = /(\$(\d*)\(?)[^$|\]\s]*\)?/;
|
||||
const STYLE_TRIGGER = () => TAC_CFG.useStyleVars && tagword.match(STYLE_REGEX);
|
||||
|
||||
var lastStyleVarIndex = "";
|
||||
|
||||
class StyleParser extends BaseTagParser {
|
||||
async parse() {
|
||||
// Refresh if needed
|
||||
await refreshStyleNamesIfChanged();
|
||||
|
||||
// Show styles
|
||||
let tempResults = [];
|
||||
let matchGroups = tagword.match(STYLE_REGEX);
|
||||
|
||||
// Save index to insert again later or clear last one
|
||||
lastStyleVarIndex = matchGroups[2] ? matchGroups[2] : "";
|
||||
|
||||
if (tagword !== matchGroups[1]) {
|
||||
let searchTerm = tagword.replace(matchGroups[1], "");
|
||||
|
||||
let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
tempResults = styleNames.filter(x => filterCondition(x)); // Filter by tagword
|
||||
} else {
|
||||
tempResults = styleNames;
|
||||
}
|
||||
|
||||
// Add final results
|
||||
let finalResults = [];
|
||||
tempResults.forEach(t => {
|
||||
let result = new AutocompleteResult(t[0].trim(), ResultType.styleName)
|
||||
result.meta = "Style";
|
||||
finalResults.push(result);
|
||||
});
|
||||
|
||||
return finalResults;
|
||||
}
|
||||
}
|
||||
|
||||
async function load(force = false) {
|
||||
if (styleNames.length === 0 || force) {
|
||||
try {
|
||||
styleNames = (await loadCSV(`${tagBasePath}/temp/styles.txt`))
|
||||
.filter(x => x[0]?.trim().length > 0) // Remove empty lines
|
||||
.filter(x => x[0] !== "None") // Remove "None" style
|
||||
.map(x => [x[0].trim()]); // Trim name
|
||||
} catch (e) {
|
||||
console.error("Error loading styles.txt: " + e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function sanitize(tagType, text) {
|
||||
if (tagType === ResultType.styleName) {
|
||||
if (text.includes(" ")) {
|
||||
return `$${lastStyleVarIndex}(${text})`;
|
||||
} else {
|
||||
return`$${lastStyleVarIndex}${text}`
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
PARSERS.push(new StyleParser(STYLE_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
@@ -221,6 +221,7 @@ async function syncOptions() {
|
||||
showWikiLinks: opts["tac_showWikiLinks"],
|
||||
showExtraNetworkPreviews: opts["tac_showExtraNetworkPreviews"],
|
||||
modelSortOrder: opts["tac_modelSortOrder"],
|
||||
useStyleVars: opts["tac_useStyleVars"],
|
||||
// Insertion related settings
|
||||
replaceUnderscores: opts["tac_replaceUnderscores"],
|
||||
escapeParentheses: opts["tac_escapeParentheses"],
|
||||
@@ -399,9 +400,10 @@ function isEnabled() {
|
||||
const WEIGHT_REGEX = /[([]([^()[\]:|]+)(?::(?:\d+(?:\.\d+)?|\.\d+))?[)\]]/g;
|
||||
const POINTY_REGEX = /<[^\s,<](?:[^\t\n\r,<>]*>|[^\t\n\r,> ]*)/g;
|
||||
const COMPLETED_WILDCARD_REGEX = /__[^\s,_][^\t\n\r,_]*[^\s,_]__[^\s,_]*/g;
|
||||
const STYLE_VAR_REGEX = /\$\(?[^$|\]\s]*\)?/g;
|
||||
const NORMAL_TAG_REGEX = /[^\s,|<>\]:]+_\([^\s,|<>\]:]*\)?|[^\s,|<>():\]]+|</g;
|
||||
const RUBY_TAG_REGEX = /[\w\d<][\w\d' \-?!/$%]{2,}>?/g;
|
||||
const TAG_REGEX = new RegExp(`${POINTY_REGEX.source}|${COMPLETED_WILDCARD_REGEX.source}|${NORMAL_TAG_REGEX.source}`, "g");
|
||||
const TAG_REGEX = new RegExp(`${POINTY_REGEX.source}|${COMPLETED_WILDCARD_REGEX.source}|${STYLE_VAR_REGEX.source}|${NORMAL_TAG_REGEX.source}`, "g");
|
||||
|
||||
// On click, insert the tag into the prompt textbox with respect to the cursor position
|
||||
async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithoutChoice = false) {
|
||||
@@ -567,9 +569,11 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
|
||||
textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length + keywordsLength;
|
||||
textArea.selectionEnd = textArea.selectionStart
|
||||
|
||||
// Set self trigger flag to show wildcard contents after the filename was inserted
|
||||
if ([ResultType.wildcardFile, ResultType.yamlWildcard, ResultType.umiWildcard].includes(result.type))
|
||||
tacSelfTrigger = true;
|
||||
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python.
|
||||
// Uses a built-in method from the webui's ui.js which also already accounts for event target
|
||||
tacSelfTrigger = true;
|
||||
updateInput(textArea);
|
||||
|
||||
// Update previous tags with the edited prompt to prevent re-searching the same term
|
||||
@@ -993,7 +997,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
.map(match => match[1]);
|
||||
let tags = prompt.match(TAG_REGEX)
|
||||
if (weightedTags !== null && tags !== null) {
|
||||
tags = tags.filter(tag => !weightedTags.some(weighted => tag.includes(weighted) && !tag.startsWith("<[")))
|
||||
tags = tags.filter(tag => !weightedTags.some(weighted => tag.includes(weighted) && !tag.startsWith("<[") && !tag.startsWith("$(")))
|
||||
.concat(weightedTags);
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from pathlib import Path
|
||||
import gradio as gr
|
||||
import yaml
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.responses import Response, FileResponse, JSONResponse
|
||||
from modules import script_callbacks, sd_hijack, shared, hashes
|
||||
|
||||
from scripts.model_keyword_support import (get_lora_simple_hash,
|
||||
@@ -307,6 +307,13 @@ def get_lyco():
|
||||
# Sort
|
||||
return sort_models(lycos_with_hash)
|
||||
|
||||
def get_style_names():
|
||||
try:
|
||||
style_names: list[str] = shared.prompt_styles.styles.keys()
|
||||
style_names = sorted(style_names, key=len, reverse=True)
|
||||
return style_names
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def write_tag_base_path():
|
||||
"""Writes the tag base path to a fixed location temporary file"""
|
||||
@@ -361,6 +368,7 @@ write_to_temp_file('umi_tags.txt', [])
|
||||
write_to_temp_file('hyp.txt', [])
|
||||
write_to_temp_file('lora.txt', [])
|
||||
write_to_temp_file('lyco.txt', [])
|
||||
write_to_temp_file('styles.txt', [])
|
||||
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
|
||||
if not TEMP_PATH.joinpath("emb.txt").exists():
|
||||
write_to_temp_file('emb.txt', [])
|
||||
@@ -391,6 +399,11 @@ def refresh_temp_files(*args, **kwargs):
|
||||
write_temp_files(skip_wildcard_refresh)
|
||||
refresh_embeddings(force=True)
|
||||
|
||||
def write_style_names(*args, **kwargs):
|
||||
styles = get_style_names()
|
||||
if styles:
|
||||
write_to_temp_file('styles.txt', styles)
|
||||
|
||||
def write_temp_files(skip_wildcard_refresh = False):
|
||||
# Write wildcards to wc.txt if found
|
||||
if WILDCARD_PATH.exists() and not skip_wildcard_refresh:
|
||||
@@ -431,6 +444,8 @@ def write_temp_files(skip_wildcard_refresh = False):
|
||||
if model_keyword_installed:
|
||||
update_hash_cache()
|
||||
|
||||
if shared.prompt_styles is not None:
|
||||
write_style_names()
|
||||
|
||||
write_temp_files()
|
||||
|
||||
@@ -480,6 +495,7 @@ def on_ui_settings():
|
||||
"tac_showWikiLinks": shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page").info("Warning: This is an external site and very likely contains NSFW examples!"),
|
||||
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
|
||||
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
|
||||
"tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like <a href=\"https://github.com/SirVeggie/extension-style-vars\" target=\"_blank\">style-vars</a> to actually apply the styles before generating."),
|
||||
# Insertion related settings
|
||||
"tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"),
|
||||
"tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"),
|
||||
@@ -561,10 +577,18 @@ def on_ui_settings():
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
|
||||
def get_style_mtime():
|
||||
style_file = getattr(shared, "styles_filename", "styles.csv")
|
||||
style_file = Path(FILE_DIR).joinpath(style_file)
|
||||
if Path.exists(style_file):
|
||||
return style_file.stat().st_mtime
|
||||
|
||||
last_style_mtime = get_style_mtime()
|
||||
|
||||
def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
async def get_json_info(base_path: Path, filename: str = None):
|
||||
if base_path is None or (not base_path.exists()):
|
||||
return JSONResponse({}, status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
json_candidates = glob.glob(base_path.as_posix() + f"/**/{filename}.json", recursive=True)
|
||||
@@ -575,7 +599,7 @@ def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
|
||||
async def get_preview_thumbnail(base_path: Path, filename: str = None, blob: bool = False):
|
||||
if base_path is None or (not base_path.exists()):
|
||||
return JSONResponse({}, status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
img_glob = glob.glob(base_path.as_posix() + f"/**/{filename}.*", recursive=True)
|
||||
@@ -639,20 +663,34 @@ def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
@app.get("/tacapi/v1/wildcard-contents")
|
||||
async def get_wildcard_contents(basepath: str, filename: str):
|
||||
if basepath is None or basepath == "":
|
||||
return JSONResponse({}, status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
base = Path(basepath)
|
||||
if base is None or (not base.exists()):
|
||||
return JSONResponse({}, status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
wildcard_path = base.joinpath(filename)
|
||||
if wildcard_path.exists() and wildcard_path.is_file():
|
||||
return FileResponse(wildcard_path)
|
||||
else:
|
||||
return JSONResponse({}, status_code=404)
|
||||
return Response(status_code=404)
|
||||
except Exception as e:
|
||||
return JSONResponse({"error": e}, status_code=500)
|
||||
|
||||
@app.get("/tacapi/v1/refresh-styles-if-changed")
|
||||
async def refresh_styles_if_changed():
|
||||
global last_style_mtime
|
||||
|
||||
mtime = get_style_mtime()
|
||||
if mtime > last_style_mtime:
|
||||
last_style_mtime = mtime
|
||||
# Update temp file
|
||||
if shared.prompt_styles is not None:
|
||||
write_style_names()
|
||||
|
||||
return Response(status_code=200) # Success
|
||||
else:
|
||||
return Response(status_code=304) # Not modified
|
||||
|
||||
script_callbacks.on_app_started(api_tac)
|
||||
|
||||
Reference in New Issue
Block a user