diff --git a/javascript/_utils.js b/javascript/_utils.js index d9c7759..ab67f0f 100644 --- a/javascript/_utils.js +++ b/javascript/_utils.js @@ -179,7 +179,10 @@ function increaseUseCount(tagName) { } // Get use count of tag from the database async function getUseCount(tagName) { - return (await fetchAPI(`tacapi/v1/get-use-count/${tagName}`, true, false))["count"]; + return (await fetchAPI(`tacapi/v1/get-use-count/${tagName}`, true, false))["result"]; +} +async function getUseCounts(tagNames) { + return (await fetchAPI(`tacapi/v1/get-use-count-list?tags=${tagNames.join("&tags=")}`))["result"]; } // Sliding window function to get possible combination groups of an array diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index 9b35232..9b37dfa 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -3,12 +3,13 @@ import glob import json +import sqlite3 import urllib.parse from pathlib import Path import gradio as gr import yaml -from fastapi import FastAPI +from fastapi import FastAPI, Query from fastapi.responses import FileResponse, JSONResponse from modules import script_callbacks, sd_hijack, shared @@ -21,10 +22,9 @@ try: from scripts.tag_frequency_db import TagFrequencyDb, db_ver db = TagFrequencyDb() if int(db.version) != int(db_ver): - raise ValueError("Tag Autocomplete: Tag frequency database version mismatch, disabling tag frequency sorting") -except (ImportError, ValueError) as e: - print(e) - print("Tag Autocomplete: Tag frequency database could not be loaded, disabling tag frequency sorting") + raise ValueError("Database version mismatch") +except (ImportError, ValueError, sqlite3.Error) as e: + print(f"Tag Autocomplete: Tag frequency database error - \"{e}\"") db = None # Attempt to get embedding load function, using the same call as api. @@ -584,36 +584,37 @@ def api_tac(_: gr.Blocks, app: FastAPI): except Exception as e: return JSONResponse({"error": e}, status_code=500) - NO_DB = JSONResponse({"error": "Database not initialized"}, status_code=500) + def db_request(func, get = False): + if db is not None: + try: + if get: + ret = func() + return JSONResponse({"result": ret}) + else: + func() + except sqlite3.Error as e: + return JSONResponse({"error": e}, status_code=500) + else: + return JSONResponse({"error": "Database not initialized"}, status_code=500) @app.post("/tacapi/v1/increase-use-count/{tagname}") async def increase_use_count(tagname: str): - if db is not None: - db.increase_tag_count(tagname) - else: - return NO_DB + db_request(lambda: db.increase_tag_count(tagname)) @app.get("/tacapi/v1/get-use-count/{tagname}") async def get_use_count(tagname: str): - if db is not None: - db_count = db.get_tag_count(tagname) - return JSONResponse({"count": db_count}) - else: - return NO_DB + return db_request(lambda: db.get_tag_count(tagname), get=True) + + @app.get("/tacapi/v1/get-use-count-list") + async def get_use_count_list(tags: list[str] | None = Query(default=None)): + return db_request(lambda: list(db.get_tag_counts(tags)), get=True) @app.put("/tacapi/v1/reset-use-count/{tagname}") async def reset_use_count(tagname: str): - if db is not None: - db.reset_tag_count(tagname) - else: - return NO_DB + db_request(lambda: db.reset_tag_count(tagname)) @app.get("/tacapi/v1/get-all-use-counts") async def get_all_tag_counts(): - if db is not None: - db_tags = db.get_all_tags() - return JSONResponse({"tags": db_tags}) - else: - return NO_DB + return db_request(lambda: db.get_all_tags(), get=True) script_callbacks.on_app_started(api_tac) diff --git a/scripts/tag_frequency_db.py b/scripts/tag_frequency_db.py index d186b0d..77945bb 100644 --- a/scripts/tag_frequency_db.py +++ b/scripts/tag_frequency_db.py @@ -110,6 +110,20 @@ class TagFrequencyDb: return tag_count[0] if tag_count else 0 + def get_tag_counts(self, tags: list[str]): + with transaction() as cursor: + for tag in tags: + cursor.execute( + """ + SELECT count + FROM tag_frequency + WHERE name = ? + """, + (tag,), + ) + tag_count = cursor.fetchone() + yield (tag, tag_count[0]) if tag_count else (tag, 0) + def increase_tag_count(self, tag): current_count = self.get_tag_count(tag) with transaction() as cursor: