Automatic wildcard & embed discovery

This commit is contained in:
Dominik Reh
2022-10-15 14:32:02 +02:00
parent a628d96a41
commit 08c10928f8
3 changed files with 80 additions and 21 deletions

View File

@@ -0,0 +1,40 @@
# This helper script scans folders for wildcards and embeddings and writes them
# to a temporary file to expose it to the javascript side
import os
# The path to the folder containing the wildcards and embeddings
FILE_DIR = os.path.dirname(os.path.realpath("__file__"))
WILDCARD_PATH = os.path.join(FILE_DIR, 'scripts/wildcards')
EMB_PATH = os.path.join(FILE_DIR, 'embeddings')
# The path to the temporary file
TEMP_PATH = os.path.join(FILE_DIR, 'tags/temp')
def get_wildcards():
"""Returns a list of all wildcards"""
return filter(lambda f: f.endswith(".txt"), os.listdir(WILDCARD_PATH))
def get_embeddings():
"""Returns a list of all embeddings"""
return filter(lambda f: f.endswith(".bin") or f.endswith(".pt"), os.listdir(EMB_PATH))
def write_to_temp_file(name, data):
"""Writes the given data to a temporary file"""
with open(os.path.join(TEMP_PATH, name), 'w') as f:
f.write(('\n'.join(data)))
# Check if the temp path exists and create it if not
if not os.path.exists(TEMP_PATH):
os.makedirs(TEMP_PATH)
# Write wildcards to wc.txt if found
if os.path.exists(WILDCARD_PATH):
wildcards = get_wildcards()
if wildcards:
write_to_temp_file('wc.txt', wildcards)
# Write embeddings to emb.txt if found
if os.path.exists(EMB_PATH):
embeddings = get_embeddings()
if embeddings:
write_to_temp_file('emb.txt', embeddings)

View File

