From 15478e73b57046e4b77bc14910452aca9f7b7cf5 Mon Sep 17 00:00:00 2001 From: DominikDoom Date: Wed, 29 Nov 2023 15:20:15 +0100 Subject: [PATCH] Count positive / negative prompt usage separately --- javascript/_utils.js | 20 +++++------ javascript/tagAutocomplete.js | 13 +++++-- scripts/tag_autocomplete_helper.py | 19 +++++----- scripts/tag_frequency_db.py | 58 +++++++++++++++++++----------- 4 files changed, 68 insertions(+), 42 deletions(-) diff --git a/javascript/_utils.js b/javascript/_utils.js index 1c6d978..45ff5e8 100644 --- a/javascript/_utils.js +++ b/javascript/_utils.js @@ -190,25 +190,25 @@ function mapUseCountArray(useCounts) { return useCounts.map(useCount => {return {"name": useCount[0], "type": useCount[1], "count": useCount[2]}}); } // Call API endpoint to increase bias of tag in the database -function increaseUseCount(tagName, type) { - postAPI(`tacapi/v1/increase-use-count?tagname=${tagName}&ttype=${type}`); +function increaseUseCount(tagName, type, negative = false) { + postAPI(`tacapi/v1/increase-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`); } // Get use count of tag from the database -async function getUseCount(tagName, type) { - return (await fetchAPI(`tacapi/v1/get-use-count?tagname=${tagName}&ttype=${type}`, true, false))["result"]; +async function getUseCount(tagName, type, negative = false) { + return (await fetchAPI(`tacapi/v1/get-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`, true, false))["result"]; } -async function getUseCounts(tagNames, types) { +async function getUseCounts(tagNames, types, negative = false) { // While semantically weird, we have to use POST here for the body, as urls are limited in length - const body = JSON.stringify({"tagNames": tagNames, "tagTypes": types}); + const body = JSON.stringify({"tagNames": tagNames, "tagTypes": types, "neg": negative}); const rawArray = (await postAPI(`tacapi/v1/get-use-count-list`, body))["result"] return mapUseCountArray(rawArray); } -async function getAllUseCounts() { - const rawArray = (await fetchAPI(`tacapi/v1/get-all-use-counts`))["result"]; +async function getAllUseCounts(negative = false) { + const rawArray = (await fetchAPI(`tacapi/v1/get-all-use-counts?neg=${negative}`))["result"]; return mapUseCountArray(rawArray); } -async function resetUseCount(tagName, type) { - await putAPI(`tacapi/v1/reset-use-count?tagname=${tagName}&ttype=${type}`); +async function resetUseCount(tagName, type, resetPosCount, resetNegCount) { + await putAPI(`tacapi/v1/reset-use-count?tagname=${tagName}&ttype=${type}&pos=${resetPosCount}&neg=${resetNegCount}`); } // Sliding window function to get possible combination groups of an array diff --git a/javascript/tagAutocomplete.js b/javascript/tagAutocomplete.js index 75a51de..b28f3f0 100644 --- a/javascript/tagAutocomplete.js +++ b/javascript/tagAutocomplete.js @@ -1,4 +1,4 @@ -const styleColors = { +const styleColors = { "--results-bg": ["#0b0f19", "#ffffff"], "--results-border-color": ["#4b5563", "#e5e7eb"], "--results-border-width": ["1px", "1.5px"], @@ -488,10 +488,13 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout } if (name && name.length > 0) { + // Check if it's a negative prompt + let textAreaId = getTextAreaIdentifier(textArea); + let isNegative = textAreaId.includes("n"); // Sanitize name for API call name = encodeURIComponent(name) // Call API & update db - increaseUseCount(name, tagType) + increaseUseCount(name, tagType, isNegative) } } @@ -1160,8 +1163,12 @@ async function autocomplete(textArea, prompt, fixedTag = null) { types.push(r.type); }); + // Check if it's a negative prompt + let textAreaId = getTextAreaIdentifier(textArea); + let isNegative = textAreaId.includes("n"); + // Request use counts from the DB - const counts = await getUseCounts(names, types); + const counts = await getUseCounts(names, types, isNegative); const usedResults = counts.filter(c => c.count > 0).map(c => c.name); // Sort all diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index d6b044e..43f6c39 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -607,29 +607,30 @@ def api_tac(_: gr.Blocks, app: FastAPI): return JSONResponse({"error": "Database not initialized"}, status_code=500) @app.post("/tacapi/v1/increase-use-count") - async def increase_use_count(tagname: str, ttype: int): - db_request(lambda: db.increase_tag_count(tagname, ttype)) + async def increase_use_count(tagname: str, ttype: int, neg: bool): + db_request(lambda: db.increase_tag_count(tagname, ttype, neg)) @app.get("/tacapi/v1/get-use-count") - async def get_use_count(tagname: str, ttype: int): - return db_request(lambda: db.get_tag_count(tagname, ttype), get=True) + async def get_use_count(tagname: str, ttype: int, neg: bool): + return db_request(lambda: db.get_tag_count(tagname, ttype, neg), get=True) # Small dataholder class class UseCountListRequest(BaseModel): tagNames: list[str] tagTypes: list[int] + neg: bool = False # Semantically weird to use post here, but it's required for the body on js side @app.post("/tacapi/v1/get-use-count-list") async def get_use_count_list(body: UseCountListRequest): - return db_request(lambda: list(db.get_tag_counts(body.tagNames, body.tagTypes)), get=True) + return db_request(lambda: list(db.get_tag_counts(body.tagNames, body.tagTypes, body.neg)), get=True) @app.put("/tacapi/v1/reset-use-count") - async def reset_use_count(tagname: str, ttype: int): - db_request(lambda: db.reset_tag_count(tagname, ttype)) + async def reset_use_count(tagname: str, ttype: int, pos: bool, neg: bool): + db_request(lambda: db.reset_tag_count(tagname, ttype, pos, neg)) @app.get("/tacapi/v1/get-all-use-counts") - async def get_all_tag_counts(): - return db_request(lambda: db.get_all_tags(), get=True) + async def get_all_tag_counts(neg: bool = False): + return db_request(lambda: db.get_all_tags(neg), get=True) script_callbacks.on_app_started(api_tac) diff --git a/scripts/tag_frequency_db.py b/scripts/tag_frequency_db.py index 71862e6..4f4f26b 100644 --- a/scripts/tag_frequency_db.py +++ b/scripts/tag_frequency_db.py @@ -55,7 +55,8 @@ class TagFrequencyDb: CREATE TABLE IF NOT EXISTS tag_frequency ( name TEXT NOT NULL, type INT NOT NULL, - count INT, + count_pos INT, + count_neg INT, last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (name, type) ) @@ -85,24 +86,26 @@ class TagFrequencyDb: return db_version[0] if db_version else 0 - def get_all_tags(self): + def get_all_tags(self, negative=False): + count_str = "count_neg" if negative else "count_pos" with transaction() as cursor: cursor.execute( - """ - SELECT name, type, count + f""" + SELECT name, type, {count_str} FROM tag_frequency - ORDER BY count DESC + ORDER BY {count_str} DESC """ ) tags = cursor.fetchall() return tags - def get_tag_count(self, tag, ttype): + def get_tag_count(self, tag, ttype, negative=False): + count_str = "count_neg" if negative else "count_pos" with transaction() as cursor: cursor.execute( - """ - SELECT count + f""" + SELECT {count_str} FROM tag_frequency WHERE name = ? AND type = ? """, @@ -112,12 +115,13 @@ class TagFrequencyDb: return tag_count[0] if tag_count else 0 - def get_tag_counts(self, tags: list[str], ttypes: list[str]): + def get_tag_counts(self, tags: list[str], ttypes: list[str], negative=False): + count_str = "count_neg" if negative else "count_pos" with transaction() as cursor: for tag, ttype in zip(tags, ttypes): cursor.execute( - """ - SELECT count + f""" + SELECT {count_str} FROM tag_frequency WHERE name = ? AND type = ? """, @@ -126,24 +130,38 @@ class TagFrequencyDb: tag_count = cursor.fetchone() yield (tag, ttype, tag_count[0]) if tag_count else (tag, ttype, 0) - def increase_tag_count(self, tag, ttype): - current_count = self.get_tag_count(tag, ttype) + def increase_tag_count(self, tag, ttype, negative=False): + pos_count = self.get_tag_count(tag, ttype, False) + neg_count = self.get_tag_count(tag, ttype, True) + + if negative: + neg_count += 1 + else: + pos_count += 1 + with transaction() as cursor: cursor.execute( - """ + f""" INSERT OR REPLACE - INTO tag_frequency (name, type, count) - VALUES (?, ?, ?) + INTO tag_frequency (name, type, count_pos, count_neg) + VALUES (?, ?, ?, ?) """, - (tag, ttype, current_count + 1), + (tag, ttype, pos_count, neg_count), ) - def reset_tag_count(self, tag, ttype): + def reset_tag_count(self, tag, ttype, positive=True, negative=False): + if positive and negative: + set_str = "count_pos = 0, count_neg = 0" + elif positive: + set_str = "count_pos = 0" + elif negative: + set_str = "count_neg = 0" + with transaction() as cursor: cursor.execute( - """ + f""" UPDATE tag_frequency - SET count = 0 + SET {set_str} WHERE name = ? AND type = ? """, (tag,ttype),