From b44c36425a148a49899db0c93b8370ea2d80c46d Mon Sep 17 00:00:00 2001 From: DominikDoom Date: Sun, 24 Sep 2023 17:59:14 +0200 Subject: [PATCH] Fix db load version comparison, add sort options --- javascript/_utils.js | 11 +++++++++++ scripts/tag_autocomplete_helper.py | 16 ++++++++++++---- scripts/tag_frequency_db.py | 10 +++++----- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/javascript/_utils.js b/javascript/_utils.js index 50ab101..d9c7759 100644 --- a/javascript/_utils.js +++ b/javascript/_utils.js @@ -92,6 +92,17 @@ async function postAPI(url, body) { return await response.json(); } +async function putAPI(url, body) { + let response = await fetch(url, { method: "PUT", body: body }); + + if (response.status != 200) { + console.error(`Error putting to API endpoint "${url}": ` + response.status, response.statusText); + return null; + } + + return await response.json(); +} + // Extra network preview thumbnails async function getExtraNetworkPreviewURL(filename, type) { const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true); diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index bbb4034..97fb756 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -1,7 +1,6 @@ # This helper script scans folders for wildcards and embeddings and writes them # to a temporary file to expose it to the javascript side -import os import glob import json import urllib.parse @@ -19,11 +18,12 @@ from scripts.model_keyword_support import (get_lora_simple_hash, from scripts.shared_paths import * try: - from scripts.tag_frequency_db import TagFrequencyDb, version + from scripts.tag_frequency_db import TagFrequencyDb, db_ver db = TagFrequencyDb() - if db.version != version: + if int(db.version) != int(db_ver): raise ValueError("Tag Autocomplete: Tag frequency database version mismatch, disabling tag frequency sorting") -except (ImportError, ValueError): +except (ImportError, ValueError) as e: + print(e) print("Tag Autocomplete: Tag frequency database could not be loaded, disabling tag frequency sorting") db = None @@ -391,6 +391,12 @@ def on_ui_settings(): return self shared.OptionInfo.needs_restart = needs_restart + # Dictionary of function options and their explanations + frequency_sort_functions = { + "Logarithmic": "Will respect the base order and slightly prefer more frequent tags", + "Usage first": "Will list used tags by frequency before all others", + } + tac_options = { # Main tag file "tac_tagFile": shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files), @@ -418,6 +424,8 @@ def on_ui_settings(): "tac_showWikiLinks": shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page").info("Warning: This is an external site and very likely contains NSFW examples!"), "tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"), "tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"), + "tac_frequencySort": shared.OptionInfo(True, "Locally record tag usage and sort frequent tags higher").info("Will also work for extra networks, keeping the specified base order"), + "tac_frequencyFunction": shared.OptionInfo("Logarithmic", "Function to use for frequency sorting", gr.Dropdown, lambda: {"choices": list(frequency_sort_functions.keys())}).info("; ".join([f'{key}: {val}' for key, val in frequency_sort_functions.items()])), # Insertion related settings "tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"), "tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"), diff --git a/scripts/tag_frequency_db.py b/scripts/tag_frequency_db.py index d215b7e..d186b0d 100644 --- a/scripts/tag_frequency_db.py +++ b/scripts/tag_frequency_db.py @@ -5,7 +5,7 @@ from scripts.shared_paths import TAGS_PATH db_file = TAGS_PATH.joinpath("tag_frequency.db") timeout = 30 -version = 1 +db_ver = 1 @contextmanager @@ -35,7 +35,7 @@ class TagFrequencyDb: print("Tag Autocomplete: Creating frequency database") with transaction() as cursor: self.__create_db(cursor) - self.__update_db_data(cursor, "version", version) + self.__update_db_data(cursor, "version", db_ver) print("Tag Autocomplete: Database successfully created") return self.__get_version() @@ -60,7 +60,7 @@ class TagFrequencyDb: """ ) - def __update_db_data(cursor: sqlite3.Cursor, key, value): + def __update_db_data(self, cursor: sqlite3.Cursor, key, value): cursor.execute( """ INSERT OR REPLACE @@ -81,7 +81,7 @@ class TagFrequencyDb: ) db_version = cursor.fetchone() - return db_version + return db_version[0] if db_version else 0 def get_all_tags(self): with transaction() as cursor: @@ -108,7 +108,7 @@ class TagFrequencyDb: ) tag_count = cursor.fetchone() - return tag_count or 0 + return tag_count[0] if tag_count else 0 def increase_tag_count(self, tag): current_count = self.get_tag_count(tag)