Compare commits

...

47 Commits

Author SHA1 Message Date
DominikDoom
119a3ad51f Merge branch 'feature-sort-by-frequent-use' 2024-04-13 15:08:57 +02:00
DominikDoom
c820a22149 Merge branch 'main' into feature-sort-by-frequent-use 2024-04-13 15:06:53 +02:00
DominikDoom
ef59cff651 Move last used date check guard to SQL side, implement max cap
- Server side date comparison and cap check further improve js sort performance
- The alias check has also been moved out of calculateUsageBias to support the new cap system
2024-03-16 16:44:43 +01:00
DominikDoom
a454383c43 Merge branch 'main' into feature-sort-by-frequent-use 2024-03-03 13:52:32 +01:00
DominikDoom
4c2ef8f770 Merge branch 'main' into feature-sort-by-frequent-use 2024-02-09 19:23:52 +01:00
DominikDoom
7437850600 Merge branch 'main' into feature-sort-by-frequent-use 2024-02-04 14:46:29 +01:00
DominikDoom
f810b2dd8f Merge branch 'main' into feature-sort-by-frequent-use 2024-01-27 12:39:40 +01:00
DominikDoom
95200e82e1 Merge branch 'main' into feature-sort-by-frequent-use 2024-01-26 17:04:53 +01:00
DominikDoom
a966be7546 Merge branch 'main' into feature-sort-by-frequent-use 2024-01-26 16:21:15 +01:00
DominikDoom
342fbc9041 Pre-calculate usage bias for all results instead of in the sort function
Roughly doubles the sort performance
2024-01-19 21:10:09 +01:00
DominikDoom
d496569c9a Cache sort key for small performance increase 2024-01-19 20:17:14 +01:00
DominikDoom
30c9593d3d Merge branch 'main' into feature-sort-by-frequent-use 2023-12-12 14:23:18 +01:00
DominikDoom
57076060df Merge branch 'main' into feature-sort-by-frequent-use 2023-12-11 11:43:26 +01:00
DominikDoom
20b6635a2a WIP usage info table
Might get replaced with gradio depending on how well it works
2023-12-04 15:00:19 +01:00
DominikDoom
1fe8f26670 Add explanatory tooltip and inline reset ability
Also add tooltip for wiki links
2023-12-04 13:56:15 +01:00
DominikDoom
e82e958c3e Fix alias check for non-aliased tag types 2023-11-29 18:15:59 +01:00
DominikDoom
2dd48eab79 Fix error with db return value for no matches 2023-11-29 18:14:14 +01:00
DominikDoom
4df90f5c95 Don't frequency sort alias results by default
with an option to enable it if desired
2023-11-29 18:04:50 +01:00
DominikDoom
a156214a48 Last used & min count settings
Also some performance improvements
2023-11-29 17:45:51 +01:00
DominikDoom
15478e73b5 Count positive / negative prompt usage separately 2023-11-29 15:22:41 +01:00
DominikDoom
434301738a Merge branch 'main' into feature-sort-by-frequent-use 2023-11-05 13:30:51 +01:00
DominikDoom
4fba7baa69 Merge branch 'main' into feature-sort-by-frequent-use 2023-10-06 18:36:24 +02:00
DominikDoom
7128efc4f4 Apply same fix to extra tags
Count now defaults to max safe integer, which simplifies the sort function
Before, it resulted in really bad performance
2023-10-02 00:45:48 +02:00
DominikDoom
bd0ddfbb24 Fix embeddings not at top
(only affecting the "include embeddings in normal results" option)
2023-10-02 00:16:58 +02:00
DominikDoom
3108daf0e8 Remove kaomoji inclusion in < search
because it interfered with use count searching and is not commonly needed
2023-10-01 23:51:35 +02:00
DominikDoom
363895494b Fix hide after insert race condition 2023-10-01 23:17:12 +02:00
DominikDoom
04551a8132 Don't await increase, limit to 2k for performance 2023-10-01 22:59:28 +02:00
DominikDoom
ffc0e378d3 Add different sorting functions 2023-10-01 22:44:35 +02:00
DominikDoom
440f109f1f Use POST + body to get around URL length limit 2023-10-01 22:30:47 +02:00
DominikDoom
80fb247dbe Sort results by usage count 2023-10-01 21:44:24 +02:00
DominikDoom
d7e98200a8 Use count increase logic 2023-09-26 12:20:15 +02:00
DominikDoom
ac790c8ede Return dict instead of array for clarity 2023-09-26 12:12:46 +02:00
DominikDoom
22365ec8d6 Add missing type return to list request 2023-09-26 12:02:36 +02:00
DominikDoom
030a83aa4d Use query parameter instead of path to fix wildcard subfolder issues 2023-09-26 11:55:12 +02:00
DominikDoom
460d32a4ed Ensure proper reload, fix error message 2023-09-26 11:45:42 +02:00
DominikDoom
581bf1e6a4 Use composite key with name & type to prevent collisions 2023-09-26 11:35:24 +02:00
DominikDoom
74ea5493e5 Add rest of utils functions 2023-09-26 10:58:46 +02:00
DominikDoom
6cf9acd6ab Catch sqlite exceptions, add tag list endpoint 2023-09-24 20:06:40 +02:00
DominikDoom
109a8a155e Change endpoint name for consistency 2023-09-24 18:00:41 +02:00
DominikDoom
3caa1b51ed Add db to gitignore 2023-09-24 17:59:39 +02:00
DominikDoom
b44c36425a Fix db load version comparison, add sort options 2023-09-24 17:59:14 +02:00
DominikDoom
1e81403180 Safety catches for DB API access 2023-09-24 16:50:39 +02:00
DominikDoom
0f487a5c5c WIP database setup inspired by ImageBrowser 2023-09-24 16:28:32 +02:00
DominikDoom
2baa12fea3 Merge branch 'main' into feature-sort-by-frequent-use 2023-09-24 15:34:18 +02:00
DominikDoom
67eeb5fbf6 Merge branch 'main' into feature-sort-by-frequent-use 2023-09-19 12:14:12 +02:00
DominikDoom
11ffed8afc Merge branch 'feature-sorting' into feature-sort-by-frequent-use 2023-09-15 16:37:34 +02:00
DominikDoom
0a8e7d7d84 Stub API setup for tag usage stats 2023-09-12 14:10:15 +02:00
6 changed files with 521 additions and 37 deletions

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
tags/temp/
__pycache__/
tags/tag_frequency.db

