Compare commits

...

5 Commits

Author SHA1 Message Date
DominikDoom
5db035cc3a Add missing comma for keyword insertion at end 2023-09-09 14:54:11 +02:00
DominikDoom
90cf3147fd Formatting 2023-09-09 14:51:24 +02:00
DominikDoom
4d4f23e551 Formatting 2023-09-09 14:43:55 +02:00
DominikDoom
80b47c61bb Add new setting to choose where keywords get inserted
Closes #232
2023-09-09 14:41:52 +02:00
DominikDoom
57821aae6a Add option to include embeddings in normal search
along with new keybind functionality for quick jumping between sections.
Closes #230
2023-09-07 13:18:04 +02:00
7 changed files with 95 additions and 60 deletions

View File

@@ -36,6 +36,7 @@ let hideBlocked = false;
// Tag selection for keyboard navigation
var selectedTag = null;
var oldSelectedTag = null;
var resultCountBeforeNormalTags = 0;
// Lora keyword undo/redo history
var textBeforeKeywordInsertion = "";

View File

@@ -41,7 +41,7 @@ function parseCSV(str) {
async function readFile(filePath, json = false, cache = false) {
if (!cache)
filePath += `?${new Date().getTime()}`;
let response = await fetch(`file=${filePath}`);
if (response.status != 200) {
@@ -146,7 +146,7 @@ function flatten(obj, roots = [], sep = ".") {
{}
);
}
// Sliding window function to get possible combination groups of an array
function toNgrams(inputArray, size) {

View File

@@ -1,5 +1,5 @@
const EMB_REGEX = /<(?!l:|h:|c:)[^,> ]*>?/g;
const EMB_TRIGGER = () => TAC_CFG.useEmbeddings && tagword.match(EMB_REGEX);
const EMB_TRIGGER = () => TAC_CFG.useEmbeddings && (tagword.match(EMB_REGEX) || TAC_CFG.includeEmbeddingsInNormalResults);
class EmbeddingParser extends BaseTagParser {
parse() {

View File

@@ -7,7 +7,7 @@ class UmiParser extends BaseTagParser {
parse(textArea, prompt) {
// We are in a UMI yaml tag definition, parse further
let umiSubPrompts = [...prompt.matchAll(UMI_PROMPT_REGEX)];
let umiTags = [];
let umiTagsWithOperators = []
@@ -15,7 +15,7 @@ class UmiParser extends BaseTagParser {
umiSubPrompts.forEach(umiSubPrompt => {
umiTags = umiTags.concat([...umiSubPrompt[0].matchAll(UMI_TAG_REGEX)].map(x => x[1].toLowerCase()));
const start = umiSubPrompt.index;
const end = umiSubPrompt.index + umiSubPrompt[0].length;
if (textArea.selectionStart >= start && textArea.selectionStart <= end) {
@@ -113,7 +113,7 @@ class UmiParser extends BaseTagParser {
|| !matches.all.includes(x[0])
);
}
if (umiTags.length > 0) {
// Get difference for subprompt
let tagCountChange = umiTags.length - umiPreviousTags.length;
@@ -152,7 +152,7 @@ class UmiParser extends BaseTagParser {
return finalResults;
} else if (showAll) {
let filteredWildcardsSorted = filteredWildcards("");
// Add final results
let finalResults = [];
filteredWildcardsSorted.forEach(t => {
@@ -160,14 +160,14 @@ class UmiParser extends BaseTagParser {
result.count = t[1];
finalResults.push(result);
});
originalTagword = tagword;
tagword = "";
return finalResults;
}
} else {
let filteredWildcardsSorted = filteredWildcards("");
// Add final results
let finalResults = [];
filteredWildcardsSorted.forEach(t => {

View File

@@ -19,7 +19,7 @@ class WildcardParser extends BaseTagParser {
let wcPairs = wcFound || wildcardExtFiles.filter(x => x[1].toLowerCase() === wcFile);
if (!wcPairs) return [];
let wildcards = [];
for (let i = 0; i < wcPairs.length; i++) {
const basePath = wcPairs[i][0];
@@ -86,7 +86,7 @@ class WildcardFileParser extends BaseTagParser {
result = new AutocompleteResult(wcFile[1].trim(), ResultType.wildcardFile);
result.meta = "Wildcard file";
}
finalResults.push(result);
alreadyAdded.set(wcFile[1], true);
});
@@ -132,7 +132,7 @@ async function load() {
// Load the yaml wildcard json file and append it as a wildcard file, appending each key as a path component until we reach the end
yamlWildcards = await readFile(`${tagBasePath}/temp/wc_yaml.json`, true);
// Append each key as a path component until we reach a leaf
Object.keys(yamlWildcards).forEach(file => {
const flattened = flatten(yamlWildcards[file], [], "/");

View File

@@ -211,6 +211,7 @@ async function syncOptions() {
useWildcards: opts["tac_useWildcards"],
sortWildcardResults: opts["tac_sortWildcardResults"],
useEmbeddings: opts["tac_useEmbeddings"],
includeEmbeddingsInNormalResults: opts["tac_includeEmbeddingsInNormalResults"],
useHypernetworks: opts["tac_useHypernetworks"],
useLoras: opts["tac_useLoras"],
useLycos: opts["tac_useLycos"],
@@ -224,6 +225,7 @@ async function syncOptions() {
alwaysSpaceAtEnd: opts["tac_alwaysSpaceAtEnd"],
wildcardCompletionMode: opts["tac_wildcardCompletionMode"],
modelKeywordCompletion: opts["tac_modelKeywordCompletion"],
modelKeywordLocation: opts["tac_modelKeywordLocation"],
// Alias settings
alias: {
searchByAlias: opts["tac_alias.searchByAlias"],
@@ -330,7 +332,7 @@ function showResults(textArea) {
if (TAC_CFG.slidingPopup) {
let caretPosition = getCaretCoordinates(textArea, textArea.selectionEnd).left;
let offset = Math.min(textArea.offsetLeft - textArea.scrollLeft + caretPosition, textArea.offsetWidth - parentDiv.offsetWidth);
parentDiv.style.left = `${offset}px`;
} else {
if (parentDiv.style.left)
@@ -346,9 +348,9 @@ function showResults(textArea) {
function hideResults(textArea) {
let textAreaId = getTextAreaIdentifier(textArea);
let resultsDiv = gradioApp().querySelector('.autocompleteParent' + textAreaId);
if (!resultsDiv) return;
resultsDiv.style.display = "none";
selectedTag = null;
}
@@ -358,12 +360,12 @@ function isEnabled() {
if (TAC_CFG.activeIn.global) {
// Skip check if the current model was not correctly detected, since it could wrongly disable the script otherwise
if (!currentModelName || !currentModelHash) return true;
let modelList = TAC_CFG.activeIn.modelList
.split(",")
.map(x => x.trim())
.filter(x => x.length > 0);
let shortHash = currentModelHash.substring(0, 10);
let modelNameWithoutHash = currentModelName.replace(/\[.*\]$/g, "").trim();
if (TAC_CFG.activeIn.modelListMode.toLowerCase() === "blacklist") {
@@ -493,7 +495,7 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
keywords = info["activation text"];
}
}
if (!keywords && modelKeywordPath.length > 0 && result.hash && result.hash !== "NOFILE" && result.hash.length > 0) {
let nameDict = modelKeywordDict.get(result.hash);
let names = [result.text + ".safetensors", result.text + ".pt", result.text + ".ckpt"];
@@ -506,7 +508,7 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
keywords = nameDict.get(name);
}
});
if (!found)
keywords = nameDict.get("none");
}
@@ -514,17 +516,25 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
if (keywords && keywords.length > 0) {
textBeforeKeywordInsertion = newPrompt;
newPrompt = `${keywords}, ${newPrompt}`; // Insert keywords
if (TAC_CFG.modelKeywordLocation === "Start of prompt")
newPrompt = `${keywords}, ${newPrompt}`; // Insert keywords
else if (TAC_CFG.modelKeywordLocation === "End of prompt")
newPrompt = `${newPrompt}, ${keywords}`; // Insert keywords
else {
let keywordStart = prompt[editStart - 1] === " " ? editStart - 1 : editStart;
newPrompt = prompt.substring(0, keywordStart) + `, ${keywords} ${insert}` + prompt.substring(editEnd);
}
textAfterKeywordInsertion = newPrompt;
keywordInsertionUndone = false;
setTimeout(() => lastEditWasKeywordInsertion = true, 200)
keywordsLength = keywords.length + 2; // +2 for the comma and space
}
}
// Insert into prompt textbox and reposition cursor
textArea.value = newPrompt;
textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length + keywordsLength;
@@ -656,7 +666,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
if (linkPart.includes("[")) {
linkPart = linkPart.split("[")[0]
}
// Set link based on selected file
let tagFileNameLower = tagFileName.toLowerCase();
if (tagFileNameLower.startsWith("danbooru")) {
@@ -664,7 +674,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
} else if (tagFileNameLower.startsWith("e621")) {
wikiLink.href = `https://e621.net/wiki_pages/${linkPart}`;
}
wikiLink.target = "_blank";
flexDiv.appendChild(wikiLink);
}
@@ -717,7 +727,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
else if (result.meta.startsWith("v2"))
itemText.classList.add("acEmbeddingV2");
}
flexDiv.appendChild(metaDiv);
}
@@ -782,7 +792,7 @@ async function updateSelectionStyle(textArea, newIndex, oldIndex) {
}
let img = previewDiv.querySelector("img");
let url = await getExtraNetworkPreviewURL(selectedFilename, shorthandType);
if (url) {
img.src = url;
@@ -808,7 +818,7 @@ function updateRuby(textArea, prompt) {
ruby.setAttribute("class", `acRuby${typeClass} notranslate`);
textArea.parentNode.appendChild(ruby);
}
ruby.innerText = prompt;
let bracketEscapedPrompt = prompt.replaceAll("\\(", "$").replaceAll("\\)", "%");
@@ -826,9 +836,9 @@ function updateRuby(textArea, prompt) {
.replaceAll(" ", "_")
.replaceAll("\\(", "(")
.replaceAll("\\)", ")");
const translation = translations?.get(tag) || translations?.get(unsanitizedTag);
let escapedTag = escapeRegExp(tag);
return { tag, escapedTag, translation };
}
@@ -844,14 +854,14 @@ function updateRuby(textArea, prompt) {
// First try to find direct matches
[...rubyTags].forEach(tag => {
let tuple = prepareTag(tag);
if (tuple.translation) {
html = replaceOccurences(html, tuple);
} else {
let subTags = tuple.tag.split(" ").filter(x => x.trim().length > 0);
// Return if there is only one word
if (subTags.length === 1) return;
let subHtml = tag.replaceAll("$", "\\(").replaceAll("%", "\\)");
let translateNgram = (windows) => {
@@ -866,14 +876,14 @@ function updateRuby(textArea, prompt) {
}
});
}
// Perform n-gram sliding window search
translateNgram(toNgrams(subTags, 3));
translateNgram(toNgrams(subTags, 2));
translateNgram(toNgrams(subTags, 1));
let escapedTag = escapeRegExp(tuple.tag);
let searchRegex = new RegExp(`(?<!<ruby>)(?:\\b)${escapedTag}(?:\\b|$|(?=[,|: \\t\\n\\r]))(?!<rt>)`, "g");
html = html.replaceAll(searchRegex, subHtml);
}
@@ -985,6 +995,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
}
results = [];
resultCountBeforeNormalTags = 0;
tagword = tagword.toLowerCase().replace(/[\n\r]/g, "");
// Process all parsers
@@ -1024,7 +1035,13 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
});
}
}
} else { // Else search the normal tag list
}
// Else search the normal tag list
if (!resultCandidates || resultCandidates.length === 0
|| (TAC_CFG.includeEmbeddingsInNormalResults && !(tagword.startsWith("<") || tagword.startsWith("*<")))
) {
resultCountBeforeNormalTags = results.length;
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
if (tagword.startsWith("*")) {
@@ -1039,7 +1056,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
let aliasFilter = (x) => x[3] && x[3].toLowerCase().search(searchRegex) > -1;
let translationFilter = (x) => (translations.has(x[0]) && translations.get(x[0]).toLowerCase().search(searchRegex) > -1)
|| x[3] && x[3].split(",").some(y => translations.has(y) && translations.get(y).toLowerCase().search(searchRegex) > -1);
let fil;
if (TAC_CFG.alias.searchByAlias && TAC_CFG.translation.searchByTranslation)
fil = (x) => baseFilter(x) || aliasFilter(x) || translationFilter(x);
@@ -1077,10 +1094,10 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
results = results.concat(extraResults);
}
}
// Slice if the user has set a max result count
if (!TAC_CFG.showAllResults) {
results = results.slice(0, TAC_CFG.maxResults);
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
}
}
@@ -1098,7 +1115,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
function navigateInList(textArea, event) {
// Return if the function is deactivated in the UI or the current model is excluded due to white/blacklist settings
if (!isEnabled()) return;
let keys = TAC_CFG.keymap;
// Close window if Home or End is pressed while not a keybinding, since it would break completion on leaving the original tag
@@ -1148,10 +1165,25 @@ function navigateInList(textArea, event) {
}
break;
case keys["JumpToStart"]:
selectedTag = 0;
if (TAC_CFG.includeEmbeddingsInNormalResults &&
selectedTag > resultCountBeforeNormalTags &&
resultCountBeforeNormalTags > 0
) {
selectedTag = resultCountBeforeNormalTags;
} else {
selectedTag = 0;
}
break;
case keys["JumpToEnd"]:
selectedTag = resultCount - 1;
// Jump to the end of the list, or the end of embeddings if they are included in the normal results
if (TAC_CFG.includeEmbeddingsInNormalResults &&
selectedTag < resultCountBeforeNormalTags &&
resultCountBeforeNormalTags > 0
) {
selectedTag = Math.min(resultCountBeforeNormalTags, resultCount - 1);
} else {
selectedTag = resultCount - 1;
}
break;
case keys["ChooseSelected"]:
if (selectedTag !== null) {
@@ -1320,7 +1352,7 @@ async function setup() {
let mode = (document.querySelector(".dark") || gradioApp().querySelector(".dark")) ? 0 : 1;
// Check if we are on webkit
let browser = navigator.userAgent.toLowerCase().indexOf('firefox') > -1 ? "firefox" : "other";
let css = autocompleteCSS;
// Replace vars with actual values (can't use actual css vars because of the way we inject the css)
Object.keys(styleColors).forEach((key) => {
@@ -1329,7 +1361,7 @@ async function setup() {
Object.keys(browserVars).forEach((key) => {
css = css.replaceAll(`var(${key})`, browserVars[key][browser]);
})
if (acStyle.styleSheet) {
acStyle.styleSheet.cssText = css;
} else {

View File

@@ -56,7 +56,7 @@ def parse_umi_format(umi_tags, count, data):
for item in data:
umi_tags[count] = ','.join(data[item]['Tags'])
count += 1
def parse_dynamic_prompt_format(yaml_wildcards, data, path):
# Recurse subkeys, delete those without string lists as values
@@ -66,7 +66,7 @@ def parse_dynamic_prompt_format(yaml_wildcards, data, path):
recurse_dict(value)
elif not (isinstance(value, list) and all(isinstance(v, str) for v in value)):
del d[key]
recurse_dict(data)
# Add to yaml_wildcards
yaml_wildcards[path.name] = data
@@ -98,13 +98,13 @@ def get_yaml_wildcards():
except yaml.YAMLError:
print('Issue in parsing YAML file ' + path.name)
continue
# Sort by count
umi_sorted = sorted(umi_tags.items(), key=lambda item: item[1], reverse=True)
umi_output = []
for tag, count in umi_sorted:
umi_output.append(f"{tag},{count}")
if (len(umi_output) > 0):
write_to_temp_file('umi_tags.txt', umi_output)
@@ -204,7 +204,7 @@ def get_lyco():
# Get a list of all LyCORIS in the folder
lyco_paths = [Path(ly) for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)]
# Get hashes
valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"}]
hashes = {}
@@ -214,7 +214,7 @@ def get_lyco():
hashes[name] = get_lora_simple_hash(ly)
else:
hashes[name] = ""
# Sort
sorted_lycos = dict(sorted(hashes.items()))
# Add hashes and return
@@ -318,7 +318,7 @@ def write_temp_files():
lora = get_lora()
if lora:
write_to_temp_file('lora.txt', lora)
lyco_exists = LYCO_PATH is not None and LYCO_PATH.exists()
if lyco_exists and not (lora_exists and LYCO_PATH.samefile(LORA_PATH)):
lyco = get_lyco()
@@ -369,6 +369,7 @@ def on_ui_settings():
"tac_useWildcards": shared.OptionInfo(True, "Search for wildcards"),
"tac_sortWildcardResults": shared.OptionInfo(True, "Sort wildcard file contents alphabetically").info("If your wildcard files have a specific custom order, disable this to keep it"),
"tac_useEmbeddings": shared.OptionInfo(True, "Search for embeddings"),
"tac_includeEmbeddingsInNormalResults": shared.OptionInfo(False, "Include embeddings in normal tag results").info("The 'JumpTo...' keybinds (End & Home key by default) will select the first non-embedding result of their direction on the first press for quick navigation in longer lists."),
"tac_useHypernetworks": shared.OptionInfo(True, "Search for hypernetworks"),
"tac_useLoras": shared.OptionInfo(True, "Search for Loras"),
"tac_useLycos": shared.OptionInfo(True, "Search for LyCORIS/LoHa"),
@@ -381,6 +382,7 @@ def on_ui_settings():
"tac_appendSpace": shared.OptionInfo(True, "Append space on tag autocompletion").info("will append after comma if the above is enabled"),
"tac_alwaysSpaceAtEnd": shared.OptionInfo(True, "Always append space if inserting at the end of the textbox").info("takes precedence over the regular space setting for that position"),
"tac_modelKeywordCompletion": shared.OptionInfo("Never", "Try to add known trigger words for LORA/LyCO models", gr.Dropdown, lambda: {"choices": ["Never","Only user list","Always"]}).info("Will use & prefer the native activation keywords settable in the extra networks UI. Other functionality requires the <a href=\"https://github.com/mix1009/model-keyword\" target=\"_blank\">model-keyword</a> extension to be installed, but will work with it disabled.").needs_restart(),
"tac_modelKeywordLocation": shared.OptionInfo("Start of prompt", "Where to insert the trigger keyword", gr.Dropdown, lambda: {"choices": ["Start of prompt","End of prompt","Before LORA/LyCO"]}).info("Only relevant if the above option is enabled"),
"tac_wildcardCompletionMode": shared.OptionInfo("To next folder level", "How to complete nested wildcard paths", gr.Dropdown, lambda: {"choices": ["To next folder level","To first difference","Always fully"]}).info("e.g. \"hair/colours/light/...\""),
# Alias settings
"tac_alias.searchByAlias": shared.OptionInfo(True, "Search by alias"),
@@ -451,25 +453,25 @@ def on_ui_settings():
shared.opts.add_option("tac_colormap", shared.OptionInfo(colorDefault, colorLabel, gr.Textbox, section=TAC_SECTION))
shared.opts.add_option("tac_refreshTempFiles", shared.OptionInfo("Refresh TAC temp files", "Refresh internal temp files", gr.HTML, {}, refresh=refresh_temp_files, section=TAC_SECTION))
script_callbacks.on_ui_settings(on_ui_settings)
def api_tac(_: gr.Blocks, app: FastAPI):
async def get_json_info(base_path: Path, filename: str = None):
if base_path is None or (not base_path.exists()):
return JSONResponse({}, status_code=404)
try:
json_candidates = glob.glob(base_path.as_posix() + f"/**/{filename}.json", recursive=True)
if json_candidates is not None and len(json_candidates) > 0:
return FileResponse(json_candidates[0])
except Exception as e:
return JSONResponse({"error": e}, status_code=500)
async def get_preview_thumbnail(base_path: Path, filename: str = None, blob: bool = False):
if base_path is None or (not base_path.exists()):
return JSONResponse({}, status_code=404)
try:
img_glob = glob.glob(base_path.as_posix() + f"/**/{filename}.*", recursive=True)
img_candidates = [img for img in img_glob if Path(img).suffix in [".png", ".jpg", ".jpeg", ".webp", ".gif"]]
@@ -484,11 +486,11 @@ def api_tac(_: gr.Blocks, app: FastAPI):
@app.get("/tacapi/v1/lora-info/{lora_name}")
async def get_lora_info(lora_name):
return await get_json_info(LORA_PATH, lora_name)
@app.get("/tacapi/v1/lyco-info/{lyco_name}")
async def get_lyco_info(lyco_name):
return await get_json_info(LYCO_PATH, lyco_name)
def get_path_for_type(type):
if type == "lora":
return LORA_PATH
@@ -504,11 +506,11 @@ def api_tac(_: gr.Blocks, app: FastAPI):
@app.get("/tacapi/v1/thumb-preview/{filename}")
async def get_thumb_preview(filename, type):
return await get_preview_thumbnail(get_path_for_type(type), filename, False)
@app.get("/tacapi/v1/thumb-preview-blob/{filename}")
async def get_thumb_preview_blob(filename, type):
return await get_preview_thumbnail(get_path_for_type(type), filename, True)
@app.get("/tacapi/v1/wildcard-contents")
async def get_wildcard_contents(basepath: str, filename: str):
if basepath is None or basepath == "":
@@ -517,7 +519,7 @@ def api_tac(_: gr.Blocks, app: FastAPI):
base = Path(basepath)
if base is None or (not base.exists()):
return JSONResponse({}, status_code=404)
try:
wildcard_path = base.joinpath(filename)
if wildcard_path.exists():