Compare commits

...

22 Commits

Author SHA1 Message Date
DominikDoom
5db035cc3a Add missing comma for keyword insertion at end 2023-09-09 14:54:11 +02:00
DominikDoom
90cf3147fd Formatting 2023-09-09 14:51:24 +02:00
DominikDoom
4d4f23e551 Formatting 2023-09-09 14:43:55 +02:00
DominikDoom
80b47c61bb Add new setting to choose where keywords get inserted
Closes #232
2023-09-09 14:41:52 +02:00
DominikDoom
57821aae6a Add option to include embeddings in normal search
along with new keybind functionality for quick jumping between sections.
Closes #230
2023-09-07 13:18:04 +02:00
DominikDoom
e23bb6d4ea Add support for --wildcards-dir cmd argument
Refactor PR #229 a bit to share code with this
2023-09-02 17:59:27 +02:00
DominikDoom
d4cca00575 Merge pull request #229 from azmodii/main 2023-09-02 17:27:00 +02:00
DominikDoom
86ea94a565 Merge pull request #228 from re-unknown/main 2023-09-02 17:08:13 +02:00
Joel Clark
53f46c91a2 feat: Allow support for custom wildcard directory in sd-dynamic-prompts 2023-09-02 21:30:01 +10:00
ReUnknown
e5f93188c3 Support for updated style editor 2023-09-02 16:51:41 +09:00
DominikDoom
3e57842ac6 Remove unnecessary autocomplete call in wildcards
which would result in duplicate file requests
2023-08-29 10:23:13 +02:00
DominikDoom
32c4589df3 Rework wildcards to use own API endpoint
Maybe fixes #226
2023-08-29 09:39:32 +02:00
DominikDoom
5bbd97588c Remove duplicate slash from wildcard files
(should be cosmetic only)
2023-08-28 19:15:34 +02:00
DominikDoom
b2a663f7a7 Merge pull request #223 from Symbiomatrix/embload 2023-08-20 19:08:40 +02:00
Symbiomatrix
6f93d19a2b Edit error message. 2023-08-20 20:02:57 +03:00
Symbiomatrix
79bab04fd2 Typo. 2023-08-20 18:59:12 +03:00
Symbiomatrix
5b69d1e622 Embedding forced reload. 2023-08-20 18:51:37 +03:00
DominikDoom
651cf5fb46 Add metaKey and Shift to non-captured modifiers
Fixes #222
2023-08-19 11:59:41 +02:00
DominikDoom
5deb72cddf Add clearer README description for legacy translations
As suggested in #221
2023-08-16 10:50:05 +02:00
DominikDoom
97ebe78205 !After Detailer (adetailer) support 2023-08-15 14:44:38 +02:00
DominikDoom
b937e853c9 Fix booru wiki links with translations 2023-08-08 19:23:13 +02:00
DominikDoom
f63bbf947f Fix API endpoint to work with symlinks / external folders
Fixes #217
2023-08-07 22:15:48 +02:00
10 changed files with 218 additions and 92 deletions

View File

@@ -474,7 +474,9 @@ You can also add this to your quicksettings bar to have the refresh button avail
# Translations
An additional file can be added in the translation section, which will be used to translate both tags and aliases and also enables searching by translation.
This file needs to be a CSV in the format `<English tag/alias>,<Translation>`, but for backwards compatibility with older files that used a three column format, you can turn on `Translation file uses old 3-column translation format instead of the new 2-column one` to support them. In that case, the second column will be unused and skipped during parsing.
This file needs to be a CSV in the format `<English tag/alias>,<Translation>`. Some older files use a three column format, which requires a compatibility setting to be activated.
You can find it under `Settings > Tag autocomplete > Translation filename > Translation file uses old 3-column translation format instead of the new 2-column one`.
With it on, the second column will be unused and skipped during parsing.
Example with Chinese translation:

View File

@@ -36,6 +36,7 @@ let hideBlocked = false;
// Tag selection for keyboard navigation
var selectedTag = null;
var oldSelectedTag = null;
var resultCountBeforeNormalTags = 0;
// Lora keyword undo/redo history
var textBeforeKeywordInsertion = "";

View File