View File

@@ -24,7 +24,8 @@ class AutocompleteResult {
// Additional info, only used in some cases
category = null;
count = null;
count = Number.MAX_SAFE_INTEGER;
usageBias = null;
aliases = null;
meta = null;
hash = null;

View File

@@ -80,8 +80,12 @@ async function fetchAPI(url, json = true, cache = false) {
return await response.text();
}
async function postAPI(url, body) {
let response = await fetch(url, { method: "POST", body: body });
async function postAPI(url, body = null) {
let response = await fetch(url, {
method: "POST",
headers: {'Content-Type': 'application/json'},
body: body
});
if (response.status != 200) {
console.error(`Error posting to API endpoint "${url}": ` + response.status, response.statusText);
@@ -91,6 +95,17 @@ async function postAPI(url, body) {
return await response.json();
}
// Send a PUT request to the given endpoint and return the parsed JSON reply.
// Logs to the console and resolves to null when the status is not 200.
async function putAPI(url, body = null) {
    const requestOptions = { method: "PUT", body: body };
    const reply = await fetch(url, requestOptions);

    if (reply.status == 200) {
        return await reply.json();
    }

    console.error(`Error putting to API endpoint "${url}": ` + reply.status, reply.statusText);
    return null;
}
// Extra network preview thumbnails
async function getExtraNetworkPreviewURL(filename, type) {
const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);
@@ -180,6 +195,104 @@ function flatten(obj, roots = [], sep = ".") {
);
}
// Calculate the sort weight of a result from its post count and how often it was used.
// Side effect: marks the result with `usageBias = true` when its use count
// reaches the configured minimum (and is not zero).
function calculateUsageBias(result, count, uses) {
    // Uses below the configured minimum are treated as "never used"
    const belowMinimum = uses < TAC_CFG.frequencyMinCount;
    const effectiveUses = belowMinimum ? 0 : uses;
    if (!belowMinimum && uses != 0) {
        result.usageBias = true;
    }

    const mode = TAC_CFG.frequencyFunction;
    if (mode === "Logarithmic (weak)") {
        return Math.log(1 + count) + Math.log(1 + effectiveUses);
    }
    if (mode === "Logarithmic (strong)") {
        return Math.log(1 + count) + 2 * Math.log(1 + effectiveUses);
    }
    if (mode === "Usage first") {
        return effectiveUses;
    }
    // Unknown / unset function: fall back to the plain post count
    return count;
}
// Turn the positional rows returned by the API into named objects.
// With posAndNeg the rows carry five values (incl. the negative count), otherwise four.
function mapUseCountArray(useCounts, posAndNeg = false) {
    const fieldNames = posAndNeg
        ? ["name", "type", "count", "negCount", "lastUseDate"]
        : ["name", "type", "count", "lastUseDate"];

    return useCounts.map(row => {
        const entry = {};
        fieldNames.forEach((field, position) => {
            entry[field] = row[position];
        });
        return entry;
    });
}
// Tell the backend to bump the use count of a tag.
// Deliberately fire-and-forget: the request is not awaited.
function increaseUseCount(tagName, type, negative = false) {
    const endpoint = `tacapi/v1/increase-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`;
    postAPI(endpoint);
}
// Get use count of tag from the database.
// Resolves to the "result" payload of the response, or undefined when the
// request produced no usable response object.
async function getUseCount(tagName, type, negative = false) {
    const response = await fetchAPI(`tacapi/v1/get-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`, true, false);
    // Guard against a null/undefined response instead of throwing a TypeError
    return response?.["result"];
}
// Get the use counts for a batch of tags from the database.
// tagNames and types are sent as parallel lists; resolves to a list of
// { name, type, count, lastUseDate } objects (empty if the request failed).
async function getUseCounts(tagNames, types, negative = false) {
    // While semantically weird, we have to use POST here for the body, as urls are limited in length
    const body = JSON.stringify({"tagNames": tagNames, "tagTypes": types, "neg": negative});
    const response = await postAPI(`tacapi/v1/get-use-count-list`, body);
    // A failed request (e.g. database not available) must not break the caller,
    // so a missing payload is treated as "no usage data"
    const rawArray = response?.["result"] || [];
    return mapUseCountArray(rawArray);
}
// Get every recorded tag with its positive / negative use counts.
// Resolves to a list of { name, type, count, negCount, lastUseDate } objects
// (empty if the request produced no usable payload).
async function getAllUseCounts() {
    const response = await fetchAPI(`tacapi/v1/get-all-use-counts`);
    // Guard against a null/undefined response instead of throwing a TypeError
    const rawArray = response?.["result"] || [];
    return mapUseCountArray(rawArray, true);
}
// Reset the positive and / or negative use count of a tag in the database.
// tagName is expected as the raw (unencoded) name; it is URL-encoded here so
// names containing characters like "&", "#" or "+" survive the query string.
async function resetUseCount(tagName, type, resetPosCount, resetNegCount) {
    const encodedName = encodeURIComponent(tagName);
    await putAPI(`tacapi/v1/reset-use-count?tagname=${encodedName}&ttype=${type}&pos=${resetPosCount}&neg=${resetNegCount}`);
}
// Build a <table> element listing the usage statistics of all given tags.
// Each entry of tagCounts needs name, type, count, negCount and lastUseDate.
function createTagUsageTable(tagCounts) {
    const table = document.createElement("table");
    // Static header row
    table.innerHTML =
`<thead>
<tr>
<td>Name</td>
<td>Type</td>
<td>Count(+)</td>
<td>Count(-)</td>
<td>Last used</td>
</tr>
</thead>`;
    table.id = "tac_tagUsageTable";

    tagCounts.forEach(entry => {
        const row = document.createElement("tr");

        // One cell per displayed value
        const cellValues = [entry.name, entry.type-1, entry.count, entry.negCount, entry.lastUseDate];
        for (const value of cellValues) {
            const cell = document.createElement("td");
            cell.innerText = value;
            row.append(cell);
        }

        // Trailing delete/reset button (no handler attached here)
        const resetButton = document.createElement("button");
        resetButton.innerText = "🗑️";
        resetButton.title = "Reset count";
        row.append(resetButton);

        table.append(row);
    });

    return table;
}
// Sliding window function to get possible combination groups of an array
function toNgrams(inputArray, size) {
@@ -242,12 +355,19 @@ function getSortFunction() {
let criterion = TAC_CFG.modelSortOrder || "Name";
const textSort = (a, b, reverse = false) => {
const textHolderA = a.type === ResultType.chant ? a.aliases : a.text;
const textHolderB = b.type === ResultType.chant ? b.aliases : b.text;
// Assign keys so next sort is faster
if (!a.sortKey) {
a.sortKey = a.type === ResultType.chant
? a.aliases
: a.text;
}
if (!b.sortKey) {
b.sortKey = b.type === ResultType.chant
? b.aliases
: b.text;
}
const aKey = a.sortKey || textHolderA;
const bKey = b.sortKey || textHolderB;
return reverse ? bKey.localeCompare(aKey) : aKey.localeCompare(bKey);
return reverse ? b.sortKey.localeCompare(a.sortKey) : a.sortKey.localeCompare(b.sortKey);
}
const numericSort = (a, b, reverse = false) => {
const noKey = reverse ? "-1" : Number.MAX_SAFE_INTEGER;

View File

@@ -86,6 +86,10 @@ const autocompleteCSS = `
white-space: nowrap;
color: var(--meta-text-color);
}
.acMetaText.biased::before {
content: "✨";
margin-right: 2px;
}
.acWikiLink {
padding: 0.5rem;
margin: -0.5rem 0 -0.5rem -0.5rem;
@@ -221,6 +225,12 @@ async function syncOptions() {
showWikiLinks: opts["tac_showWikiLinks"],
showExtraNetworkPreviews: opts["tac_showExtraNetworkPreviews"],
modelSortOrder: opts["tac_modelSortOrder"],
frequencySort: opts["tac_frequencySort"],
frequencyFunction: opts["tac_frequencyFunction"],
frequencyMinCount: opts["tac_frequencyMinCount"],
frequencyMaxAge: opts["tac_frequencyMaxAge"],
frequencyRecommendCap: opts["tac_frequencyRecommendCap"],
frequencyIncludeAlias: opts["tac_frequencyIncludeAlias"],
useStyleVars: opts["tac_useStyleVars"],
// Insertion related settings
replaceUnderscores: opts["tac_replaceUnderscores"],
@@ -466,6 +476,37 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
}
}
// Frequency db update
if (TAC_CFG.frequencySort) {
let name = null;
switch (tagType) {
case ResultType.wildcardFile:
case ResultType.yamlWildcard:
// We only want to update the frequency for a full wildcard, not partial paths
if (sanitizedText.endsWith("__"))
name = text
break;
case ResultType.chant:
// Chants use a slightly different format
name = result.aliases;
break;
default:
name = text;
break;
}
if (name && name.length > 0) {
// Check if it's a negative prompt
let textAreaId = getTextAreaIdentifier(textArea);
let isNegative = textAreaId.includes("n");
// Sanitize name for API call
name = encodeURIComponent(name)
// Call API & update db
increaseUseCount(name, tagType, isNegative)
}
}
var prompt = textArea.value;
// Edit prompt text
@@ -574,6 +615,8 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
tacSelfTrigger = true;
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python.
// Uses a built-in method from the webui's ui.js which also already accounts for event target
if (tagType === ResultType.wildcardTag || tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard)
tacSelfTrigger = true;
updateInput(textArea);
// Update previous tags with the edited prompt to prevent re-searching the same term
@@ -688,6 +731,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
let wikiLink = document.createElement("a");
wikiLink.classList.add("acWikiLink");
wikiLink.innerText = "?";
wikiLink.title = "Open external wiki page for this tag"
let linkPart = displayText;
// Only use alias result if it is one
@@ -733,7 +777,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
}
// Post count
if (result.count && !isNaN(result.count)) {
if (result.count && !isNaN(result.count) && result.count !== Number.MAX_SAFE_INTEGER) {
let postCount = result.count;
let formatter;
@@ -765,8 +809,24 @@ function addResultsToList(textArea, results, tagword, resetList) {
flexDiv.appendChild(metaDiv);
}
// Add small ✨ marker to indicate usage sorting
if (result.usageBias) {
flexDiv.querySelector(".acMetaText").classList.add("biased");
flexDiv.title = "✨ Frequent tag. Ctrl/Cmd + click to reset usage count."
}
// Check if it's a negative prompt
let isNegative = textAreaId.includes("n");
// Add listener
li.addEventListener("click", function () { insertTextAtCursor(textArea, result, tagword); });
li.addEventListener("click", (e) => {
if (e.ctrlKey || e.metaKey) {
resetUseCount(result.text, result.type, !isNegative, isNegative);
flexDiv.querySelector(".acMetaText").classList.remove("biased");
} else {
insertTextAtCursor(textArea, result, tagword);
}
});
// Add element to list
resultsList.appendChild(li);
}
@@ -1034,6 +1094,9 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
resultCountBeforeNormalTags = 0;
tagword = tagword.toLowerCase().replace(/[\n\r]/g, "");
// Needed for slicing check later
let normalTags = false;
// Process all parsers
let resultCandidates = (await processParsers(textArea, prompt))?.filter(x => x.length > 0);
// If one or more result candidates match, use their results
@@ -1043,32 +1106,12 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
// Sort results, but not if it's umi tags since they are sorted by count
if (!(resultCandidates.length === 1 && results[0].type === ResultType.umiWildcard))
results = results.sort(getSortFunction());
// Since some tags are kaomoji, we have to add the normal results in some cases
if (tagword.startsWith("<") || tagword.startsWith("*<")) {
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
if (tagword.startsWith("*")) {
tagword = tagword.slice(1);
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
} else {
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
}
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, TAC_CFG.maxResults);
genericResults.forEach(g => {
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
result.category = g[1];
result.count = g[2];
result.aliases = g[3];
results.push(result);
});
}
}
// Else search the normal tag list
if (!resultCandidates || resultCandidates.length === 0
|| (TAC_CFG.includeEmbeddingsInNormalResults && !(tagword.startsWith("<") || tagword.startsWith("*<")))
) {
normalTags = true;
resultCountBeforeNormalTags = results.length;
// Create escaped search regex with support for * as a start placeholder
@@ -1123,11 +1166,6 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
results = results.concat(extraResults);
}
}
// Slice if the user has set a max result count
if (!TAC_CFG.showAllResults) {
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
}
}
// Guard for empty results
@@ -1137,6 +1175,57 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
return;
}
// Sort again with frequency / usage count if enabled
if (TAC_CFG.frequencySort) {
// Split our results into a list of names and types
let tagNames = [];
let aliasNames = [];
let types = [];
// Limit to 2k for performance reasons
const aliasTypes = [ResultType.tag, ResultType.extra];
results.slice(0,2000).forEach(r => {
const name = r.type === ResultType.chant ? r.aliases : r.text;
// Add to alias list or tag list depending on if the name includes the tagword
// (this alias check used to live in calculateUsageBias and was moved here)
if (aliasTypes.includes(r.type) && !name.includes(tagword)) {
aliasNames.push(name);
} else {
tagNames.push(name);
}
types.push(r.type);
});
// Check if it's a negative prompt
let textAreaId = getTextAreaIdentifier(textArea);
let isNegative = textAreaId.includes("n");
// Request use counts from the DB
const names = TAC_CFG.frequencyIncludeAlias ? tagNames.concat(aliasNames) : tagNames;
const counts = await getUseCounts(names, types, isNegative);
// Pre-calculate weights to prevent duplicate work
const resultBiasMap = new Map();
results.forEach(result => {
const name = result.type === ResultType.chant ? result.aliases : result.text;
const type = result.type;
// Find matching pair from DB results
const useStats = counts.find(c => c.name === name && c.type === type);
const uses = useStats?.count || 0;
// Calculate & set weight
const weight = calculateUsageBias(result, result.count, uses)
resultBiasMap.set(result, weight);
});
// Actual sorting with the pre-calculated weights
results = results.sort((a, b) => {
return resultBiasMap.get(b) - resultBiasMap.get(a);
});
}
// Slice if the user has set a max result count and we are not in an extra networks / wildcard list
if (!TAC_CFG.showAllResults && normalTags) {
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
}
addResultsToList(textArea, results, tagword, true);
showResults(textArea);
}
@@ -1272,7 +1361,7 @@ async function refreshTacTempFiles(api = false) {
}
if (api) {
await postAPI("tacapi/v1/refresh-temp-files", null);
await postAPI("tacapi/v1/refresh-temp-files");
await reload();
} else {
setTimeout(async () => {

View File

@@ -2,7 +2,9 @@
# to a temporary file to expose it to the javascript side
import glob
import importlib
import json
import sqlite3
import urllib.parse
from pathlib import Path
@@ -11,12 +13,26 @@ import yaml
from fastapi import FastAPI
from fastapi.responses import Response, FileResponse, JSONResponse
from modules import script_callbacks, sd_hijack, shared, hashes
from pydantic import BaseModel
from scripts.model_keyword_support import (get_lora_simple_hash,
load_hash_cache, update_hash_cache,
write_model_keyword_path)
from scripts.shared_paths import *
# Set up the tag frequency database. `db` ends up as None when the module
# can't be imported, the stored version doesn't match or sqlite fails.
try:
    import scripts.tag_frequency_db as tdb

    # Ensure the db dependency is reloaded on script reload
    importlib.reload(tdb)

    db = tdb.TagFrequencyDb()
    versions_match = int(db.version) == int(tdb.db_ver)
    if not versions_match:
        raise ValueError("Database version mismatch")
except (ImportError, ValueError, sqlite3.Error) as err:
    print(f"Tag Autocomplete: Tag frequency database error - \"{err}\"")
    db = None
# Attempt to get embedding load function, using the same call as api.
try:
load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings
@@ -488,6 +504,13 @@ def on_ui_settings():
return self
shared.OptionInfo.needs_restart = needs_restart
# Dictionary of function options and their explanations
frequency_sort_functions = {
"Logarithmic (weak)": "Will respect the base order and slightly prefer often used tags",
"Logarithmic (strong)": "Same as Logarithmic (weak), but with a stronger bias",
"Usage first": "Will list used tags by frequency before all others",
}
tac_options = {
# Main tag file
"tac_tagFile": shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files),
@@ -519,6 +542,13 @@ def on_ui_settings():
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
"tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like <a href=\"https://github.com/SirVeggie/extension-style-vars\" target=\"_blank\">style-vars</a> to actually apply the styles before generating."),
# Frequency sorting settings
"tac_frequencySort": shared.OptionInfo(True, "Locally record tag usage and sort frequent tags higher").info("Will also work for extra networks, keeping the specified base order"),
"tac_frequencyFunction": shared.OptionInfo("Logarithmic (weak)", "Function to use for frequency sorting", gr.Dropdown, lambda: {"choices": list(frequency_sort_functions.keys())}).info("; ".join([f'<b>{key}</b>: {val}' for key, val in frequency_sort_functions.items()])),
"tac_frequencyMinCount": shared.OptionInfo(3, "Minimum number of uses for a tag to be considered frequent").info("Tags with less uses than this will not be sorted higher, even if the sorting function would normally result in a higher position."),
"tac_frequencyMaxAge": shared.OptionInfo(30, "Maximum days since last use for a tag to be considered frequent").info("Similar to the above, tags that haven't been used in this many days will not be sorted higher. Set to 0 to disable."),
"tac_frequencyRecommendCap": shared.OptionInfo(10, "Maximum number of recommended tags").info("Limits the maximum number of recommended tags to not drown out normal results. Set to 0 to disable."),
"tac_frequencyIncludeAlias": shared.OptionInfo(False, "Frequency sorting matches aliases for frequent tags").info("Tag frequency will be increased for the main tag even if an alias is used for completion. This option can be used to override the default behavior of alias results being ignored for frequency sorting."),
# Insertion related settings
"tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"),
"tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"),
@@ -736,5 +766,59 @@ def api_tac(_: gr.Blocks, app: FastAPI):
return Response(status_code=200) # Success
else:
return Response(status_code=304) # Not modified
def db_request(func, get = False):
    """Run a database callable behind the shared safety checks.

    Args:
        func: Zero-argument callable performing the database access.
        get: If True, the callable's return value is sent back to the client
            as ``{"result": ...}``.

    Returns:
        A ``JSONResponse`` with the result when ``get`` is True, ``None`` after
        a successful call without ``get``, or a ``JSONResponse`` with status
        500 and an ``"error"`` message when the database is not available or
        sqlite raised an error.
    """
    if db is None:
        return JSONResponse({"error": "Database not initialized"}, status_code=500)

    try:
        ret = func()
    except sqlite3.Error as e:
        # Send the message text; the exception's __cause__ is usually None and
        # an exception object would not be JSON serializable anyway
        return JSONResponse({"error": str(e)}, status_code=500)

    if get:
        # Rows are returned as plain positional lists on purpose, the js side
        # maps them to named fields by index
        return JSONResponse({"result": ret})
    return None
@app.post("/tacapi/v1/increase-use-count")
async def increase_use_count(tagname: str, ttype: int, neg: bool):
    # Increase the positive (or, with neg, the negative) use count of a tag.
    # The db_request result is returned so that a missing database or an
    # sqlite error reaches the client as an error response instead of a 200.
    return db_request(lambda: db.increase_tag_count(tagname, ttype, neg))
@app.get("/tacapi/v1/get-use-count")
async def get_use_count(tagname: str, ttype: int, neg: bool):
    # Look up the use count of a single tag, wrapped in the shared db safety checks
    def lookup():
        return db.get_tag_count(tagname, ttype, neg)

    return db_request(lookup, get=True)
# Small dataholder class
# Describes the JSON body accepted by the get-use-count-list endpoint
class UseCountListRequest(BaseModel):
    # Names of the tags to look up
    tagNames: list[str]
    # Result types of the tags, paired with tagNames by position
    tagTypes: list[int]
    # If True, the negative prompt counts are read instead of the positive ones
    neg: bool = False
# Semantically weird to use post here, but it's required for the body on js side
@app.post("/tacapi/v1/get-use-count-list")
async def get_use_count_list(body: UseCountListRequest):
    # Return the use counts for all requested tags, optionally limited by
    # last-use age and capped to the most used entries.
    # The whole lookup runs inside db_request so that a missing database or an
    # sqlite error results in a proper error response.
    def fetch_counts():
        # If a date limit is set > 0, pass it to the db
        date_limit = getattr(shared.opts, "tac_frequencyMaxAge", 30)
        date_limit = date_limit if date_limit > 0 else None

        count_list = list(db.get_tag_counts(body.tagNames, body.tagTypes, body.neg, date_limit))

        # If a cap is set, return at max the top n results by count (0 disables the cap)
        cap = int(getattr(shared.opts, "tac_frequencyRecommendCap", 10))
        if cap > 0:
            count_list = sorted(count_list, key=lambda x: x[2], reverse=True)[:cap]

        return count_list

    return db_request(fetch_counts, get=True)
@app.put("/tacapi/v1/reset-use-count")
async def reset_use_count(tagname: str, ttype: int, pos: bool, neg: bool):
    # Reset the positive and / or negative use count of a tag.
    # The db_request result is returned so that a missing database or an
    # sqlite error reaches the client as an error response instead of a 200.
    return db_request(lambda: db.reset_tag_count(tagname, ttype, pos, neg))
@app.get("/tacapi/v1/get-all-use-counts")
async def get_all_tag_counts():
    # List every recorded tag, wrapped in the shared db safety checks
    def list_all():
        return db.get_all_tags()

    return db_request(list_all, get=True)
script_callbacks.on_app_started(api_tac)

189
scripts/tag_frequency_db.py Normal file
View File

@@ -0,0 +1,189 @@
import sqlite3
from contextlib import contextmanager
from scripts.shared_paths import TAGS_PATH
# Path of the SQLite database file, located directly under TAGS_PATH
db_file = TAGS_PATH.joinpath("tag_frequency.db")
# Passed as the `timeout` argument to sqlite3.connect (seconds)
timeout = 30
# Version number stored under the "version" key when the database is created
db_ver = 1
@contextmanager
def transaction(db=db_file):
    """Context manager for database transactions.

    Opens a connection to ``db``, begins an explicit transaction and yields a
    cursor. The transaction is committed once the ``with`` body has finished,
    and the connection is always closed afterwards.

    A ``sqlite3.Error`` raised by the ``with`` body or by the commit is printed
    and suppressed, so execution continues after the ``with`` block; the
    changes of that transaction are not committed. Errors raised while
    connecting or while starting the transaction propagate to the caller.
    """
    # Connect before entering the try block: if this fails there is no
    # connection to clean up, and the error must reach the caller.
    conn = sqlite3.connect(db, timeout=timeout)
    try:
        # Disable the sqlite3 module's implicit transaction handling,
        # BEGIN / COMMIT are issued explicitly below
        conn.isolation_level = None
        cursor = conn.cursor()
        cursor.execute("BEGIN")
        # Only errors from here on are suppressed. Swallowing an error before
        # the yield would make contextmanager fail with "generator didn't yield".
        try:
            yield cursor
            cursor.execute("COMMIT")
        except sqlite3.Error as e:
            print("Tag Autocomplete: Frequency database error:", e)
    finally:
        conn.close()
class TagFrequencyDb:
    """Class containing creation and interaction methods for the tag frequency database.

    Usage is stored per (name, type) pair with separate counters for the
    positive and the negative prompt and a last-used timestamp.
    """

    def __init__(self) -> None:
        self.version = self.__check()

    def __check(self):
        """Create the database file if it does not exist and return the stored version."""
        if not db_file.exists():
            print("Tag Autocomplete: Creating frequency database")
            with transaction() as cursor:
                self.__create_db(cursor)
                self.__update_db_data(cursor, "version", db_ver)
            print("Tag Autocomplete: Database successfully created")

        return self.__get_version()

    def __create_db(self, cursor: sqlite3.Cursor):
        """Create the metadata and the tag frequency table if they are missing."""
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS db_data (
                key TEXT PRIMARY KEY,
                value TEXT
            )
            """
        )

        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS tag_frequency (
                name TEXT NOT NULL,
                type INT NOT NULL,
                count_pos INT,
                count_neg INT,
                last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (name, type)
            )
            """
        )

    def __update_db_data(self, cursor: sqlite3.Cursor, key, value):
        """Insert or overwrite a key / value pair in the metadata table."""
        cursor.execute(
            """
            INSERT OR REPLACE
            INTO db_data (key, value)
            VALUES (?, ?)
            """,
            (key, value),
        )

    def __get_version(self):
        """Return the stored database version, or 0 if none could be read."""
        # Pre-set so the method still returns a value if the transaction
        # suppressed a database error
        db_version = None
        with transaction() as cursor:
            cursor.execute(
                """
                SELECT value
                FROM db_data
                WHERE key = 'version'
                """
            )
            db_version = cursor.fetchone()

        return db_version[0] if db_version else 0

    def get_all_tags(self):
        """Return all tags with at least one use, most used first.

        Returns:
            A list of ``(name, type, count_pos, count_neg, last_used)`` rows,
            empty if nothing could be read.
        """
        tags = []
        with transaction() as cursor:
            cursor.execute(
                """
                SELECT name, type, count_pos, count_neg, last_used
                FROM tag_frequency
                WHERE count_pos > 0 OR count_neg > 0
                ORDER BY count_pos + count_neg DESC
                """
            )
            tags = cursor.fetchall()

        return tags

    def get_tag_count(self, tag, ttype, negative=False):
        """Return ``(count, last_used)`` for a single tag.

        Args:
            tag: Name of the tag.
            ttype: Result type of the tag.
            negative: Read the negative prompt count instead of the positive one.

        Returns:
            The stored count and last-used value, or ``(0, None)`` if there is
            no entry for the tag.
        """
        # Only ever one of two fixed column names, so safe to format into the query
        count_str = "count_neg" if negative else "count_pos"
        tag_count = None
        with transaction() as cursor:
            cursor.execute(
                f"""
                SELECT {count_str}, last_used
                FROM tag_frequency
                WHERE name = ? AND type = ?
                """,
                (tag, ttype),
            )
            tag_count = cursor.fetchone()

        if tag_count:
            return tag_count[0], tag_count[1]
        return 0, None

    def get_tag_counts(self, tags: list[str], ttypes: list[int], negative=False, date_limit=None):
        """Yield ``(name, type, count, last_used)`` for each requested tag.

        Args:
            tags: Names of the tags to look up.
            ttypes: Result types, paired with ``tags`` by position.
            negative: Read the negative prompt counts instead of the positive ones.
            date_limit: If not None, only entries used within this many days
                are matched.

        Yields:
            One tuple per name / type pair; tags without a matching entry are
            reported as ``(name, type, 0, None)``.
        """
        # Only ever one of two fixed column names, so safe to format into the query
        count_str = "count_neg" if negative else "count_pos"

        # The statement is the same for every tag, so build it once
        query = f"""
            SELECT {count_str}, last_used
            FROM tag_frequency
            WHERE name = ? AND type = ?
            """
        if date_limit is not None:
            query += "AND last_used > datetime('now', '-' || ? || ' days')"

        with transaction() as cursor:
            for tag, ttype in zip(tags, ttypes):
                if date_limit is not None:
                    params = (tag, ttype, date_limit)
                else:
                    params = (tag, ttype)
                cursor.execute(query, params)

                tag_count = cursor.fetchone()
                if tag_count:
                    yield (tag, ttype, tag_count[0], tag_count[1])
                else:
                    yield (tag, ttype, 0, None)

    def increase_tag_count(self, tag, ttype, negative=False):
        """Increase the positive or negative use count of a tag by one.

        Creates the entry if it does not exist yet. Reading the current counts
        and writing the new ones happens within a single transaction.
        """
        with transaction() as cursor:
            cursor.execute(
                """
                SELECT count_pos, count_neg
                FROM tag_frequency
                WHERE name = ? AND type = ?
                """,
                (tag, ttype),
            )
            current = cursor.fetchone()
            pos_count = (current[0] or 0) if current else 0
            neg_count = (current[1] or 0) if current else 0

            if negative:
                neg_count += 1
            else:
                pos_count += 1

            # last_used is not listed, so the replaced row gets the column
            # default (CURRENT_TIMESTAMP), which refreshes the last-used date
            cursor.execute(
                """
                INSERT OR REPLACE
                INTO tag_frequency (name, type, count_pos, count_neg)
                VALUES (?, ?, ?, ?)
                """,
                (tag, ttype, pos_count, neg_count),
            )

    def reset_tag_count(self, tag, ttype, positive=True, negative=False):
        """Set the positive and / or negative use count of a tag back to zero.

        Does nothing if neither ``positive`` nor ``negative`` is set.
        """
        assignments = []
        if positive:
            assignments.append("count_pos = 0")
        if negative:
            assignments.append("count_neg = 0")
        if not assignments:
            # Nothing selected for reset
            return
        set_str = ", ".join(assignments)

        with transaction() as cursor:
            cursor.execute(
                f"""
                UPDATE tag_frequency
                SET {set_str}
                WHERE name = ? AND type = ?
                """,
                (tag, ttype),
            )