Add sorting to javascript side

Now uses the sortKey if available. Elements without a sortKey will always use name as fallback.
Removed sort direction API again since it needs to be modeled case-by-case in the javascript anyway.
This commit is contained in:
DominikDoom
2023-09-13 14:02:28 +02:00
parent 475ef59197
commit 44effca702
9 changed files with 117 additions and 78 deletions

View File

@@ -27,6 +27,7 @@ class AutocompleteResult {
aliases = null;
meta = null;
hash = null;
sortKey = null;
// Constructor
constructor(text, type) {

View File

@@ -81,6 +81,17 @@ async function fetchAPI(url, json = true, cache = false) {
return await response.text();
}
async function postAPI(url, body) {
let response = await fetch(url, { method: "POST", body: body });
if (response.status != 200) {
console.error(`Error posting to API endpoint "${url}": ` + response.status, response.statusText);
return null;
}
return await response.json();
}
// Extra network preview thumbnails
async function getExtraNetworkPreviewURL(filename, type) {
const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);
@@ -200,6 +211,33 @@ function observeElement(element, property, callback, delay = 0) {
}
}
// Sort functions
function getSortFunction() {
let criterium = TAC_CFG.modelSortOrder || "Name";
return (a, b) => {
let textHolderA = a.type === ResultType.chant ? a.aliases : a.text;
let textHolderB = b.type === ResultType.chant ? b.aliases : b.text;
switch (criterium) {
case "Date Modified":
let aParsed = parseFloat(a.sortKey || "-1");
let bParsed = parseFloat(b.sortKey || "-1");
if (aParsed === bParsed) {
let aKey = a.sortKey || textHolderA;
let bKey = b.sortKey || textHolderB;
return aKey.localeCompare(bKey);
}
return bParsed - aParsed;
default:
let aKey = a.sortKey || textHolderA;
let bKey = b.sortKey || textHolderB;
return aKey.localeCompare(bKey);
}
}
}
// Queue calling function to process global queues
async function processQueue(queue, context, ...args) {
for (let i = 0; i < queue.length; i++) {

View File

@@ -16,7 +16,7 @@ class EmbeddingParser extends BaseTagParser {
let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm);
if (versionString)
tempResults = embeddings.filter(x => filterCondition(x) && x[1] && x[1] === versionString); // Filter by tagword
tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2] === versionString); // Filter by tagword
else
tempResults = embeddings.filter(x => filterCondition(x)); // Filter by tagword
} else {
@@ -27,7 +27,8 @@ class EmbeddingParser extends BaseTagParser {
let finalResults = [];
tempResults.forEach(t => {
let result = new AutocompleteResult(t[0].trim(), ResultType.embedding)
result.meta = t[1] + " Embedding";
result.sortKey = t[1];
result.meta = t[2] + " Embedding";
finalResults.push(result);
});
@@ -38,9 +39,9 @@ class EmbeddingParser extends BaseTagParser {
async function load() {
if (embeddings.length === 0) {
try {
embeddings = (await readFile(`${tagBasePath}/temp/emb.txt`)).split("\n")
.filter(x => x.trim().length > 0) // Remove empty lines
.map(x => x.trim().split(",")); // Split into name, version type pairs
embeddings = (await loadCSV(`${tagBasePath}/temp/emb.txt`))
.filter(x => x[0]?.trim().length > 0) // Remove empty lines
.map(x => [x[0].trim(), x[1], x[2]]); // Return name, sortKey, hash tuples
} catch (e) {
console.error("Error loading embeddings.txt: " + e);
}

View File

@@ -8,7 +8,7 @@ class HypernetParser extends BaseTagParser {
if (tagword !== "<" && tagword !== "<h:" && tagword !== "<hypernet:") {
let searchTerm = tagword.replace("<hypernet:", "").replace("<h:", "").replace("<", "");
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
tempResults = hypernetworks.filter(x => filterCondition(x)); // Filter by tagword
tempResults = hypernetworks.filter(x => filterCondition(x[0])); // Filter by tagword
} else {
tempResults = hypernetworks;
}
@@ -16,8 +16,9 @@ class HypernetParser extends BaseTagParser {
// Add final results
let finalResults = [];
tempResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.hypernetwork)
let result = new AutocompleteResult(t[0].trim(), ResultType.hypernetwork)
result.meta = "Hypernetwork";
result.sortKey = t[1];
finalResults.push(result);
});
@@ -28,9 +29,9 @@ class HypernetParser extends BaseTagParser {
async function load() {
if (hypernetworks.length === 0) {
try {
hypernetworks = (await readFile(`${tagBasePath}/temp/hyp.txt`)).split("\n")
.filter(x => x.trim().length > 0) //Remove empty lines
.map(x => x.trim()); // Remove carriage returns and padding if it exists
hypernetworks = (await loadCSV(`${tagBasePath}/temp/hyp.txt`))
.filter(x => x[0]?.trim().length > 0) //Remove empty lines
.map(x => [x[0]?.trim(), x[1]]); // Remove carriage returns and padding if it exists
} catch (e) {
console.error("Error loading hypernetworks.txt: " + e);
}

View File

@@ -23,7 +23,8 @@ class LoraParser extends BaseTagParser {
let result = new AutocompleteResult(name, ResultType.lora)
result.meta = "Lora";
result.hash = t[1];
result.sortKey = t[1];
result.hash = t[2];
finalResults.push(result);
});
@@ -36,7 +37,7 @@ async function load() {
try {
loras = (await loadCSV(`${tagBasePath}/temp/lora.txt`))
.filter(x => x[0]?.trim().length > 0) // Remove empty lines
.map(x => [x[0]?.trim(), x[1]]); // Trim filenames and return the name, hash pairs
.map(x => [x[0]?.trim(), x[1], x[2]]); // Trim filenames and return the name, sortKey, hash pairs
} catch (e) {
console.error("Error loading lora.txt: " + e);
}

View File

@@ -23,7 +23,8 @@ class LycoParser extends BaseTagParser {
let result = new AutocompleteResult(name, ResultType.lyco)
result.meta = "Lyco";
result.hash = t[1];
result.sortKey = t[1];
result.hash = t[2];
finalResults.push(result);
});
@@ -36,7 +37,7 @@ async function load() {
try {
lycos = (await loadCSV(`${tagBasePath}/temp/lyco.txt`))
.filter(x => x[0]?.trim().length > 0) // Remove empty lines
.map(x => [x[0]?.trim(), x[1]]); // Trim filenames and return the name, hash pairs
.map(x => [x[0]?.trim(), x[1], x[2]]); // Trim filenames and return the name, sortKey, hash pairs
} catch (e) {
console.error("Error loading lyco.txt: " + e);
}

View File

@@ -85,13 +85,14 @@ class WildcardFileParser extends BaseTagParser {
} else {
result = new AutocompleteResult(wcFile[1].trim(), ResultType.wildcardFile);
result.meta = "Wildcard file";
result.sortKey = wcFile[2].trim();
}
finalResults.push(result);
alreadyAdded.set(wcFile[1], true);
});
finalResults.sort((a, b) => a.text.localeCompare(b.text));
finalResults.sort(getSortFunction());
return finalResults;
}
@@ -100,17 +101,17 @@ class WildcardFileParser extends BaseTagParser {
async function load() {
if (wildcardFiles.length === 0 && wildcardExtFiles.length === 0) {
try {
let wcFileArr = (await readFile(`${tagBasePath}/temp/wc.txt`)).split("\n");
let wcBasePath = wcFileArr[0].trim(); // First line should be the base path
let wcFileArr = await loadCSV(`${tagBasePath}/temp/wc.txt`);
let wcBasePath = wcFileArr[0][0].trim(); // First line should be the base path
wildcardFiles = wcFileArr.slice(1)
.filter(x => x.trim().length > 0) // Remove empty lines
.map(x => [wcBasePath, x.trim().replace(".txt", "")]); // Remove file extension & newlines
.filter(x => x[0]?.trim().length > 0) //Remove empty lines
.map(x => [wcBasePath, x[0]?.trim().replace(".txt", ""), x[1]]); // Remove file extension & newlines
// To support multiple sources, we need to separate them using the provided "-----" strings
let wcExtFileArr = (await readFile(`${tagBasePath}/temp/wce.txt`)).split("\n");
let wcExtFileArr = await loadCSV(`${tagBasePath}/temp/wce.txt`);
let splitIndices = [];
for (let index = 0; index < wcExtFileArr.length; index++) {
if (wcExtFileArr[index].trim() === "-----") {
if (wcExtFileArr[index][0].trim() === "-----") {
splitIndices.push(index);
}
}
@@ -121,12 +122,10 @@ async function load() {
let end = splitIndices[i];
let wcExtFile = wcExtFileArr.slice(start, end);
let base = wcExtFile[0].trim() + "/";
let base = wcExtFile[0][0].trim() + "/";
wcExtFile = wcExtFile.slice(1)
.filter(x => x.trim().length > 0) // Remove empty lines
.map(x => x.trim().replace(base, "").replace(".txt", "")); // Remove file extension & newlines;
wcExtFile = wcExtFile.map(x => [base, x]);
.filter(x => x[0]?.trim().length > 0) //Remove empty lines
.map(x => [base, x[0]?.trim().replace(base, "").replace(".txt", ""), x[1]]);
wildcardExtFiles.push(...wcExtFile);
}

View File

@@ -217,6 +217,7 @@ async function syncOptions() {
useLycos: opts["tac_useLycos"],
showWikiLinks: opts["tac_showWikiLinks"],
showExtraNetworkPreviews: opts["tac_showExtraNetworkPreviews"],
modelSortOrder: opts["tac_modelSortOrder"],
// Insertion related settings
replaceUnderscores: opts["tac_replaceUnderscores"],
escapeParentheses: opts["tac_escapeParentheses"],
@@ -269,6 +270,12 @@ async function syncOptions() {
await loadTags(newCFG);
}
// Refresh temp files if model sort order changed
// Contrary to the other loads, this one shouldn't happen on a first time load
if (TAC_CFG && newCFG.modelSortOrder !== TAC_CFG.modelSortOrder) {
await refreshTacTempFiles(true);
}
// Update CSS if maxResults changed
if (TAC_CFG && newCFG.maxResults !== TAC_CFG.maxResults) {
gradioApp().querySelectorAll(".autocompleteResults").forEach(r => {
@@ -1007,36 +1014,28 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
if (resultCandidates && resultCandidates.length > 0) {
// Flatten our candidate(s)
results = resultCandidates.flat();
// If there was more than one candidate, sort the results by text to mix them
// instead of having them added in the order of the parsers
let shouldSort = resultCandidates.length > 1;
if (shouldSort) {
results = results.sort((a, b) => {
let sortByA = a.type === ResultType.chant ? a.aliases : a.text;
let sortByB = b.type === ResultType.chant ? b.aliases : b.text;
return sortByA.localeCompare(sortByB);
});
// Sort results
results = results.sort(getSortFunction());
// Since some tags are kaomoji, we have to add the normal results in some cases
if (tagword.startsWith("<") || tagword.startsWith("*<")) {
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
if (tagword.startsWith("*")) {
tagword = tagword.slice(1);
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
} else {
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
}
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, TAC_CFG.maxResults);
genericResults.forEach(g => {
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
result.category = g[1];
result.count = g[2];
result.aliases = g[3];
results.push(result);
});
// Since some tags are kaomoji, we have to add the normal results in some cases
if (tagword.startsWith("<") || tagword.startsWith("*<")) {
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
if (tagword.startsWith("*")) {
tagword = tagword.slice(1);
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
} else {
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
}
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, TAC_CFG.maxResults);
genericResults.forEach(g => {
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
result.category = g[1];
result.count = g[2];
result.aliases = g[3];
results.push(result);
});
}
}
// Else search the normal tag list
@@ -1223,8 +1222,8 @@ function navigateInList(textArea, event) {
event.stopPropagation();
}
async function refreshTacTempFiles() {
setTimeout(async () => {
async function refreshTacTempFiles(api = false) {
const reload = async () => {
wildcardFiles = [];
wildcardExtFiles = [];
umiWildcards = [];
@@ -1236,7 +1235,16 @@ async function refreshTacTempFiles() {
await processQueue(QUEUE_FILE_LOAD, null);
console.log("TAC: Refreshed temp files");
}, 2000);
}
if (api) {
await postAPI("tacapi/v1/refresh-temp-files", null);
await reload();
} else {
setTimeout(async () => {
await reload();
}, 2000);
}
}
function addAutocompleteToArea(area) {

View File

@@ -27,24 +27,16 @@ except Exception as e: # Not supported.
# Sorting functions for extra networks / embeddings stuff
sort_criteria = {
"Name": {
"key": lambda path, name: name.lower() if Path(name).parts > 1 else path.stem.lower(),
"reverse": False
},
"Date Modified": {
"key": lambda path, name: path.stat().st_mtime,
"reverse": True
},
"Name": lambda path, name, subpath: name.lower() if subpath else path.stem.lower(),
"Date Modified": lambda path, name, subpath: path.stat().st_mtime
}
def sort_models(model_list, sort_method = None):
def sort_models(model_list, sort_method = None, name_has_subpath = False):
"""Sorts models according to the setting.
Input: list of (full_path, display_name, {hash}) models.
Returns models in the format of name, sort key, meta.
Meta is optional and can be a hash, version string or other required info.
Whether the currently selected sort method needs to be reversed is provided
by an API endpoint to reduce duplication in temp files.
"""
if len(model_list) == 0:
return model_list
@@ -58,9 +50,9 @@ def sort_models(model_list, sort_method = None):
# During merging on the JS side we need to re-sort anyway, so here only the sort criteria are calculated.
# The list itself doesn't need to get sorted at this point.
if len(model_list[0]) > 2:
results = [f'{name},"{sorter["key"](path, name)}",{meta}' for path, name, meta in model_list]
results = [f'{name},"{sorter(path, name, name_has_subpath)}",{meta}' for path, name, meta in model_list]
else:
results = [f'{name},"{sorter["key"](path, name)}"' for path, name in model_list]
results = [f'{name},"{sorter(path, name, name_has_subpath)}"' for path, name in model_list]
return results
@@ -70,7 +62,7 @@ def get_wildcards():
resolved = [(w, w.relative_to(WILDCARD_PATH).as_posix())
for w in wildcard_files
if w.name != "put wildcards here.txt"]
return sort_models(resolved)
return sort_models(resolved, name_has_subpath=True)
def get_ext_wildcards():
@@ -82,7 +74,7 @@ def get_ext_wildcards():
resolved = [(w, w.relative_to(path).as_posix())
for w in path.rglob("*.txt")
if w.name != "put wildcards here.txt"]
wildcard_files.extend(sort_models(resolved))
wildcard_files.extend(sort_models(resolved, name_has_subpath=True))
wildcard_files.append("-----")
return wildcard_files
@@ -232,7 +224,6 @@ def get_lora():
loras_with_hash = []
for l in valid_loras:
name = l.relative_to(LORA_PATH).as_posix()
name = f'"{name}"'
if model_keyword_installed:
hash = get_lora_simple_hash(l)
else:
@@ -253,7 +244,6 @@ def get_lyco():
lycos_with_hash = []
for ly in valid_lycos:
name = ly.relative_to(LYCO_PATH).as_posix()
name = f'"{name}"'
if model_keyword_installed:
hash = get_lora_simple_hash(ly)
else:
@@ -525,10 +515,9 @@ def api_tac(_: gr.Blocks, app: FastAPI):
except Exception as e:
return JSONResponse({"error": e}, status_code=500)
@app.get("/tacapi/v1/sort-direction")
async def get_sort_direction():
criterium = getattr(shared.opts, "tac_modelSortOrder", "Name")
return sort_criteria[criterium]['reverse'] if sort_criteria[criterium] else sort_criteria['Name']['reverse']
@app.post("/tacapi/v1/refresh-temp-files")
async def api_refresh_temp_files():
refresh_temp_files()
@app.get("/tacapi/v1/lora-info/{lora_name}")
async def get_lora_info(lora_name):