@@ -9,7 +9,11 @@ const core = [
"#img2img_prompt > label > textarea",
"#txt2img_neg_prompt > label > textarea",
"#img2img_neg_prompt > label > textarea",
".prompt > label > textarea"
".prompt > label > textarea",
"#txt2img_edit_style_prompt > label > textarea",
"#txt2img_edit_style_neg_prompt > label > textarea",
"#img2img_edit_style_prompt > label > textarea",
"#img2img_edit_style_neg_prompt > label > textarea"
];
// Third party text area selectors
@@ -57,6 +61,24 @@ const thirdParty = {
"[id^=MD-i2i][id$=prompt] textarea",
"[id^=MD-i2i][id$=prompt] input[type='text']"
]
},
"adetailer-t2i": {
"base": "#txt2img_script_container",
"hasIds": true,
"onDemand": true,
"selectors": [
"[id^=script_txt2img_adetailer_ad_prompt] textarea",
"[id^=script_txt2img_adetailer_ad_negative_prompt] textarea"
]
},
"adetailer-i2i": {
"base": "#img2img_script_container",
"hasIds": true,
"onDemand": true,
"selectors": [
"[id^=script_img2img_adetailer_ad_prompt] textarea",
"[id^=script_img2img_adetailer_ad_negative_prompt] textarea"
]
}
}
@@ -90,7 +112,7 @@ function addOnDemandObservers(setupFunction) {
let base = gradioApp().querySelector(entry.base);
if (!base) continue;
let accordions = [...base?.querySelectorAll(".gradio-accordion")];
if (!accordions) continue;
@@ -115,12 +137,12 @@ function addOnDemandObservers(setupFunction) {
[...gradioApp().querySelectorAll(entry.selectors.join(", "))].forEach(x => setupFunction(x));
} else { // Otherwise, we have to find the text areas by their adjacent labels
let base = gradioApp().querySelector(entry.base);
// Safety check
if (!base) continue;
let allTextAreas = [...base.querySelectorAll("textarea, input[type='text']")];
// Filter the text areas where the adjacent label matches one of the selectors
let matchingTextAreas = allTextAreas.filter(ta => [...ta.parentElement.childNodes].some(x => entry.selectors.includes(x.innerText)));
matchingTextAreas.forEach(x => setupFunction(x));

View File

@@ -41,7 +41,7 @@ function parseCSV(str) {
async function readFile(filePath, json = false, cache = false) {
if (!cache)
filePath += `?${new Date().getTime()}`;
let response = await fetch(`file=${filePath}`);
if (response.status != 200) {
@@ -84,10 +84,18 @@ async function fetchAPI(url, json = true, cache = false) {
// Extra network preview thumbnails
async function getExtraNetworkPreviewURL(filename, type) {
const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);
if (previewJSON?.url)
return `file=${previewJSON.url}`;
else
if (previewJSON?.url) {
const properURL = `sd_extra_networks/thumb?filename=${previewJSON.url}`;
if ((await fetch(properURL)).status == 200) {
return properURL;
} else {
// create blob url
const blob = await (await fetch(`tacapi/v1/thumb-preview-blob/${filename}?type=${type}`)).blob();
return URL.createObjectURL(blob);
}
} else {
return null;
}
}
// Debounce function to prevent spamming the autocomplete function
@@ -138,7 +146,7 @@ function flatten(obj, roots = [], sep = ".") {
{}
);
}
// Sliding window function to get possible combination groups of an array
function toNgrams(inputArray, size) {

View File

@@ -1,5 +1,5 @@
const EMB_REGEX = /<(?!l:|h:|c:)[^,> ]*>?/g;
const EMB_TRIGGER = () => TAC_CFG.useEmbeddings && tagword.match(EMB_REGEX);
const EMB_TRIGGER = () => TAC_CFG.useEmbeddings && (tagword.match(EMB_REGEX) || TAC_CFG.includeEmbeddingsInNormalResults);
class EmbeddingParser extends BaseTagParser {
parse() {

View File

@@ -7,7 +7,7 @@ class UmiParser extends BaseTagParser {
parse(textArea, prompt) {
// We are in a UMI yaml tag definition, parse further
let umiSubPrompts = [...prompt.matchAll(UMI_PROMPT_REGEX)];
let umiTags = [];
let umiTagsWithOperators = []
@@ -15,7 +15,7 @@ class UmiParser extends BaseTagParser {
umiSubPrompts.forEach(umiSubPrompt => {
umiTags = umiTags.concat([...umiSubPrompt[0].matchAll(UMI_TAG_REGEX)].map(x => x[1].toLowerCase()));
const start = umiSubPrompt.index;
const end = umiSubPrompt.index + umiSubPrompt[0].length;
if (textArea.selectionStart >= start && textArea.selectionStart <= end) {
@@ -113,7 +113,7 @@ class UmiParser extends BaseTagParser {
|| !matches.all.includes(x[0])
);
}
if (umiTags.length > 0) {
// Get difference for subprompt
let tagCountChange = umiTags.length - umiPreviousTags.length;
@@ -152,7 +152,7 @@ class UmiParser extends BaseTagParser {
return finalResults;
} else if (showAll) {
let filteredWildcardsSorted = filteredWildcards("");
// Add final results
let finalResults = [];
filteredWildcardsSorted.forEach(t => {
@@ -160,14 +160,14 @@ class UmiParser extends BaseTagParser {
result.count = t[1];
finalResults.push(result);
});
originalTagword = tagword;
tagword = "";
return finalResults;
}
} else {
let filteredWildcardsSorted = filteredWildcards("");
// Add final results
let finalResults = [];
filteredWildcardsSorted.forEach(t => {

View File

@@ -19,13 +19,16 @@ class WildcardParser extends BaseTagParser {
let wcPairs = wcFound || wildcardExtFiles.filter(x => x[1].toLowerCase() === wcFile);
if (!wcPairs) return [];
let wildcards = [];
for (let i = 0; i < wcPairs.length; i++) {
const wcPair = wcPairs[i];
if (!wcPair[0] || !wcPair[1]) continue;
const basePath = wcPairs[i][0];
const fileName = wcPairs[i][1];
if (!basePath || !fileName) return;
if (wcPair[0].endsWith(".yaml")) {
// YAML wildcards are already loaded as json, so we can get the values directly.
// basePath is the name of the file in this case, and fileName the key
if (basePath.endsWith(".yaml")) {
const getDescendantProp = (obj, desc) => {
const arr = desc.split("/");
while (arr.length) {
@@ -33,10 +36,11 @@ class WildcardParser extends BaseTagParser {
}
return obj;
}
wildcards = wildcards.concat(getDescendantProp(yamlWildcards[wcPair[0]], wcPair[1]));
wildcards = wildcards.concat(getDescendantProp(yamlWildcards[basePath], fileName));
} else {
const fileContent = (await readFile(`${wcPair[0]}/${wcPair[1]}.txt`)).split("\n")
.filter(x => x.trim().length > 0 && !x.startsWith('#')); // Remove empty lines and comments
const fileContent = (await fetchAPI(`tacapi/v1/wildcard-contents?basepath=${basePath}&filename=${fileName}.txt`, false))
.split("\n")
.filter(x => x.trim().length > 0 && !x.startsWith('#')); // Remove empty lines and comments
wildcards = wildcards.concat(fileContent);
}
}
@@ -82,7 +86,7 @@ class WildcardFileParser extends BaseTagParser {
result = new AutocompleteResult(wcFile[1].trim(), ResultType.wildcardFile);
result.meta = "Wildcard file";
}
finalResults.push(result);
alreadyAdded.set(wcFile[1], true);
});
@@ -128,7 +132,7 @@ async function load() {
// Load the yaml wildcard json file and append it as a wildcard file, appending each key as a path component until we reach the end
yamlWildcards = await readFile(`${tagBasePath}/temp/wc_yaml.json`, true);
// Append each key as a path component until we reach a leaf
Object.keys(yamlWildcards).forEach(file => {
const flattened = flatten(yamlWildcards[file], [], "/");
@@ -155,7 +159,6 @@ function keepOpenIfWildcard(tagType, sanitizedText, newPrompt, textArea) {
// If it's a wildcard, we want to keep the results open so the user can select another wildcard
if (tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard) {
hideBlocked = true;
autocomplete(textArea, newPrompt, sanitizedText);
setTimeout(() => { hideBlocked = false; }, 450);
return true;
}

View File

@@ -211,6 +211,7 @@ async function syncOptions() {
useWildcards: opts["tac_useWildcards"],
sortWildcardResults: opts["tac_sortWildcardResults"],
useEmbeddings: opts["tac_useEmbeddings"],
includeEmbeddingsInNormalResults: opts["tac_includeEmbeddingsInNormalResults"],
useHypernetworks: opts["tac_useHypernetworks"],
useLoras: opts["tac_useLoras"],
useLycos: opts["tac_useLycos"],
@@ -224,6 +225,7 @@ async function syncOptions() {
alwaysSpaceAtEnd: opts["tac_alwaysSpaceAtEnd"],
wildcardCompletionMode: opts["tac_wildcardCompletionMode"],
modelKeywordCompletion: opts["tac_modelKeywordCompletion"],
modelKeywordLocation: opts["tac_modelKeywordLocation"],
// Alias settings
alias: {
searchByAlias: opts["tac_alias.searchByAlias"],
@@ -330,7 +332,7 @@ function showResults(textArea) {
if (TAC_CFG.slidingPopup) {
let caretPosition = getCaretCoordinates(textArea, textArea.selectionEnd).left;
let offset = Math.min(textArea.offsetLeft - textArea.scrollLeft + caretPosition, textArea.offsetWidth - parentDiv.offsetWidth);
parentDiv.style.left = `${offset}px`;
} else {
if (parentDiv.style.left)
@@ -346,9 +348,9 @@ function showResults(textArea) {
function hideResults(textArea) {
let textAreaId = getTextAreaIdentifier(textArea);
let resultsDiv = gradioApp().querySelector('.autocompleteParent' + textAreaId);
if (!resultsDiv) return;
resultsDiv.style.display = "none";
selectedTag = null;
}
@@ -358,12 +360,12 @@ function isEnabled() {
if (TAC_CFG.activeIn.global) {
// Skip check if the current model was not correctly detected, since it could wrongly disable the script otherwise
if (!currentModelName || !currentModelHash) return true;
let modelList = TAC_CFG.activeIn.modelList
.split(",")
.map(x => x.trim())
.filter(x => x.length > 0);
let shortHash = currentModelHash.substring(0, 10);
let modelNameWithoutHash = currentModelName.replace(/\[.*\]$/g, "").trim();
if (TAC_CFG.activeIn.modelListMode.toLowerCase() === "blacklist") {
@@ -493,7 +495,7 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
keywords = info["activation text"];
}
}
if (!keywords && modelKeywordPath.length > 0 && result.hash && result.hash !== "NOFILE" && result.hash.length > 0) {
let nameDict = modelKeywordDict.get(result.hash);
let names = [result.text + ".safetensors", result.text + ".pt", result.text + ".ckpt"];
@@ -506,7 +508,7 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
keywords = nameDict.get(name);
}
});
if (!found)
keywords = nameDict.get("none");
}
@@ -514,17 +516,25 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
if (keywords && keywords.length > 0) {
textBeforeKeywordInsertion = newPrompt;
newPrompt = `${keywords}, ${newPrompt}`; // Insert keywords
if (TAC_CFG.modelKeywordLocation === "Start of prompt")
newPrompt = `${keywords}, ${newPrompt}`; // Insert keywords
else if (TAC_CFG.modelKeywordLocation === "End of prompt")
newPrompt = `${newPrompt}, ${keywords}`; // Insert keywords
else {
let keywordStart = prompt[editStart - 1] === " " ? editStart - 1 : editStart;
newPrompt = prompt.substring(0, keywordStart) + `, ${keywords} ${insert}` + prompt.substring(editEnd);
}
textAfterKeywordInsertion = newPrompt;
keywordInsertionUndone = false;
setTimeout(() => lastEditWasKeywordInsertion = true, 200)
keywordsLength = keywords.length + 2; // +2 for the comma and space
}
}
// Insert into prompt textbox and reposition cursor
textArea.value = newPrompt;
textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length + keywordsLength;
@@ -651,7 +661,12 @@ function addResultsToList(textArea, results, tagword, resetList) {
// Only use alias result if it is one
if (displayText.includes("➝"))
linkPart = displayText.split(" ➝ ")[1];
// Remove any trailing translations
if (linkPart.includes("[")) {
linkPart = linkPart.split("[")[0]
}
// Set link based on selected file
let tagFileNameLower = tagFileName.toLowerCase();
if (tagFileNameLower.startsWith("danbooru")) {
@@ -659,7 +674,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
} else if (tagFileNameLower.startsWith("e621")) {
wikiLink.href = `https://e621.net/wiki_pages/${linkPart}`;
}
wikiLink.target = "_blank";
flexDiv.appendChild(wikiLink);
}
@@ -712,7 +727,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
else if (result.meta.startsWith("v2"))
itemText.classList.add("acEmbeddingV2");
}
flexDiv.appendChild(metaDiv);
}
@@ -777,7 +792,7 @@ async function updateSelectionStyle(textArea, newIndex, oldIndex) {
}
let img = previewDiv.querySelector("img");
let url = await getExtraNetworkPreviewURL(selectedFilename, shorthandType);
if (url) {
img.src = url;
@@ -803,7 +818,7 @@ function updateRuby(textArea, prompt) {
ruby.setAttribute("class", `acRuby${typeClass} notranslate`);
textArea.parentNode.appendChild(ruby);
}
ruby.innerText = prompt;
let bracketEscapedPrompt = prompt.replaceAll("\\(", "$").replaceAll("\\)", "%");
@@ -821,9 +836,9 @@ function updateRuby(textArea, prompt) {
.replaceAll(" ", "_")
.replaceAll("\\(", "(")
.replaceAll("\\)", ")");
const translation = translations?.get(tag) || translations?.get(unsanitizedTag);
let escapedTag = escapeRegExp(tag);
return { tag, escapedTag, translation };
}
@@ -839,14 +854,14 @@ function updateRuby(textArea, prompt) {
// First try to find direct matches
[...rubyTags].forEach(tag => {
let tuple = prepareTag(tag);
if (tuple.translation) {
html = replaceOccurences(html, tuple);
} else {
let subTags = tuple.tag.split(" ").filter(x => x.trim().length > 0);
// Return if there is only one word
if (subTags.length === 1) return;
let subHtml = tag.replaceAll("$", "\\(").replaceAll("%", "\\)");
let translateNgram = (windows) => {
@@ -861,14 +876,14 @@ function updateRuby(textArea, prompt) {
}
});
}
// Perform n-gram sliding window search
translateNgram(toNgrams(subTags, 3));
translateNgram(toNgrams(subTags, 2));
translateNgram(toNgrams(subTags, 1));
let escapedTag = escapeRegExp(tuple.tag);
let searchRegex = new RegExp(`(?<!<ruby>)(?:\\b)${escapedTag}(?:\\b|$|(?=[,|: \\t\\n\\r]))(?!<rt>)`, "g");
html = html.replaceAll(searchRegex, subHtml);
}
@@ -980,6 +995,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
}
results = [];
resultCountBeforeNormalTags = 0;
tagword = tagword.toLowerCase().replace(/[\n\r]/g, "");
// Process all parsers
@@ -1019,7 +1035,13 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
});
}
}
} else { // Else search the normal tag list
}
// Else search the normal tag list
if (!resultCandidates || resultCandidates.length === 0
|| (TAC_CFG.includeEmbeddingsInNormalResults && !(tagword.startsWith("<") || tagword.startsWith("*<")))
) {
resultCountBeforeNormalTags = results.length;
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
if (tagword.startsWith("*")) {
@@ -1034,7 +1056,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
let aliasFilter = (x) => x[3] && x[3].toLowerCase().search(searchRegex) > -1;
let translationFilter = (x) => (translations.has(x[0]) && translations.get(x[0]).toLowerCase().search(searchRegex) > -1)
|| x[3] && x[3].split(",").some(y => translations.has(y) && translations.get(y).toLowerCase().search(searchRegex) > -1);
let fil;
if (TAC_CFG.alias.searchByAlias && TAC_CFG.translation.searchByTranslation)
fil = (x) => baseFilter(x) || aliasFilter(x) || translationFilter(x);
@@ -1072,10 +1094,10 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
results = results.concat(extraResults);
}
}
// Slice if the user has set a max result count
if (!TAC_CFG.showAllResults) {
results = results.slice(0, TAC_CFG.maxResults);
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
}
}
@@ -1093,7 +1115,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
function navigateInList(textArea, event) {
// Return if the function is deactivated in the UI or the current model is excluded due to white/blacklist settings
if (!isEnabled()) return;
let keys = TAC_CFG.keymap;
// Close window if Home or End is pressed while not a keybinding, since it would break completion on leaving the original tag
@@ -1109,7 +1131,7 @@ function navigateInList(textArea, event) {
if (!validKeys.includes(event.key)) return;
if (!isVisible(textArea)) return
// Return if ctrl key is pressed to not interfere with weight editing shortcut
if (event.ctrlKey || event.altKey) return;
if (event.ctrlKey || event.altKey || event.shiftKey || event.metaKey) return;
oldSelectedTag = selectedTag;
@@ -1143,10 +1165,25 @@ function navigateInList(textArea, event) {
}
break;
case keys["JumpToStart"]:
selectedTag = 0;
if (TAC_CFG.includeEmbeddingsInNormalResults &&
selectedTag > resultCountBeforeNormalTags &&
resultCountBeforeNormalTags > 0
) {
selectedTag = resultCountBeforeNormalTags;
} else {
selectedTag = 0;
}
break;
case keys["JumpToEnd"]:
selectedTag = resultCount - 1;
// Jump to the end of the list, or the end of embeddings if they are included in the normal results
if (TAC_CFG.includeEmbeddingsInNormalResults &&
selectedTag < resultCountBeforeNormalTags &&
resultCountBeforeNormalTags > 0
) {
selectedTag = Math.min(resultCountBeforeNormalTags, resultCount - 1);
} else {
selectedTag = resultCount - 1;
}
break;
case keys["ChooseSelected"]:
if (selectedTag !== null) {
@@ -1315,7 +1352,7 @@ async function setup() {
let mode = (document.querySelector(".dark") || gradioApp().querySelector(".dark")) ? 0 : 1;
// Check if we are on webkit
let browser = navigator.userAgent.toLowerCase().indexOf('firefox') > -1 ? "firefox" : "other";
let css = autocompleteCSS;
// Replace vars with actual values (can't use actual css vars because of the way we inject the css)
Object.keys(styleColors).forEach((key) => {
@@ -1324,7 +1361,7 @@ async function setup() {
Object.keys(browserVars).forEach((key) => {
css = css.replaceAll(`var(${key})`, browserVars[key][browser]);
})
if (acStyle.styleSheet) {
acStyle.styleSheet.cssText = css;
} else {

View File

@@ -1,4 +1,5 @@
from pathlib import Path
from modules import scripts, shared
try:
@@ -37,6 +38,21 @@ except AttributeError:
def find_ext_wildcard_paths():
"""Returns the path to the extension wildcards folder"""
found = list(EXT_PATH.glob("*/wildcards/"))
# Try to find the wildcard path from the shared opts
try:
from modules.shared import opts
except ImportError: # likely not in an a1111 context
opts = None
# Append custom wildcard paths
custom_paths = [
getattr(shared.cmd_opts, "wildcards_dir", None), # Cmd arg from the wildcard extension
getattr(opts, "wildcard_dir", None), # Custom path from sd-dynamic-prompts
]
for path in [Path(p) for p in custom_paths if p is not None]:
if path.exists():
found.append(path)
return found

View File

@@ -17,6 +17,12 @@ from scripts.model_keyword_support import (get_lora_simple_hash,
write_model_keyword_path)
from scripts.shared_paths import *
# Attempt to get embedding load function, using the same call as api.
try:
load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings
except Exception as e: # Not supported.
load_textual_inversion_embeddings = lambda *args, **kwargs: None
print("Tag Autocomplete: Cannot reload embeddings instantly:", e)
def get_wildcards():
"""Returns a list of all wildcards. Works on nested folders."""
@@ -50,7 +56,7 @@ def parse_umi_format(umi_tags, count, data):
for item in data:
umi_tags[count] = ','.join(data[item]['Tags'])
count += 1
def parse_dynamic_prompt_format(yaml_wildcards, data, path):
# Recurse subkeys, delete those without string lists as values
@@ -60,7 +66,7 @@ def parse_dynamic_prompt_format(yaml_wildcards, data, path):
recurse_dict(value)
elif not (isinstance(value, list) and all(isinstance(v, str) for v in value)):
del d[key]
recurse_dict(data)
# Add to yaml_wildcards
yaml_wildcards[path.name] = data
@@ -92,13 +98,13 @@ def get_yaml_wildcards():
except yaml.YAMLError:
print('Issue in parsing YAML file ' + path.name)
continue
# Sort by count
umi_sorted = sorted(umi_tags.items(), key=lambda item: item[1], reverse=True)
umi_output = []
for tag, count in umi_sorted:
umi_output.append(f"{tag},{count}")
if (len(umi_output) > 0):
write_to_temp_file('umi_tags.txt', umi_output)
@@ -198,7 +204,7 @@ def get_lyco():
# Get a list of all LyCORIS in the folder
lyco_paths = [Path(ly) for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)]
# Get hashes
valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"}]
hashes = {}
@@ -208,7 +214,7 @@ def get_lyco():
hashes[name] = get_lora_simple_hash(ly)
else:
hashes[name] = ""
# Sort
sorted_lycos = dict(sorted(hashes.items()))
# Add hashes and return
@@ -280,6 +286,7 @@ if EMB_PATH.exists():
def refresh_temp_files():
global WILDCARD_EXT_PATHS
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
load_textual_inversion_embeddings(force_reload = True) # Instant embedding reload.
write_temp_files()
get_embeddings(shared.sd_model)
@@ -311,7 +318,7 @@ def write_temp_files():
lora = get_lora()
if lora:
write_to_temp_file('lora.txt', lora)
lyco_exists = LYCO_PATH is not None and LYCO_PATH.exists()
if lyco_exists and not (lora_exists and LYCO_PATH.samefile(LORA_PATH)):
lyco = get_lyco()
@@ -362,6 +369,7 @@ def on_ui_settings():
"tac_useWildcards": shared.OptionInfo(True, "Search for wildcards"),
"tac_sortWildcardResults": shared.OptionInfo(True, "Sort wildcard file contents alphabetically").info("If your wildcard files have a specific custom order, disable this to keep it"),
"tac_useEmbeddings": shared.OptionInfo(True, "Search for embeddings"),
"tac_includeEmbeddingsInNormalResults": shared.OptionInfo(False, "Include embeddings in normal tag results").info("The 'JumpTo...' keybinds (End & Home key by default) will select the first non-embedding result of their direction on the first press for quick navigation in longer lists."),
"tac_useHypernetworks": shared.OptionInfo(True, "Search for hypernetworks"),
"tac_useLoras": shared.OptionInfo(True, "Search for Loras"),
"tac_useLycos": shared.OptionInfo(True, "Search for LyCORIS/LoHa"),
@@ -374,6 +382,7 @@ def on_ui_settings():
"tac_appendSpace": shared.OptionInfo(True, "Append space on tag autocompletion").info("will append after comma if the above is enabled"),
"tac_alwaysSpaceAtEnd": shared.OptionInfo(True, "Always append space if inserting at the end of the textbox").info("takes precedence over the regular space setting for that position"),
"tac_modelKeywordCompletion": shared.OptionInfo("Never", "Try to add known trigger words for LORA/LyCO models", gr.Dropdown, lambda: {"choices": ["Never","Only user list","Always"]}).info("Will use & prefer the native activation keywords settable in the extra networks UI. Other functionality requires the <a href=\"https://github.com/mix1009/model-keyword\" target=\"_blank\">model-keyword</a> extension to be installed, but will work with it disabled.").needs_restart(),
"tac_modelKeywordLocation": shared.OptionInfo("Start of prompt", "Where to insert the trigger keyword", gr.Dropdown, lambda: {"choices": ["Start of prompt","End of prompt","Before LORA/LyCO"]}).info("Only relevant if the above option is enabled"),
"tac_wildcardCompletionMode": shared.OptionInfo("To next folder level", "How to complete nested wildcard paths", gr.Dropdown, lambda: {"choices": ["To next folder level","To first difference","Always fully"]}).info("e.g. \"hair/colours/light/...\""),
# Alias settings
"tac_alias.searchByAlias": shared.OptionInfo(True, "Search by alias"),
@@ -444,53 +453,81 @@ def on_ui_settings():
shared.opts.add_option("tac_colormap", shared.OptionInfo(colorDefault, colorLabel, gr.Textbox, section=TAC_SECTION))
shared.opts.add_option("tac_refreshTempFiles", shared.OptionInfo("Refresh TAC temp files", "Refresh internal temp files", gr.HTML, {}, refresh=refresh_temp_files, section=TAC_SECTION))
script_callbacks.on_ui_settings(on_ui_settings)
def api_tac(_: gr.Blocks, app: FastAPI):
async def get_json_info(base_path: Path, filename: str = None):
if base_path is None or (not base_path.exists()):
return json.dumps({})
return JSONResponse({}, status_code=404)
try:
json_candidates = glob.glob(base_path.as_posix() + f"/**/{filename}.json", recursive=True)
if json_candidates is not None and len(json_candidates) > 0:
return FileResponse(json_candidates[0])
except Exception as e:
return json.dumps({"error": e})
async def get_preview_thumbnail(base_path: Path, filename: str = None):
return JSONResponse({"error": e}, status_code=500)
async def get_preview_thumbnail(base_path: Path, filename: str = None, blob: bool = False):
if base_path is None or (not base_path.exists()):
return json.dumps({})
return JSONResponse({}, status_code=404)
try:
img_glob = glob.glob(base_path.as_posix() + f"/**/{filename}.*", recursive=True)
img_candidates = [img for img in img_glob if Path(img).suffix in [".png", ".jpg", ".jpeg", ".webp"]]
img_candidates = [img for img in img_glob if Path(img).suffix in [".png", ".jpg", ".jpeg", ".webp", ".gif"]]
if img_candidates is not None and len(img_candidates) > 0:
return JSONResponse({"url": urllib.parse.quote(img_candidates[0])})
if blob:
return FileResponse(img_candidates[0])
else:
return JSONResponse({"url": urllib.parse.quote(img_candidates[0])})
except Exception as e:
return json.dumps({"error": e})
return JSONResponse({"error": e}, status_code=500)
@app.get("/tacapi/v1/lora-info/{lora_name}")
async def get_lora_info(lora_name):
return await get_json_info(LORA_PATH, lora_name)
@app.get("/tacapi/v1/lyco-info/{lyco_name}")
async def get_lyco_info(lyco_name):
return await get_json_info(LYCO_PATH, lyco_name)
def get_path_for_type(type):
if type == "lora":
return LORA_PATH
elif type == "lyco":
return LYCO_PATH
elif type == "hyper":
return HYP_PATH
elif type == "embed":
return EMB_PATH
else:
return None
@app.get("/tacapi/v1/thumb-preview/{filename}")
async def get_thumb_preview(filename, type):
if type == "lora":
return await get_preview_thumbnail(LORA_PATH, filename)
elif type == "lyco":
return await get_preview_thumbnail(LYCO_PATH, filename)
elif type == "hyper":
return await get_preview_thumbnail(HYP_PATH, filename)
elif type == "embed":
return await get_preview_thumbnail(EMB_PATH, filename)
else:
return "Invalid type"
return await get_preview_thumbnail(get_path_for_type(type), filename, False)
@app.get("/tacapi/v1/thumb-preview-blob/{filename}")
async def get_thumb_preview_blob(filename, type):
return await get_preview_thumbnail(get_path_for_type(type), filename, True)
@app.get("/tacapi/v1/wildcard-contents")
async def get_wildcard_contents(basepath: str, filename: str):
if basepath is None or basepath == "":
return JSONResponse({}, status_code=404)
base = Path(basepath)
if base is None or (not base.exists()):
return JSONResponse({}, status_code=404)
try:
wildcard_path = base.joinpath(filename)
if wildcard_path.exists():
return FileResponse(wildcard_path)
else:
return JSONResponse({}, status_code=404)
except Exception as e:
return JSONResponse({"error": e}, status_code=500)
script_callbacks.on_app_started(api_tac)
script_callbacks.on_app_started(api_tac)