mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-26 19:19:57 +00:00
Catch sqlite exceptions, add tag list endpoint
This commit is contained in:
@@ -3,12 +3,13 @@
|
||||
|
||||
import glob
|
||||
import json
|
||||
import sqlite3
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import yaml
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Query
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from modules import script_callbacks, sd_hijack, shared
|
||||
|
||||
@@ -21,10 +22,9 @@ try:
|
||||
from scripts.tag_frequency_db import TagFrequencyDb, db_ver
|
||||
db = TagFrequencyDb()
|
||||
if int(db.version) != int(db_ver):
|
||||
raise ValueError("Tag Autocomplete: Tag frequency database version mismatch, disabling tag frequency sorting")
|
||||
except (ImportError, ValueError) as e:
|
||||
print(e)
|
||||
print("Tag Autocomplete: Tag frequency database could not be loaded, disabling tag frequency sorting")
|
||||
raise ValueError("Database version mismatch")
|
||||
except (ImportError, ValueError, sqlite3.Error) as e:
|
||||
print(f"Tag Autocomplete: Tag frequency database error - \"{e}\"")
|
||||
db = None
|
||||
|
||||
# Attempt to get embedding load function, using the same call as api.
|
||||
@@ -584,36 +584,37 @@ def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
except Exception as e:
|
||||
return JSONResponse({"error": e}, status_code=500)
|
||||
|
||||
NO_DB = JSONResponse({"error": "Database not initialized"}, status_code=500)
|
||||
def db_request(func, get = False):
|
||||
if db is not None:
|
||||
try:
|
||||
if get:
|
||||
ret = func()
|
||||
return JSONResponse({"result": ret})
|
||||
else:
|
||||
func()
|
||||
except sqlite3.Error as e:
|
||||
return JSONResponse({"error": e}, status_code=500)
|
||||
else:
|
||||
return JSONResponse({"error": "Database not initialized"}, status_code=500)
|
||||
|
||||
@app.post("/tacapi/v1/increase-use-count/{tagname}")
|
||||
async def increase_use_count(tagname: str):
|
||||
if db is not None:
|
||||
db.increase_tag_count(tagname)
|
||||
else:
|
||||
return NO_DB
|
||||
db_request(lambda: db.increase_tag_count(tagname))
|
||||
|
||||
@app.get("/tacapi/v1/get-use-count/{tagname}")
|
||||
async def get_use_count(tagname: str):
|
||||
if db is not None:
|
||||
db_count = db.get_tag_count(tagname)
|
||||
return JSONResponse({"count": db_count})
|
||||
else:
|
||||
return NO_DB
|
||||
return db_request(lambda: db.get_tag_count(tagname), get=True)
|
||||
|
||||
@app.get("/tacapi/v1/get-use-count-list")
|
||||
async def get_use_count_list(tags: list[str] | None = Query(default=None)):
|
||||
return db_request(lambda: list(db.get_tag_counts(tags)), get=True)
|
||||
|
||||
@app.put("/tacapi/v1/reset-use-count/{tagname}")
|
||||
async def reset_use_count(tagname: str):
|
||||
if db is not None:
|
||||
db.reset_tag_count(tagname)
|
||||
else:
|
||||
return NO_DB
|
||||
db_request(lambda: db.reset_tag_count(tagname))
|
||||
|
||||
@app.get("/tacapi/v1/get-all-use-counts")
|
||||
async def get_all_tag_counts():
|
||||
if db is not None:
|
||||
db_tags = db.get_all_tags()
|
||||
return JSONResponse({"tags": db_tags})
|
||||
else:
|
||||
return NO_DB
|
||||
return db_request(lambda: db.get_all_tags(), get=True)
|
||||
|
||||
script_callbacks.on_app_started(api_tac)
|
||||
|
||||
@@ -110,6 +110,20 @@ class TagFrequencyDb:
|
||||
|
||||
return tag_count[0] if tag_count else 0
|
||||
|
||||
def get_tag_counts(self, tags: list[str]):
|
||||
with transaction() as cursor:
|
||||
for tag in tags:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT count
|
||||
FROM tag_frequency
|
||||
WHERE name = ?
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_count = cursor.fetchone()
|
||||
yield (tag, tag_count[0]) if tag_count else (tag, 0)
|
||||
|
||||
def increase_tag_count(self, tag):
|
||||
current_count = self.get_tag_count(tag)
|
||||
with transaction() as cursor:
|
||||
|
||||
Reference in New Issue
Block a user