Compare commits

...

60 Commits

Author SHA1 Message Date
DominikDoom
119a3ad51f Merge branch 'feature-sort-by-frequent-use' 2024-04-13 15:08:57 +02:00
DominikDoom
c820a22149 Merge branch 'main' into feature-sort-by-frequent-use 2024-04-13 15:06:53 +02:00
DominikDoom
eb1e1820f9 English dictionary and derpibooru tag lists, thanks to @Nenotriple
Closes #280
2024-04-13 14:42:21 +02:00
DominikDoom
ef59cff651 Move last used date check guard to SQL side, implement max cap
- Server side date comparison and cap check further improve js sort performance
- The alias check has also been moved out of calculateUsageBias to support the new cap system
2024-03-16 16:44:43 +01:00
DominikDoom
a454383c43 Merge branch 'main' into feature-sort-by-frequent-use 2024-03-03 13:52:32 +01:00
DominikDoom
bec567fe26 Merge pull request #277 from Symbiomatrix/embpath
Embedding relative path.
2024-03-03 13:50:33 +01:00
DominikDoom
d4041096c9 Merge pull request #275 from Symbiomatrix/wildcard 2024-03-03 13:43:27 +01:00
Symbiomatrix
0903259ddf Update ext_embeddings.js 2024-03-03 13:49:36 +02:00
Symbiomatrix
f3e64b1fa5 Update tag_autocomplete_helper.py 2024-03-03 13:47:39 +02:00
DominikDoom
312cec5d71 Merge pull request #276 from Symbiomatrix/modkey
[Feature] Modifier keys for list navigation.
2024-03-03 11:46:05 +01:00
Symbiomatrix
b71e6339bd Fix tabs. 2024-03-03 12:21:29 +02:00
Symbiomatrix
7ddbc3c0b2 Update tagAutocomplete.js 2024-03-03 04:23:13 +02:00
DominikDoom
4c2ef8f770 Merge branch 'main' into feature-sort-by-frequent-use 2024-02-09 19:23:52 +01:00
DominikDoom
97c5e4f53c Fix embeddings not loading in SD.Next diffusers backend
Fixes #273
2024-02-09 19:06:23 +01:00
DominikDoom
1d8d9f64b5 Update danbooru.csv with 2023 data
Closes #274
2024-02-09 17:59:06 +01:00
DominikDoom
7437850600 Merge branch 'main' into feature-sort-by-frequent-use 2024-02-04 14:46:29 +01:00
DominikDoom
829a4a7b89 Merge pull request #272 from rkfg/lora-visibility
Hide loras according to settings
2024-02-04 14:45:48 +01:00
rkfg
22472ac8ad Hide loras according to settings 2024-02-04 16:44:33 +03:00
DominikDoom
5f77fa26d3 Update README.md
Add feedback wanted notice
2024-02-04 12:04:02 +01:00
DominikDoom
f810b2dd8f Merge branch 'main' into feature-sort-by-frequent-use 2024-01-27 12:39:40 +01:00
DominikDoom
95200e82e1 Merge branch 'main' into feature-sort-by-frequent-use 2024-01-26 17:04:53 +01:00
DominikDoom
a966be7546 Merge branch 'main' into feature-sort-by-frequent-use 2024-01-26 16:21:15 +01:00
DominikDoom
342fbc9041 Pre-calculate usage bias for all results instead of in the sort function
Roughly doubles the sort performance
2024-01-19 21:10:09 +01:00
DominikDoom
d496569c9a Cache sort key for small performance increase 2024-01-19 20:17:14 +01:00
DominikDoom
30c9593d3d Merge branch 'main' into feature-sort-by-frequent-use 2023-12-12 14:23:18 +01:00
DominikDoom
57076060df Merge branch 'main' into feature-sort-by-frequent-use 2023-12-11 11:43:26 +01:00
DominikDoom
20b6635a2a WIP usage info table
Might get replaced with gradio depending on how well it works
2023-12-04 15:00:19 +01:00
DominikDoom
1fe8f26670 Add explanatory tooltip and inline reset ability
Also add tooltip for wiki links
2023-12-04 13:56:15 +01:00
DominikDoom
e82e958c3e Fix alias check for non-aliased tag types 2023-11-29 18:15:59 +01:00
DominikDoom
2dd48eab79 Fix error with db return value for no matches 2023-11-29 18:14:14 +01:00
DominikDoom
4df90f5c95 Don't frequency sort alias results by default
with an option to enable it if desired
2023-11-29 18:04:50 +01:00
DominikDoom
a156214a48 Last used & min count settings
Also some performance improvements
2023-11-29 17:45:51 +01:00
DominikDoom
15478e73b5 Count positive / negative prompt usage separately 2023-11-29 15:22:41 +01:00
DominikDoom
434301738a Merge branch 'main' into feature-sort-by-frequent-use 2023-11-05 13:30:51 +01:00
DominikDoom
4fba7baa69 Merge branch 'main' into feature-sort-by-frequent-use 2023-10-06 18:36:24 +02:00
DominikDoom
7128efc4f4 Apply same fix to extra tags
Count now defaults to max safe integer, which simplifies the sort function
Before, it resulted in really bad performance
2023-10-02 00:45:48 +02:00
DominikDoom
bd0ddfbb24 Fix embeddings not at top
(only affecting the "include embeddings in normal results" option)
2023-10-02 00:16:58 +02:00
DominikDoom
3108daf0e8 Remove kaomoji inclusion in < search
because it interfered with use count searching and is not commonly needed
2023-10-01 23:51:35 +02:00
DominikDoom
363895494b Fix hide after insert race condition 2023-10-01 23:17:12 +02:00
DominikDoom
04551a8132 Don't await increase, limit to 2k for performance 2023-10-01 22:59:28 +02:00
DominikDoom
ffc0e378d3 Add different sorting functions 2023-10-01 22:44:35 +02:00
DominikDoom
440f109f1f Use POST + body to get around URL length limit 2023-10-01 22:30:47 +02:00
DominikDoom
80fb247dbe Sort results by usage count 2023-10-01 21:44:24 +02:00
DominikDoom
d7e98200a8 Use count increase logic 2023-09-26 12:20:15 +02:00
DominikDoom
ac790c8ede Return dict instead of array for clarity 2023-09-26 12:12:46 +02:00
DominikDoom
22365ec8d6 Add missing type return to list request 2023-09-26 12:02:36 +02:00
DominikDoom
030a83aa4d Use query parameter instead of path to fix wildcard subfolder issues 2023-09-26 11:55:12 +02:00
DominikDoom
460d32a4ed Ensure proper reload, fix error message 2023-09-26 11:45:42 +02:00
DominikDoom
581bf1e6a4 Use composite key with name & type to prevent collisions 2023-09-26 11:35:24 +02:00
DominikDoom
74ea5493e5 Add rest of utils functions 2023-09-26 10:58:46 +02:00
DominikDoom
6cf9acd6ab Catch sqlite exceptions, add tag list endpoint 2023-09-24 20:06:40 +02:00
DominikDoom
109a8a155e Change endpoint name for consistency 2023-09-24 18:00:41 +02:00
DominikDoom
3caa1b51ed Add db to gitignore 2023-09-24 17:59:39 +02:00
DominikDoom
b44c36425a Fix db load version comparison, add sort options 2023-09-24 17:59:14 +02:00
DominikDoom
1e81403180 Safety catches for DB API access 2023-09-24 16:50:39 +02:00
DominikDoom
0f487a5c5c WIP database setup inspired by ImageBrowser 2023-09-24 16:28:32 +02:00
DominikDoom
2baa12fea3 Merge branch 'main' into feature-sort-by-frequent-use 2023-09-24 15:34:18 +02:00
DominikDoom
67eeb5fbf6 Merge branch 'main' into feature-sort-by-frequent-use 2023-09-19 12:14:12 +02:00
DominikDoom
11ffed8afc Merge branch 'feature-sorting' into feature-sort-by-frequent-use 2023-09-15 16:37:34 +02:00
DominikDoom
0a8e7d7d84 Stub API setup for tag usage stats 2023-09-12 14:10:15 +02:00
16 changed files with 293344 additions and 84394 deletions

.gitignore vendored
View File

@@ -1,2 +1,3 @@
tags/temp/
__pycache__/
tags/tag_frequency.db

View File

@@ -20,6 +20,10 @@ Booru style tag autocompletion for the AUTOMATIC1111 Stable Diffusion WebUI
</div>
<br/>
#### ⚠️ Notice:
I am currently looking for feedback on a new feature that I'm working on and plan to release soon.<br/>
Please check [the announcement post](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/discussions/270) for more info if you are interested in helping.
# 📄 Description
Tag Autocomplete is an extension for the popular [AUTOMATIC1111 web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for Stable Diffusion.

View File

@@ -24,7 +24,8 @@ class AutocompleteResult {
// Additional info, only used in some cases
category = null;
count = null;
count = Number.MAX_SAFE_INTEGER;
usageBias = null;
aliases = null;
meta = null;
hash = null;

View File

@@ -80,8 +80,12 @@ async function fetchAPI(url, json = true, cache = false) {
return await response.text();
}
async function postAPI(url, body) {
let response = await fetch(url, { method: "POST", body: body });
async function postAPI(url, body = null) {
let response = await fetch(url, {
method: "POST",
headers: {'Content-Type': 'application/json'},
body: body
});
if (response.status != 200) {
console.error(`Error posting to API endpoint "${url}": ` + response.status, response.statusText);
@@ -91,6 +95,17 @@ async function postAPI(url, body) {
return await response.json();
}
async function putAPI(url, body = null) {
let response = await fetch(url, { method: "PUT", body: body });
if (response.status != 200) {
console.error(`Error putting to API endpoint "${url}": ` + response.status, response.statusText);
return null;
}
return await response.json();
}
// Extra network preview thumbnails
async function getExtraNetworkPreviewURL(filename, type) {
const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);
@@ -180,6 +195,104 @@ function flatten(obj, roots = [], sep = ".") {
);
}
// Calculate biased tag score based on post count and frequent usage
function calculateUsageBias(result, count, uses) {
// Check setting conditions
if (uses < TAC_CFG.frequencyMinCount) {
uses = 0;
} else if (uses != 0) {
result.usageBias = true;
}
switch (TAC_CFG.frequencyFunction) {
case "Logarithmic (weak)":
return Math.log(1 + count) + Math.log(1 + uses);
case "Logarithmic (strong)":
return Math.log(1 + count) + 2 * Math.log(1 + uses);
case "Usage first":
return uses;
default:
return count;
}
}
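The weighting above can be illustrated with a standalone sketch (an approximation of `calculateUsageBias`: `minCount` mirrors the `tac_frequencyMinCount` setting, and the `result.usageBias` flagging is omitted; the sample counts are made up):

```javascript
// Illustrative standalone version of the frequency bias weighting.
function usageBias(count, uses, fn = "Logarithmic (weak)", minCount = 3) {
    if (uses < minCount) uses = 0; // rarely-used tags get no bias at all
    switch (fn) {
        case "Logarithmic (weak)":
            return Math.log(1 + count) + Math.log(1 + uses);
        case "Logarithmic (strong)":
            return Math.log(1 + count) + 2 * Math.log(1 + uses);
        case "Usage first":
            return uses;
        default:
            return count; // plain post count, no usage bias
    }
}

// With the weak log function, 5 personal uses lift a slightly less popular
// tag (ln(9001) + ln(6) ≈ 10.9 vs ln(10001) ≈ 9.2), while 2 uses fall
// below minCount and change nothing:
usageBias(9000, 5) > usageBias(10000, 0); // true
usageBias(9000, 2) < usageBias(10000, 0); // true
```

Because the bias is additive in log space, the base popularity order is mostly preserved and only nudged, which matches the "Will respect the base order" description of the setting.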
// Beautify return type for easier parsing
function mapUseCountArray(useCounts, posAndNeg = false) {
return useCounts.map(useCount => {
if (posAndNeg) {
return {
"name": useCount[0],
"type": useCount[1],
"count": useCount[2],
"negCount": useCount[3],
"lastUseDate": useCount[4]
}
}
return {
"name": useCount[0],
"type": useCount[1],
"count": useCount[2],
"lastUseDate": useCount[3]
}
});
}
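The endpoints return positional arrays, which this helper converts into named objects. A self-contained copy with a hypothetical row (tag name, type, counts, and date are made up for illustration):

```javascript
// Standalone copy of the row-mapping helper from the hunk above.
function mapUseCountArray(useCounts, posAndNeg = false) {
    return useCounts.map(row => posAndNeg
        ? { name: row[0], type: row[1], count: row[2], negCount: row[3], lastUseDate: row[4] }
        : { name: row[0], type: row[1], count: row[2], lastUseDate: row[3] });
}

// Hypothetical row as returned by the get-use-count-list endpoint:
mapUseCountArray([["1girl", 1, 12, "2024-01-19"]]);
// → [{ name: "1girl", type: 1, count: 12, lastUseDate: "2024-01-19" }]
```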
// Call API endpoint to increase bias of tag in the database
function increaseUseCount(tagName, type, negative = false) {
postAPI(`tacapi/v1/increase-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`);
}
// Get use count of tag from the database
async function getUseCount(tagName, type, negative = false) {
return (await fetchAPI(`tacapi/v1/get-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`, true, false))["result"];
}
async function getUseCounts(tagNames, types, negative = false) {
// While semantically weird, we have to use POST here for the body, as urls are limited in length
const body = JSON.stringify({"tagNames": tagNames, "tagTypes": types, "neg": negative});
const rawArray = (await postAPI(`tacapi/v1/get-use-count-list`, body))["result"]
return mapUseCountArray(rawArray);
}
async function getAllUseCounts() {
const rawArray = (await fetchAPI(`tacapi/v1/get-all-use-counts`))["result"];
return mapUseCountArray(rawArray, true);
}
async function resetUseCount(tagName, type, resetPosCount, resetNegCount) {
await putAPI(`tacapi/v1/reset-use-count?tagname=${tagName}&ttype=${type}&pos=${resetPosCount}&neg=${resetNegCount}`);
}
function createTagUsageTable(tagCounts) {
// Create table
let tagTable = document.createElement("table");
tagTable.innerHTML =
`<thead>
<tr>
<td>Name</td>
<td>Type</td>
<td>Count(+)</td>
<td>Count(-)</td>
<td>Last used</td>
</tr>
</thead>`;
tagTable.id = "tac_tagUsageTable"
tagCounts.forEach(t => {
let tr = document.createElement("tr");
// Fill values
let values = [t.name, t.type-1, t.count, t.negCount, t.lastUseDate]
values.forEach(v => {
let td = document.createElement("td");
td.innerText = v;
tr.append(td);
});
// Add delete/reset button
let delButton = document.createElement("button");
delButton.innerText = "🗑️";
delButton.title = "Reset count";
tr.append(delButton);
tagTable.append(tr)
});
return tagTable;
}
// Sliding window function to get possible combination groups of an array
function toNgrams(inputArray, size) {
@@ -189,7 +302,11 @@ function toNgrams(inputArray, size) {
);
}
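The body of `toNgrams` is truncated in this diff; a minimal sketch of such a sliding-window helper (an assumption about the implementation, not the extension's exact code) would be:

```javascript
// Sliding window over an array: all contiguous groups of `size` elements.
function toNgrams(inputArray, size) {
    const ngrams = [];
    for (let i = 0; i + size <= inputArray.length; i++) {
        ngrams.push(inputArray.slice(i, i + size));
    }
    return ngrams;
}

toNgrams(["a", "b", "c", "d"], 2); // [["a","b"],["b","c"],["c","d"]]
```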
function escapeRegExp(string) {
function escapeRegExp(string, wildcardMatching = false) {
if (wildcardMatching) {
// Escape all characters except asterisks and ?, which should be treated separately as placeholders.
return string.replace(/[-[\]{}()+.,\\^$|#\s]/g, '\\$&').replace(/\*/g, '.*').replace(/\?/g, '.');
}
return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string
}
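In wildcard mode the function keeps `*` and `?` as placeholders (translated to `.*` and `.`) while escaping everything else. A self-contained copy showing the effect:

```javascript
// Copy of escapeRegExp from the hunk above, for standalone use.
function escapeRegExp(string, wildcardMatching = false) {
    if (wildcardMatching) {
        // Escape regex metacharacters, but keep * and ? as placeholders.
        return string
            .replace(/[-[\]{}()+.,\\^$|#\s]/g, "\\$&")
            .replace(/\*/g, ".*")
            .replace(/\?/g, ".");
    }
    return string.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); // $& is the whole match
}

new RegExp(escapeRegExp("sk*rt", true), "i").test("skirt");  // true
new RegExp(escapeRegExp("sk?rt", true), "i").test("skirt");  // true
new RegExp(escapeRegExp("sk?rt", true), "i").test("skiirt"); // false
```

This is what lets `<lora:sk*rt` style searches match anywhere in a model name while a literal `.` or `(` in the search term stays literal.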
function escapeHTML(unsafeText) {
@@ -238,12 +355,19 @@ function getSortFunction() {
let criterion = TAC_CFG.modelSortOrder || "Name";
const textSort = (a, b, reverse = false) => {
const textHolderA = a.type === ResultType.chant ? a.aliases : a.text;
const textHolderB = b.type === ResultType.chant ? b.aliases : b.text;
// Assign keys so next sort is faster
if (!a.sortKey) {
a.sortKey = a.type === ResultType.chant
? a.aliases
: a.text;
}
if (!b.sortKey) {
b.sortKey = b.type === ResultType.chant
? b.aliases
: b.text;
}
const aKey = a.sortKey || textHolderA;
const bKey = b.sortKey || textHolderB;
return reverse ? bKey.localeCompare(aKey) : aKey.localeCompare(bKey);
return reverse ? b.sortKey.localeCompare(a.sortKey) : a.sortKey.localeCompare(b.sortKey);
}
const numericSort = (a, b, reverse = false) => {
const noKey = reverse ? "-1" : Number.MAX_SAFE_INTEGER;
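The hunk above memoizes each result's comparison key so repeated sorts skip the recomputation. The pattern in isolation (using a hypothetical `isChant` flag in place of `type === ResultType.chant`; illustrative, not the extension's exact code):

```javascript
// Memoize the comparison key on each item so repeated sorts reuse it.
function textSortKey(item) {
    if (!item.sortKey) {
        item.sortKey = item.isChant ? item.aliases : item.text;
    }
    return item.sortKey;
}

const results = [
    { text: "zzz" },
    { isChant: true, aliases: "aaa", text: "ignored" },
    { text: "mmm" },
];
results.sort((a, b) => textSortKey(a).localeCompare(textSortKey(b)));
// order of sortKeys: "aaa" (chant alias), "mmm", "zzz"
```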

View File

@@ -7,7 +7,10 @@ class ChantParser extends BaseTagParser {
let tempResults = [];
if (tagword !== "<" && tagword !== "<c:") {
let searchTerm = tagword.replace("<chant:", "").replace("<c:", "").replace("<", "");
let filterCondition = x => x.terms.toLowerCase().includes(searchTerm) || x.name.toLowerCase().includes(searchTerm);
let filterCondition = x => {
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
return regex.test(x.terms.toLowerCase()) || regex.test(x.name.toLowerCase());
};
tempResults = chants.filter(x => filterCondition(x)); // Filter by tagword
} else {
tempResults = chants;
@@ -51,4 +54,4 @@ PARSERS.push(new ChantParser(CHANT_TRIGGER));
// Add our utility functions to their respective queues
QUEUE_FILE_LOAD.push(load);
QUEUE_SANITIZE.push(sanitize);
QUEUE_AFTER_CONFIG_CHANGE.push(load);
QUEUE_AFTER_CONFIG_CHANGE.push(load);

View File

@@ -16,7 +16,10 @@ class EmbeddingParser extends BaseTagParser {
searchTerm = searchTerm.slice(3);
}
let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm);
let filterCondition = x => {
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
return regex.test(x[0].toLowerCase()) || regex.test(x[0].toLowerCase().replaceAll(" ", "_"));
};
if (versionString)
tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2].toLowerCase() === versionString.toLowerCase()); // Filter by tagword
@@ -29,7 +32,11 @@ class EmbeddingParser extends BaseTagParser {
// Add final results
let finalResults = [];
tempResults.forEach(t => {
let result = new AutocompleteResult(t[0].trim(), ResultType.embedding)
let lastDot = t[0].lastIndexOf(".") > -1 ? t[0].lastIndexOf(".") : t[0].length;
let lastSlash = t[0].lastIndexOf("/") > -1 ? t[0].lastIndexOf("/") : -1;
let name = t[0].trim().substring(lastSlash + 1, lastDot);
let result = new AutocompleteResult(name, ResultType.embedding)
result.sortKey = t[1];
result.meta = t[2] + " Embedding";
finalResults.push(result);
@@ -62,4 +69,4 @@ PARSERS.push(new EmbeddingParser(EMB_TRIGGER));
// Add our utility functions to their respective queues
QUEUE_FILE_LOAD.push(load);
QUEUE_SANITIZE.push(sanitize);
QUEUE_SANITIZE.push(sanitize);

View File

@@ -7,7 +7,10 @@ class HypernetParser extends BaseTagParser {
let tempResults = [];
if (tagword !== "<" && tagword !== "<h:" && tagword !== "<hypernet:") {
let searchTerm = tagword.replace("<hypernet:", "").replace("<h:", "").replace("<", "");
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
let filterCondition = x => {
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
return regex.test(x.toLowerCase()) || regex.test(x.toLowerCase().replaceAll(" ", "_"));
};
tempResults = hypernetworks.filter(x => filterCondition(x[0])); // Filter by tagword
} else {
tempResults = hypernetworks;
@@ -49,4 +52,4 @@ PARSERS.push(new HypernetParser(HYP_TRIGGER));
// Add our utility functions to their respective queues
QUEUE_FILE_LOAD.push(load);
QUEUE_SANITIZE.push(sanitize);
QUEUE_SANITIZE.push(sanitize);

View File

@@ -7,7 +7,10 @@ class LoraParser extends BaseTagParser {
let tempResults = [];
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lora:") {
let searchTerm = tagword.replace("<lora:", "").replace("<l:", "").replace("<", "");
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
let filterCondition = x => {
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
return regex.test(x.toLowerCase()) || regex.test(x.toLowerCase().replaceAll(" ", "_"));
};
tempResults = loras.filter(x => filterCondition(x[0])); // Filter by tagword
} else {
tempResults = loras;
@@ -61,4 +64,4 @@ PARSERS.push(new LoraParser(LORA_TRIGGER));
// Add our utility functions to their respective queues
QUEUE_FILE_LOAD.push(load);
QUEUE_SANITIZE.push(sanitize);
QUEUE_SANITIZE.push(sanitize);

View File

@@ -7,7 +7,10 @@ class LycoParser extends BaseTagParser {
let tempResults = [];
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lyco:" && tagword !== "<lora:") {
let searchTerm = tagword.replace("<lyco:", "").replace("<lora:", "").replace("<l:", "").replace("<", "");
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
let filterCondition = x => {
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
return regex.test(x.toLowerCase()) || regex.test(x.toLowerCase().replaceAll(" ", "_"));
};
tempResults = lycos.filter(x => filterCondition(x[0])); // Filter by tagword
} else {
tempResults = lycos;
@@ -62,4 +65,4 @@ PARSERS.push(new LycoParser(LYCO_TRIGGER));
// Add our utility functions to their respective queues
QUEUE_FILE_LOAD.push(load);
QUEUE_SANITIZE.push(sanitize);
QUEUE_SANITIZE.push(sanitize);

View File

@@ -18,7 +18,10 @@ class StyleParser extends BaseTagParser {
if (tagword !== matchGroups[1]) {
let searchTerm = tagword.replace(matchGroups[1], "");
let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm);
let filterCondition = x => {
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
return regex.test(x[0].toLowerCase()) || regex.test(x[0].toLowerCase().replaceAll(" ", "_"));
};
tempResults = styleNames.filter(x => filterCondition(x)); // Filter by tagword
} else {
tempResults = styleNames;
@@ -64,4 +67,4 @@ PARSERS.push(new StyleParser(STYLE_TRIGGER));
// Add our utility functions to their respective queues
QUEUE_FILE_LOAD.push(load);
QUEUE_SANITIZE.push(sanitize);
QUEUE_SANITIZE.push(sanitize);

View File

@@ -86,6 +86,10 @@ const autocompleteCSS = `
white-space: nowrap;
color: var(--meta-text-color);
}
.acMetaText.biased::before {
content: "✨";
margin-right: 2px;
}
.acWikiLink {
padding: 0.5rem;
margin: -0.5rem 0 -0.5rem -0.5rem;
@@ -216,11 +220,17 @@ async function syncOptions() {
includeEmbeddingsInNormalResults: opts["tac_includeEmbeddingsInNormalResults"],
useHypernetworks: opts["tac_useHypernetworks"],
useLoras: opts["tac_useLoras"],
useLycos: opts["tac_useLycos"],
useLycos: opts["tac_useLycos"],
useLoraPrefixForLycos: opts["tac_useLoraPrefixForLycos"],
showWikiLinks: opts["tac_showWikiLinks"],
showExtraNetworkPreviews: opts["tac_showExtraNetworkPreviews"],
modelSortOrder: opts["tac_modelSortOrder"],
frequencySort: opts["tac_frequencySort"],
frequencyFunction: opts["tac_frequencyFunction"],
frequencyMinCount: opts["tac_frequencyMinCount"],
frequencyMaxAge: opts["tac_frequencyMaxAge"],
frequencyRecommendCap: opts["tac_frequencyRecommendCap"],
frequencyIncludeAlias: opts["tac_frequencyIncludeAlias"],
useStyleVars: opts["tac_useStyleVars"],
// Insertion related settings
replaceUnderscores: opts["tac_replaceUnderscores"],
@@ -466,6 +476,37 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
}
}
// Frequency db update
if (TAC_CFG.frequencySort) {
let name = null;
switch (tagType) {
case ResultType.wildcardFile:
case ResultType.yamlWildcard:
// We only want to update the frequency for a full wildcard, not partial paths
if (sanitizedText.endsWith("__"))
name = text
break;
case ResultType.chant:
// Chants use a slightly different format
name = result.aliases;
break;
default:
name = text;
break;
}
if (name && name.length > 0) {
// Check if it's a negative prompt
let textAreaId = getTextAreaIdentifier(textArea);
let isNegative = textAreaId.includes("n");
// Sanitize name for API call
name = encodeURIComponent(name)
// Call API & update db
increaseUseCount(name, tagType, isNegative)
}
}
var prompt = textArea.value;
// Edit prompt text
@@ -574,6 +615,8 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
tacSelfTrigger = true;
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python.
// Uses a built-in method from the webui's ui.js which also already accounts for event target
if (tagType === ResultType.wildcardTag || tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard)
tacSelfTrigger = true;
updateInput(textArea);
// Update previous tags with the edited prompt to prevent re-searching the same term
@@ -688,6 +731,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
let wikiLink = document.createElement("a");
wikiLink.classList.add("acWikiLink");
wikiLink.innerText = "?";
wikiLink.title = "Open external wiki page for this tag"
let linkPart = displayText;
// Only use alias result if it is one
@@ -733,7 +777,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
}
// Post count
if (result.count && !isNaN(result.count)) {
if (result.count && !isNaN(result.count) && result.count !== Number.MAX_SAFE_INTEGER) {
let postCount = result.count;
let formatter;
@@ -765,8 +809,24 @@ function addResultsToList(textArea, results, tagword, resetList) {
flexDiv.appendChild(metaDiv);
}
// Add small ✨ marker to indicate usage sorting
if (result.usageBias) {
flexDiv.querySelector(".acMetaText").classList.add("biased");
flexDiv.title = "✨ Frequent tag. Ctrl/Cmd + click to reset usage count."
}
// Check if it's a negative prompt
let isNegative = textAreaId.includes("n");
// Add listener
li.addEventListener("click", function () { insertTextAtCursor(textArea, result, tagword); });
li.addEventListener("click", (e) => {
if (e.ctrlKey || e.metaKey) {
resetUseCount(result.text, result.type, !isNegative, isNegative);
flexDiv.querySelector(".acMetaText").classList.remove("biased");
} else {
insertTextAtCursor(textArea, result, tagword);
}
});
// Add element to list
resultsList.appendChild(li);
}
@@ -1034,6 +1094,9 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
resultCountBeforeNormalTags = 0;
tagword = tagword.toLowerCase().replace(/[\n\r]/g, "");
// Needed for slicing check later
let normalTags = false;
// Process all parsers
let resultCandidates = (await processParsers(textArea, prompt))?.filter(x => x.length > 0);
// If one or more result candidates match, use their results
@@ -1043,32 +1106,12 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
// Sort results, but not if it's umi tags since they are sorted by count
if (!(resultCandidates.length === 1 && results[0].type === ResultType.umiWildcard))
results = results.sort(getSortFunction());
// Since some tags are kaomoji, we have to add the normal results in some cases
if (tagword.startsWith("<") || tagword.startsWith("*<")) {
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
if (tagword.startsWith("*")) {
tagword = tagword.slice(1);
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
} else {
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
}
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, TAC_CFG.maxResults);
genericResults.forEach(g => {
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
result.category = g[1];
result.count = g[2];
result.aliases = g[3];
results.push(result);
});
}
}
// Else search the normal tag list
if (!resultCandidates || resultCandidates.length === 0
|| (TAC_CFG.includeEmbeddingsInNormalResults && !(tagword.startsWith("<") || tagword.startsWith("*<")))
) {
normalTags = true;
resultCountBeforeNormalTags = results.length;
// Create escaped search regex with support for * as a start placeholder
@@ -1123,11 +1166,6 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
results = results.concat(extraResults);
}
}
// Slice if the user has set a max result count
if (!TAC_CFG.showAllResults) {
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
}
}
// Guard for empty results
@@ -1137,6 +1175,57 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
return;
}
// Sort again with frequency / usage count if enabled
if (TAC_CFG.frequencySort) {
// Split our results into a list of names and types
let tagNames = [];
let aliasNames = [];
let types = [];
// Limit to 2k for performance reasons
const aliasTypes = [ResultType.tag, ResultType.extra];
results.slice(0,2000).forEach(r => {
const name = r.type === ResultType.chant ? r.aliases : r.text;
// Add to alias list or tag list depending on if the name includes the tagword
// (the same criteria is used in the filter in calculateUsageBias)
if (aliasTypes.includes(r.type) && !name.includes(tagword)) {
aliasNames.push(name);
} else {
tagNames.push(name);
}
types.push(r.type);
});
// Check if it's a negative prompt
let textAreaId = getTextAreaIdentifier(textArea);
let isNegative = textAreaId.includes("n");
// Request use counts from the DB
const names = TAC_CFG.frequencyIncludeAlias ? tagNames.concat(aliasNames) : tagNames;
const counts = await getUseCounts(names, types, isNegative);
// Pre-calculate weights to prevent duplicate work
const resultBiasMap = new Map();
results.forEach(result => {
const name = result.type === ResultType.chant ? result.aliases : result.text;
const type = result.type;
// Find matching pair from DB results
const useStats = counts.find(c => c.name === name && c.type === type);
const uses = useStats?.count || 0;
// Calculate & set weight
const weight = calculateUsageBias(result, result.count, uses)
resultBiasMap.set(result, weight);
});
// Actual sorting with the pre-calculated weights
results = results.sort((a, b) => {
return resultBiasMap.get(b) - resultBiasMap.get(a);
});
}
// Slice if the user has set a max result count and we are not in an extra networks / wildcard list
if (!TAC_CFG.showAllResults && normalTags) {
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
}
addResultsToList(textArea, results, tagword, true);
showResults(textArea);
}
@@ -1159,12 +1248,17 @@ function navigateInList(textArea, event) {
if (!validKeys.includes(event.key)) return;
if (!isVisible(textArea)) return
// Return if ctrl key is pressed to not interfere with weight editing shortcut
if (event.ctrlKey || event.altKey || event.shiftKey || event.metaKey) return;
// Add modifier keys to base as text+.
let modKey = "";
if (event.ctrlKey) modKey += "Ctrl+";
if (event.altKey) modKey += "Alt+";
if (event.shiftKey) modKey += "Shift+";
if (event.metaKey) modKey += "Meta+";
modKey += event.key;
oldSelectedTag = selectedTag;
switch (event.key) {
switch (modKey) {
case keys["MoveUp"]:
if (selectedTag === null) {
selectedTag = resultCount - 1;
@@ -1235,6 +1329,8 @@ function navigateInList(textArea, event) {
case keys["Close"]:
hideResults(textArea);
break;
default:
if (event.ctrlKey || event.altKey || event.shiftKey || event.metaKey) return;
}
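The navigation handler now matches on a combined string such as `"Ctrl+ArrowUp"` instead of the bare key, so keybindings can include modifiers. The string-building step in isolation (a sketch of the inline logic above, extracted into a hypothetical helper):

```javascript
// Build a "Ctrl+Alt+Shift+Meta+Key" style string from a keyboard event.
function buildModKey(event) {
    let modKey = "";
    if (event.ctrlKey) modKey += "Ctrl+";
    if (event.altKey) modKey += "Alt+";
    if (event.shiftKey) modKey += "Shift+";
    if (event.metaKey) modKey += "Meta+";
    return modKey + event.key;
}

buildModKey({ ctrlKey: true, key: "ArrowUp" }); // "Ctrl+ArrowUp"
buildModKey({ key: "Escape" });                 // "Escape"
```

Unhandled modifier combinations fall through to the `default` case, which preserves the old behavior of not interfering with other shortcuts such as weight editing.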
let moveKeys = [keys["MoveUp"], keys["MoveDown"], keys["JumpUp"], keys["JumpDown"], keys["JumpToStart"], keys["JumpToEnd"]];
if (selectedTag === resultCount - 1 && moveKeys.includes(event.key)) {
@@ -1265,7 +1361,7 @@ async function refreshTacTempFiles(api = false) {
}
if (api) {
await postAPI("tacapi/v1/refresh-temp-files", null);
await postAPI("tacapi/v1/refresh-temp-files");
await reload();
} else {
setTimeout(async () => {

View File

@@ -2,7 +2,9 @@
# to a temporary file to expose it to the javascript side
import glob
import importlib
import json
import sqlite3
import urllib.parse
from pathlib import Path
@@ -11,12 +13,26 @@ import yaml
from fastapi import FastAPI
from fastapi.responses import Response, FileResponse, JSONResponse
from modules import script_callbacks, sd_hijack, shared, hashes
from pydantic import BaseModel
from scripts.model_keyword_support import (get_lora_simple_hash,
load_hash_cache, update_hash_cache,
write_model_keyword_path)
from scripts.shared_paths import *
try:
import scripts.tag_frequency_db as tdb
# Ensure the db dependency is reloaded on script reload
importlib.reload(tdb)
db = tdb.TagFrequencyDb()
if int(db.version) != int(tdb.db_ver):
raise ValueError("Database version mismatch")
except (ImportError, ValueError, sqlite3.Error) as e:
print(f"Tag Autocomplete: Tag frequency database error - \"{e}\"")
db = None
# Attempt to get embedding load function, using the same call as api.
try:
load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings
@@ -162,26 +178,41 @@ def get_embeddings(sd_model):
emb_v1 = []
emb_v2 = []
emb_vXL = []
emb_unknown = []
results = []
try:
# The sd_model embedding_db reference only exists in sd.next with diffusers backend
try:
loaded_sdnext = sd_model.embedding_db.word_embeddings
skipped_sdnext = sd_model.embedding_db.skipped_embeddings
except (NameError, AttributeError):
loaded_sdnext = {}
skipped_sdnext = {}
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
loaded = loaded | loaded_sdnext
skipped = skipped | skipped_sdnext
# Add embeddings to the correct list
for key, emb in (loaded | skipped).items():
if emb.filename is None or emb.shape is None:
if emb.filename is None:
continue
if emb.shape == V1_SHAPE:
emb_v1.append((Path(emb.filename), key, "v1"))
if emb.shape is None:
emb_unknown.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), ""))
elif emb.shape == V1_SHAPE:
emb_v1.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "v1"))
elif emb.shape == V2_SHAPE:
emb_v2.append((Path(emb.filename), key, "v2"))
emb_v2.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "v2"))
elif emb.shape == VXL_SHAPE:
emb_vXL.append((Path(emb.filename), key, "vXL"))
emb_vXL.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "vXL"))
else:
emb_unknown.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), ""))
results = sort_models(emb_v1) + sort_models(emb_v2) + sort_models(emb_vXL)
results = sort_models(emb_v1) + sort_models(emb_v2) + sort_models(emb_vXL) + sort_models(emb_unknown)
except AttributeError:
print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
# Get a list of all embeddings in the folder
@@ -272,13 +303,21 @@ except Exception as e:
# print(f'Exception setting-up performant fetchers: {e}')
def is_visible(p: Path) -> bool:
if getattr(shared.opts, "extra_networks_hidden_models", "When searched") != "Never":
return True
for part in p.parts:
if part.startswith('.'):
return False
return True
def get_lora():
"""Write a list of all lora"""
# Get hashes
valid_loras = _get_lora()
loras_with_hash = []
for l in valid_loras:
if not l.exists() or not l.is_file():
if not l.exists() or not l.is_file() or not is_visible(l):
continue
name = l.relative_to(LORA_PATH).as_posix()
if model_keyword_installed:
@@ -296,7 +335,7 @@ def get_lyco():
valid_lycos = _get_lyco()
lycos_with_hash = []
for ly in valid_lycos:
if not ly.exists() or not ly.is_file():
if not ly.exists() or not ly.is_file() or not is_visible(ly):
continue
name = ly.relative_to(LYCO_PATH).as_posix()
if model_keyword_installed:
@@ -465,6 +504,13 @@ def on_ui_settings():
return self
shared.OptionInfo.needs_restart = needs_restart
# Dictionary of function options and their explanations
frequency_sort_functions = {
"Logarithmic (weak)": "Will respect the base order and slightly prefer often used tags",
"Logarithmic (strong)": "Same as Logarithmic (weak), but with a stronger bias",
"Usage first": "Will list used tags by frequency before all others",
}
tac_options = {
# Main tag file
"tac_tagFile": shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files),
@@ -496,6 +542,13 @@ def on_ui_settings():
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
"tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like <a href=\"https://github.com/SirVeggie/extension-style-vars\" target=\"_blank\">style-vars</a> to actually apply the styles before generating."),
# Frequency sorting settings
"tac_frequencySort": shared.OptionInfo(True, "Locally record tag usage and sort frequent tags higher").info("Will also work for extra networks, keeping the specified base order"),
"tac_frequencyFunction": shared.OptionInfo("Logarithmic (weak)", "Function to use for frequency sorting", gr.Dropdown, lambda: {"choices": list(frequency_sort_functions.keys())}).info("; ".join([f'<b>{key}</b>: {val}' for key, val in frequency_sort_functions.items()])),
"tac_frequencyMinCount": shared.OptionInfo(3, "Minimum number of uses for a tag to be considered frequent").info("Tags with fewer uses than this will not be sorted higher, even if the sorting function would normally result in a higher position."),
"tac_frequencyMaxAge": shared.OptionInfo(30, "Maximum days since last use for a tag to be considered frequent").info("Similar to the above, tags that haven't been used in this many days will not be sorted higher. Set to 0 to disable."),
"tac_frequencyRecommendCap": shared.OptionInfo(10, "Maximum number of recommended tags").info("Limits the maximum number of recommended tags to not drown out normal results. Set to 0 to disable."),
"tac_frequencyIncludeAlias": shared.OptionInfo(False, "Frequency sorting matches aliases for frequent tags").info("Tag frequency will be increased for the main tag even if an alias is used for completion. This option can be used to override the default behavior of alias results being ignored for frequency sorting."),
# Insertion related settings
"tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"),
"tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"),
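The three frequency thresholds above (minimum use count, maximum age, recommendation cap) combine roughly as sketched below. This is an illustrative standalone function, not the extension's actual JS/SQL implementation; the name `filter_frequent` and the tuple layout are assumptions for the example.

```python
from datetime import datetime, timedelta

def filter_frequent(tags, min_count=3, max_age_days=30, cap=10):
    """Illustrative sketch of how tac_frequencyMinCount, tac_frequencyMaxAge
    and tac_frequencyRecommendCap interact. `tags` is a list of
    (name, count, last_used) tuples with datetime last_used values."""
    now = datetime.now()
    frequent = [
        (name, count, last_used)
        for name, count, last_used in tags
        if count >= min_count
        and (max_age_days == 0  # 0 disables the age check
             or now - last_used <= timedelta(days=max_age_days))
    ]
    # Highest-count tags first; a cap of 0 disables the limit
    frequent.sort(key=lambda t: t[1], reverse=True)
    return frequent[:cap] if cap > 0 else frequent
```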
@@ -560,6 +613,20 @@ def on_ui_settings():
"6": ["red", "maroon"],
"7": ["whitesmoke", "black"],
"8": ["seagreen", "darkseagreen"]
},
"derpibooru": {
"-1": ["red", "maroon"],
"0": ["#60d160", "#3d9d3d"],
"1": ["#fff956", "#918e2e"],
"3": ["#fd9961", "#a14c2e"],
"4": ["#cf5bbe", "#6c1e6c"],
"5": ["#3c8ad9", "#1e5e93"],
"6": ["#a6a6a6", "#555555"],
"7": ["#47abc1", "#1f6c7c"],
"8": ["#7871d0", "#392f7d"],
"9": ["#df3647", "#8e1c2b"],
"10": ["#c98f2b", "#7b470e"],
"11": ["#e87ebe", "#a83583"]
}
}
"""
@@ -699,5 +766,59 @@ def api_tac(_: gr.Blocks, app: FastAPI):
return Response(status_code=200) # Success
else:
return Response(status_code=304) # Not modified
def db_request(func, get=False):
if db is not None:
try:
if get:
ret = func()
if isinstance(ret, list):
ret = [{"name": t[0], "type": t[1], "count": t[2], "lastUseDate": t[3]} for t in ret]
return JSONResponse({"result": ret})
else:
func()
except sqlite3.Error as e:
return JSONResponse({"error": str(e)}, status_code=500)
else:
return JSONResponse({"error": "Database not initialized"}, status_code=500)
@app.post("/tacapi/v1/increase-use-count")
async def increase_use_count(tagname: str, ttype: int, neg: bool):
db_request(lambda: db.increase_tag_count(tagname, ttype, neg))
@app.get("/tacapi/v1/get-use-count")
async def get_use_count(tagname: str, ttype: int, neg: bool):
return db_request(lambda: db.get_tag_count(tagname, ttype, neg), get=True)
# Small dataholder class
class UseCountListRequest(BaseModel):
tagNames: list[str]
tagTypes: list[int]
neg: bool = False
# Semantically weird to use post here, but it's required for the body on js side
@app.post("/tacapi/v1/get-use-count-list")
async def get_use_count_list(body: UseCountListRequest):
# If a date limit is set > 0, pass it to the db
date_limit = getattr(shared.opts, "tac_frequencyMaxAge", 30)
date_limit = date_limit if date_limit > 0 else None
count_list = list(db.get_tag_counts(body.tagNames, body.tagTypes, body.neg, date_limit))
# If a limit is set, return at max the top n results by count
if count_list:
limit = int(min(getattr(shared.opts, "tac_frequencyRecommendCap", 10), len(count_list)))
# Sort by count and return the top n
if limit > 0:
count_list = sorted(count_list, key=lambda x: x[2], reverse=True)[:limit]
return db_request(lambda: count_list, get=True)
@app.put("/tacapi/v1/reset-use-count")
async def reset_use_count(tagname: str, ttype: int, pos: bool, neg: bool):
db_request(lambda: db.reset_tag_count(tagname, ttype, pos, neg))
@app.get("/tacapi/v1/get-all-use-counts")
async def get_all_tag_counts():
return db_request(lambda: db.get_all_tags(), get=True)
script_callbacks.on_app_started(api_tac)
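For reference, the body expected by `POST /tacapi/v1/get-use-count-list` mirrors the `UseCountListRequest` model above, and `db_request` maps each `(name, type, count, lastUseDate)` row into a dict. A minimal offline sketch (the tag values are illustrative, and `to_response` is a hypothetical helper re-stating the mapping done inside `db_request`):

```python
import json

# Shape of the JSON body for POST /tacapi/v1/get-use-count-list
# (field names match the UseCountListRequest model; values are examples)
payload = {
    "tagNames": ["1girl", "solo"],
    "tagTypes": [0, 0],
    "neg": False,
}

def to_response(rows):
    """Same tuple-to-dict mapping that db_request applies to list results."""
    return [{"name": t[0], "type": t[1], "count": t[2], "lastUseDate": t[3]} for t in rows]

body = json.dumps({"result": to_response([("1girl", 0, 12, "2024-04-01 10:00:00")])})
```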

scripts/tag_frequency_db.py Normal file (189 lines added)

@@ -0,0 +1,189 @@
import sqlite3
from contextlib import contextmanager
from scripts.shared_paths import TAGS_PATH
db_file = TAGS_PATH.joinpath("tag_frequency.db")
timeout = 30
db_ver = 1
@contextmanager
def transaction(db=db_file):
"""Context manager for database transactions.
Ensures that the connection is properly closed after the transaction.
"""
conn = None
try:
conn = sqlite3.connect(db, timeout=timeout)
conn.isolation_level = None
cursor = conn.cursor()
cursor.execute("BEGIN")
yield cursor
cursor.execute("COMMIT")
except sqlite3.Error as e:
print("Tag Autocomplete: Frequency database error:", e)
finally:
if conn:
conn.close()
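Using the `transaction` context manager then amounts to the following. This is a self-contained sketch that re-declares a minimal copy of the helper (with an in-memory database, and `conn` initialized before the `try` so the `finally` block is safe if the connection itself fails):

```python
import sqlite3
from contextlib import contextmanager

@contextmanager
def transaction(db=":memory:"):
    # Minimal standalone copy of the helper above, for illustration only
    conn = None
    try:
        conn = sqlite3.connect(db, timeout=30)
        conn.isolation_level = None  # autocommit off; we issue BEGIN/COMMIT manually
        cursor = conn.cursor()
        cursor.execute("BEGIN")
        yield cursor
        cursor.execute("COMMIT")
    except sqlite3.Error as e:
        print("Frequency database error:", e)
    finally:
        if conn:
            conn.close()

with transaction() as cursor:
    cursor.execute("CREATE TABLE demo (k TEXT PRIMARY KEY, v INT)")
    cursor.execute("INSERT INTO demo VALUES ('a', 1)")
    cursor.execute("SELECT v FROM demo WHERE k = 'a'")
    result = cursor.fetchone()[0]
```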
class TagFrequencyDb:
"""Class containing creation and interaction methods for the tag frequency database"""
def __init__(self) -> None:
self.version = self.__check()
def __check(self):
if not db_file.exists():
print("Tag Autocomplete: Creating frequency database")
with transaction() as cursor:
self.__create_db(cursor)
self.__update_db_data(cursor, "version", db_ver)
print("Tag Autocomplete: Database successfully created")
return self.__get_version()
def __create_db(self, cursor: sqlite3.Cursor):
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS db_data (
key TEXT PRIMARY KEY,
value TEXT
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS tag_frequency (
name TEXT NOT NULL,
type INT NOT NULL,
count_pos INT,
count_neg INT,
last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (name, type)
)
"""
)
def __update_db_data(self, cursor: sqlite3.Cursor, key, value):
cursor.execute(
"""
INSERT OR REPLACE
INTO db_data (key, value)
VALUES (?, ?)
""",
(key, value),
)
def __get_version(self):
with transaction() as cursor:
cursor.execute(
"""
SELECT value
FROM db_data
WHERE key = 'version'
"""
)
db_version = cursor.fetchone()
return db_version[0] if db_version else 0
def get_all_tags(self):
with transaction() as cursor:
cursor.execute(
"""
SELECT name, type, count_pos, count_neg, last_used
FROM tag_frequency
WHERE count_pos > 0 OR count_neg > 0
ORDER BY count_pos + count_neg DESC
"""
)
tags = cursor.fetchall()
return tags
def get_tag_count(self, tag, ttype, negative=False):
count_str = "count_neg" if negative else "count_pos"
with transaction() as cursor:
cursor.execute(
f"""
SELECT {count_str}, last_used
FROM tag_frequency
WHERE name = ? AND type = ?
""",
(tag, ttype),
)
tag_count = cursor.fetchone()
if tag_count:
return tag_count[0], tag_count[1]
else:
return 0, None
def get_tag_counts(self, tags: list[str], ttypes: list[int], negative=False, date_limit=None):
count_str = "count_neg" if negative else "count_pos"
with transaction() as cursor:
for tag, ttype in zip(tags, ttypes):
if date_limit is not None:
cursor.execute(
f"""
SELECT {count_str}, last_used
FROM tag_frequency
WHERE name = ? AND type = ?
AND last_used > datetime('now', '-' || ? || ' days')
""",
(tag, ttype, date_limit),
)
else:
cursor.execute(
f"""
SELECT {count_str}, last_used
FROM tag_frequency
WHERE name = ? AND type = ?
""",
(tag, ttype),
)
tag_count = cursor.fetchone()
if tag_count:
yield (tag, ttype, tag_count[0], tag_count[1])
else:
yield (tag, ttype, 0, None)
def increase_tag_count(self, tag, ttype, negative=False):
pos_count = self.get_tag_count(tag, ttype, False)[0]
neg_count = self.get_tag_count(tag, ttype, True)[0]
if negative:
neg_count += 1
else:
pos_count += 1
with transaction() as cursor:
cursor.execute(
f"""
INSERT OR REPLACE
INTO tag_frequency (name, type, count_pos, count_neg)
VALUES (?, ?, ?, ?)
""",
(tag, ttype, pos_count, neg_count),
)
def reset_tag_count(self, tag, ttype, positive=True, negative=False):
if positive and negative:
set_str = "count_pos = 0, count_neg = 0"
elif positive:
set_str = "count_pos = 0"
elif negative:
set_str = "count_neg = 0"
else:
return
with transaction() as cursor:
cursor.execute(
f"""
UPDATE tag_frequency
SET {set_str}
WHERE name = ? AND type = ?
""",
(tag, ttype),
)

tags/EnglishDictionary.csv Normal file (113,301 lines; diff too large to display)

(one additional file diff suppressed because it is too large)

tags/derpibooru.csv Normal file (95,091 lines; diff too large to display)