mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-27 11:39:55 +00:00
Compare commits
60 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
119a3ad51f | ||
|
|
c820a22149 | ||
|
|
eb1e1820f9 | ||
|
|
ef59cff651 | ||
|
|
a454383c43 | ||
|
|
bec567fe26 | ||
|
|
d4041096c9 | ||
|
|
0903259ddf | ||
|
|
f3e64b1fa5 | ||
|
|
312cec5d71 | ||
|
|
b71e6339bd | ||
|
|
7ddbc3c0b2 | ||
|
|
4c2ef8f770 | ||
|
|
97c5e4f53c | ||
|
|
1d8d9f64b5 | ||
|
|
7437850600 | ||
|
|
829a4a7b89 | ||
|
|
22472ac8ad | ||
|
|
5f77fa26d3 | ||
|
|
f810b2dd8f | ||
|
|
95200e82e1 | ||
|
|
a966be7546 | ||
|
|
342fbc9041 | ||
|
|
d496569c9a | ||
|
|
30c9593d3d | ||
|
|
57076060df | ||
|
|
20b6635a2a | ||
|
|
1fe8f26670 | ||
|
|
e82e958c3e | ||
|
|
2dd48eab79 | ||
|
|
4df90f5c95 | ||
|
|
a156214a48 | ||
|
|
15478e73b5 | ||
|
|
434301738a | ||
|
|
4fba7baa69 | ||
|
|
7128efc4f4 | ||
|
|
bd0ddfbb24 | ||
|
|
3108daf0e8 | ||
|
|
363895494b | ||
|
|
04551a8132 | ||
|
|
ffc0e378d3 | ||
|
|
440f109f1f | ||
|
|
80fb247dbe | ||
|
|
d7e98200a8 | ||
|
|
ac790c8ede | ||
|
|
22365ec8d6 | ||
|
|
030a83aa4d | ||
|
|
460d32a4ed | ||
|
|
581bf1e6a4 | ||
|
|
74ea5493e5 | ||
|
|
6cf9acd6ab | ||
|
|
109a8a155e | ||
|
|
3caa1b51ed | ||
|
|
b44c36425a | ||
|
|
1e81403180 | ||
|
|
0f487a5c5c | ||
|
|
2baa12fea3 | ||
|
|
67eeb5fbf6 | ||
|
|
11ffed8afc | ||
|
|
0a8e7d7d84 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
tags/temp/
|
||||
__pycache__/
|
||||
tags/tag_frequency.db
|
||||
|
||||
@@ -20,6 +20,10 @@ Booru style tag autocompletion for the AUTOMATIC1111 Stable Diffusion WebUI
|
||||
</div>
|
||||
<br/>
|
||||
|
||||
#### ⚠️ Notice:
|
||||
I am currently looking for feedback on a new feature I'm working on and want to release soon.<br/>
|
||||
Please check [the announcement post](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/discussions/270) for more info if you are interested to help.
|
||||
|
||||
# 📄 Description
|
||||
|
||||
Tag Autocomplete is an extension for the popular [AUTOMATIC1111 web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for Stable Diffusion.
|
||||
|
||||
@@ -24,7 +24,8 @@ class AutocompleteResult {
|
||||
|
||||
// Additional info, only used in some cases
|
||||
category = null;
|
||||
count = null;
|
||||
count = Number.MAX_SAFE_INTEGER;
|
||||
usageBias = null;
|
||||
aliases = null;
|
||||
meta = null;
|
||||
hash = null;
|
||||
|
||||
@@ -80,8 +80,12 @@ async function fetchAPI(url, json = true, cache = false) {
|
||||
return await response.text();
|
||||
}
|
||||
|
||||
async function postAPI(url, body) {
|
||||
let response = await fetch(url, { method: "POST", body: body });
|
||||
async function postAPI(url, body = null) {
|
||||
let response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: body
|
||||
});
|
||||
|
||||
if (response.status != 200) {
|
||||
console.error(`Error posting to API endpoint "${url}": ` + response.status, response.statusText);
|
||||
@@ -91,6 +95,17 @@ async function postAPI(url, body) {
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
async function putAPI(url, body = null) {
|
||||
let response = await fetch(url, { method: "PUT", body: body });
|
||||
|
||||
if (response.status != 200) {
|
||||
console.error(`Error putting to API endpoint "${url}": ` + response.status, response.statusText);
|
||||
return null;
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
// Extra network preview thumbnails
|
||||
async function getExtraNetworkPreviewURL(filename, type) {
|
||||
const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);
|
||||
@@ -180,6 +195,104 @@ function flatten(obj, roots = [], sep = ".") {
|
||||
);
|
||||
}
|
||||
|
||||
// Calculate biased tag score based on post count and frequent usage
|
||||
function calculateUsageBias(result, count, uses) {
|
||||
// Check setting conditions
|
||||
if (uses < TAC_CFG.frequencyMinCount) {
|
||||
uses = 0;
|
||||
} else if (uses != 0) {
|
||||
result.usageBias = true;
|
||||
}
|
||||
|
||||
switch (TAC_CFG.frequencyFunction) {
|
||||
case "Logarithmic (weak)":
|
||||
return Math.log(1 + count) + Math.log(1 + uses);
|
||||
case "Logarithmic (strong)":
|
||||
return Math.log(1 + count) + 2 * Math.log(1 + uses);
|
||||
case "Usage first":
|
||||
return uses;
|
||||
default:
|
||||
return count;
|
||||
}
|
||||
}
|
||||
// Beautify return type for easier parsing
|
||||
function mapUseCountArray(useCounts, posAndNeg = false) {
|
||||
return useCounts.map(useCount => {
|
||||
if (posAndNeg) {
|
||||
return {
|
||||
"name": useCount[0],
|
||||
"type": useCount[1],
|
||||
"count": useCount[2],
|
||||
"negCount": useCount[3],
|
||||
"lastUseDate": useCount[4]
|
||||
}
|
||||
}
|
||||
return {
|
||||
"name": useCount[0],
|
||||
"type": useCount[1],
|
||||
"count": useCount[2],
|
||||
"lastUseDate": useCount[3]
|
||||
}
|
||||
});
|
||||
}
|
||||
// Call API endpoint to increase bias of tag in the database
|
||||
function increaseUseCount(tagName, type, negative = false) {
|
||||
postAPI(`tacapi/v1/increase-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`);
|
||||
}
|
||||
// Get use count of tag from the database
|
||||
async function getUseCount(tagName, type, negative = false) {
|
||||
return (await fetchAPI(`tacapi/v1/get-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`, true, false))["result"];
|
||||
}
|
||||
async function getUseCounts(tagNames, types, negative = false) {
|
||||
// While semantically weird, we have to use POST here for the body, as urls are limited in length
|
||||
const body = JSON.stringify({"tagNames": tagNames, "tagTypes": types, "neg": negative});
|
||||
const rawArray = (await postAPI(`tacapi/v1/get-use-count-list`, body))["result"]
|
||||
return mapUseCountArray(rawArray);
|
||||
}
|
||||
async function getAllUseCounts() {
|
||||
const rawArray = (await fetchAPI(`tacapi/v1/get-all-use-counts`))["result"];
|
||||
return mapUseCountArray(rawArray, true);
|
||||
}
|
||||
async function resetUseCount(tagName, type, resetPosCount, resetNegCount) {
|
||||
await putAPI(`tacapi/v1/reset-use-count?tagname=${tagName}&ttype=${type}&pos=${resetPosCount}&neg=${resetNegCount}`);
|
||||
}
|
||||
|
||||
function createTagUsageTable(tagCounts) {
|
||||
// Create table
|
||||
let tagTable = document.createElement("table");
|
||||
tagTable.innerHTML =
|
||||
`<thead>
|
||||
<tr>
|
||||
<td>Name</td>
|
||||
<td>Type</td>
|
||||
<td>Count(+)</td>
|
||||
<td>Count(-)</td>
|
||||
<td>Last used</td>
|
||||
</tr>
|
||||
</thead>`;
|
||||
tagTable.id = "tac_tagUsageTable"
|
||||
|
||||
tagCounts.forEach(t => {
|
||||
let tr = document.createElement("tr");
|
||||
|
||||
// Fill values
|
||||
let values = [t.name, t.type-1, t.count, t.negCount, t.lastUseDate]
|
||||
values.forEach(v => {
|
||||
let td = document.createElement("td");
|
||||
td.innerText = v;
|
||||
tr.append(td);
|
||||
});
|
||||
// Add delete/reset button
|
||||
let delButton = document.createElement("button");
|
||||
delButton.innerText = "🗑️";
|
||||
delButton.title = "Reset count";
|
||||
tr.append(delButton);
|
||||
|
||||
tagTable.append(tr)
|
||||
});
|
||||
|
||||
return tagTable;
|
||||
}
|
||||
|
||||
// Sliding window function to get possible combination groups of an array
|
||||
function toNgrams(inputArray, size) {
|
||||
@@ -189,7 +302,11 @@ function toNgrams(inputArray, size) {
|
||||
);
|
||||
}
|
||||
|
||||
function escapeRegExp(string) {
|
||||
function escapeRegExp(string, wildcardMatching = false) {
|
||||
if (wildcardMatching) {
|
||||
// Escape all characters except asterisks and ?, which should be treated separately as placeholders.
|
||||
return string.replace(/[-[\]{}()+.,\\^$|#\s]/g, '\\$&').replace(/\*/g, '.*').replace(/\?/g, '.');
|
||||
}
|
||||
return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string
|
||||
}
|
||||
function escapeHTML(unsafeText) {
|
||||
@@ -238,12 +355,19 @@ function getSortFunction() {
|
||||
let criterion = TAC_CFG.modelSortOrder || "Name";
|
||||
|
||||
const textSort = (a, b, reverse = false) => {
|
||||
const textHolderA = a.type === ResultType.chant ? a.aliases : a.text;
|
||||
const textHolderB = b.type === ResultType.chant ? b.aliases : b.text;
|
||||
// Assign keys so next sort is faster
|
||||
if (!a.sortKey) {
|
||||
a.sortKey = a.type === ResultType.chant
|
||||
? a.aliases
|
||||
: a.text;
|
||||
}
|
||||
if (!b.sortKey) {
|
||||
b.sortKey = b.type === ResultType.chant
|
||||
? b.aliases
|
||||
: b.text;
|
||||
}
|
||||
|
||||
const aKey = a.sortKey || textHolderA;
|
||||
const bKey = b.sortKey || textHolderB;
|
||||
return reverse ? bKey.localeCompare(aKey) : aKey.localeCompare(bKey);
|
||||
return reverse ? b.sortKey.localeCompare(a.sortKey) : a.sortKey.localeCompare(b.sortKey);
|
||||
}
|
||||
const numericSort = (a, b, reverse = false) => {
|
||||
const noKey = reverse ? "-1" : Number.MAX_SAFE_INTEGER;
|
||||
|
||||
@@ -7,7 +7,10 @@ class ChantParser extends BaseTagParser {
|
||||
let tempResults = [];
|
||||
if (tagword !== "<" && tagword !== "<c:") {
|
||||
let searchTerm = tagword.replace("<chant:", "").replace("<c:", "").replace("<", "");
|
||||
let filterCondition = x => x.terms.toLowerCase().includes(searchTerm) || x.name.toLowerCase().includes(searchTerm);
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x.terms.toLowerCase()) || regex.test(x.name.toLowerCase());
|
||||
};
|
||||
tempResults = chants.filter(x => filterCondition(x)); // Filter by tagword
|
||||
} else {
|
||||
tempResults = chants;
|
||||
@@ -51,4 +54,4 @@ PARSERS.push(new ChantParser(CHANT_TRIGGER));
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_AFTER_CONFIG_CHANGE.push(load);
|
||||
QUEUE_AFTER_CONFIG_CHANGE.push(load);
|
||||
|
||||
@@ -16,7 +16,10 @@ class EmbeddingParser extends BaseTagParser {
|
||||
searchTerm = searchTerm.slice(3);
|
||||
}
|
||||
|
||||
let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x[0].toLowerCase()) || regex.test(x[0].toLowerCase().replaceAll(" ", "_"));
|
||||
};
|
||||
|
||||
if (versionString)
|
||||
tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2].toLowerCase() === versionString.toLowerCase()); // Filter by tagword
|
||||
@@ -29,7 +32,11 @@ class EmbeddingParser extends BaseTagParser {
|
||||
// Add final results
|
||||
let finalResults = [];
|
||||
tempResults.forEach(t => {
|
||||
let result = new AutocompleteResult(t[0].trim(), ResultType.embedding)
|
||||
let lastDot = t[0].lastIndexOf(".") > -1 ? t[0].lastIndexOf(".") : t[0].length;
|
||||
let lastSlash = t[0].lastIndexOf("/") > -1 ? t[0].lastIndexOf("/") : -1;
|
||||
let name = t[0].trim().substring(lastSlash + 1, lastDot);
|
||||
|
||||
let result = new AutocompleteResult(name, ResultType.embedding)
|
||||
result.sortKey = t[1];
|
||||
result.meta = t[2] + " Embedding";
|
||||
finalResults.push(result);
|
||||
@@ -62,4 +69,4 @@ PARSERS.push(new EmbeddingParser(EMB_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
|
||||
@@ -7,7 +7,10 @@ class HypernetParser extends BaseTagParser {
|
||||
let tempResults = [];
|
||||
if (tagword !== "<" && tagword !== "<h:" && tagword !== "<hypernet:") {
|
||||
let searchTerm = tagword.replace("<hypernet:", "").replace("<h:", "").replace("<", "");
|
||||
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x.toLowerCase()) || regex.test(x.toLowerCase().replaceAll(" ", "_"));
|
||||
};
|
||||
tempResults = hypernetworks.filter(x => filterCondition(x[0])); // Filter by tagword
|
||||
} else {
|
||||
tempResults = hypernetworks;
|
||||
@@ -49,4 +52,4 @@ PARSERS.push(new HypernetParser(HYP_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
|
||||
@@ -7,7 +7,10 @@ class LoraParser extends BaseTagParser {
|
||||
let tempResults = [];
|
||||
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lora:") {
|
||||
let searchTerm = tagword.replace("<lora:", "").replace("<l:", "").replace("<", "");
|
||||
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x.toLowerCase()) || regex.test(x.toLowerCase().replaceAll(" ", "_"));
|
||||
};
|
||||
tempResults = loras.filter(x => filterCondition(x[0])); // Filter by tagword
|
||||
} else {
|
||||
tempResults = loras;
|
||||
@@ -61,4 +64,4 @@ PARSERS.push(new LoraParser(LORA_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
|
||||
@@ -7,7 +7,10 @@ class LycoParser extends BaseTagParser {
|
||||
let tempResults = [];
|
||||
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lyco:" && tagword !== "<lora:") {
|
||||
let searchTerm = tagword.replace("<lyco:", "").replace("<lora:", "").replace("<l:", "").replace("<", "");
|
||||
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x.toLowerCase()) || regex.test(x.toLowerCase().replaceAll(" ", "_"));
|
||||
};
|
||||
tempResults = lycos.filter(x => filterCondition(x[0])); // Filter by tagword
|
||||
} else {
|
||||
tempResults = lycos;
|
||||
@@ -62,4 +65,4 @@ PARSERS.push(new LycoParser(LYCO_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
|
||||
@@ -18,7 +18,10 @@ class StyleParser extends BaseTagParser {
|
||||
if (tagword !== matchGroups[1]) {
|
||||
let searchTerm = tagword.replace(matchGroups[1], "");
|
||||
|
||||
let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x[0].toLowerCase()) || regex.test(x[0].toLowerCase().replaceAll(" ", "_"));
|
||||
};
|
||||
tempResults = styleNames.filter(x => filterCondition(x)); // Filter by tagword
|
||||
} else {
|
||||
tempResults = styleNames;
|
||||
@@ -64,4 +67,4 @@ PARSERS.push(new StyleParser(STYLE_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
|
||||
@@ -86,6 +86,10 @@ const autocompleteCSS = `
|
||||
white-space: nowrap;
|
||||
color: var(--meta-text-color);
|
||||
}
|
||||
.acMetaText.biased::before {
|
||||
content: "✨";
|
||||
margin-right: 2px;
|
||||
}
|
||||
.acWikiLink {
|
||||
padding: 0.5rem;
|
||||
margin: -0.5rem 0 -0.5rem -0.5rem;
|
||||
@@ -216,11 +220,17 @@ async function syncOptions() {
|
||||
includeEmbeddingsInNormalResults: opts["tac_includeEmbeddingsInNormalResults"],
|
||||
useHypernetworks: opts["tac_useHypernetworks"],
|
||||
useLoras: opts["tac_useLoras"],
|
||||
useLycos: opts["tac_useLycos"],
|
||||
useLycos: opts["tac_useLycos"],
|
||||
useLoraPrefixForLycos: opts["tac_useLoraPrefixForLycos"],
|
||||
showWikiLinks: opts["tac_showWikiLinks"],
|
||||
showExtraNetworkPreviews: opts["tac_showExtraNetworkPreviews"],
|
||||
modelSortOrder: opts["tac_modelSortOrder"],
|
||||
frequencySort: opts["tac_frequencySort"],
|
||||
frequencyFunction: opts["tac_frequencyFunction"],
|
||||
frequencyMinCount: opts["tac_frequencyMinCount"],
|
||||
frequencyMaxAge: opts["tac_frequencyMaxAge"],
|
||||
frequencyRecommendCap: opts["tac_frequencyRecommendCap"],
|
||||
frequencyIncludeAlias: opts["tac_frequencyIncludeAlias"],
|
||||
useStyleVars: opts["tac_useStyleVars"],
|
||||
// Insertion related settings
|
||||
replaceUnderscores: opts["tac_replaceUnderscores"],
|
||||
@@ -466,6 +476,37 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
|
||||
}
|
||||
}
|
||||
|
||||
// Frequency db update
|
||||
if (TAC_CFG.frequencySort) {
|
||||
let name = null;
|
||||
|
||||
switch (tagType) {
|
||||
case ResultType.wildcardFile:
|
||||
case ResultType.yamlWildcard:
|
||||
// We only want to update the frequency for a full wildcard, not partial paths
|
||||
if (sanitizedText.endsWith("__"))
|
||||
name = text
|
||||
break;
|
||||
case ResultType.chant:
|
||||
// Chants use a slightly different format
|
||||
name = result.aliases;
|
||||
break;
|
||||
default:
|
||||
name = text;
|
||||
break;
|
||||
}
|
||||
|
||||
if (name && name.length > 0) {
|
||||
// Check if it's a negative prompt
|
||||
let textAreaId = getTextAreaIdentifier(textArea);
|
||||
let isNegative = textAreaId.includes("n");
|
||||
// Sanitize name for API call
|
||||
name = encodeURIComponent(name)
|
||||
// Call API & update db
|
||||
increaseUseCount(name, tagType, isNegative)
|
||||
}
|
||||
}
|
||||
|
||||
var prompt = textArea.value;
|
||||
|
||||
// Edit prompt text
|
||||
@@ -574,6 +615,8 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
|
||||
tacSelfTrigger = true;
|
||||
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python.
|
||||
// Uses a built-in method from the webui's ui.js which also already accounts for event target
|
||||
if (tagType === ResultType.wildcardTag || tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard)
|
||||
tacSelfTrigger = true;
|
||||
updateInput(textArea);
|
||||
|
||||
// Update previous tags with the edited prompt to prevent re-searching the same term
|
||||
@@ -688,6 +731,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
let wikiLink = document.createElement("a");
|
||||
wikiLink.classList.add("acWikiLink");
|
||||
wikiLink.innerText = "?";
|
||||
wikiLink.title = "Open external wiki page for this tag"
|
||||
|
||||
let linkPart = displayText;
|
||||
// Only use alias result if it is one
|
||||
@@ -733,7 +777,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
}
|
||||
|
||||
// Post count
|
||||
if (result.count && !isNaN(result.count)) {
|
||||
if (result.count && !isNaN(result.count) && result.count !== Number.MAX_SAFE_INTEGER) {
|
||||
let postCount = result.count;
|
||||
let formatter;
|
||||
|
||||
@@ -765,8 +809,24 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
flexDiv.appendChild(metaDiv);
|
||||
}
|
||||
|
||||
// Add small ✨ marker to indicate usage sorting
|
||||
if (result.usageBias) {
|
||||
flexDiv.querySelector(".acMetaText").classList.add("biased");
|
||||
flexDiv.title = "✨ Frequent tag. Ctrl/Cmd + click to reset usage count."
|
||||
}
|
||||
|
||||
// Check if it's a negative prompt
|
||||
let isNegative = textAreaId.includes("n");
|
||||
|
||||
// Add listener
|
||||
li.addEventListener("click", function () { insertTextAtCursor(textArea, result, tagword); });
|
||||
li.addEventListener("click", (e) => {
|
||||
if (e.ctrlKey || e.metaKey) {
|
||||
resetUseCount(result.text, result.type, !isNegative, isNegative);
|
||||
flexDiv.querySelector(".acMetaText").classList.remove("biased");
|
||||
} else {
|
||||
insertTextAtCursor(textArea, result, tagword);
|
||||
}
|
||||
});
|
||||
// Add element to list
|
||||
resultsList.appendChild(li);
|
||||
}
|
||||
@@ -1034,6 +1094,9 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
resultCountBeforeNormalTags = 0;
|
||||
tagword = tagword.toLowerCase().replace(/[\n\r]/g, "");
|
||||
|
||||
// Needed for slicing check later
|
||||
let normalTags = false;
|
||||
|
||||
// Process all parsers
|
||||
let resultCandidates = (await processParsers(textArea, prompt))?.filter(x => x.length > 0);
|
||||
// If one ore more result candidates match, use their results
|
||||
@@ -1043,32 +1106,12 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
// Sort results, but not if it's umi tags since they are sorted by count
|
||||
if (!(resultCandidates.length === 1 && results[0].type === ResultType.umiWildcard))
|
||||
results = results.sort(getSortFunction());
|
||||
|
||||
// Since some tags are kaomoji, we have to add the normal results in some cases
|
||||
if (tagword.startsWith("<") || tagword.startsWith("*<")) {
|
||||
// Create escaped search regex with support for * as a start placeholder
|
||||
let searchRegex;
|
||||
if (tagword.startsWith("*")) {
|
||||
tagword = tagword.slice(1);
|
||||
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
|
||||
} else {
|
||||
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
|
||||
}
|
||||
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, TAC_CFG.maxResults);
|
||||
|
||||
genericResults.forEach(g => {
|
||||
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
|
||||
result.category = g[1];
|
||||
result.count = g[2];
|
||||
result.aliases = g[3];
|
||||
results.push(result);
|
||||
});
|
||||
}
|
||||
}
|
||||
// Else search the normal tag list
|
||||
if (!resultCandidates || resultCandidates.length === 0
|
||||
|| (TAC_CFG.includeEmbeddingsInNormalResults && !(tagword.startsWith("<") || tagword.startsWith("*<")))
|
||||
) {
|
||||
normalTags = true;
|
||||
resultCountBeforeNormalTags = results.length;
|
||||
|
||||
// Create escaped search regex with support for * as a start placeholder
|
||||
@@ -1123,11 +1166,6 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
results = results.concat(extraResults);
|
||||
}
|
||||
}
|
||||
|
||||
// Slice if the user has set a max result count
|
||||
if (!TAC_CFG.showAllResults) {
|
||||
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
|
||||
}
|
||||
}
|
||||
|
||||
// Guard for empty results
|
||||
@@ -1137,6 +1175,57 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort again with frequency / usage count if enabled
|
||||
if (TAC_CFG.frequencySort) {
|
||||
// Split our results into a list of names and types
|
||||
let tagNames = [];
|
||||
let aliasNames = [];
|
||||
let types = [];
|
||||
// Limit to 2k for performance reasons
|
||||
const aliasTypes = [ResultType.tag, ResultType.extra];
|
||||
results.slice(0,2000).forEach(r => {
|
||||
const name = r.type === ResultType.chant ? r.aliases : r.text;
|
||||
// Add to alias list or tag list depending on if the name includes the tagword
|
||||
// (the same criteria is used in the filter in calculateUsageBias)
|
||||
if (aliasTypes.includes(r.type) && !name.includes(tagword)) {
|
||||
aliasNames.push(name);
|
||||
} else {
|
||||
tagNames.push(name);
|
||||
}
|
||||
types.push(r.type);
|
||||
});
|
||||
|
||||
// Check if it's a negative prompt
|
||||
let textAreaId = getTextAreaIdentifier(textArea);
|
||||
let isNegative = textAreaId.includes("n");
|
||||
|
||||
// Request use counts from the DB
|
||||
const names = TAC_CFG.frequencyIncludeAlias ? tagNames.concat(aliasNames) : tagNames;
|
||||
const counts = await getUseCounts(names, types, isNegative);
|
||||
|
||||
// Pre-calculate weights to prevent duplicate work
|
||||
const resultBiasMap = new Map();
|
||||
results.forEach(result => {
|
||||
const name = result.type === ResultType.chant ? result.aliases : result.text;
|
||||
const type = result.type;
|
||||
// Find matching pair from DB results
|
||||
const useStats = counts.find(c => c.name === name && c.type === type);
|
||||
const uses = useStats?.count || 0;
|
||||
// Calculate & set weight
|
||||
const weight = calculateUsageBias(result, result.count, uses)
|
||||
resultBiasMap.set(result, weight);
|
||||
});
|
||||
// Actual sorting with the pre-calculated weights
|
||||
results = results.sort((a, b) => {
|
||||
return resultBiasMap.get(b) - resultBiasMap.get(a);
|
||||
});
|
||||
}
|
||||
|
||||
// Slice if the user has set a max result count and we are not in a extra networks / wildcard list
|
||||
if (!TAC_CFG.showAllResults && normalTags) {
|
||||
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
|
||||
}
|
||||
|
||||
addResultsToList(textArea, results, tagword, true);
|
||||
showResults(textArea);
|
||||
}
|
||||
@@ -1159,12 +1248,17 @@ function navigateInList(textArea, event) {
|
||||
|
||||
if (!validKeys.includes(event.key)) return;
|
||||
if (!isVisible(textArea)) return
|
||||
// Return if ctrl key is pressed to not interfere with weight editing shortcut
|
||||
if (event.ctrlKey || event.altKey || event.shiftKey || event.metaKey) return;
|
||||
// Add modifier keys to base as text+.
|
||||
let modKey = "";
|
||||
if (event.ctrlKey) modKey += "Ctrl+";
|
||||
if (event.altKey) modKey += "Alt+";
|
||||
if (event.shiftKey) modKey += "Shift+";
|
||||
if (event.metaKey) modKey += "Meta+";
|
||||
modKey += event.key;
|
||||
|
||||
oldSelectedTag = selectedTag;
|
||||
|
||||
switch (event.key) {
|
||||
switch (modKey) {
|
||||
case keys["MoveUp"]:
|
||||
if (selectedTag === null) {
|
||||
selectedTag = resultCount - 1;
|
||||
@@ -1235,6 +1329,8 @@ function navigateInList(textArea, event) {
|
||||
case keys["Close"]:
|
||||
hideResults(textArea);
|
||||
break;
|
||||
default:
|
||||
if (event.ctrlKey || event.altKey || event.shiftKey || event.metaKey) return;
|
||||
}
|
||||
let moveKeys = [keys["MoveUp"], keys["MoveDown"], keys["JumpUp"], keys["JumpDown"], keys["JumpToStart"], keys["JumpToEnd"]];
|
||||
if (selectedTag === resultCount - 1 && moveKeys.includes(event.key)) {
|
||||
@@ -1265,7 +1361,7 @@ async function refreshTacTempFiles(api = false) {
|
||||
}
|
||||
|
||||
if (api) {
|
||||
await postAPI("tacapi/v1/refresh-temp-files", null);
|
||||
await postAPI("tacapi/v1/refresh-temp-files");
|
||||
await reload();
|
||||
} else {
|
||||
setTimeout(async () => {
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
# to a temporary file to expose it to the javascript side
|
||||
|
||||
import glob
|
||||
import importlib
|
||||
import json
|
||||
import sqlite3
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
@@ -11,12 +13,26 @@ import yaml
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import Response, FileResponse, JSONResponse
|
||||
from modules import script_callbacks, sd_hijack, shared, hashes
|
||||
from pydantic import BaseModel
|
||||
|
||||
from scripts.model_keyword_support import (get_lora_simple_hash,
|
||||
load_hash_cache, update_hash_cache,
|
||||
write_model_keyword_path)
|
||||
from scripts.shared_paths import *
|
||||
|
||||
try:
|
||||
import scripts.tag_frequency_db as tdb
|
||||
|
||||
# Ensure the db dependency is reloaded on script reload
|
||||
importlib.reload(tdb)
|
||||
|
||||
db = tdb.TagFrequencyDb()
|
||||
if int(db.version) != int(tdb.db_ver):
|
||||
raise ValueError("Database version mismatch")
|
||||
except (ImportError, ValueError, sqlite3.Error) as e:
|
||||
print(f"Tag Autocomplete: Tag frequency database error - \"{e}\"")
|
||||
db = None
|
||||
|
||||
# Attempt to get embedding load function, using the same call as api.
|
||||
try:
|
||||
load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings
|
||||
@@ -162,26 +178,41 @@ def get_embeddings(sd_model):
|
||||
emb_v1 = []
|
||||
emb_v2 = []
|
||||
emb_vXL = []
|
||||
emb_unknown = []
|
||||
results = []
|
||||
|
||||
try:
|
||||
# The sd_model embedding_db reference only exists in sd.next with diffusers backend
|
||||
try:
|
||||
loaded_sdnext = sd_model.embedding_db.word_embeddings
|
||||
skipped_sdnext = sd_model.embedding_db.skipped_embeddings
|
||||
except (NameError, AttributeError):
|
||||
loaded_sdnext = {}
|
||||
skipped_sdnext = {}
|
||||
|
||||
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
|
||||
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
|
||||
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
|
||||
loaded = loaded | loaded_sdnext
|
||||
skipped = skipped | skipped_sdnext
|
||||
|
||||
# Add embeddings to the correct list
|
||||
for key, emb in (loaded | skipped).items():
|
||||
if emb.filename is None or emb.shape is None:
|
||||
if emb.filename is None:
|
||||
continue
|
||||
|
||||
if emb.shape == V1_SHAPE:
|
||||
emb_v1.append((Path(emb.filename), key, "v1"))
|
||||
if emb.shape is None:
|
||||
emb_unknown.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), ""))
|
||||
elif emb.shape == V1_SHAPE:
|
||||
emb_v1.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "v1"))
|
||||
elif emb.shape == V2_SHAPE:
|
||||
emb_v2.append((Path(emb.filename), key, "v2"))
|
||||
emb_v2.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "v2"))
|
||||
elif emb.shape == VXL_SHAPE:
|
||||
emb_vXL.append((Path(emb.filename), key, "vXL"))
|
||||
emb_vXL.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "vXL"))
|
||||
else:
|
||||
emb_unknown.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), ""))
|
||||
|
||||
results = sort_models(emb_v1) + sort_models(emb_v2) + sort_models(emb_vXL)
|
||||
results = sort_models(emb_v1) + sort_models(emb_v2) + sort_models(emb_vXL) + sort_models(emb_unknown)
|
||||
except AttributeError:
|
||||
print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
|
||||
# Get a list of all embeddings in the folder
|
||||
@@ -272,13 +303,21 @@ except Exception as e:
|
||||
# print(f'Exception setting-up performant fetchers: {e}')
|
||||
|
||||
|
||||
def is_visible(p: Path) -> bool:
|
||||
if getattr(shared.opts, "extra_networks_hidden_models", "When searched") != "Never":
|
||||
return True
|
||||
for part in p.parts:
|
||||
if part.startswith('.'):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_lora():
|
||||
"""Write a list of all lora"""
|
||||
# Get hashes
|
||||
valid_loras = _get_lora()
|
||||
loras_with_hash = []
|
||||
for l in valid_loras:
|
||||
if not l.exists() or not l.is_file():
|
||||
if not l.exists() or not l.is_file() or not is_visible(l):
|
||||
continue
|
||||
name = l.relative_to(LORA_PATH).as_posix()
|
||||
if model_keyword_installed:
|
||||
@@ -296,7 +335,7 @@ def get_lyco():
|
||||
valid_lycos = _get_lyco()
|
||||
lycos_with_hash = []
|
||||
for ly in valid_lycos:
|
||||
if not ly.exists() or not ly.is_file():
|
||||
if not ly.exists() or not ly.is_file() or not is_visible(ly):
|
||||
continue
|
||||
name = ly.relative_to(LYCO_PATH).as_posix()
|
||||
if model_keyword_installed:
|
||||
@@ -465,6 +504,13 @@ def on_ui_settings():
|
||||
return self
|
||||
shared.OptionInfo.needs_restart = needs_restart
|
||||
|
||||
# Dictionary of function options and their explanations
|
||||
frequency_sort_functions = {
|
||||
"Logarithmic (weak)": "Will respect the base order and slightly prefer often used tags",
|
||||
"Logarithmic (strong)": "Same as Logarithmic (weak), but with a stronger bias",
|
||||
"Usage first": "Will list used tags by frequency before all others",
|
||||
}
|
||||
|
||||
tac_options = {
|
||||
# Main tag file
|
||||
"tac_tagFile": shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files),
|
||||
@@ -496,6 +542,13 @@ def on_ui_settings():
|
||||
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
|
||||
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
|
||||
"tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like <a href=\"https://github.com/SirVeggie/extension-style-vars\" target=\"_blank\">style-vars</a> to actually apply the styles before generating."),
|
||||
# Frequency sorting settings
|
||||
"tac_frequencySort": shared.OptionInfo(True, "Locally record tag usage and sort frequent tags higher").info("Will also work for extra networks, keeping the specified base order"),
|
||||
"tac_frequencyFunction": shared.OptionInfo("Logarithmic (weak)", "Function to use for frequency sorting", gr.Dropdown, lambda: {"choices": list(frequency_sort_functions.keys())}).info("; ".join([f'<b>{key}</b>: {val}' for key, val in frequency_sort_functions.items()])),
|
||||
"tac_frequencyMinCount": shared.OptionInfo(3, "Minimum number of uses for a tag to be considered frequent").info("Tags with less uses than this will not be sorted higher, even if the sorting function would normally result in a higher position."),
|
||||
"tac_frequencyMaxAge": shared.OptionInfo(30, "Maximum days since last use for a tag to be considered frequent").info("Similar to the above, tags that haven't been used in this many days will not be sorted higher. Set to 0 to disable."),
|
||||
"tac_frequencyRecommendCap": shared.OptionInfo(10, "Maximum number of recommended tags").info("Limits the maximum number of recommended tags to not drown out normal results. Set to 0 to disable."),
|
||||
"tac_frequencyIncludeAlias": shared.OptionInfo(False, "Frequency sorting matches aliases for frequent tags").info("Tag frequency will be increased for the main tag even if an alias is used for completion. This option can be used to override the default behavior of alias results being ignored for frequency sorting."),
|
||||
# Insertion related settings
|
||||
"tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"),
|
||||
"tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"),
|
||||
@@ -560,6 +613,20 @@ def on_ui_settings():
|
||||
"6": ["red", "maroon"],
|
||||
"7": ["whitesmoke", "black"],
|
||||
"8": ["seagreen", "darkseagreen"]
|
||||
},
|
||||
"derpibooru": {
|
||||
"-1": ["red", "maroon"],
|
||||
"0": ["#60d160", "#3d9d3d"],
|
||||
"1": ["#fff956", "#918e2e"],
|
||||
"3": ["#fd9961", "#a14c2e"],
|
||||
"4": ["#cf5bbe", "#6c1e6c"],
|
||||
"5": ["#3c8ad9", "#1e5e93"],
|
||||
"6": ["#a6a6a6", "#555555"],
|
||||
"7": ["#47abc1", "#1f6c7c"],
|
||||
"8": ["#7871d0", "#392f7d"],
|
||||
"9": ["#df3647", "#8e1c2b"],
|
||||
"10": ["#c98f2b", "#7b470e"],
|
||||
"11": ["#e87ebe", "#a83583"]
|
||||
}
|
||||
}\
|
||||
"""
|
||||
@@ -699,5 +766,59 @@ def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
return Response(status_code=200) # Success
|
||||
else:
|
||||
return Response(status_code=304) # Not modified
|
||||
def db_request(func, get = False):
|
||||
if db is not None:
|
||||
try:
|
||||
if get:
|
||||
ret = func()
|
||||
if ret is list:
|
||||
ret = [{"name": t[0], "type": t[1], "count": t[2], "lastUseDate": t[3]} for t in ret]
|
||||
return JSONResponse({"result": ret})
|
||||
else:
|
||||
func()
|
||||
except sqlite3.Error as e:
|
||||
return JSONResponse({"error": e.__cause__}, status_code=500)
|
||||
else:
|
||||
return JSONResponse({"error": "Database not initialized"}, status_code=500)
|
||||
|
||||
@app.post("/tacapi/v1/increase-use-count")
|
||||
async def increase_use_count(tagname: str, ttype: int, neg: bool):
|
||||
db_request(lambda: db.increase_tag_count(tagname, ttype, neg))
|
||||
|
||||
@app.get("/tacapi/v1/get-use-count")
|
||||
async def get_use_count(tagname: str, ttype: int, neg: bool):
|
||||
return db_request(lambda: db.get_tag_count(tagname, ttype, neg), get=True)
|
||||
|
||||
# Small dataholder class
|
||||
class UseCountListRequest(BaseModel):
|
||||
tagNames: list[str]
|
||||
tagTypes: list[int]
|
||||
neg: bool = False
|
||||
|
||||
# Semantically weird to use post here, but it's required for the body on js side
|
||||
@app.post("/tacapi/v1/get-use-count-list")
|
||||
async def get_use_count_list(body: UseCountListRequest):
|
||||
# If a date limit is set > 0, pass it to the db
|
||||
date_limit = getattr(shared.opts, "tac_frequencyMaxAge", 30)
|
||||
date_limit = date_limit if date_limit > 0 else None
|
||||
|
||||
count_list = list(db.get_tag_counts(body.tagNames, body.tagTypes, body.neg, date_limit))
|
||||
|
||||
# If a limit is set, return at max the top n results by count
|
||||
if count_list and len(count_list):
|
||||
limit = int(min(getattr(shared.opts, "tac_frequencyRecommendCap", 10), len(count_list)))
|
||||
# Sort by count and return the top n
|
||||
if limit > 0:
|
||||
count_list = sorted(count_list, key=lambda x: x[2], reverse=True)[:limit]
|
||||
|
||||
return db_request(lambda: count_list, get=True)
|
||||
|
||||
@app.put("/tacapi/v1/reset-use-count")
|
||||
async def reset_use_count(tagname: str, ttype: int, pos: bool, neg: bool):
|
||||
db_request(lambda: db.reset_tag_count(tagname, ttype, pos, neg))
|
||||
|
||||
@app.get("/tacapi/v1/get-all-use-counts")
|
||||
async def get_all_tag_counts():
|
||||
return db_request(lambda: db.get_all_tags(), get=True)
|
||||
|
||||
script_callbacks.on_app_started(api_tac)
|
||||
|
||||
189
scripts/tag_frequency_db.py
Normal file
189
scripts/tag_frequency_db.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
|
||||
from scripts.shared_paths import TAGS_PATH
|
||||
|
||||
db_file = TAGS_PATH.joinpath("tag_frequency.db")
|
||||
timeout = 30
|
||||
db_ver = 1
|
||||
|
||||
|
||||
@contextmanager
|
||||
def transaction(db=db_file):
|
||||
"""Context manager for database transactions.
|
||||
Ensures that the connection is properly closed after the transaction.
|
||||
"""
|
||||
try:
|
||||
conn = sqlite3.connect(db, timeout=timeout)
|
||||
|
||||
conn.isolation_level = None
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("BEGIN")
|
||||
yield cursor
|
||||
cursor.execute("COMMIT")
|
||||
except sqlite3.Error as e:
|
||||
print("Tag Autocomplete: Frequency database error:", e)
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
class TagFrequencyDb:
|
||||
"""Class containing creation and interaction methods for the tag frequency database"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.version = self.__check()
|
||||
|
||||
def __check(self):
|
||||
if not db_file.exists():
|
||||
print("Tag Autocomplete: Creating frequency database")
|
||||
with transaction() as cursor:
|
||||
self.__create_db(cursor)
|
||||
self.__update_db_data(cursor, "version", db_ver)
|
||||
print("Tag Autocomplete: Database successfully created")
|
||||
|
||||
return self.__get_version()
|
||||
|
||||
def __create_db(self, cursor: sqlite3.Cursor):
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS db_data (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS tag_frequency (
|
||||
name TEXT NOT NULL,
|
||||
type INT NOT NULL,
|
||||
count_pos INT,
|
||||
count_neg INT,
|
||||
last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (name, type)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
def __update_db_data(self, cursor: sqlite3.Cursor, key, value):
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE
|
||||
INTO db_data (key, value)
|
||||
VALUES (?, ?)
|
||||
""",
|
||||
(key, value),
|
||||
)
|
||||
|
||||
def __get_version(self):
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT value
|
||||
FROM db_data
|
||||
WHERE key = 'version'
|
||||
"""
|
||||
)
|
||||
db_version = cursor.fetchone()
|
||||
|
||||
return db_version[0] if db_version else 0
|
||||
|
||||
def get_all_tags(self):
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT name, type, count_pos, count_neg, last_used
|
||||
FROM tag_frequency
|
||||
WHERE count_pos > 0 OR count_neg > 0
|
||||
ORDER BY count_pos + count_neg DESC
|
||||
"""
|
||||
)
|
||||
tags = cursor.fetchall()
|
||||
|
||||
return tags
|
||||
|
||||
def get_tag_count(self, tag, ttype, negative=False):
|
||||
count_str = "count_neg" if negative else "count_pos"
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT {count_str}, last_used
|
||||
FROM tag_frequency
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
(tag, ttype),
|
||||
)
|
||||
tag_count = cursor.fetchone()
|
||||
|
||||
if tag_count:
|
||||
return tag_count[0], tag_count[1]
|
||||
else:
|
||||
return 0, None
|
||||
|
||||
def get_tag_counts(self, tags: list[str], ttypes: list[str], negative=False, date_limit=None):
|
||||
count_str = "count_neg" if negative else "count_pos"
|
||||
with transaction() as cursor:
|
||||
for tag, ttype in zip(tags, ttypes):
|
||||
if date_limit is not None:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT {count_str}, last_used
|
||||
FROM tag_frequency
|
||||
WHERE name = ? AND type = ?
|
||||
AND last_used > datetime('now', '-' || ? || ' days')
|
||||
""",
|
||||
(tag, ttype, date_limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT {count_str}, last_used
|
||||
FROM tag_frequency
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
(tag, ttype),
|
||||
)
|
||||
tag_count = cursor.fetchone()
|
||||
if tag_count:
|
||||
yield (tag, ttype, tag_count[0], tag_count[1])
|
||||
else:
|
||||
yield (tag, ttype, 0, None)
|
||||
|
||||
def increase_tag_count(self, tag, ttype, negative=False):
|
||||
pos_count = self.get_tag_count(tag, ttype, False)[0]
|
||||
neg_count = self.get_tag_count(tag, ttype, True)[0]
|
||||
|
||||
if negative:
|
||||
neg_count += 1
|
||||
else:
|
||||
pos_count += 1
|
||||
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
INSERT OR REPLACE
|
||||
INTO tag_frequency (name, type, count_pos, count_neg)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(tag, ttype, pos_count, neg_count),
|
||||
)
|
||||
|
||||
def reset_tag_count(self, tag, ttype, positive=True, negative=False):
|
||||
if positive and negative:
|
||||
set_str = "count_pos = 0, count_neg = 0"
|
||||
elif positive:
|
||||
set_str = "count_pos = 0"
|
||||
elif negative:
|
||||
set_str = "count_neg = 0"
|
||||
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
UPDATE tag_frequency
|
||||
SET {set_str}
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
(tag, ttype),
|
||||
)
|
||||
113301
tags/EnglishDictionary.csv
Normal file
113301
tags/EnglishDictionary.csv
Normal file
File diff suppressed because it is too large
Load Diff
168662
tags/danbooru.csv
168662
tags/danbooru.csv
File diff suppressed because it is too large
Load Diff
95091
tags/derpibooru.csv
Normal file
95091
tags/derpibooru.csv
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user