diff --git a/javascript/_result.js b/javascript/_result.js index 96129f1..823f26d 100644 --- a/javascript/_result.js +++ b/javascript/_result.js @@ -27,6 +27,7 @@ class AutocompleteResult { aliases = null; meta = null; hash = null; + sortKey = null; // Constructor constructor(text, type) { diff --git a/javascript/_utils.js b/javascript/_utils.js index 6ef46c0..069ee2e 100644 --- a/javascript/_utils.js +++ b/javascript/_utils.js @@ -81,6 +81,17 @@ async function fetchAPI(url, json = true, cache = false) { return await response.text(); } +async function postAPI(url, body) { + let response = await fetch(url, { method: "POST", body: body }); + + if (response.status != 200) { + console.error(`Error posting to API endpoint "${url}": ` + response.status, response.statusText); + return null; + } + + return await response.json(); +} + // Extra network preview thumbnails async function getExtraNetworkPreviewURL(filename, type) { const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true); @@ -200,6 +211,33 @@ function observeElement(element, property, callback, delay = 0) { } } +// Sort functions +function getSortFunction() { + let criterium = TAC_CFG.modelSortOrder || "Name"; + return (a, b) => { + let textHolderA = a.type === ResultType.chant ? a.aliases : a.text; + let textHolderB = b.type === ResultType.chant ? b.aliases : b.text; + + switch (criterium) { + case "Date Modified": + let aParsed = parseFloat(a.sortKey || "-1"); + let bParsed = parseFloat(b.sortKey || "-1"); + + if (aParsed === bParsed) { + let aKey = a.sortKey || textHolderA; + let bKey = b.sortKey || textHolderB; + return aKey.localeCompare(bKey); + } + + return bParsed - aParsed; + default: + let aKey = a.sortKey || textHolderA; + let bKey = b.sortKey || textHolderB; + return aKey.localeCompare(bKey); + } + } +} + // Queue calling function to process global queues async function processQueue(queue, context, ...args) { for (let i = 0; i < queue.length; i++) { diff --git a/javascript/ext_embeddings.js b/javascript/ext_embeddings.js index e51aa4b..9c7bd44 100644 --- a/javascript/ext_embeddings.js +++ b/javascript/ext_embeddings.js @@ -16,7 +16,7 @@ class EmbeddingParser extends BaseTagParser { let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm); if (versionString) - tempResults = embeddings.filter(x => filterCondition(x) && x[1] && x[1] === versionString); // Filter by tagword + tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2] === versionString); // Filter by tagword else tempResults = embeddings.filter(x => filterCondition(x)); // Filter by tagword } else { @@ -27,7 +27,8 @@ class EmbeddingParser extends BaseTagParser { let finalResults = []; tempResults.forEach(t => { let result = new AutocompleteResult(t[0].trim(), ResultType.embedding) - result.meta = t[1] + " Embedding"; + result.sortKey = t[1]; + result.meta = t[2] + " Embedding"; finalResults.push(result); }); @@ -38,9 +39,9 @@ class EmbeddingParser extends BaseTagParser { async function load() { if (embeddings.length === 0) { try { - embeddings = (await readFile(`${tagBasePath}/temp/emb.txt`)).split("\n") - .filter(x => x.trim().length > 0) // Remove empty lines - .map(x => x.trim().split(",")); // Split into name, version type pairs + embeddings = (await loadCSV(`${tagBasePath}/temp/emb.txt`)) + .filter(x => x[0]?.trim().length > 0) // Remove empty lines + .map(x => [x[0].trim(), x[1], x[2]]); // Return name, sortKey, hash tuples } catch (e) { console.error("Error loading embeddings.txt: " + e); } diff --git a/javascript/ext_hypernets.js b/javascript/ext_hypernets.js index 7f564fd..3613b2a 100644 --- a/javascript/ext_hypernets.js +++ b/javascript/ext_hypernets.js @@ -8,7 +8,7 @@ class HypernetParser extends BaseTagParser { if (tagword !== "<" && tagword !== " x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm); - tempResults = hypernetworks.filter(x => filterCondition(x)); // Filter by tagword + tempResults = hypernetworks.filter(x => filterCondition(x[0])); // Filter by tagword } else { tempResults = hypernetworks; } @@ -16,8 +16,9 @@ class HypernetParser extends BaseTagParser { // Add final results let finalResults = []; tempResults.forEach(t => { - let result = new AutocompleteResult(t.trim(), ResultType.hypernetwork) + let result = new AutocompleteResult(t[0].trim(), ResultType.hypernetwork) result.meta = "Hypernetwork"; + result.sortKey = t[1]; finalResults.push(result); }); @@ -28,9 +29,9 @@ class HypernetParser extends BaseTagParser { async function load() { if (hypernetworks.length === 0) { try { - hypernetworks = (await readFile(`${tagBasePath}/temp/hyp.txt`)).split("\n") - .filter(x => x.trim().length > 0) //Remove empty lines - .map(x => x.trim()); // Remove carriage returns and padding if it exists + hypernetworks = (await loadCSV(`${tagBasePath}/temp/hyp.txt`)) + .filter(x => x[0]?.trim().length > 0) //Remove empty lines + .map(x => [x[0]?.trim(), x[1]]); // Remove carriage returns and padding if it exists } catch (e) { console.error("Error loading hypernetworks.txt: " + e); } diff --git a/javascript/ext_loras.js b/javascript/ext_loras.js index 9a94b75..22c879c 100644 --- a/javascript/ext_loras.js +++ b/javascript/ext_loras.js @@ -23,7 +23,8 @@ class LoraParser extends BaseTagParser { let result = new AutocompleteResult(name, ResultType.lora) result.meta = "Lora"; - result.hash = t[1]; + result.sortKey = t[1]; + result.hash = t[2]; finalResults.push(result); }); @@ -36,7 +37,7 @@ async function load() { try { loras = (await loadCSV(`${tagBasePath}/temp/lora.txt`)) .filter(x => x[0]?.trim().length > 0) // Remove empty lines - .map(x => [x[0]?.trim(), x[1]]); // Trim filenames and return the name, hash pairs + .map(x => [x[0]?.trim(), x[1], x[2]]); // Trim filenames and return the name, sortKey, hash pairs } catch (e) { console.error("Error loading lora.txt: " + e); } diff --git a/javascript/ext_lycos.js b/javascript/ext_lycos.js index dd1b439..ad6271e 100644 --- a/javascript/ext_lycos.js +++ b/javascript/ext_lycos.js @@ -23,7 +23,8 @@ class LycoParser extends BaseTagParser { let result = new AutocompleteResult(name, ResultType.lyco) result.meta = "Lyco"; - result.hash = t[1]; + result.sortKey = t[1]; + result.hash = t[2]; finalResults.push(result); }); @@ -36,7 +37,7 @@ async function load() { try { lycos = (await loadCSV(`${tagBasePath}/temp/lyco.txt`)) .filter(x => x[0]?.trim().length > 0) // Remove empty lines - .map(x => [x[0]?.trim(), x[1]]); // Trim filenames and return the name, hash pairs + .map(x => [x[0]?.trim(), x[1], x[2]]); // Trim filenames and return the name, sortKey, hash pairs } catch (e) { console.error("Error loading lyco.txt: " + e); } diff --git a/javascript/ext_wildcards.js b/javascript/ext_wildcards.js index cde0421..34361b8 100644 --- a/javascript/ext_wildcards.js +++ b/javascript/ext_wildcards.js @@ -85,13 +85,14 @@ class WildcardFileParser extends BaseTagParser { } else { result = new AutocompleteResult(wcFile[1].trim(), ResultType.wildcardFile); result.meta = "Wildcard file"; + result.sortKey = wcFile[2].trim(); } finalResults.push(result); alreadyAdded.set(wcFile[1], true); }); - finalResults.sort((a, b) => a.text.localeCompare(b.text)); + finalResults.sort(getSortFunction()); return finalResults; } @@ -100,17 +101,17 @@ class WildcardFileParser extends BaseTagParser { async function load() { if (wildcardFiles.length === 0 && wildcardExtFiles.length === 0) { try { - let wcFileArr = (await readFile(`${tagBasePath}/temp/wc.txt`)).split("\n"); - let wcBasePath = wcFileArr[0].trim(); // First line should be the base path + let wcFileArr = await loadCSV(`${tagBasePath}/temp/wc.txt`); + let wcBasePath = wcFileArr[0][0].trim(); // First line should be the base path wildcardFiles = wcFileArr.slice(1) - .filter(x => x.trim().length > 0) // Remove empty lines - .map(x => [wcBasePath, x.trim().replace(".txt", "")]); // Remove file extension & newlines + .filter(x => x[0]?.trim().length > 0) //Remove empty lines + .map(x => [wcBasePath, x[0]?.trim().replace(".txt", ""), x[1]]); // Remove file extension & newlines // To support multiple sources, we need to separate them using the provided "-----" strings - let wcExtFileArr = (await readFile(`${tagBasePath}/temp/wce.txt`)).split("\n"); + let wcExtFileArr = await loadCSV(`${tagBasePath}/temp/wce.txt`); let splitIndices = []; for (let index = 0; index < wcExtFileArr.length; index++) { - if (wcExtFileArr[index].trim() === "-----") { + if (wcExtFileArr[index][0].trim() === "-----") { splitIndices.push(index); } } @@ -121,12 +122,10 @@ async function load() { let end = splitIndices[i]; let wcExtFile = wcExtFileArr.slice(start, end); - let base = wcExtFile[0].trim() + "/"; + let base = wcExtFile[0][0].trim() + "/"; wcExtFile = wcExtFile.slice(1) - .filter(x => x.trim().length > 0) // Remove empty lines - .map(x => x.trim().replace(base, "").replace(".txt", "")); // Remove file extension & newlines; - - wcExtFile = wcExtFile.map(x => [base, x]); + .filter(x => x[0]?.trim().length > 0) //Remove empty lines + .map(x => [base, x[0]?.trim().replace(base, "").replace(".txt", ""), x[1]]); wildcardExtFiles.push(...wcExtFile); } diff --git a/javascript/tagAutocomplete.js b/javascript/tagAutocomplete.js index 775ba39..9331296 100644 --- a/javascript/tagAutocomplete.js +++ b/javascript/tagAutocomplete.js @@ -217,6 +217,7 @@ async function syncOptions() { useLycos: opts["tac_useLycos"], showWikiLinks: opts["tac_showWikiLinks"], showExtraNetworkPreviews: opts["tac_showExtraNetworkPreviews"], + modelSortOrder: opts["tac_modelSortOrder"], // Insertion related settings replaceUnderscores: opts["tac_replaceUnderscores"], escapeParentheses: opts["tac_escapeParentheses"], @@ -269,6 +270,12 @@ async function syncOptions() { await loadTags(newCFG); } + // Refresh temp files if model sort order changed + // Contrary to the other loads, this one shouldn't happen on a first time load + if (TAC_CFG && newCFG.modelSortOrder !== TAC_CFG.modelSortOrder) { + await refreshTacTempFiles(true); + } + // Update CSS if maxResults changed if (TAC_CFG && newCFG.maxResults !== TAC_CFG.maxResults) { gradioApp().querySelectorAll(".autocompleteResults").forEach(r => { @@ -1007,36 +1014,28 @@ async function autocomplete(textArea, prompt, fixedTag = null) { if (resultCandidates && resultCandidates.length > 0) { // Flatten our candidate(s) results = resultCandidates.flat(); - // If there was more than one candidate, sort the results by text to mix them - // instead of having them added in the order of the parsers - let shouldSort = resultCandidates.length > 1; - if (shouldSort) { - results = results.sort((a, b) => { - let sortByA = a.type === ResultType.chant ? a.aliases : a.text; - let sortByB = b.type === ResultType.chant ? b.aliases : b.text; - return sortByA.localeCompare(sortByB); - }); + // Sort results + results = results.sort(getSortFunction()); - // Since some tags are kaomoji, we have to add the normal results in some cases - if (tagword.startsWith("<") || tagword.startsWith("*<")) { - // Create escaped search regex with support for * as a start placeholder - let searchRegex; - if (tagword.startsWith("*")) { - tagword = tagword.slice(1); - searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i'); - } else { - searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i'); - } - let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, TAC_CFG.maxResults); - - genericResults.forEach(g => { - let result = new AutocompleteResult(g[0].trim(), ResultType.tag) - result.category = g[1]; - result.count = g[2]; - result.aliases = g[3]; - results.push(result); - }); + // Since some tags are kaomoji, we have to add the normal results in some cases + if (tagword.startsWith("<") || tagword.startsWith("*<")) { + // Create escaped search regex with support for * as a start placeholder + let searchRegex; + if (tagword.startsWith("*")) { + tagword = tagword.slice(1); + searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i'); + } else { + searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i'); } + let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, TAC_CFG.maxResults); + + genericResults.forEach(g => { + let result = new AutocompleteResult(g[0].trim(), ResultType.tag) + result.category = g[1]; + result.count = g[2]; + result.aliases = g[3]; + results.push(result); + }); } } // Else search the normal tag list @@ -1223,8 +1222,8 @@ function navigateInList(textArea, event) { event.stopPropagation(); } -async function refreshTacTempFiles() { - setTimeout(async () => { +async function refreshTacTempFiles(api = false) { + const reload = async () => { wildcardFiles = []; wildcardExtFiles = []; umiWildcards = []; @@ -1236,7 +1235,16 @@ async function refreshTacTempFiles() { await processQueue(QUEUE_FILE_LOAD, null); console.log("TAC: Refreshed temp files"); - }, 2000); + } + + if (api) { + await postAPI("tacapi/v1/refresh-temp-files", null); + await reload(); + } else { + setTimeout(async () => { + await reload(); + }, 2000); + } } function addAutocompleteToArea(area) { diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index 9094f77..0af571a 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -27,24 +27,16 @@ except Exception as e: # Not supported. # Sorting functions for extra networks / embeddings stuff sort_criteria = { - "Name": { - "key": lambda path, name: name.lower() if Path(name).parts > 1 else path.stem.lower(), - "reverse": False - }, - "Date Modified": { - "key": lambda path, name: path.stat().st_mtime, - "reverse": True - }, + "Name": lambda path, name, subpath: name.lower() if subpath else path.stem.lower(), + "Date Modified": lambda path, name, subpath: path.stat().st_mtime } -def sort_models(model_list, sort_method = None): +def sort_models(model_list, sort_method = None, name_has_subpath = False): """Sorts models according to the setting. Input: list of (full_path, display_name, {hash}) models. Returns models in the format of name, sort key, meta. Meta is optional and can be a hash, version string or other required info. - Whether the currently selected sort method needs to be reversed is provided - by an API endpoint to reduce duplication in temp files. """ if len(model_list) == 0: return model_list @@ -58,9 +50,9 @@ def sort_models(model_list, sort_method = None): # During merging on the JS side we need to re-sort anyway, so here only the sort criteria are calculated. # The list itself doesn't need to get sorted at this point. if len(model_list[0]) > 2: - results = [f'{name},"{sorter["key"](path, name)}",{meta}' for path, name, meta in model_list] + results = [f'{name},"{sorter(path, name, name_has_subpath)}",{meta}' for path, name, meta in model_list] else: - results = [f'{name},"{sorter["key"](path, name)}"' for path, name in model_list] + results = [f'{name},"{sorter(path, name, name_has_subpath)}"' for path, name in model_list] return results @@ -70,7 +62,7 @@ def get_wildcards(): resolved = [(w, w.relative_to(WILDCARD_PATH).as_posix()) for w in wildcard_files if w.name != "put wildcards here.txt"] - return sort_models(resolved) + return sort_models(resolved, name_has_subpath=True) def get_ext_wildcards(): @@ -82,7 +74,7 @@ def get_ext_wildcards(): resolved = [(w, w.relative_to(path).as_posix()) for w in path.rglob("*.txt") if w.name != "put wildcards here.txt"] - wildcard_files.extend(sort_models(resolved)) + wildcard_files.extend(sort_models(resolved, name_has_subpath=True)) wildcard_files.append("-----") return wildcard_files @@ -232,7 +224,6 @@ def get_lora(): loras_with_hash = [] for l in valid_loras: name = l.relative_to(LORA_PATH).as_posix() - name = f'"{name}"' if model_keyword_installed: hash = get_lora_simple_hash(l) else: @@ -253,7 +244,6 @@ def get_lyco(): lycos_with_hash = [] for ly in valid_lycos: name = ly.relative_to(LYCO_PATH).as_posix() - name = f'"{name}"' if model_keyword_installed: hash = get_lora_simple_hash(ly) else: @@ -525,10 +515,9 @@ def api_tac(_: gr.Blocks, app: FastAPI): except Exception as e: return JSONResponse({"error": e}, status_code=500) - @app.get("/tacapi/v1/sort-direction") - async def get_sort_direction(): - criterium = getattr(shared.opts, "tac_modelSortOrder", "Name") - return sort_criteria[criterium]['reverse'] if sort_criteria[criterium] else sort_criteria['Name']['reverse'] + @app.post("/tacapi/v1/refresh-temp-files") + async def api_refresh_temp_files(): + refresh_temp_files() @app.get("/tacapi/v1/lora-info/{lora_name}") async def get_lora_info(lora_name):