diff --git a/javascript/_utils.js b/javascript/_utils.js index 04321bb..1c6d978 100644 --- a/javascript/_utils.js +++ b/javascript/_utils.js @@ -1,13 +1,14 @@ // Utility functions for tag autocomplete // Parse the CSV file into a 2D array. Doesn't use regex, so it is very lightweight. +// We are ignoring newlines in quote fields since we expect one-line entries and parsing would break for unclosed quotes otherwise function parseCSV(str) { - var arr = []; - var quote = false; // 'true' means we're inside a quoted field + const arr = []; + let quote = false; // 'true' means we're inside a quoted field // Iterate over each character, keep track of current row and column (of the returned array) - for (var row = 0, col = 0, c = 0; c < str.length; c++) { - var cc = str[c], nc = str[c + 1]; // Current character, next character + for (let row = 0, col = 0, c = 0; c < str.length; c++) { + let cc = str[c], nc = str[c+1]; // Current character, next character arr[row] = arr[row] || []; // Create a new row if necessary arr[row][col] = arr[row][col] || ''; // Create a new column (start with empty string) if necessary @@ -22,14 +23,12 @@ function parseCSV(str) { // If it's a comma and we're not in a quoted field, move on to the next column if (cc == ',' && !quote) { ++col; continue; } - // If it's a newline (CRLF) and we're not in a quoted field, skip the next character - // and move on to the next row and move to column 0 of that new row - if (cc == '\r' && nc == '\n' && !quote) { ++row; col = 0; ++c; continue; } + // If it's a newline (CRLF), skip the next character and move on to the next row and move to column 0 of that new row + if (cc == '\r' && nc == '\n') { ++row; col = 0; ++c; quote = false; continue; } - // If it's a newline (LF or CR) and we're not in a quoted field, - // move on to the next row and move to column 0 of that new row - if (cc == '\n' && !quote) { ++row; col = 0; continue; } - if (cc == '\r' && !quote) { ++row; col = 0; continue; } + // If it's a newline (LF or CR) move on to the next row and move to column 0 of that new row + if (cc == '\n') { ++row; col = 0; quote = false; continue; } + if (cc == '\r') { ++row; col = 0; quote = false; continue; } // Otherwise, append the current character to the current column arr[row][col] += cc; diff --git a/javascript/ext_embeddings.js b/javascript/ext_embeddings.js index 9c7bd44..820aae4 100644 --- a/javascript/ext_embeddings.js +++ b/javascript/ext_embeddings.js @@ -11,12 +11,15 @@ class EmbeddingParser extends BaseTagParser { if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) { versionString = searchTerm.slice(0, 2); searchTerm = searchTerm.slice(2); + } else if (searchTerm.startsWith("vxl")) { + versionString = searchTerm.slice(0, 3); + searchTerm = searchTerm.slice(3); } let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm); if (versionString) - tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2] === versionString); // Filter by tagword + tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2].toLowerCase() === versionString.toLowerCase()); // Filter by tagword else tempResults = embeddings.filter(x => filterCondition(x)); // Filter by tagword } else { diff --git a/javascript/ext_modelKeyword.js b/javascript/ext_modelKeyword.js index ff07910..ac88747 100644 --- a/javascript/ext_modelKeyword.js +++ b/javascript/ext_modelKeyword.js @@ -20,7 +20,7 @@ async function load() { // Add to the dict csv_lines.forEach(parts => { const hash = parts[0]; - const keywords = parts[1].replaceAll("| ", ", ").replaceAll("|", ", ").trim(); + const keywords = parts[1]?.replaceAll("| ", ", ")?.replaceAll("|", ", ")?.trim(); const lastSepIndex = parts[2]?.lastIndexOf("/") + 1 || parts[2]?.lastIndexOf("\\") + 1 || 0; const name = parts[2]?.substring(lastSepIndex).trim() || "none" diff --git a/javascript/ext_umi.js b/javascript/ext_umi.js index a55f80c..ea5067a 100644 --- a/javascript/ext_umi.js +++ b/javascript/ext_umi.js @@ -129,7 +129,7 @@ class UmiParser extends BaseTagParser { return; } - let umiTagword = diff[0] || ''; + let umiTagword = tagCountChange < 0 ? '' : diff[0] || ''; let tempResults = []; if (umiTagword && umiTagword.length > 0) { umiTagword = umiTagword.toLowerCase().replace(/[\n\r]/g, ""); @@ -188,7 +188,7 @@ class UmiParser extends BaseTagParser { } } -function updateUmiTags( tagType, sanitizedText, newPrompt, textArea) { +function updateUmiTags(tagType, sanitizedText, newPrompt, textArea) { // If it was a umi wildcard, also update the umiPreviousTags if (tagType === ResultType.umiWildcard && originalTagword.length > 0) { let umiSubPrompts = [...newPrompt.matchAll(UMI_PROMPT_REGEX)]; diff --git a/javascript/tagAutocomplete.js b/javascript/tagAutocomplete.js index d197fb7..59aa708 100644 --- a/javascript/tagAutocomplete.js +++ b/javascript/tagAutocomplete.js @@ -546,6 +546,14 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout let nameDict = modelKeywordDict.get(result.hash); let names = [result.text + ".safetensors", result.text + ".pt", result.text + ".ckpt"]; + // No match, try to find a sha256 match from the cache file + if (!nameDict) { + const sha256 = await fetchAPI(`/tacapi/v1/lora-cached-hash/${result.text}`) + if (sha256) { + nameDict = modelKeywordDict.get(sha256); + } + } + if (nameDict) { let found = false; names.forEach(name => { @@ -715,6 +723,8 @@ function addResultsToList(textArea, results, tagword, resetList) { linkPart = linkPart.split("[")[0] } + linkPart = encodeURIComponent(linkPart); + // Set link based on selected file let tagFileNameLower = tagFileName.toLowerCase(); if (tagFileNameLower.startsWith("danbooru")) { diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index 25baa9c..7820899 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -12,7 +12,7 @@ import gradio as gr import yaml from fastapi import FastAPI from fastapi.responses import FileResponse, JSONResponse -from modules import script_callbacks, sd_hijack, shared +from modules import script_callbacks, sd_hijack, shared, hashes from pydantic import BaseModel from scripts.model_keyword_support import (get_lora_simple_hash, @@ -171,44 +171,30 @@ def get_embeddings(sd_model): # Version constants V1_SHAPE = 768 V2_SHAPE = 1024 + VXL_SHAPE = 2048 emb_v1 = [] emb_v2 = [] + emb_vXL = [] results = [] try: # Get embedding dict from sd_hijack to separate v1/v2 embeddings - emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings - emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings - # Get the shape of the first item in the dict - emb_a_shape = -1 - emb_b_shape = -1 - if (len(emb_type_a) > 0): - emb_a_shape = next(iter(emb_type_a.items()))[1].shape - if (len(emb_type_b) > 0): - emb_b_shape = next(iter(emb_type_b.items()))[1].shape + loaded = sd_hijack.model_hijack.embedding_db.word_embeddings + skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings # Add embeddings to the correct list - if (emb_a_shape == V1_SHAPE): - emb_v1 = [(Path(v.filename), k, "v1") for (k,v) in emb_type_a.items()] - elif (emb_a_shape == V2_SHAPE): - emb_v2 = [(Path(v.filename), k, "v2") for (k,v) in emb_type_a.items()] + for key, emb in (loaded | skipped).items(): + if emb.filename is None or emb.shape is None: + continue - if (emb_b_shape == V1_SHAPE): - emb_v1 = [(Path(v.filename), k, "v1") for (k,v) in emb_type_b.items()] - elif (emb_b_shape == V2_SHAPE): - emb_v2 = [(Path(v.filename), k, "v2") for (k,v) in emb_type_b.items()] + if emb.shape == V1_SHAPE: + emb_v1.append((Path(emb.filename), key, "v1")) + elif emb.shape == V2_SHAPE: + emb_v2.append((Path(emb.filename), key, "v2")) + elif emb.shape == VXL_SHAPE: + emb_vXL.append((Path(emb.filename), key, "vXL")) - # Get shape of current model - #vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1) - #model_shape = vec.shape[1] - # Show relevant entries at the top - #if (model_shape == V1_SHAPE): - # results = [e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2] - #elif (model_shape == V2_SHAPE): - # results = [e + ",v2" for e in emb_v2] + [e + ",v1" for e in emb_v1] - #else: - # raise AttributeError # Fallback to old method - results = sort_models(emb_v1) + sort_models(emb_v2) + results = sort_models(emb_v1) + sort_models(emb_v2) + sort_models(emb_vXL) except AttributeError: print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.") # Get a list of all embeddings in the folder @@ -553,6 +539,18 @@ def api_tac(_: gr.Blocks, app: FastAPI): async def get_lyco_info(lyco_name): return await get_json_info(LYCO_PATH, lyco_name) + @app.get("/tacapi/v1/lora-cached-hash/{lora_name}") + async def get_lora_cached_hash(lora_name: str): + path_glob = glob.glob(LORA_PATH.as_posix() + f"/**/{lora_name}.*", recursive=True) + paths = [lora for lora in path_glob if Path(lora).suffix in [".safetensors", ".ckpt", ".pt"]] + if paths is not None and len(paths) > 0: + path = paths[0] + hash = hashes.sha256_from_cache(path, f"lora/{lora_name}", path.endswith(".safetensors")) + if hash is not None: + return hash + + return None + def get_path_for_type(type): if type == "lora": return LORA_PATH