mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-03-14 01:30:03 +00:00
Use composite key with name & type to prevent collisions
This commit is contained in:
@@ -174,21 +174,21 @@ function tagBias(count, uses) {
|
||||
return Math.log(count) + Math.log(uses);
|
||||
}
|
||||
// Call API endpoint to increase bias of tag in the database
|
||||
function increaseUseCount(tagName) {
|
||||
postAPI(`tacapi/v1/increase-use-count/${tagName}`, null);
|
||||
function increaseUseCount(tagName, type) {
|
||||
postAPI(`tacapi/v1/increase-use-count/${tagName}?ttype=${type}`, null);
|
||||
}
|
||||
// Get use count of tag from the database
|
||||
async function getUseCount(tagName) {
|
||||
return (await fetchAPI(`tacapi/v1/get-use-count/${tagName}`, true, false))["result"];
|
||||
async function getUseCount(tagName, type) {
|
||||
return (await fetchAPI(`tacapi/v1/get-use-count/${tagName}?ttype=${type}`, true, false))["result"];
|
||||
}
|
||||
async function getUseCounts(tagNames) {
|
||||
return (await fetchAPI(`tacapi/v1/get-use-count-list?tags=${tagNames.join("&tags=")}`))["result"];
|
||||
async function getUseCounts(tagNames, types) {
|
||||
return (await fetchAPI(`tacapi/v1/get-use-count-list?tags=${tagNames.join("&tags=")}&ttypes=${types.join("&ttypes=")}`))["result"];
|
||||
}
|
||||
async function getAllUseCounts() {
|
||||
return (await fetchAPI(`tacapi/v1/get-all-use-counts`))["result"];
|
||||
}
|
||||
async function resetUseCount(tagName) {
|
||||
putAPI(`tacapi/v1/reset-use-count/${tagName}`, null);
|
||||
async function resetUseCount(tagName, type) {
|
||||
putAPI(`tacapi/v1/reset-use-count/${tagName}?ttype=${type}`, null);
|
||||
}
|
||||
|
||||
// Sliding window function to get possible combination groups of an array
|
||||
|
||||
@@ -598,20 +598,20 @@ def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
return JSONResponse({"error": "Database not initialized"}, status_code=500)
|
||||
|
||||
@app.post("/tacapi/v1/increase-use-count/{tagname}")
|
||||
async def increase_use_count(tagname: str):
|
||||
db_request(lambda: db.increase_tag_count(tagname))
|
||||
async def increase_use_count(tagname: str, ttype: int):
|
||||
db_request(lambda: db.increase_tag_count(tagname, ttype))
|
||||
|
||||
@app.get("/tacapi/v1/get-use-count/{tagname}")
|
||||
async def get_use_count(tagname: str):
|
||||
return db_request(lambda: db.get_tag_count(tagname), get=True)
|
||||
async def get_use_count(tagname: str, ttype: int):
|
||||
return db_request(lambda: db.get_tag_count(tagname, ttype), get=True)
|
||||
|
||||
@app.get("/tacapi/v1/get-use-count-list")
|
||||
async def get_use_count_list(tags: list[str] | None = Query(default=None)):
|
||||
return db_request(lambda: list(db.get_tag_counts(tags)), get=True)
|
||||
async def get_use_count_list(tags: list[str] | None = Query(default=None), ttypes: list[int] | None = Query(default=None)):
|
||||
return db_request(lambda: list(db.get_tag_counts(tags, ttypes)), get=True)
|
||||
|
||||
@app.put("/tacapi/v1/reset-use-count/{tagname}")
|
||||
async def reset_use_count(tagname: str):
|
||||
db_request(lambda: db.reset_tag_count(tagname))
|
||||
async def reset_use_count(tagname: str, ttype: int):
|
||||
db_request(lambda: db.reset_tag_count(tagname, ttype))
|
||||
|
||||
@app.get("/tacapi/v1/get-all-use-counts")
|
||||
async def get_all_tag_counts():
|
||||
|
||||
@@ -53,9 +53,11 @@ class TagFrequencyDb:
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS tag_frequency (
|
||||
name TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
type INT NOT NULL,
|
||||
count INT,
|
||||
last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (name, type)
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -87,7 +89,7 @@ class TagFrequencyDb:
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT name
|
||||
SELECT name, type, count
|
||||
FROM tag_frequency
|
||||
ORDER BY count DESC
|
||||
"""
|
||||
@@ -96,53 +98,53 @@ class TagFrequencyDb:
|
||||
|
||||
return tags
|
||||
|
||||
def get_tag_count(self, tag):
|
||||
def get_tag_count(self, tag, ttype):
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT count
|
||||
FROM tag_frequency
|
||||
WHERE name = ?
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
(tag,),
|
||||
(tag,ttype),
|
||||
)
|
||||
tag_count = cursor.fetchone()
|
||||
|
||||
return tag_count[0] if tag_count else 0
|
||||
|
||||
def get_tag_counts(self, tags: list[str]):
|
||||
def get_tag_counts(self, tags: list[str], ttypes: list[str]):
|
||||
with transaction() as cursor:
|
||||
for tag in tags:
|
||||
for tag, ttype in zip(tags, ttypes):
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT count
|
||||
FROM tag_frequency
|
||||
WHERE name = ?
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
(tag,),
|
||||
(tag,ttype),
|
||||
)
|
||||
tag_count = cursor.fetchone()
|
||||
yield (tag, tag_count[0]) if tag_count else (tag, 0)
|
||||
|
||||
def increase_tag_count(self, tag):
|
||||
current_count = self.get_tag_count(tag)
|
||||
def increase_tag_count(self, tag, ttype):
|
||||
current_count = self.get_tag_count(tag, ttype)
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE
|
||||
INTO tag_frequency (name, count)
|
||||
VALUES (?, ?)
|
||||
INTO tag_frequency (name, type, count)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(tag, current_count + 1),
|
||||
(tag, ttype, current_count + 1),
|
||||
)
|
||||
|
||||
def reset_tag_count(self, tag):
|
||||
def reset_tag_count(self, tag, ttype):
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE tag_frequency
|
||||
SET count = 0
|
||||
WHERE name = ?
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
(tag,),
|
||||
(tag,ttype),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user