From 0f487a5c5cb2aadb8d15dcc11713a99a197bf200 Mon Sep 17 00:00:00 2001 From: DominikDoom Date: Sun, 24 Sep 2023 16:28:32 +0200 Subject: [PATCH] WIP database setup inspired by ImageBrowser --- scripts/tag_autocomplete_helper.py | 21 ++++- scripts/tag_frequency_db.py | 134 +++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+), 3 deletions(-) create mode 100644 scripts/tag_frequency_db.py diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index 1271673..5af3b34 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -18,6 +18,15 @@ from scripts.model_keyword_support import (get_lora_simple_hash, write_model_keyword_path) from scripts.shared_paths import * +try: + from scripts.tag_frequency_db import TagFrequencyDb, version + db = TagFrequencyDb() + if db.version != version: + raise ValueError("Tag Autocomplete: Tag frequency database version mismatch, disabling tag frequency sorting") +except (ImportError, ValueError): + print("Tag Autocomplete: Tag frequency database could not be loaded, disabling tag frequency sorting") + db = None + # Attempt to get embedding load function, using the same call as api. try: load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings @@ -569,14 +578,20 @@ def api_tac(_: gr.Blocks, app: FastAPI): @app.post("/tacapi/v1/increase-use-count/{tagname}") async def increase_use_count(tagname: str): - pass + db.increase_tag_count(tagname) @app.get("/tacapi/v1/get-use-count/{tagname}") async def get_use_count(tagname: str): - return JSONResponse({"count": 0}) + db_count = db.get_tag_count(tagname) + return JSONResponse({"count": db_count}) @app.put("/tacapi/v1/reset-use-count/{tagname}") async def reset_use_count(tagname: str): - pass + db.reset_tag_count(tagname) + + @app.get("/tacapi/v1/get-all-tag-counts") + async def get_all_tag_counts(): + db_counts = db.get_all_tags() + return JSONResponse({"counts": db_counts}) script_callbacks.on_app_started(api_tac) diff --git a/scripts/tag_frequency_db.py b/scripts/tag_frequency_db.py new file mode 100644 index 0000000..d215b7e --- /dev/null +++ b/scripts/tag_frequency_db.py @@ -0,0 +1,134 @@ +import sqlite3 +from contextlib import contextmanager + +from scripts.shared_paths import TAGS_PATH + +db_file = TAGS_PATH.joinpath("tag_frequency.db") +timeout = 30 +version = 1 + + +@contextmanager +def transaction(db=db_file): + """Context manager for database transactions. + Ensures that the connection is properly closed after the transaction. + """ + conn = sqlite3.connect(db, timeout=timeout) + try: + conn.isolation_level = None + cursor = conn.cursor() + cursor.execute("BEGIN") + yield cursor + cursor.execute("COMMIT") + finally: + conn.close() + + +class TagFrequencyDb: + """Class containing creation and interaction methods for the tag frequency database""" + + def __init__(self) -> None: + self.version = self.__check() + + def __check(self): + if not db_file.exists(): + print("Tag Autocomplete: Creating frequency database") + with transaction() as cursor: + self.__create_db(cursor) + self.__update_db_data(cursor, "version", version) + print("Tag Autocomplete: Database successfully created") + + return self.__get_version() + + def __create_db(self, cursor: sqlite3.Cursor): + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS db_data ( + key TEXT PRIMARY KEY, + value TEXT + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS tag_frequency ( + name TEXT PRIMARY KEY, + count INT, + last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + def __update_db_data(cursor: sqlite3.Cursor, key, value): + cursor.execute( + """ + INSERT OR REPLACE + INTO db_data (key, value) + VALUES (?, ?) + """, + (key, value), + ) + + def __get_version(self): + with transaction() as cursor: + cursor.execute( + """ + SELECT value + FROM db_data + WHERE key = 'version' + """ + ) + db_version = cursor.fetchone() + + return db_version + + def get_all_tags(self): + with transaction() as cursor: + cursor.execute( + """ + SELECT name + FROM tag_frequency + ORDER BY count DESC + """ + ) + tags = cursor.fetchall() + + return tags + + def get_tag_count(self, tag): + with transaction() as cursor: + cursor.execute( + """ + SELECT count + FROM tag_frequency + WHERE name = ? + """, + (tag,), + ) + tag_count = cursor.fetchone() + + return tag_count or 0 + + def increase_tag_count(self, tag): + current_count = self.get_tag_count(tag) + with transaction() as cursor: + cursor.execute( + """ + INSERT OR REPLACE + INTO tag_frequency (name, count) + VALUES (?, ?) + """, + (tag, current_count + 1), + ) + + def reset_tag_count(self, tag): + with transaction() as cursor: + cursor.execute( + """ + UPDATE tag_frequency + SET count = 0 + WHERE name = ? + """, + (tag,), + )