Catch sqlite exceptions, add tag list endpoint

This commit is contained in:
DominikDoom
2023-09-24 20:06:40 +02:00
parent 109a8a155e
commit 6cf9acd6ab
3 changed files with 43 additions and 25 deletions

View File

@@ -3,12 +3,13 @@
import glob
import json
import sqlite3
import urllib.parse
from pathlib import Path
import gradio as gr
import yaml
from fastapi import FastAPI
from fastapi import FastAPI, Query
from fastapi.responses import FileResponse, JSONResponse
from modules import script_callbacks, sd_hijack, shared
@@ -21,10 +22,9 @@ try:
from scripts.tag_frequency_db import TagFrequencyDb, db_ver
db = TagFrequencyDb()
if int(db.version) != int(db_ver):
raise ValueError("Tag Autocomplete: Tag frequency database version mismatch, disabling tag frequency sorting")
except (ImportError, ValueError) as e:
print(e)
print("Tag Autocomplete: Tag frequency database could not be loaded, disabling tag frequency sorting")
raise ValueError("Database version mismatch")
except (ImportError, ValueError, sqlite3.Error) as e:
print(f"Tag Autocomplete: Tag frequency database error - \"{e}\"")
db = None
# Attempt to get embedding load function, using the same call as api.
@@ -584,36 +584,37 @@ def api_tac(_: gr.Blocks, app: FastAPI):
except Exception as e:
return JSONResponse({"error": e}, status_code=500)
NO_DB = JSONResponse({"error": "Database not initialized"}, status_code=500)
def db_request(func, get = False):
if db is not None:
try:
if get:
ret = func()
return JSONResponse({"result": ret})
else:
func()
except sqlite3.Error as e:
return JSONResponse({"error": e}, status_code=500)
else:
return JSONResponse({"error": "Database not initialized"}, status_code=500)
@app.post("/tacapi/v1/increase-use-count/{tagname}")
async def increase_use_count(tagname: str):
if db is not None:
db.increase_tag_count(tagname)
else:
return NO_DB
db_request(lambda: db.increase_tag_count(tagname))
@app.get("/tacapi/v1/get-use-count/{tagname}")
async def get_use_count(tagname: str):
if db is not None:
db_count = db.get_tag_count(tagname)
return JSONResponse({"count": db_count})
else:
return NO_DB
return db_request(lambda: db.get_tag_count(tagname), get=True)
@app.get("/tacapi/v1/get-use-count-list")
async def get_use_count_list(tags: list[str] | None = Query(default=None)):
return db_request(lambda: list(db.get_tag_counts(tags)), get=True)
@app.put("/tacapi/v1/reset-use-count/{tagname}")
async def reset_use_count(tagname: str):
if db is not None:
db.reset_tag_count(tagname)
else:
return NO_DB
db_request(lambda: db.reset_tag_count(tagname))
@app.get("/tacapi/v1/get-all-use-counts")
async def get_all_tag_counts():
if db is not None:
db_tags = db.get_all_tags()
return JSONResponse({"tags": db_tags})
else:
return NO_DB
return db_request(lambda: db.get_all_tags(), get=True)
script_callbacks.on_app_started(api_tac)

View File

@@ -110,6 +110,20 @@ class TagFrequencyDb:
return tag_count[0] if tag_count else 0
def get_tag_counts(self, tags: list[str]):
with transaction() as cursor:
for tag in tags:
cursor.execute(
"""
SELECT count
FROM tag_frequency
WHERE name = ?
""",
(tag,),
)
tag_count = cursor.fetchone()
yield (tag, tag_count[0]) if tag_count else (tag, 0)
def increase_tag_count(self, tag):
current_count = self.get_tag_count(tag)
with transaction() as cursor: