Count positive / negative prompt usage separately

This commit is contained in:
DominikDoom
2023-11-29 15:20:15 +01:00
parent 434301738a
commit 15478e73b5
4 changed files with 68 additions and 42 deletions

View File

@@ -607,29 +607,30 @@ def api_tac(_: gr.Blocks, app: FastAPI):
return JSONResponse({"error": "Database not initialized"}, status_code=500)
@app.post("/tacapi/v1/increase-use-count")
async def increase_use_count(tagname: str, ttype: int):
db_request(lambda: db.increase_tag_count(tagname, ttype))
async def increase_use_count(tagname: str, ttype: int, neg: bool):
db_request(lambda: db.increase_tag_count(tagname, ttype, neg))
@app.get("/tacapi/v1/get-use-count")
async def get_use_count(tagname: str, ttype: int):
return db_request(lambda: db.get_tag_count(tagname, ttype), get=True)
async def get_use_count(tagname: str, ttype: int, neg: bool):
return db_request(lambda: db.get_tag_count(tagname, ttype, neg), get=True)
# Small dataholder class
class UseCountListRequest(BaseModel):
tagNames: list[str]
tagTypes: list[int]
neg: bool = False
# Semantically weird to use post here, but it's required for the body on js side
@app.post("/tacapi/v1/get-use-count-list")
async def get_use_count_list(body: UseCountListRequest):
return db_request(lambda: list(db.get_tag_counts(body.tagNames, body.tagTypes)), get=True)
return db_request(lambda: list(db.get_tag_counts(body.tagNames, body.tagTypes, body.neg)), get=True)
@app.put("/tacapi/v1/reset-use-count")
async def reset_use_count(tagname: str, ttype: int):
db_request(lambda: db.reset_tag_count(tagname, ttype))
async def reset_use_count(tagname: str, ttype: int, pos: bool, neg: bool):
db_request(lambda: db.reset_tag_count(tagname, ttype, pos, neg))
@app.get("/tacapi/v1/get-all-use-counts")
async def get_all_tag_counts():
return db_request(lambda: db.get_all_tags(), get=True)
async def get_all_tag_counts(neg: bool = False):
return db_request(lambda: db.get_all_tags(neg), get=True)
script_callbacks.on_app_started(api_tac)

View File

@@ -55,7 +55,8 @@ class TagFrequencyDb:
CREATE TABLE IF NOT EXISTS tag_frequency (
name TEXT NOT NULL,
type INT NOT NULL,
count INT,
count_pos INT,
count_neg INT,
last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (name, type)
)
@@ -85,24 +86,26 @@ class TagFrequencyDb:
return db_version[0] if db_version else 0
def get_all_tags(self):
def get_all_tags(self, negative=False):
count_str = "count_neg" if negative else "count_pos"
with transaction() as cursor:
cursor.execute(
"""
SELECT name, type, count
f"""
SELECT name, type, {count_str}
FROM tag_frequency
ORDER BY count DESC
ORDER BY {count_str} DESC
"""
)
tags = cursor.fetchall()
return tags
def get_tag_count(self, tag, ttype):
def get_tag_count(self, tag, ttype, negative=False):
count_str = "count_neg" if negative else "count_pos"
with transaction() as cursor:
cursor.execute(
"""
SELECT count
f"""
SELECT {count_str}
FROM tag_frequency
WHERE name = ? AND type = ?
""",
@@ -112,12 +115,13 @@ class TagFrequencyDb:
return tag_count[0] if tag_count else 0
def get_tag_counts(self, tags: list[str], ttypes: list[str]):
def get_tag_counts(self, tags: list[str], ttypes: list[str], negative=False):
count_str = "count_neg" if negative else "count_pos"
with transaction() as cursor:
for tag, ttype in zip(tags, ttypes):
cursor.execute(
"""
SELECT count
f"""
SELECT {count_str}
FROM tag_frequency
WHERE name = ? AND type = ?
""",
@@ -126,24 +130,38 @@ class TagFrequencyDb:
tag_count = cursor.fetchone()
yield (tag, ttype, tag_count[0]) if tag_count else (tag, ttype, 0)
def increase_tag_count(self, tag, ttype):
current_count = self.get_tag_count(tag, ttype)
def increase_tag_count(self, tag, ttype, negative=False):
pos_count = self.get_tag_count(tag, ttype, False)
neg_count = self.get_tag_count(tag, ttype, True)
if negative:
neg_count += 1
else:
pos_count += 1
with transaction() as cursor:
cursor.execute(
"""
f"""
INSERT OR REPLACE
INTO tag_frequency (name, type, count)
VALUES (?, ?, ?)
INTO tag_frequency (name, type, count_pos, count_neg)
VALUES (?, ?, ?, ?)
""",
(tag, ttype, current_count + 1),
(tag, ttype, pos_count, neg_count),
)
def reset_tag_count(self, tag, ttype):
def reset_tag_count(self, tag, ttype, positive=True, negative=False):
if positive and negative:
set_str = "count_pos = 0, count_neg = 0"
elif positive:
set_str = "count_pos = 0"
elif negative:
set_str = "count_neg = 0"
with transaction() as cursor:
cursor.execute(
"""
f"""
UPDATE tag_frequency
SET count = 0
SET {set_str}
WHERE name = ? AND type = ?
""",
(tag,ttype),