Fix db load version comparison, add sort options

This commit is contained in:
DominikDoom
2023-09-24 17:59:14 +02:00
parent 1e81403180
commit b44c36425a
3 changed files with 28 additions and 9 deletions

View File

@@ -92,6 +92,17 @@ async function postAPI(url, body) {
return await response.json();
}
async function putAPI(url, body) {
let response = await fetch(url, { method: "PUT", body: body });
if (response.status != 200) {
console.error(`Error putting to API endpoint "${url}": ` + response.status, response.statusText);
return null;
}
return await response.json();
}
// Extra network preview thumbnails
async function getExtraNetworkPreviewURL(filename, type) {
const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);

View File

@@ -1,7 +1,6 @@
# This helper script scans folders for wildcards and embeddings and writes them
# to a temporary file to expose it to the javascript side
import os
import glob
import json
import urllib.parse
@@ -19,11 +18,12 @@ from scripts.model_keyword_support import (get_lora_simple_hash,
from scripts.shared_paths import *
try:
from scripts.tag_frequency_db import TagFrequencyDb, version
from scripts.tag_frequency_db import TagFrequencyDb, db_ver
db = TagFrequencyDb()
if db.version != version:
if int(db.version) != int(db_ver):
raise ValueError("Tag Autocomplete: Tag frequency database version mismatch, disabling tag frequency sorting")
except (ImportError, ValueError):
except (ImportError, ValueError) as e:
print(e)
print("Tag Autocomplete: Tag frequency database could not be loaded, disabling tag frequency sorting")
db = None
@@ -391,6 +391,12 @@ def on_ui_settings():
return self
shared.OptionInfo.needs_restart = needs_restart
# Dictionary of function options and their explanations
frequency_sort_functions = {
"Logarithmic": "Will respect the base order and slightly prefer more frequent tags",
"Usage first": "Will list used tags by frequency before all others",
}
tac_options = {
# Main tag file
"tac_tagFile": shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files),
@@ -418,6 +424,8 @@ def on_ui_settings():
"tac_showWikiLinks": shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page").info("Warning: This is an external site and very likely contains NSFW examples!"),
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
"tac_frequencySort": shared.OptionInfo(True, "Locally record tag usage and sort frequent tags higher").info("Will also work for extra networks, keeping the specified base order"),
"tac_frequencyFunction": shared.OptionInfo("Logarithmic", "Function to use for frequency sorting", gr.Dropdown, lambda: {"choices": list(frequency_sort_functions.keys())}).info("; ".join([f'<b>{key}</b>: {val}' for key, val in frequency_sort_functions.items()])),
# Insertion related settings
"tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"),
"tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"),

View File

@@ -5,7 +5,7 @@ from scripts.shared_paths import TAGS_PATH
db_file = TAGS_PATH.joinpath("tag_frequency.db")
timeout = 30
version = 1
db_ver = 1
@contextmanager
@@ -35,7 +35,7 @@ class TagFrequencyDb:
print("Tag Autocomplete: Creating frequency database")
with transaction() as cursor:
self.__create_db(cursor)
self.__update_db_data(cursor, "version", version)
self.__update_db_data(cursor, "version", db_ver)
print("Tag Autocomplete: Database successfully created")
return self.__get_version()
@@ -60,7 +60,7 @@ class TagFrequencyDb:
"""
)
def __update_db_data(cursor: sqlite3.Cursor, key, value):
def __update_db_data(self, cursor: sqlite3.Cursor, key, value):
cursor.execute(
"""
INSERT OR REPLACE
@@ -81,7 +81,7 @@ class TagFrequencyDb:
)
db_version = cursor.fetchone()
return db_version
return db_version[0] if db_version else 0
def get_all_tags(self):
with transaction() as cursor:
@@ -108,7 +108,7 @@ class TagFrequencyDb:
)
tag_count = cursor.fetchone()
return tag_count or 0
return tag_count[0] if tag_count else 0
def increase_tag_count(self, tag):
current_count = self.get_tag_count(tag)