diff --git a/.gitignore b/.gitignore
index e9e3707..d324b1b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
tags/temp/
__pycache__/
+tags/tag_frequency.db
diff --git a/javascript/_result.js b/javascript/_result.js
index 73dab12..8b81a1b 100644
--- a/javascript/_result.js
+++ b/javascript/_result.js
@@ -24,7 +24,8 @@ class AutocompleteResult {
// Additional info, only used in some cases
category = null;
- count = null;
+ count = Number.MAX_SAFE_INTEGER;
+ usageBias = null;
aliases = null;
meta = null;
hash = null;
diff --git a/javascript/_utils.js b/javascript/_utils.js
index 2eaed4a..f44aff2 100644
--- a/javascript/_utils.js
+++ b/javascript/_utils.js
@@ -80,8 +80,12 @@ async function fetchAPI(url, json = true, cache = false) {
return await response.text();
}
-async function postAPI(url, body) {
- let response = await fetch(url, { method: "POST", body: body });
+async function postAPI(url, body = null) {
+ let response = await fetch(url, {
+ method: "POST",
+ headers: {'Content-Type': 'application/json'},
+ body: body
+ });
if (response.status != 200) {
console.error(`Error posting to API endpoint "${url}": ` + response.status, response.statusText);
@@ -91,6 +95,17 @@ async function postAPI(url, body) {
return await response.json();
}
+// Send a PUT request to `url`. Resolves to the parsed JSON response, or null (after logging) on a non-200 status.
+async function putAPI(url, body = null) {
+    const options = { method: "PUT", body: body };
+    const response = await fetch(url, options);
+    if (response.status == 200) {
+        return await response.json();
+    }
+    console.error(`Error putting to API endpoint "${url}": ` + response.status, response.statusText);
+    return null;
+}
+
// Extra network preview thumbnails
async function getExtraNetworkPreviewURL(filename, type) {
const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);
@@ -180,6 +195,104 @@ function flatten(obj, roots = [], sep = ".") {
);
}
+// Calculate the sort weight of a result from its post count and how often the tag was used.
+// Sets `result.usageBias` when the usage is high enough to influence the order.
+function calculateUsageBias(result, count, uses) {
+    // Uses below the configured minimum are treated as "never used"
+    const effectiveUses = uses < TAC_CFG.frequencyMinCount ? 0 : uses;
+    if (effectiveUses != 0) result.usageBias = true;
+    // Counts are presumably numeric strings in some cases (tag list data); coerce so `1 + count` adds instead of concatenating
+    const postCount = Number(count);
+
+    switch (TAC_CFG.frequencyFunction) {
+        case "Logarithmic (weak)":
+            return Math.log(1 + postCount) + Math.log(1 + effectiveUses);
+        case "Logarithmic (strong)":
+            return Math.log(1 + postCount) + 2 * Math.log(1 + effectiveUses);
+        case "Usage first":
+            return effectiveUses;
+        default:
+            return count;
+    }
+}
+// Convert raw use count rows into objects with named fields for easier handling.
+// Each row is read by position: [name, type, count, lastUseDate], or with posAndNeg
+// [name, type, count, negCount, lastUseDate].
+function mapUseCountArray(useCounts, posAndNeg = false) {
+    // The extra negCount column shifts the position of the date by one
+    const dateIndex = posAndNeg ? 4 : 3;
+    return useCounts.map((row) => {
+        const entry = {
+            "name": row[0],
+            "type": row[1],
+            "count": row[2]
+        };
+        // negCount is inserted before the date so the key order stays name, type, count, negCount, lastUseDate
+        if (posAndNeg) {
+            entry["negCount"] = row[3];
+        }
+        entry["lastUseDate"] = row[dateIndex];
+        return entry;
+    });
+}
+// Call API endpoint to increase the use count of a tag in the database (fire-and-forget; tagName is inserted as-is, so pass it URL-encoded)
+function increaseUseCount(tagName, type, negative = false) {
+    postAPI(`tacapi/v1/increase-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`).catch(e => console.error("Error increasing tag use count:", e));
+}
+// Fetch the use count of a single tag from the database (resolves to the "result" field of the API response)
+async function getUseCount(tagName, type, negative = false) {
+    return fetchAPI(`tacapi/v1/get-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`, true, false).then(response => response["result"]);
+}
+async function getUseCounts(tagNames, types, negative = false) {
+    // Batch lookup of use counts. While semantically weird, we have to use POST here for the body, as urls are limited in length
+    const body = JSON.stringify({"tagNames": tagNames, "tagTypes": types, "neg": negative});
+    const rawArray = (await postAPI(`tacapi/v1/get-use-count-list`, body))?.["result"] ?? []; // failed or empty response -> no counts instead of a TypeError
+    return mapUseCountArray(rawArray);
+}
+async function getAllUseCounts() {
+    // Every recorded tag; these rows hold both counters, so map them with posAndNeg enabled
+    return mapUseCountArray((await fetchAPI(`tacapi/v1/get-all-use-counts`))["result"], true);
+}
+async function resetUseCount(tagName, type, resetPosCount, resetNegCount) { // Reset the positive and/or negative use count of a tag
+    await putAPI(`tacapi/v1/reset-use-count?tagname=${encodeURIComponent(tagName)}&ttype=${type}&pos=${resetPosCount}&neg=${resetNegCount}`); // tagName is raw tag text, so encode it for the query string
+}
+// Build and return a <table> element with one row per entry of tagCounts (reads name, type, count, negCount and lastUseDate from each entry)
+function createTagUsageTable(tagCounts) {
+ // Create table and set its header content. NOTE(review): the template below looks like header markup with its tags stripped - verify against the real file
+ let tagTable = document.createElement("table");
+ tagTable.innerHTML =
+ `
+
+ | Name |
+ Type |
+ Count(+) |
+ Count(-) |
+ Last used |
+
+ `;
+ tagTable.id = "tac_tagUsageTable"
+
+ tagCounts.forEach(t => {
+ let tr = document.createElement("tr");
+
+ // Fill values, one cell per column. NOTE(review): type is shown minus 1 - presumably an index offset, confirm
+ let values = [t.name, t.type-1, t.count, t.negCount, t.lastUseDate]
+ values.forEach(v => {
+ let td = document.createElement("td");
+ td.innerText = v;
+ tr.append(td);
+ });
+ // Add delete/reset button (appended directly to the row; no click handler is attached inside this function)
+ let delButton = document.createElement("button");
+ delButton.innerText = "🗑️";
+ delButton.title = "Reset count";
+ tr.append(delButton);
+
+ tagTable.append(tr)
+ });
+
+ return tagTable;
+}
// Sliding window function to get possible combination groups of an array
function toNgrams(inputArray, size) {
@@ -242,12 +355,19 @@ function getSortFunction() {
let criterion = TAC_CFG.modelSortOrder || "Name";
const textSort = (a, b, reverse = false) => {
- const textHolderA = a.type === ResultType.chant ? a.aliases : a.text;
- const textHolderB = b.type === ResultType.chant ? b.aliases : b.text;
+ // Assign keys so next sort is faster
+ if (!a.sortKey) {
+ a.sortKey = a.type === ResultType.chant
+ ? a.aliases
+ : a.text;
+ }
+ if (!b.sortKey) {
+ b.sortKey = b.type === ResultType.chant
+ ? b.aliases
+ : b.text;
+ }
- const aKey = a.sortKey || textHolderA;
- const bKey = b.sortKey || textHolderB;
- return reverse ? bKey.localeCompare(aKey) : aKey.localeCompare(bKey);
+ return reverse ? b.sortKey.localeCompare(a.sortKey) : a.sortKey.localeCompare(b.sortKey);
}
const numericSort = (a, b, reverse = false) => {
const noKey = reverse ? "-1" : Number.MAX_SAFE_INTEGER;
diff --git a/javascript/tagAutocomplete.js b/javascript/tagAutocomplete.js
index 6b199a4..e29a983 100644
--- a/javascript/tagAutocomplete.js
+++ b/javascript/tagAutocomplete.js
@@ -86,6 +86,10 @@ const autocompleteCSS = `
white-space: nowrap;
color: var(--meta-text-color);
}
+ .acMetaText.biased::before {
+ content: "✨";
+ margin-right: 2px;
+ }
.acWikiLink {
padding: 0.5rem;
margin: -0.5rem 0 -0.5rem -0.5rem;
@@ -221,6 +225,12 @@ async function syncOptions() {
showWikiLinks: opts["tac_showWikiLinks"],
showExtraNetworkPreviews: opts["tac_showExtraNetworkPreviews"],
modelSortOrder: opts["tac_modelSortOrder"],
+ frequencySort: opts["tac_frequencySort"],
+ frequencyFunction: opts["tac_frequencyFunction"],
+ frequencyMinCount: opts["tac_frequencyMinCount"],
+ frequencyMaxAge: opts["tac_frequencyMaxAge"],
+ frequencyRecommendCap: opts["tac_frequencyRecommendCap"],
+ frequencyIncludeAlias: opts["tac_frequencyIncludeAlias"],
useStyleVars: opts["tac_useStyleVars"],
// Insertion related settings
replaceUnderscores: opts["tac_replaceUnderscores"],
@@ -466,6 +476,37 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
}
}
+ // Frequency db update
+ if (TAC_CFG.frequencySort) {
+ let name = null;
+
+ switch (tagType) {
+ case ResultType.wildcardFile:
+ case ResultType.yamlWildcard:
+ // We only want to update the frequency for a full wildcard, not partial paths
+ if (sanitizedText.endsWith("__"))
+ name = text
+ break;
+ case ResultType.chant:
+ // Chants use a slightly different format
+ name = result.aliases;
+ break;
+ default:
+ name = text;
+ break;
+ }
+
+ if (name && name.length > 0) {
+ // Check if it's a negative prompt
+ let textAreaId = getTextAreaIdentifier(textArea);
+ let isNegative = textAreaId.includes("n");
+ // Sanitize name for API call
+ name = encodeURIComponent(name)
+ // Call API & update db
+ increaseUseCount(name, tagType, isNegative)
+ }
+ }
+
var prompt = textArea.value;
// Edit prompt text
@@ -574,6 +615,8 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
tacSelfTrigger = true;
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python.
// Uses a built-in method from the webui's ui.js which also already accounts for event target
+ if (tagType === ResultType.wildcardTag || tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard)
+ tacSelfTrigger = true;
updateInput(textArea);
// Update previous tags with the edited prompt to prevent re-searching the same term
@@ -688,6 +731,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
let wikiLink = document.createElement("a");
wikiLink.classList.add("acWikiLink");
wikiLink.innerText = "?";
+ wikiLink.title = "Open external wiki page for this tag"
let linkPart = displayText;
// Only use alias result if it is one
@@ -733,7 +777,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
}
// Post count
- if (result.count && !isNaN(result.count)) {
+ if (result.count && !isNaN(result.count) && result.count !== Number.MAX_SAFE_INTEGER) {
let postCount = result.count;
let formatter;
@@ -765,8 +809,24 @@ function addResultsToList(textArea, results, tagword, resetList) {
flexDiv.appendChild(metaDiv);
}
+ // Add small ✨ marker to indicate usage sorting
+ if (result.usageBias) {
+ flexDiv.querySelector(".acMetaText").classList.add("biased");
+ flexDiv.title = "✨ Frequent tag. Ctrl/Cmd + click to reset usage count."
+ }
+
+ // Check if it's a negative prompt
+ let isNegative = textAreaId.includes("n");
+
// Add listener
- li.addEventListener("click", function () { insertTextAtCursor(textArea, result, tagword); });
+ li.addEventListener("click", (e) => {
+ if (e.ctrlKey || e.metaKey) {
+ resetUseCount(result.text, result.type, !isNegative, isNegative);
+ flexDiv.querySelector(".acMetaText").classList.remove("biased");
+ } else {
+ insertTextAtCursor(textArea, result, tagword);
+ }
+ });
// Add element to list
resultsList.appendChild(li);
}
@@ -1034,6 +1094,9 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
resultCountBeforeNormalTags = 0;
tagword = tagword.toLowerCase().replace(/[\n\r]/g, "");
+ // Needed for slicing check later
+ let normalTags = false;
+
// Process all parsers
let resultCandidates = (await processParsers(textArea, prompt))?.filter(x => x.length > 0);
// If one ore more result candidates match, use their results
@@ -1043,32 +1106,12 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
// Sort results, but not if it's umi tags since they are sorted by count
if (!(resultCandidates.length === 1 && results[0].type === ResultType.umiWildcard))
results = results.sort(getSortFunction());
-
- // Since some tags are kaomoji, we have to add the normal results in some cases
- if (tagword.startsWith("<") || tagword.startsWith("*<")) {
- // Create escaped search regex with support for * as a start placeholder
- let searchRegex;
- if (tagword.startsWith("*")) {
- tagword = tagword.slice(1);
- searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
- } else {
- searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
- }
- let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, TAC_CFG.maxResults);
-
- genericResults.forEach(g => {
- let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
- result.category = g[1];
- result.count = g[2];
- result.aliases = g[3];
- results.push(result);
- });
- }
}
// Else search the normal tag list
if (!resultCandidates || resultCandidates.length === 0
|| (TAC_CFG.includeEmbeddingsInNormalResults && !(tagword.startsWith("<") || tagword.startsWith("*<")))
) {
+ normalTags = true;
resultCountBeforeNormalTags = results.length;
// Create escaped search regex with support for * as a start placeholder
@@ -1123,11 +1166,6 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
results = results.concat(extraResults);
}
}
-
- // Slice if the user has set a max result count
- if (!TAC_CFG.showAllResults) {
- results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
- }
}
// Guard for empty results
@@ -1137,6 +1175,57 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
return;
}
+ // Sort again with frequency / usage count if enabled
+ if (TAC_CFG.frequencySort) {
+ // Split our results into a list of names and types
+ let tagNames = [];
+ let aliasNames = [];
+ let types = [];
+ // Limit to 2k for performance reasons
+ const aliasTypes = [ResultType.tag, ResultType.extra];
+ results.slice(0,2000).forEach(r => {
+ const name = r.type === ResultType.chant ? r.aliases : r.text;
+ // Add to alias list or tag list depending on if the name includes the tagword
+ // (the same criteria is used in the filter in calculateUsageBias)
+ if (aliasTypes.includes(r.type) && !name.includes(tagword)) {
+ aliasNames.push(name);
+ } else {
+ tagNames.push(name);
+ }
+ types.push(r.type);
+ });
+
+ // Check if it's a negative prompt
+ let textAreaId = getTextAreaIdentifier(textArea);
+ let isNegative = textAreaId.includes("n");
+
+ // Request use counts from the DB
+ const names = TAC_CFG.frequencyIncludeAlias ? tagNames.concat(aliasNames) : tagNames;
+ const counts = await getUseCounts(names, types, isNegative);
+
+ // Pre-calculate weights to prevent duplicate work
+ const resultBiasMap = new Map();
+ results.forEach(result => {
+ const name = result.type === ResultType.chant ? result.aliases : result.text;
+ const type = result.type;
+ // Find matching pair from DB results
+ const useStats = counts.find(c => c.name === name && c.type === type);
+ const uses = useStats?.count || 0;
+ // Calculate & set weight
+ const weight = calculateUsageBias(result, result.count, uses)
+ resultBiasMap.set(result, weight);
+ });
+ // Actual sorting with the pre-calculated weights
+ results = results.sort((a, b) => {
+ return resultBiasMap.get(b) - resultBiasMap.get(a);
+ });
+ }
+
+ // Slice if the user has set a max result count and we are not in an extra networks / wildcard list
+ if (!TAC_CFG.showAllResults && normalTags) {
+ results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
+ }
+
addResultsToList(textArea, results, tagword, true);
showResults(textArea);
}
@@ -1272,7 +1361,7 @@ async function refreshTacTempFiles(api = false) {
}
if (api) {
- await postAPI("tacapi/v1/refresh-temp-files", null);
+ await postAPI("tacapi/v1/refresh-temp-files");
await reload();
} else {
setTimeout(async () => {
diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py
index 9e6ca05..0a2c479 100644
--- a/scripts/tag_autocomplete_helper.py
+++ b/scripts/tag_autocomplete_helper.py
@@ -2,7 +2,9 @@
# to a temporary file to expose it to the javascript side
import glob
+import importlib
import json
+import sqlite3
import urllib.parse
from pathlib import Path
@@ -11,12 +13,26 @@ import yaml
from fastapi import FastAPI
from fastapi.responses import Response, FileResponse, JSONResponse
from modules import script_callbacks, sd_hijack, shared, hashes
+from pydantic import BaseModel
from scripts.model_keyword_support import (get_lora_simple_hash,
load_hash_cache, update_hash_cache,
write_model_keyword_path)
from scripts.shared_paths import *
+try:
+ import scripts.tag_frequency_db as tdb
+ # The frequency db is optional: any error caught below leaves `db = None` instead of aborting this script
+ # Ensure the db dependency is reloaded on script reload
+ importlib.reload(tdb)
+ # Constructing the wrapper creates the db file if it is missing and reads the stored version
+ db = tdb.TagFrequencyDb()
+ if int(db.version) != int(tdb.db_ver):
+ raise ValueError("Database version mismatch")
+except (ImportError, ValueError, sqlite3.Error) as e:
+ print(f"Tag Autocomplete: Tag frequency database error - \"{e}\"")
+ db = None
+
# Attempt to get embedding load function, using the same call as api.
try:
load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings
@@ -488,6 +504,13 @@ def on_ui_settings():
return self
shared.OptionInfo.needs_restart = needs_restart
+ # Dictionary of function options and their explanations
+ frequency_sort_functions = {
+ "Logarithmic (weak)": "Will respect the base order and slightly prefer often used tags",
+ "Logarithmic (strong)": "Same as Logarithmic (weak), but with a stronger bias",
+ "Usage first": "Will list used tags by frequency before all others",
+ }
+
tac_options = {
# Main tag file
"tac_tagFile": shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files),
@@ -519,6 +542,13 @@ def on_ui_settings():
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
"tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like style-vars to actually apply the styles before generating."),
+ # Frequency sorting settings
+ "tac_frequencySort": shared.OptionInfo(True, "Locally record tag usage and sort frequent tags higher").info("Will also work for extra networks, keeping the specified base order"),
+ "tac_frequencyFunction": shared.OptionInfo("Logarithmic (weak)", "Function to use for frequency sorting", gr.Dropdown, lambda: {"choices": list(frequency_sort_functions.keys())}).info("; ".join([f'{key}: {val}' for key, val in frequency_sort_functions.items()])),
+ "tac_frequencyMinCount": shared.OptionInfo(3, "Minimum number of uses for a tag to be considered frequent").info("Tags with less uses than this will not be sorted higher, even if the sorting function would normally result in a higher position."),
+ "tac_frequencyMaxAge": shared.OptionInfo(30, "Maximum days since last use for a tag to be considered frequent").info("Similar to the above, tags that haven't been used in this many days will not be sorted higher. Set to 0 to disable."),
+ "tac_frequencyRecommendCap": shared.OptionInfo(10, "Maximum number of recommended tags").info("Limits the maximum number of recommended tags to not drown out normal results. Set to 0 to disable."),
+ "tac_frequencyIncludeAlias": shared.OptionInfo(False, "Frequency sorting matches aliases for frequent tags").info("Tag frequency will be increased for the main tag even if an alias is used for completion. This option can be used to override the default behavior of alias results being ignored for frequency sorting."),
# Insertion related settings
"tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"),
"tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"),
@@ -736,5 +766,59 @@ def api_tac(_: gr.Blocks, app: FastAPI):
return Response(status_code=200) # Success
else:
return Response(status_code=304) # Not modified
+ def db_request(func, get = False):
+     """Run `func` against the tag frequency db and wrap the outcome for the API.
+     get=True returns {"result": <func result>}; otherwise None on success. Failures return a 500 JSON error."""
+     if db is None:
+         return JSONResponse({"error": "Database not initialized"}, status_code=500)
+     try:
+         if not get:
+             func()
+             return None
+         # Rows are passed through unchanged; the JS client reads them by position
+         return JSONResponse({"result": func()})
+     except sqlite3.Error as e:
+         # Report the message text: the exception's __cause__ is usually None and is not JSON serializable otherwise
+         return JSONResponse({"error": str(e)}, status_code=500)
+ @app.post("/tacapi/v1/increase-use-count")
+ async def increase_use_count(tagname: str, ttype: int, neg: bool):
+     return db_request(lambda: db.increase_tag_count(tagname, ttype, neg))  # returned so a db error reaches the client as HTTP 500
+
+ @app.get("/tacapi/v1/get-use-count")
+ async def get_use_count(tagname: str, ttype: int, neg: bool):
+     return db_request(get=True, func=lambda: db.get_tag_count(tagname, ttype, neg))  # wraps the (count, last_used) pair as "result"
+
+ # Request body of the get-use-count-list endpoint: tag names with their matching types (paired by position in the db query), and whether to read the negative prompt counts
+ class UseCountListRequest(BaseModel):
+ tagNames: list[str]
+ tagTypes: list[int]
+ neg: bool = False
+
+ # Semantically weird to use post here, but it's required for the body on js side
+ @app.post("/tacapi/v1/get-use-count-list")
+ async def get_use_count_list(body: UseCountListRequest):
+     """Return the use counts for the posted tag names/types, capped to the most used ones if configured."""
+     def fetch_counts():
+         # If a date limit is set > 0, pass it to the db
+         date_limit = getattr(shared.opts, "tac_frequencyMaxAge", 30)
+         date_limit = date_limit if date_limit > 0 else None
+         count_list = list(db.get_tag_counts(body.tagNames, body.tagTypes, body.neg, date_limit))
+         # If a cap is set (> 0), keep at most the top n results by count
+         limit = int(getattr(shared.opts, "tac_frequencyRecommendCap", 10))
+         if limit > 0:
+             count_list = sorted(count_list, key=lambda x: x[2], reverse=True)[:limit]
+         return count_list
+
+     # Query inside db_request so a missing db or sqlite error becomes a JSON error instead of an unhandled exception
+     return db_request(fetch_counts, get=True)
+
+ @app.put("/tacapi/v1/reset-use-count")
+ async def reset_use_count(tagname: str, ttype: int, pos: bool, neg: bool):
+     return db_request(lambda: db.reset_tag_count(tagname, ttype, pos, neg))  # returned so a db error reaches the client as HTTP 500
+
+ @app.get("/tacapi/v1/get-all-use-counts")
+ async def get_all_tag_counts():
+     return db_request(get=True, func=lambda: db.get_all_tags())  # rows of every used tag, with both counters
+
script_callbacks.on_app_started(api_tac)
diff --git a/scripts/tag_frequency_db.py b/scripts/tag_frequency_db.py
new file mode 100644
index 0000000..5b1b195
--- /dev/null
+++ b/scripts/tag_frequency_db.py
@@ -0,0 +1,189 @@
+import sqlite3
+from contextlib import contextmanager
+
+from scripts.shared_paths import TAGS_PATH
+
+db_file = TAGS_PATH.joinpath("tag_frequency.db")
+timeout = 30
+db_ver = 1
+
+
+@contextmanager
+def transaction(db=db_file):
+    """Yield a cursor inside an explicit BEGIN/COMMIT on a fresh connection to `db`.
+    sqlite3 errors are printed, not re-raised; the connection is always closed.
+    """
+    conn = None  # stays None if connect() itself fails, so `finally` never touches an unbound name
+    try:
+        conn = sqlite3.connect(db, timeout=timeout)
+        conn.isolation_level = None  # no implicit transactions: BEGIN/COMMIT are issued manually below
+        cursor = conn.cursor()
+        cursor.execute("BEGIN")
+        yield cursor
+        cursor.execute("COMMIT")
+    except sqlite3.Error as e:
+        print("Tag Autocomplete: Frequency database error:", e)
+    finally:
+        if conn is not None:
+            conn.close()
+
+
+class TagFrequencyDb:
+    """Class containing creation and interaction methods for the tag frequency database"""
+
+    def __init__(self) -> None:
+        self.version = self.__check()
+
+    def __check(self):
+        """Create the database (tables + version entry) if the file is missing, then return its version."""
+        if not db_file.exists():
+            print("Tag Autocomplete: Creating frequency database")
+            with transaction() as cursor:
+                self.__create_db(cursor)
+                self.__update_db_data(cursor, "version", db_ver)
+            print("Tag Autocomplete: Database successfully created")
+        return self.__get_version()
+
+    def __create_db(self, cursor: sqlite3.Cursor):
+        """Create the key/value `db_data` table and the `tag_frequency` table if they do not exist."""
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS db_data (
+                key TEXT PRIMARY KEY,
+                value TEXT
+            )
+            """
+        )
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS tag_frequency (
+                name TEXT NOT NULL,
+                type INT NOT NULL,
+                count_pos INT,
+                count_neg INT,
+                last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                PRIMARY KEY (name, type)
+            )
+            """
+        )
+
+    def __update_db_data(self, cursor: sqlite3.Cursor, key, value):
+        """Insert or overwrite a single key/value entry in `db_data`."""
+        cursor.execute(
+            """
+            INSERT OR REPLACE
+            INTO db_data (key, value)
+            VALUES (?, ?)
+            """, (key, value),
+        )
+
+    def __get_version(self):
+        """Return the version stored in `db_data`, or 0 if it is missing or could not be read."""
+        db_version = None  # pre-bound: transaction() swallows sqlite3 errors, which skips the assignment below
+        with transaction() as cursor:
+            db_version = cursor.execute(
+                """
+                SELECT value
+                FROM db_data
+                WHERE key = 'version'
+                """
+            ).fetchone()
+        return db_version[0] if db_version else 0
+
+    def get_all_tags(self):
+        """Return (name, type, count_pos, count_neg, last_used) rows of all used tags, most used first."""
+        tags = []  # returned as-is if the query fails, since transaction() swallows sqlite3 errors
+        with transaction() as cursor:
+            tags = cursor.execute(
+                """
+                SELECT name, type, count_pos, count_neg, last_used
+                FROM tag_frequency
+                WHERE count_pos > 0 OR count_neg > 0
+                ORDER BY count_pos + count_neg DESC
+                """
+            ).fetchall()
+        return tags
+
+    def get_tag_count(self, tag, ttype, negative=False):
+        """Look up the use count of a single tag.
+
+        Returns (count, last_used) from the negative or positive column; (0, None) if there is no row.
+        """
+        count_str = "count_neg" if negative else "count_pos"
+        tag_count = None  # pre-bound in case transaction() swallows a sqlite3 error
+        with transaction() as cursor:
+            tag_count = cursor.execute(
+                f"""
+                SELECT {count_str}, last_used
+                FROM tag_frequency
+                WHERE name = ? AND type = ?
+                """,
+                (tag, ttype),
+            ).fetchone()
+        return (tag_count[0], tag_count[1]) if tag_count else (0, None)
+
+    def get_tag_counts(self, tags: list[str], ttypes: list[int], negative=False, date_limit=None):
+        """Yield (tag, ttype, count, last_used) for each tag/type pair, in input order.
+
+        Names and types are paired with zip(), so the shorter list bounds the output.
+        With `date_limit` (in days), rows last used longer ago are treated like unknown
+        tags, which yield a count of 0 and a last_used of None.
+        """
+        count_str = "count_neg" if negative else "count_pos"
+
+        # Only the optional age filter varies, so build the statement once outside the loop
+        query = f"""
+            SELECT {count_str}, last_used
+            FROM tag_frequency
+            WHERE name = ? AND type = ?
+            """
+        extra_params = ()
+        if date_limit is not None:
+            query += "AND last_used > datetime('now', '-' || ? || ' days')"
+            extra_params = (date_limit,)
+
+        with transaction() as cursor:
+            for tag, ttype in zip(tags, ttypes):
+                tag_count = cursor.execute(query, (tag, ttype, *extra_params)).fetchone()
+                # No matching (or recent enough) row means the tag is reported as unused
+                if tag_count:
+                    yield (tag, ttype, tag_count[0], tag_count[1])
+                else:
+                    yield (tag, ttype, 0, None)
+
+    def increase_tag_count(self, tag, ttype, negative=False):
+        """Add one use to the tag's negative or positive counter and refresh last_used.
+
+        The row is created if missing and then incremented in SQL, both inside one
+        transaction, so the increment never works from a stale read of the counts and
+        a single connection is opened per call.
+        """
+        count_str = "count_neg" if negative else "count_pos"
+        with transaction() as cursor:
+            cursor.execute(
+                "INSERT OR IGNORE INTO tag_frequency (name, type, count_pos, count_neg) VALUES (?, ?, 0, 0)",
+                (tag, ttype),
+            )
+            cursor.execute(
+                f"UPDATE tag_frequency SET {count_str} = COALESCE({count_str}, 0) + 1, "
+                "last_used = CURRENT_TIMESTAMP WHERE name = ? AND type = ?",
+                (tag, ttype),
+            )
+
+    def reset_tag_count(self, tag, ttype, positive=True, negative=False):
+        """Zero the selected use counter(s) of a tag.
+        `positive` / `negative` choose which column is reset; with neither set this is a no-op.
+        """
+        columns = [col for col, wanted in (("count_pos", positive), ("count_neg", negative)) if wanted]
+        if not columns:
+            return  # neither counter selected: there is nothing to update
+        set_str = ", ".join(f"{col} = 0" for col in columns)
+        with transaction() as cursor:
+            cursor.execute(
+                f"""
+                UPDATE tag_frequency
+                SET {set_str}
+                WHERE name = ? AND type = ?
+                """,
+                (tag, ttype),
+            )