mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-04-26 09:19:07 +00:00
Count positive / negative prompt usage separately
This commit is contained in:
@@ -190,25 +190,25 @@ function mapUseCountArray(useCounts) {
|
||||
return useCounts.map(useCount => {return {"name": useCount[0], "type": useCount[1], "count": useCount[2]}});
|
||||
}
|
||||
// Call API endpoint to increase bias of tag in the database
|
||||
function increaseUseCount(tagName, type) {
|
||||
postAPI(`tacapi/v1/increase-use-count?tagname=${tagName}&ttype=${type}`);
|
||||
function increaseUseCount(tagName, type, negative = false) {
|
||||
postAPI(`tacapi/v1/increase-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`);
|
||||
}
|
||||
// Get use count of tag from the database
|
||||
async function getUseCount(tagName, type) {
|
||||
return (await fetchAPI(`tacapi/v1/get-use-count?tagname=${tagName}&ttype=${type}`, true, false))["result"];
|
||||
async function getUseCount(tagName, type, negative = false) {
|
||||
return (await fetchAPI(`tacapi/v1/get-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`, true, false))["result"];
|
||||
}
|
||||
async function getUseCounts(tagNames, types) {
|
||||
async function getUseCounts(tagNames, types, negative = false) {
|
||||
// While semantically weird, we have to use POST here for the body, as urls are limited in length
|
||||
const body = JSON.stringify({"tagNames": tagNames, "tagTypes": types});
|
||||
const body = JSON.stringify({"tagNames": tagNames, "tagTypes": types, "neg": negative});
|
||||
const rawArray = (await postAPI(`tacapi/v1/get-use-count-list`, body))["result"]
|
||||
return mapUseCountArray(rawArray);
|
||||
}
|
||||
async function getAllUseCounts() {
|
||||
const rawArray = (await fetchAPI(`tacapi/v1/get-all-use-counts`))["result"];
|
||||
async function getAllUseCounts(negative = false) {
|
||||
const rawArray = (await fetchAPI(`tacapi/v1/get-all-use-counts?neg=${negative}`))["result"];
|
||||
return mapUseCountArray(rawArray);
|
||||
}
|
||||
async function resetUseCount(tagName, type) {
|
||||
await putAPI(`tacapi/v1/reset-use-count?tagname=${tagName}&ttype=${type}`);
|
||||
async function resetUseCount(tagName, type, resetPosCount, resetNegCount) {
|
||||
await putAPI(`tacapi/v1/reset-use-count?tagname=${tagName}&ttype=${type}&pos=${resetPosCount}&neg=${resetNegCount}`);
|
||||
}
|
||||
|
||||
// Sliding window function to get possible combination groups of an array
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const styleColors = {
|
||||
const styleColors = {
|
||||
"--results-bg": ["#0b0f19", "#ffffff"],
|
||||
"--results-border-color": ["#4b5563", "#e5e7eb"],
|
||||
"--results-border-width": ["1px", "1.5px"],
|
||||
@@ -488,10 +488,13 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
|
||||
}
|
||||
|
||||
if (name && name.length > 0) {
|
||||
// Check if it's a negative prompt
|
||||
let textAreaId = getTextAreaIdentifier(textArea);
|
||||
let isNegative = textAreaId.includes("n");
|
||||
// Sanitize name for API call
|
||||
name = encodeURIComponent(name)
|
||||
// Call API & update db
|
||||
increaseUseCount(name, tagType)
|
||||
increaseUseCount(name, tagType, isNegative)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1160,8 +1163,12 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
types.push(r.type);
|
||||
});
|
||||
|
||||
// Check if it's a negative prompt
|
||||
let textAreaId = getTextAreaIdentifier(textArea);
|
||||
let isNegative = textAreaId.includes("n");
|
||||
|
||||
// Request use counts from the DB
|
||||
const counts = await getUseCounts(names, types);
|
||||
const counts = await getUseCounts(names, types, isNegative);
|
||||
const usedResults = counts.filter(c => c.count > 0).map(c => c.name);
|
||||
|
||||
// Sort all
|
||||
|
||||
@@ -607,29 +607,30 @@ def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
return JSONResponse({"error": "Database not initialized"}, status_code=500)
|
||||
|
||||
@app.post("/tacapi/v1/increase-use-count")
|
||||
async def increase_use_count(tagname: str, ttype: int):
|
||||
db_request(lambda: db.increase_tag_count(tagname, ttype))
|
||||
async def increase_use_count(tagname: str, ttype: int, neg: bool):
|
||||
db_request(lambda: db.increase_tag_count(tagname, ttype, neg))
|
||||
|
||||
@app.get("/tacapi/v1/get-use-count")
|
||||
async def get_use_count(tagname: str, ttype: int):
|
||||
return db_request(lambda: db.get_tag_count(tagname, ttype), get=True)
|
||||
async def get_use_count(tagname: str, ttype: int, neg: bool):
|
||||
return db_request(lambda: db.get_tag_count(tagname, ttype, neg), get=True)
|
||||
|
||||
# Small dataholder class
|
||||
class UseCountListRequest(BaseModel):
|
||||
tagNames: list[str]
|
||||
tagTypes: list[int]
|
||||
neg: bool = False
|
||||
|
||||
# Semantically weird to use post here, but it's required for the body on js side
|
||||
@app.post("/tacapi/v1/get-use-count-list")
|
||||
async def get_use_count_list(body: UseCountListRequest):
|
||||
return db_request(lambda: list(db.get_tag_counts(body.tagNames, body.tagTypes)), get=True)
|
||||
return db_request(lambda: list(db.get_tag_counts(body.tagNames, body.tagTypes, body.neg)), get=True)
|
||||
|
||||
@app.put("/tacapi/v1/reset-use-count")
|
||||
async def reset_use_count(tagname: str, ttype: int):
|
||||
db_request(lambda: db.reset_tag_count(tagname, ttype))
|
||||
async def reset_use_count(tagname: str, ttype: int, pos: bool, neg: bool):
|
||||
db_request(lambda: db.reset_tag_count(tagname, ttype, pos, neg))
|
||||
|
||||
@app.get("/tacapi/v1/get-all-use-counts")
|
||||
async def get_all_tag_counts():
|
||||
return db_request(lambda: db.get_all_tags(), get=True)
|
||||
async def get_all_tag_counts(neg: bool = False):
|
||||
return db_request(lambda: db.get_all_tags(neg), get=True)
|
||||
|
||||
script_callbacks.on_app_started(api_tac)
|
||||
|
||||
@@ -55,7 +55,8 @@ class TagFrequencyDb:
|
||||
CREATE TABLE IF NOT EXISTS tag_frequency (
|
||||
name TEXT NOT NULL,
|
||||
type INT NOT NULL,
|
||||
count INT,
|
||||
count_pos INT,
|
||||
count_neg INT,
|
||||
last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (name, type)
|
||||
)
|
||||
@@ -85,24 +86,26 @@ class TagFrequencyDb:
|
||||
|
||||
return db_version[0] if db_version else 0
|
||||
|
||||
def get_all_tags(self):
|
||||
def get_all_tags(self, negative=False):
|
||||
count_str = "count_neg" if negative else "count_pos"
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT name, type, count
|
||||
f"""
|
||||
SELECT name, type, {count_str}
|
||||
FROM tag_frequency
|
||||
ORDER BY count DESC
|
||||
ORDER BY {count_str} DESC
|
||||
"""
|
||||
)
|
||||
tags = cursor.fetchall()
|
||||
|
||||
return tags
|
||||
|
||||
def get_tag_count(self, tag, ttype):
|
||||
def get_tag_count(self, tag, ttype, negative=False):
|
||||
count_str = "count_neg" if negative else "count_pos"
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT count
|
||||
f"""
|
||||
SELECT {count_str}
|
||||
FROM tag_frequency
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
@@ -112,12 +115,13 @@ class TagFrequencyDb:
|
||||
|
||||
return tag_count[0] if tag_count else 0
|
||||
|
||||
def get_tag_counts(self, tags: list[str], ttypes: list[str]):
|
||||
def get_tag_counts(self, tags: list[str], ttypes: list[str], negative=False):
|
||||
count_str = "count_neg" if negative else "count_pos"
|
||||
with transaction() as cursor:
|
||||
for tag, ttype in zip(tags, ttypes):
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT count
|
||||
f"""
|
||||
SELECT {count_str}
|
||||
FROM tag_frequency
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
@@ -126,24 +130,38 @@ class TagFrequencyDb:
|
||||
tag_count = cursor.fetchone()
|
||||
yield (tag, ttype, tag_count[0]) if tag_count else (tag, ttype, 0)
|
||||
|
||||
def increase_tag_count(self, tag, ttype):
|
||||
current_count = self.get_tag_count(tag, ttype)
|
||||
def increase_tag_count(self, tag, ttype, negative=False):
|
||||
pos_count = self.get_tag_count(tag, ttype, False)
|
||||
neg_count = self.get_tag_count(tag, ttype, True)
|
||||
|
||||
if negative:
|
||||
neg_count += 1
|
||||
else:
|
||||
pos_count += 1
|
||||
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
f"""
|
||||
INSERT OR REPLACE
|
||||
INTO tag_frequency (name, type, count)
|
||||
VALUES (?, ?, ?)
|
||||
INTO tag_frequency (name, type, count_pos, count_neg)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(tag, ttype, current_count + 1),
|
||||
(tag, ttype, pos_count, neg_count),
|
||||
)
|
||||
|
||||
def reset_tag_count(self, tag, ttype):
|
||||
def reset_tag_count(self, tag, ttype, positive=True, negative=False):
|
||||
if positive and negative:
|
||||
set_str = "count_pos = 0, count_neg = 0"
|
||||
elif positive:
|
||||
set_str = "count_pos = 0"
|
||||
elif negative:
|
||||
set_str = "count_neg = 0"
|
||||
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
f"""
|
||||
UPDATE tag_frequency
|
||||
SET count = 0
|
||||
SET {set_str}
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
(tag,ttype),
|
||||
|
||||
Reference in New Issue
Block a user