mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-03-04 21:00:03 +00:00
WIP database setup inspired by ImageBrowser
This commit is contained in:
@@ -18,6 +18,15 @@ from scripts.model_keyword_support import (get_lora_simple_hash,
|
||||
write_model_keyword_path)
|
||||
from scripts.shared_paths import *
|
||||
|
||||
try:
|
||||
from scripts.tag_frequency_db import TagFrequencyDb, version
|
||||
db = TagFrequencyDb()
|
||||
if db.version != version:
|
||||
raise ValueError("Tag Autocomplete: Tag frequency database version mismatch, disabling tag frequency sorting")
|
||||
except (ImportError, ValueError):
|
||||
print("Tag Autocomplete: Tag frequency database could not be loaded, disabling tag frequency sorting")
|
||||
db = None
|
||||
|
||||
# Attempt to get embedding load function, using the same call as api.
|
||||
try:
|
||||
load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings
|
||||
@@ -569,14 +578,20 @@ def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
|
||||
@app.post("/tacapi/v1/increase-use-count/{tagname}")
|
||||
async def increase_use_count(tagname: str):
|
||||
pass
|
||||
db.increase_tag_count(tagname)
|
||||
|
||||
@app.get("/tacapi/v1/get-use-count/{tagname}")
|
||||
async def get_use_count(tagname: str):
|
||||
return JSONResponse({"count": 0})
|
||||
db_count = db.get_tag_count(tagname)
|
||||
return JSONResponse({"count": db_count})
|
||||
|
||||
@app.put("/tacapi/v1/reset-use-count/{tagname}")
|
||||
async def reset_use_count(tagname: str):
|
||||
pass
|
||||
db.reset_tag_count(tagname)
|
||||
|
||||
@app.get("/tacapi/v1/get-all-tag-counts")
|
||||
async def get_all_tag_counts():
|
||||
db_counts = db.get_all_tags()
|
||||
return JSONResponse({"counts": db_counts})
|
||||
|
||||
script_callbacks.on_app_started(api_tac)
|
||||
|
||||
134
scripts/tag_frequency_db.py
Normal file
134
scripts/tag_frequency_db.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
|
||||
from scripts.shared_paths import TAGS_PATH
|
||||
|
||||
db_file = TAGS_PATH.joinpath("tag_frequency.db")
|
||||
timeout = 30
|
||||
version = 1
|
||||
|
||||
|
||||
@contextmanager
|
||||
def transaction(db=db_file):
|
||||
"""Context manager for database transactions.
|
||||
Ensures that the connection is properly closed after the transaction.
|
||||
"""
|
||||
conn = sqlite3.connect(db, timeout=timeout)
|
||||
try:
|
||||
conn.isolation_level = None
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("BEGIN")
|
||||
yield cursor
|
||||
cursor.execute("COMMIT")
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
class TagFrequencyDb:
|
||||
"""Class containing creation and interaction methods for the tag frequency database"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.version = self.__check()
|
||||
|
||||
def __check(self):
|
||||
if not db_file.exists():
|
||||
print("Tag Autocomplete: Creating frequency database")
|
||||
with transaction() as cursor:
|
||||
self.__create_db(cursor)
|
||||
self.__update_db_data(cursor, "version", version)
|
||||
print("Tag Autocomplete: Database successfully created")
|
||||
|
||||
return self.__get_version()
|
||||
|
||||
def __create_db(self, cursor: sqlite3.Cursor):
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS db_data (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS tag_frequency (
|
||||
name TEXT PRIMARY KEY,
|
||||
count INT,
|
||||
last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
def __update_db_data(cursor: sqlite3.Cursor, key, value):
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE
|
||||
INTO db_data (key, value)
|
||||
VALUES (?, ?)
|
||||
""",
|
||||
(key, value),
|
||||
)
|
||||
|
||||
def __get_version(self):
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT value
|
||||
FROM db_data
|
||||
WHERE key = 'version'
|
||||
"""
|
||||
)
|
||||
db_version = cursor.fetchone()
|
||||
|
||||
return db_version
|
||||
|
||||
def get_all_tags(self):
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT name
|
||||
FROM tag_frequency
|
||||
ORDER BY count DESC
|
||||
"""
|
||||
)
|
||||
tags = cursor.fetchall()
|
||||
|
||||
return tags
|
||||
|
||||
def get_tag_count(self, tag):
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT count
|
||||
FROM tag_frequency
|
||||
WHERE name = ?
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_count = cursor.fetchone()
|
||||
|
||||
return tag_count or 0
|
||||
|
||||
def increase_tag_count(self, tag):
|
||||
current_count = self.get_tag_count(tag)
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE
|
||||
INTO tag_frequency (name, count)
|
||||
VALUES (?, ?)
|
||||
""",
|
||||
(tag, current_count + 1),
|
||||
)
|
||||
|
||||
def reset_tag_count(self, tag):
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE tag_frequency
|
||||
SET count = 0
|
||||
WHERE name = ?
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
Reference in New Issue
Block a user