mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-26 19:19:57 +00:00
Show version info for embeddings
Also allows searching by version to quickly find v1 or v2 model embeddings Closes #97
This commit is contained in:
@@ -7,7 +7,9 @@ const styleColors = {
|
||||
"--results-bg-odd": ["#111827", "#f9fafb"],
|
||||
"--results-hover": ["#1f2937", "#f5f6f8"],
|
||||
"--results-selected": ["#374151", "#e5e7eb"],
|
||||
"--post-count-color": ["#6b6f7b", "#a2a9b4"]
|
||||
"--post-count-color": ["#6b6f7b", "#a2a9b4"],
|
||||
"--embedding-v1-color": ["lightsteelblue", "#2b5797"],
|
||||
"--embedding-v2-color": ["skyblue", "#2d89ef"],
|
||||
}
|
||||
const browserVars = {
|
||||
"--results-overflow-y": {
|
||||
@@ -66,6 +68,12 @@ const autocompleteCSS = `
|
||||
flex-grow: 1;
|
||||
color: var(--post-count-color);
|
||||
}
|
||||
.acListItem.acEmbeddingV1 {
|
||||
color: var(--embedding-v1-color);
|
||||
}
|
||||
.acListItem.acEmbeddingV2 {
|
||||
color: var(--embedding-v2-color);
|
||||
}
|
||||
`;
|
||||
|
||||
// Parse the CSV file into a 2D array. Doesn't use regex, so it is very lightweight.
|
||||
@@ -364,7 +372,7 @@ function insertTextAtCursor(textArea, result, tagword) {
|
||||
} else if (tagType === "yamlWildcard" && !yamlWildcards.includes(text)) {
|
||||
sanitizedText = text.replaceAll("_", " "); // Replace underscores only if the yaml tag is not using them
|
||||
} else if (tagType === "embedding") {
|
||||
sanitizedText = `<${text.replace(/^.*?: /g, "")}>`;
|
||||
sanitizedText = `${text.replace(/^.*?: /g, "")}`;
|
||||
} else {
|
||||
sanitizedText = CFG.replaceUnderscores ? text.replaceAll("_", " ") : text;
|
||||
}
|
||||
@@ -422,7 +430,6 @@ function insertTextAtCursor(textArea, result, tagword) {
|
||||
let match = surrounding.match(new RegExp(escapeRegExp(`${tagword}`), "i"));
|
||||
let insert = surrounding.replace(match, sanitizedText);
|
||||
|
||||
let modifiedTagword = prompt.substring(0, editStart) + insert + prompt.substring(editEnd);
|
||||
let umiSubPrompts = [...newPrompt.matchAll(UMI_PROMPT_REGEX)];
|
||||
|
||||
let umiTags = [];
|
||||
@@ -549,6 +556,17 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
countDiv.classList.add("acPostCount");
|
||||
flexDiv.appendChild(countDiv);
|
||||
}
|
||||
} else if (result[1] === "embedding" && result[2]) { // Check if it is an embedding we have version info for
|
||||
let versionDiv = document.createElement("div");
|
||||
versionDiv.textContent = result[2];
|
||||
versionDiv.classList.add("acPostCount");
|
||||
|
||||
if (result[2].startsWith("v1"))
|
||||
itemText.classList.add("acEmbeddingV1");
|
||||
else if (result[2].startsWith("v2"))
|
||||
itemText.classList.add("acEmbeddingV2");
|
||||
|
||||
flexDiv.appendChild(versionDiv);
|
||||
}
|
||||
|
||||
// Add listener
|
||||
@@ -811,7 +829,14 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
// Show embeddings
|
||||
let tempResults = [];
|
||||
if (tagword !== "<") {
|
||||
tempResults = embeddings.filter(x => x.toLowerCase().includes(tagword.replace("<", ""))) // Filter by tagword
|
||||
let searchTerm = tagword.replace("<", "")
|
||||
let versionString;
|
||||
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
|
||||
versionString = searchTerm.slice(0, 2);
|
||||
searchTerm = searchTerm.slice(2);
|
||||
}
|
||||
let versionCondition = x => x[1] && x[1] === versionString;
|
||||
tempResults = embeddings.filter(x => x[0].toLowerCase().includes(searchTerm) && versionCondition(x)); // Filter by tagword
|
||||
} else {
|
||||
tempResults = embeddings;
|
||||
}
|
||||
@@ -825,7 +850,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
|
||||
}
|
||||
genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, CFG.maxResults);
|
||||
results = genericResults.concat(tempResults.map(x => ["Embeddings: " + x.trim(), "embedding"])); // Mark as embedding
|
||||
results = tempResults.map(x => [x[0].trim(), "embedding", x[1] + " Embedding"]).concat(genericResults); // Mark as embedding
|
||||
} else {
|
||||
// Create escaped search regex with support for * as a start placeholder
|
||||
let searchRegex;
|
||||
@@ -1022,7 +1047,7 @@ async function setup() {
|
||||
try {
|
||||
embeddings = (await readFile(`${tagBasePath}/temp/emb.txt?${new Date().getTime()}`)).split("\n")
|
||||
.filter(x => x.trim().length > 0) // Remove empty lines
|
||||
.map(x => x.replace(".bin", "").replace(".pt", "").replace(".png", "")); // Remove file extensions
|
||||
.map(x => x.trim().split(",")); // Split into name, version type pairs
|
||||
} catch (e) {
|
||||
console.error("Error loading embeddings.txt: " + e);
|
||||
}
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
|
||||
import gradio as gr
|
||||
from pathlib import Path
|
||||
from modules import scripts, script_callbacks, shared
|
||||
from modules import scripts, script_callbacks, shared, sd_hijack
|
||||
import yaml
|
||||
import time
|
||||
import threading
|
||||
|
||||
# Webui root path
|
||||
FILE_DIR = Path().absolute()
|
||||
@@ -78,9 +80,54 @@ def get_ext_wildcard_tags():
|
||||
output.append(f"{tag},{count}")
|
||||
return output
|
||||
|
||||
|
||||
def get_embeddings():
|
||||
"""Returns a list of all embeddings"""
|
||||
return [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png"}]
|
||||
"""Write a list of all embeddings with their version"""
|
||||
# Get a list of all embeddings in the folder
|
||||
embs_in_dir = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}]
|
||||
# Remove file extensions
|
||||
embs_in_dir = [e[:e.rfind('.')] for e in embs_in_dir]
|
||||
|
||||
# Wait for all embeddings to be loaded
|
||||
while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir):
|
||||
time.sleep(2) # Sleep for 2 seconds
|
||||
|
||||
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
|
||||
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
|
||||
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
|
||||
# Get the shape of the first item in the dict
|
||||
emb_a_shape = -1
|
||||
if (len(emb_type_a) > 0):
|
||||
emb_a_shape = next(iter(emb_type_a.items()))[1].shape
|
||||
|
||||
# Add embeddings to the correct list
|
||||
V1_SHAPE = 768
|
||||
V2_SHAPE = 1024
|
||||
emb_v1 = []
|
||||
emb_v2 = []
|
||||
|
||||
if (emb_a_shape == V1_SHAPE):
|
||||
emb_v1 = list(emb_type_a.keys())
|
||||
emb_v2 = list(emb_type_b)
|
||||
elif (emb_a_shape == V2_SHAPE):
|
||||
emb_v1 = list(emb_type_b)
|
||||
emb_v2 = list(emb_type_a.keys())
|
||||
|
||||
# Create a new list to store the modified strings
|
||||
results = []
|
||||
|
||||
# Iterate through each string in the big list
|
||||
for string in embs_in_dir:
|
||||
if string in emb_v1:
|
||||
results.append(string + ",v1")
|
||||
elif string in emb_v2:
|
||||
results.append(string + ",v2")
|
||||
# If the string is not in either, default to v1
|
||||
# (we can't know what it is since the startup model loaded none of them, but it's probably v1 since v2 is newer)
|
||||
else:
|
||||
results.append(string + ",v1")
|
||||
|
||||
write_to_temp_file('emb.txt', results)
|
||||
|
||||
|
||||
def write_tag_base_path():
|
||||
@@ -143,9 +190,10 @@ if WILDCARD_EXT_PATHS is not None:
|
||||
|
||||
# Write embeddings to emb.txt if found
|
||||
if EMB_PATH.exists():
|
||||
embeddings = get_embeddings()
|
||||
if embeddings:
|
||||
write_to_temp_file('emb.txt', embeddings)
|
||||
# We need to load the embeddings in a separate thread since we wait for them to be checked (after the model loads)
|
||||
thread = threading.Thread(target=get_embeddings)
|
||||
thread.start()
|
||||
|
||||
|
||||
# Register autocomplete options
|
||||
def on_ui_settings():
|
||||
|
||||
Reference in New Issue
Block a user