From 6deefda279ae91ad050bdc55a56b2b4cd453abb4 Mon Sep 17 00:00:00 2001 From: Dominik Reh Date: Mon, 2 Jan 2023 00:38:48 +0100 Subject: [PATCH] Show version info for embeddings Also allows searching by version to quickly find v1 or v2 model embeddings Closes #97 --- javascript/tagAutocomplete.js | 37 +++++++++++++++--- scripts/tag_autocomplete_helper.py | 60 +++++++++++++++++++++++++++--- 2 files changed, 85 insertions(+), 12 deletions(-) diff --git a/javascript/tagAutocomplete.js b/javascript/tagAutocomplete.js index b2eadf1..1b1209a 100644 --- a/javascript/tagAutocomplete.js +++ b/javascript/tagAutocomplete.js @@ -7,7 +7,9 @@ const styleColors = { "--results-bg-odd": ["#111827", "#f9fafb"], "--results-hover": ["#1f2937", "#f5f6f8"], "--results-selected": ["#374151", "#e5e7eb"], - "--post-count-color": ["#6b6f7b", "#a2a9b4"] + "--post-count-color": ["#6b6f7b", "#a2a9b4"], + "--embedding-v1-color": ["lightsteelblue", "#2b5797"], + "--embedding-v2-color": ["skyblue", "#2d89ef"], } const browserVars = { "--results-overflow-y": { @@ -66,6 +68,12 @@ const autocompleteCSS = ` flex-grow: 1; color: var(--post-count-color); } + .acListItem.acEmbeddingV1 { + color: var(--embedding-v1-color); + } + .acListItem.acEmbeddingV2 { + color: var(--embedding-v2-color); + } `; // Parse the CSV file into a 2D array. Doesn't use regex, so it is very lightweight. @@ -364,7 +372,7 @@ function insertTextAtCursor(textArea, result, tagword) { } else if (tagType === "yamlWildcard" && !yamlWildcards.includes(text)) { sanitizedText = text.replaceAll("_", " "); // Replace underscores only if the yaml tag is not using them } else if (tagType === "embedding") { - sanitizedText = `<${text.replace(/^.*?: /g, "")}>`; + sanitizedText = `${text.replace(/^.*?: /g, "")}`; } else { sanitizedText = CFG.replaceUnderscores ? text.replaceAll("_", " ") : text; } @@ -422,7 +430,6 @@ function insertTextAtCursor(textArea, result, tagword) { let match = surrounding.match(new RegExp(escapeRegExp(`${tagword}`), "i")); let insert = surrounding.replace(match, sanitizedText); - let modifiedTagword = prompt.substring(0, editStart) + insert + prompt.substring(editEnd); let umiSubPrompts = [...newPrompt.matchAll(UMI_PROMPT_REGEX)]; let umiTags = []; @@ -549,6 +556,17 @@ function addResultsToList(textArea, results, tagword, resetList) { countDiv.classList.add("acPostCount"); flexDiv.appendChild(countDiv); } + } else if (result[1] === "embedding" && result[2]) { // Check if it is an embedding we have version info for + let versionDiv = document.createElement("div"); + versionDiv.textContent = result[2]; + versionDiv.classList.add("acPostCount"); + + if (result[2].startsWith("v1")) + itemText.classList.add("acEmbeddingV1"); + else if (result[2].startsWith("v2")) + itemText.classList.add("acEmbeddingV2"); + + flexDiv.appendChild(versionDiv); } // Add listener @@ -811,7 +829,14 @@ async function autocomplete(textArea, prompt, fixedTag = null) { // Show embeddings let tempResults = []; if (tagword !== "<") { - tempResults = embeddings.filter(x => x.toLowerCase().includes(tagword.replace("<", ""))) // Filter by tagword + let searchTerm = tagword.replace("<", "") + let versionString; + if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) { + versionString = searchTerm.slice(0, 2); + searchTerm = searchTerm.slice(2); + } + let versionCondition = x => x[1] && x[1] === versionString; + tempResults = embeddings.filter(x => x[0].toLowerCase().includes(searchTerm) && versionCondition(x)); // Filter by tagword } else { tempResults = embeddings; } @@ -825,7 +850,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) { searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i'); } genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, CFG.maxResults); - results = genericResults.concat(tempResults.map(x => ["Embeddings: " + x.trim(), "embedding"])); // Mark as embedding + results = tempResults.map(x => [x[0].trim(), "embedding", x[1] + " Embedding"]).concat(genericResults); // Mark as embedding } else { // Create escaped search regex with support for * as a start placeholder let searchRegex; @@ -1022,7 +1047,7 @@ async function setup() { try { embeddings = (await readFile(`${tagBasePath}/temp/emb.txt?${new Date().getTime()}`)).split("\n") .filter(x => x.trim().length > 0) // Remove empty lines - .map(x => x.replace(".bin", "").replace(".pt", "").replace(".png", "")); // Remove file extensions + .map(x => x.trim().split(",")); // Split into name, version type pairs } catch (e) { console.error("Error loading embeddings.txt: " + e); } diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index d1012bb..24a107b 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -3,8 +3,10 @@ import gradio as gr from pathlib import Path -from modules import scripts, script_callbacks, shared +from modules import scripts, script_callbacks, shared, sd_hijack import yaml +import time +import threading # Webui root path FILE_DIR = Path().absolute() @@ -78,9 +80,54 @@ def get_ext_wildcard_tags(): output.append(f"{tag},{count}") return output + def get_embeddings(): - """Returns a list of all embeddings""" - return [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png"}] + """Write a list of all embeddings with their version""" + # Get a list of all embeddings in the folder + embs_in_dir = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}] + # Remove file extensions + embs_in_dir = [e[:e.rfind('.')] for e in embs_in_dir] + + # Wait for all embeddings to be loaded + while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir): + time.sleep(2) # Sleep for 2 seconds + + # Get embedding dict from sd_hijack to separate v1/v2 embeddings + emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings + emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings + # Get the shape of the first item in the dict + emb_a_shape = -1 + if (len(emb_type_a) > 0): + emb_a_shape = next(iter(emb_type_a.items()))[1].shape + + # Add embeddings to the correct list + V1_SHAPE = 768 + V2_SHAPE = 1024 + emb_v1 = [] + emb_v2 = [] + + if (emb_a_shape == V1_SHAPE): + emb_v1 = list(emb_type_a.keys()) + emb_v2 = list(emb_type_b) + elif (emb_a_shape == V2_SHAPE): + emb_v1 = list(emb_type_b) + emb_v2 = list(emb_type_a.keys()) + + # Create a new list to store the modified strings + results = [] + + # Iterate through each string in the big list + for string in embs_in_dir: + if string in emb_v1: + results.append(string + ",v1") + elif string in emb_v2: + results.append(string + ",v2") + # If the string is not in either, default to v1 + # (we can't know what it is since the startup model loaded none of them, but it's probably v1 since v2 is newer) + else: + results.append(string + ",v1") + + write_to_temp_file('emb.txt', results) def write_tag_base_path(): @@ -143,9 +190,10 @@ if WILDCARD_EXT_PATHS is not None: # Write embeddings to emb.txt if found if EMB_PATH.exists(): - embeddings = get_embeddings() - if embeddings: - write_to_temp_file('emb.txt', embeddings) + # We need to load the embeddings in a separate thread since we wait for them to be checked (after the model loads) + thread = threading.Thread(target=get_embeddings) + thread.start() + # Register autocomplete options def on_ui_settings():