@@ -226,6 +226,8 @@ function insertTextAtCursor(textArea, result, tagword) {
sanitizedText = "__" + text.replace("Wildcards: ", "") + "__";
} else if (tagType === "wildcardTag") {
sanitizedText = text.replace(/^.*?: /g, "");
} else if (tagType === "embedding") {
sanitizedText = `<${text.replace(/^.*?: /g, "")}>`;
} else {
sanitizedText = acConfig.replaceUnderscores ? text.replaceAll("_", " ") : text;
}
@@ -283,8 +285,13 @@ function insertTextAtCursor(textArea, result, tagword) {
function addResultsToList(textArea, results, tagword) {
let textAreaId = getTextAreaIdentifier(textArea);
let resultsList = gradioApp().querySelector('.autocompleteResults' + textAreaId + ' > ul');
let resultDiv = gradioApp().querySelector('.autocompleteResults' + textAreaId);
let resultsList = resultDiv.querySelector('ul');
// Reset list, selection and scrollTop since the list changed
resultsList.innerHTML = "";
selectedTag = null;
resultDiv.scrollTop = 0;
// Find right colors from config
let tagFileName = acConfig.tagFile.split(".")[0];
@@ -297,14 +304,15 @@ function addResultsToList(textArea, results, tagword) {
let li = document.createElement("li");
li.textContent = result[0];
// Wildcards have no tag type
if (!result[1].startsWith("wildcard")) {
// Wildcards & Embeds have no tag type
if (!result[1].startsWith("wildcard") && result[1] !== "embedding") {
// Set the color of the tag
let tagType = result[1];
let colorGroup = tagColors[tagFileName];
// Default to danbooru scheme if no matching one is found
if (colorGroup === undefined) colorGroup = tagColors["danbooru"];
console.log(colorGroup[tagType][mode]);
li.style = `color: ${colorGroup[tagType][mode]};`;
}
@@ -336,6 +344,7 @@ function updateSelectionStyle(textArea, num) {
wildcardFiles = [];
wildcards = {};
embeddings = [];
allTags = [];
results = [];
tagword = "";
@@ -375,14 +384,14 @@ function autocomplete(textArea, prompt, fixedTag = null) {
tagword = tagword.toLowerCase();
if ([...tagword.matchAll(/\b__([^,_ ]+)__([^, ]*)\b/g)].length > 0 && acConfig.useWildcards) {
if (acConfig.useWildcards && [...tagword.matchAll(/\b__([^,_ ]+)__([^, ]*)\b/g)].length > 0) {
// Show wildcards from a file with that name
wcMatch = [...tagword.matchAll(/\b__([^,_ ]+)__([^, ]*)\b/g)]
let wcFile = wcMatch[0][1];
let wcWord = wcMatch[0][2];
results = wildcards[wcFile].filter(x => (wcWord !== null) ? x.toLowerCase().includes(wcWord) : x) // Filter by tagword
.map(x => [wcFile + ": " + x.trim(), "wildcardTag"]); // Mark as wildcard
} else if ((tagword.startsWith("__") && !tagword.endsWith("__") || tagword === "__") && acConfig.useWildcards) {
} else if (acConfig.useWildcards && (tagword.startsWith("__") && !tagword.endsWith("__") || tagword === "__")) {
// Show available wildcard files
let tempResults = [];
if (tagword !== "__") {
@@ -391,6 +400,17 @@ function autocomplete(textArea, prompt, fixedTag = null) {
tempResults = wildcardFiles;
}
results = tempResults.map(x => ["Wildcards: " + x.trim(), "wildcardFile"]); // Mark as wildcard
} else if (acConfig.useEmbeddings && tagword.match(/<[^,> ]*>?/g)) {
// Show embeddings
let tempResults = [];
if (tagword !== "<") {
tempResults = embeddings.filter(x => x.toLowerCase().includes(tagword.replace("<", ""))) // Filter by tagword
} else {
tempResults = embeddings;
}
// Since some tags are kaomoji, we have to still get the normal results first.
genericResults = allTags.filter(x => x[0].toLowerCase().includes(tagword)).slice(0, acConfig.maxResults);
results = genericResults.concat(tempResults.map(x => ["Embeddings: " + x.trim(), "embedding"])); // Mark as embedding
} else {
results = allTags.filter(x => x[0].toLowerCase().includes(tagword)).slice(0, acConfig.maxResults);
}
@@ -402,7 +422,6 @@ function autocomplete(textArea, prompt, fixedTag = null) {
return;
}
selectedTag = null; // Reset since the list changed
showResults(textArea);
addResultsToList(textArea, results, tagword);
}
@@ -457,7 +476,7 @@ function navigateInList(textArea, event) {
styleAdded = false;
onUiUpdate(function () {
// One-time config, tags & wildcards loading
// Load config
if (acConfig === null) {
try {
acConfig = JSON.parse(readFile("file/tags/config.json"));
@@ -466,6 +485,7 @@ onUiUpdate(function () {
return;
}
}
// Load main tags
if (allTags.length === 0) {
try {
allTags = loadCSV();
@@ -474,17 +494,14 @@ onUiUpdate(function () {
return;
}
}
// Load wildcards
if (wildcardFiles.length === 0 && acConfig.useWildcards) {
try {
wildcardFiles = readFile("file/tags/wildcardNames.txt").split("\n")
.filter(x => !x.startsWith("//")) // Remove comments
.filter(x => x.toLowerCase().includes(tagword.substring(2))) // Filter by tagword
.filter(x => x.trim().length > 0) // Remove empty lines
wildcardFiles = readFile("file/tags/temp/wc.txt").split("\n");
wildcardFiles.forEach(fName => {
try {
wildcards[fName.trim()] = readFile(`file/scripts/wildcards/${fName}.txt`).split("\n")
.filter(x => x.trim().length > 0) // Remove empty lines
wildcards[fName.trim()] = readFile(`file/scripts/wildcards/${fName}`).split("\n")
.filter(x => x.trim().length > 0); // Remove empty lines
} catch (e) {
console.log(`Could not load wildcards for ${fName}`);
}
@@ -493,6 +510,15 @@ onUiUpdate(function () {
console.error("Error loading wildcardNames.txt: " + e);
}
}
// Load embeddings
if (embeddings.length === 0 && acConfig.useEmbeddings) {
try {
embeddings = readFile("file/tags/temp/emb.txt").split("\n")
.map(x => x.split(".")[0]); // Remove file extensions
} catch (e) {
console.error("Error loading embeddings.txt: " + e);
}
}
// Find all textareas
let txt2imgTextArea = gradioApp().querySelector('#txt2img_prompt > label > textarea');

View File

@@ -1,7 +0,0 @@
// Put the file names of wildcard files you want to use here. Needed so that the script can access them.
// The default ones are the following, you can uncomment them if you have them
//adjective
//artist
//genre
//site
//style