mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-26 11:09:54 +00:00
Merge branch 'main' into feature-sort-by-frequent-use
@@ -1,13 +1,14 @@
 // Utility functions for tag autocomplete
 
 // Parse the CSV file into a 2D array. Doesn't use regex, so it is very lightweight.
+// We are ignoring newlines in quoted fields since we expect one-line entries and parsing would break for unclosed quotes otherwise
 function parseCSV(str) {
-    var arr = [];
-    var quote = false; // 'true' means we're inside a quoted field
+    const arr = [];
+    let quote = false; // 'true' means we're inside a quoted field
 
     // Iterate over each character, keep track of current row and column (of the returned array)
-    for (var row = 0, col = 0, c = 0; c < str.length; c++) {
-        var cc = str[c], nc = str[c + 1]; // Current character, next character
+    for (let row = 0, col = 0, c = 0; c < str.length; c++) {
+        let cc = str[c], nc = str[c+1]; // Current character, next character
         arr[row] = arr[row] || []; // Create a new row if necessary
         arr[row][col] = arr[row][col] || ''; // Create a new column (start with empty string) if necessary
 
@@ -22,14 +23,12 @@ function parseCSV(str) {
         // If it's a comma and we're not in a quoted field, move on to the next column
         if (cc == ',' && !quote) { ++col; continue; }
 
-        // If it's a newline (CRLF) and we're not in a quoted field, skip the next character
-        // and move on to the next row and move to column 0 of that new row
-        if (cc == '\r' && nc == '\n' && !quote) { ++row; col = 0; ++c; continue; }
+        // If it's a newline (CRLF), skip the next character and move on to the next row and move to column 0 of that new row
+        if (cc == '\r' && nc == '\n') { ++row; col = 0; ++c; quote = false; continue; }
 
-        // If it's a newline (LF or CR) and we're not in a quoted field,
-        // move on to the next row and move to column 0 of that new row
-        if (cc == '\n' && !quote) { ++row; col = 0; continue; }
-        if (cc == '\r' && !quote) { ++row; col = 0; continue; }
+        // If it's a newline (LF or CR), move on to the next row and move to column 0 of that new row
+        if (cc == '\n') { ++row; col = 0; quote = false; continue; }
+        if (cc == '\r') { ++row; col = 0; quote = false; continue; }
 
         // Otherwise, append the current character to the current column
        arr[row][col] += cc;
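For reference, a minimal sketch of what the parser produces (illustrative input, not from the commit; it assumes the quote-toggling branch between these two hunks works as its comments describe). Quoted fields keep their commas, and after this change a newline always ends the current row, even if a quote was left unclosed:

    const input = 'hash1,"kw1, kw2",path1\r\nhash2,kw3,path2';
    parseCSV(input);
    // -> [["hash1", "kw1, kw2", "path1"],
    //     ["hash2", "kw3", "path2"]]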
@@ -11,12 +11,15 @@ class EmbeddingParser extends BaseTagParser {
             if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
                 versionString = searchTerm.slice(0, 2);
                 searchTerm = searchTerm.slice(2);
+            } else if (searchTerm.startsWith("vxl")) {
+                versionString = searchTerm.slice(0, 3);
+                searchTerm = searchTerm.slice(3);
             }
 
             let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm);
 
             if (versionString)
-                tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2] === versionString); // Filter by tagword
+                tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2].toLowerCase() === versionString.toLowerCase()); // Filter by tagword
             else
                 tempResults = embeddings.filter(x => filterCondition(x)); // Filter by tagword
         } else {
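To illustrate the new prefix handling with hypothetical data (the entry layout [name, ?, version] mirrors what the filter above expects): a leading "v1", "v2", or now "vxl" is split off the search term, and the version column is compared case-insensitively, so entries tagged "vXL" still match a typed "vxl":

    const embeddings = [["style_a", "", "vXL"], ["style_b", "", "v1"]];
    const searchTerm = "style", versionString = "vxl"; // parsed from "vxlstyle"
    const filterCondition = x => x[0].toLowerCase().includes(searchTerm);
    const hits = embeddings.filter(x => filterCondition(x) && x[2] && x[2].toLowerCase() === versionString.toLowerCase());
    // hits -> [["style_a", "", "vXL"]]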
@@ -20,7 +20,7 @@ async function load() {
         // Add to the dict
         csv_lines.forEach(parts => {
             const hash = parts[0];
-            const keywords = parts[1].replaceAll("| ", ", ").replaceAll("|", ", ").trim();
+            const keywords = parts[1]?.replaceAll("| ", ", ")?.replaceAll("|", ", ")?.trim();
             const lastSepIndex = parts[2]?.lastIndexOf("/") + 1 || parts[2]?.lastIndexOf("\\") + 1 || 0;
             const name = parts[2]?.substring(lastSepIndex).trim() || "none"
 
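The switch to optional chaining is what keeps a malformed keyword row from crashing the whole load; a quick sketch with a hypothetical row:

    const parts = ["deadbeef"]; // row with a hash but no keyword column
    // old: parts[1].replaceAll(...) throws a TypeError and aborts the forEach
    // new: the chain short-circuits to undefined and the row is handled gracefully
    const keywords = parts[1]?.replaceAll("| ", ", ")?.replaceAll("|", ", ")?.trim();
    // keywords === undefined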
@@ -129,7 +129,7 @@ class UmiParser extends BaseTagParser {
             return;
         }
 
-        let umiTagword = diff[0] || '';
+        let umiTagword = tagCountChange < 0 ? '' : diff[0] || '';
         let tempResults = [];
         if (umiTagword && umiTagword.length > 0) {
             umiTagword = umiTagword.toLowerCase().replace(/[\n\r]/g, "");
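The new ternary guards against treating a just-deleted tag as the active tagword; with hypothetical values:

    let tagCountChange = -1;    // a tag was removed from the prompt
    let diff = ["removed_tag"]; // the removed tag still appears in the diff
    let umiTagword = tagCountChange < 0 ? '' : diff[0] || '';
    // umiTagword === '', so the filtering branch below is skipped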
@@ -188,7 +188,7 @@ class UmiParser extends BaseTagParser {
     }
 }
 
-function updateUmiTags( tagType, sanitizedText, newPrompt, textArea) {
+function updateUmiTags(tagType, sanitizedText, newPrompt, textArea) {
     // If it was a umi wildcard, also update the umiPreviousTags
     if (tagType === ResultType.umiWildcard && originalTagword.length > 0) {
         let umiSubPrompts = [...newPrompt.matchAll(UMI_PROMPT_REGEX)];
@@ -546,6 +546,14 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
         let nameDict = modelKeywordDict.get(result.hash);
         let names = [result.text + ".safetensors", result.text + ".pt", result.text + ".ckpt"];
 
+        // No match, try to find a sha256 match from the cache file
+        if (!nameDict) {
+            const sha256 = await fetchAPI(`/tacapi/v1/lora-cached-hash/${result.text}`)
+            if (sha256) {
+                nameDict = modelKeywordDict.get(sha256);
+            }
+        }
+
         if (nameDict) {
             let found = false;
             names.forEach(name => {
@@ -715,6 +723,8 @@ function addResultsToList(textArea, results, tagword, resetList) {
             linkPart = linkPart.split("[")[0]
         }
 
+        linkPart = encodeURIComponent(linkPart);
+
         // Set link based on selected file
         let tagFileNameLower = tagFileName.toLowerCase();
         if (tagFileNameLower.startsWith("danbooru")) {
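Encoding the link fragment matters for tags containing URL-reserved characters; for example:

    encodeURIComponent("fate/stay_night"); // -> "fate%2Fstay_night"
    // without this, the "/" would be read as a path separator in the generated wiki link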
@@ -12,7 +12,7 @@ import gradio as gr
 import yaml
 from fastapi import FastAPI
 from fastapi.responses import FileResponse, JSONResponse
-from modules import script_callbacks, sd_hijack, shared
+from modules import script_callbacks, sd_hijack, shared, hashes
 from pydantic import BaseModel
 
 from scripts.model_keyword_support import (get_lora_simple_hash,
@@ -171,44 +171,30 @@ def get_embeddings(sd_model):
     # Version constants
     V1_SHAPE = 768
    V2_SHAPE = 1024
+    VXL_SHAPE = 2048
     emb_v1 = []
     emb_v2 = []
+    emb_vXL = []
     results = []
 
     try:
         # Get embedding dict from sd_hijack to separate v1/v2 embeddings
-        emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
-        emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
-        # Get the shape of the first item in the dict
-        emb_a_shape = -1
-        emb_b_shape = -1
-        if (len(emb_type_a) > 0):
-            emb_a_shape = next(iter(emb_type_a.items()))[1].shape
-        if (len(emb_type_b) > 0):
-            emb_b_shape = next(iter(emb_type_b.items()))[1].shape
+        loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
+        skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
 
         # Add embeddings to the correct list
-        if (emb_a_shape == V1_SHAPE):
-            emb_v1 = [(Path(v.filename), k, "v1") for (k,v) in emb_type_a.items()]
-        elif (emb_a_shape == V2_SHAPE):
-            emb_v2 = [(Path(v.filename), k, "v2") for (k,v) in emb_type_a.items()]
+        for key, emb in (loaded | skipped).items():
+            if emb.filename is None or emb.shape is None:
+                continue
 
-        if (emb_b_shape == V1_SHAPE):
-            emb_v1 = [(Path(v.filename), k, "v1") for (k,v) in emb_type_b.items()]
-        elif (emb_b_shape == V2_SHAPE):
-            emb_v2 = [(Path(v.filename), k, "v2") for (k,v) in emb_type_b.items()]
+            if emb.shape == V1_SHAPE:
+                emb_v1.append((Path(emb.filename), key, "v1"))
+            elif emb.shape == V2_SHAPE:
+                emb_v2.append((Path(emb.filename), key, "v2"))
+            elif emb.shape == VXL_SHAPE:
+                emb_vXL.append((Path(emb.filename), key, "vXL"))
 
-        # Get shape of current model
-        #vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
-        #model_shape = vec.shape[1]
-        # Show relevant entries at the top
-        #if (model_shape == V1_SHAPE):
-        #    results = [e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2]
-        #elif (model_shape == V2_SHAPE):
-        #    results = [e + ",v2" for e in emb_v2] + [e + ",v1" for e in emb_v1]
-        #else:
-        #    raise AttributeError # Fallback to old method
-        results = sort_models(emb_v1) + sort_models(emb_v2)
+        results = sort_models(emb_v1) + sort_models(emb_v2) + sort_models(emb_vXL)
     except AttributeError:
         print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
         # Get a list of all embeddings in the folder
@@ -553,6 +539,18 @@ def api_tac(_: gr.Blocks, app: FastAPI):
     async def get_lyco_info(lyco_name):
         return await get_json_info(LYCO_PATH, lyco_name)
 
+    @app.get("/tacapi/v1/lora-cached-hash/{lora_name}")
+    async def get_lora_cached_hash(lora_name: str):
+        path_glob = glob.glob(LORA_PATH.as_posix() + f"/**/{lora_name}.*", recursive=True)
+        paths = [lora for lora in path_glob if Path(lora).suffix in [".safetensors", ".ckpt", ".pt"]]
+        if paths is not None and len(paths) > 0:
+            path = paths[0]
+            hash = hashes.sha256_from_cache(path, f"lora/{lora_name}", path.endswith(".safetensors"))
+            if hash is not None:
+                return hash
+
+        return None
+
     def get_path_for_type(type):
         if type == "lora":
             return LORA_PATH
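On the JavaScript side, this endpoint is consumed through the extension's fetchAPI helper (as in the insertTextAtCursor hunk above); a sketch inside an async function, with a hypothetical lora name:

    // resolves to the cached sha256 for "myLora", or null if none is cached
    const sha256 = await fetchAPI(`/tacapi/v1/lora-cached-hash/myLora`);
    if (sha256) {
        nameDict = modelKeywordDict.get(sha256); // keyword lookup by full hash
    }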