WIP database setup inspired by ImageBrowser

This commit is contained in:
DominikDoom
2023-09-24 16:28:32 +02:00
parent 2baa12fea3
commit 0f487a5c5c
2 changed files with 152 additions and 3 deletions

View File

@@ -18,6 +18,15 @@ from scripts.model_keyword_support import (get_lora_simple_hash,
write_model_keyword_path)
from scripts.shared_paths import *
try:
from scripts.tag_frequency_db import TagFrequencyDb, version
db = TagFrequencyDb()
if db.version != version:
raise ValueError("Tag Autocomplete: Tag frequency database version mismatch, disabling tag frequency sorting")
except (ImportError, ValueError):
print("Tag Autocomplete: Tag frequency database could not be loaded, disabling tag frequency sorting")
db = None
# Attempt to get embedding load function, using the same call as api.
try:
load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings
@@ -569,14 +578,20 @@ def api_tac(_: gr.Blocks, app: FastAPI):
@app.post("/tacapi/v1/increase-use-count/{tagname}")
async def increase_use_count(tagname: str):
pass
db.increase_tag_count(tagname)
@app.get("/tacapi/v1/get-use-count/{tagname}")
async def get_use_count(tagname: str):
return JSONResponse({"count": 0})
db_count = db.get_tag_count(tagname)
return JSONResponse({"count": db_count})
@app.put("/tacapi/v1/reset-use-count/{tagname}")
async def reset_use_count(tagname: str):
pass
db.reset_tag_count(tagname)
@app.get("/tacapi/v1/get-all-tag-counts")
async def get_all_tag_counts():
db_counts = db.get_all_tags()
return JSONResponse({"counts": db_counts})
script_callbacks.on_app_started(api_tac)

134
scripts/tag_frequency_db.py Normal file
View File

@@ -0,0 +1,134 @@
import sqlite3
from contextlib import contextmanager
from scripts.shared_paths import TAGS_PATH
db_file = TAGS_PATH.joinpath("tag_frequency.db")
timeout = 30
version = 1
@contextmanager
def transaction(db=db_file):
"""Context manager for database transactions.
Ensures that the connection is properly closed after the transaction.
"""
conn = sqlite3.connect(db, timeout=timeout)
try:
conn.isolation_level = None
cursor = conn.cursor()
cursor.execute("BEGIN")
yield cursor
cursor.execute("COMMIT")
finally:
conn.close()
class TagFrequencyDb:
"""Class containing creation and interaction methods for the tag frequency database"""
def __init__(self) -> None:
self.version = self.__check()
def __check(self):
if not db_file.exists():
print("Tag Autocomplete: Creating frequency database")
with transaction() as cursor:
self.__create_db(cursor)
self.__update_db_data(cursor, "version", version)
print("Tag Autocomplete: Database successfully created")
return self.__get_version()
def __create_db(self, cursor: sqlite3.Cursor):
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS db_data (
key TEXT PRIMARY KEY,
value TEXT
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS tag_frequency (
name TEXT PRIMARY KEY,
count INT,
last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
def __update_db_data(cursor: sqlite3.Cursor, key, value):
cursor.execute(
"""
INSERT OR REPLACE
INTO db_data (key, value)
VALUES (?, ?)
""",
(key, value),
)
def __get_version(self):
with transaction() as cursor:
cursor.execute(
"""
SELECT value
FROM db_data
WHERE key = 'version'
"""
)
db_version = cursor.fetchone()
return db_version
def get_all_tags(self):
with transaction() as cursor:
cursor.execute(
"""
SELECT name
FROM tag_frequency
ORDER BY count DESC
"""
)
tags = cursor.fetchall()
return tags
def get_tag_count(self, tag):
with transaction() as cursor:
cursor.execute(
"""
SELECT count
FROM tag_frequency
WHERE name = ?
""",
(tag,),
)
tag_count = cursor.fetchone()
return tag_count or 0
def increase_tag_count(self, tag):
current_count = self.get_tag_count(tag)
with transaction() as cursor:
cursor.execute(
"""
INSERT OR REPLACE
INTO tag_frequency (name, count)
VALUES (?, ?)
""",
(tag, current_count + 1),
)
def reset_tag_count(self, tag):
with transaction() as cursor:
cursor.execute(
"""
UPDATE tag_frequency
SET count = 0
WHERE name = ?
""",
(tag,),
)