Compare commits

..

77 Commits

Author SHA1 Message Date
DominikDoom
38700d4743 Formatting 2025-01-04 19:35:14 +01:00
DominikDoom
bb492ba059 Add default color config & wiki link fix for merged tag list 2025-01-04 19:33:29 +01:00
Drac
40ad070a02 Add danbooru_e621_merged.csv (#312)
Post count threshold for this file is 25
2025-01-04 19:12:57 +01:00
DominikDoom
209b1dd76b End of 2024 tag list update
Danbooru and e621 tag lists as of 2024-12-22 (no Derpibooru for now, sorry).
Both cut off at a post count of 25, slightly improved consistency & new aliases included.
Thanks a lot to @DraconicDragon for the up-to-date tag list at https://github.com/DraconicDragon/dbr-e621-lists-archive
2025-01-03 14:03:26 +01:00
DominikDoom
196fa19bfc Fix derpibooru tags containing merge conflict markers
Thanks to @heftig for noticing this, as discussed in #293
2024-12-08 18:23:21 +01:00
DominikDoom
6ffeeafc49 Update danbooru tags (2024-11-9)
Thanks to @yamosin.
Closes #309

Note: This changes the cutoff type from top 100k to post count > 30, which adds ~21k rows
2024-11-09 15:35:59 +01:00
DominikDoom
08b7c58ea7 More catches for fixing #308 2024-11-02 15:52:10 +01:00
DominikDoom
6be91449f3 Try-catch in umi format check
Possible fix for #308
2024-11-02 13:51:51 +01:00
david419kr
b515c15e01 Underscore replacement exclusion list feature (#306) 2024-10-30 17:45:32 +01:00
DominikDoom
827b99c961 Make embedding refresh non-force by default
Added option for force-refreshing embeddings to restore old behavior
Fixes #301
2024-09-04 22:58:55 +02:00
DominikDoom
49ec047af8 Fix extra network tab refresh listener 2024-08-15 11:52:49 +02:00
DominikDoom
f94da07ed1 Fix ref 2024-08-11 14:56:58 +02:00
DominikDoom
e2cfe7341b Re-register embed load callback after model load if needed 2024-08-11 14:55:35 +02:00
DominikDoom
ce51ec52a2 Fix for forge type detection, sorting fallback if filename is missing 2024-08-11 14:26:37 +02:00
DominikDoom
f64d728ac6 Partial embedding fixes for webui forge
Resolves some symptoms of #297, but doesn't fix the underlying cause
2024-08-11 14:08:31 +02:00
DominikDoom
1c6bba2a3d Formatting 2024-08-05 11:47:12 +02:00
DominikDoom
9a47c2ec2c Mouse hover can now trigger selection event for extra network previews
Closes #292
2024-08-05 11:39:38 +02:00
DominikDoom
fe32ad739d Merge pull request #293 from Siberpone/derpi-update
Derpibooru csv update
2024-07-05 17:15:26 +02:00
siberpone
ade67e30a6 Updated Derpibooru tags 2024-07-05 21:57:15 +07:00
DominikDoom
e9a21e7a55 Add pony quality tags to demo chants 2024-07-05 10:03:23 +02:00
DominikDoom
3ef2a7d206 Prefer loaded over skipped embeddings on name collisions
Fixes #290
2024-06-11 16:23:16 +02:00
DominikDoom
29b5bf0701 Fix db version being accessed before creation if table fails to create
This should prevent the script from hard crashing like in #288
2024-05-28 09:47:15 +02:00
DominikDoom
3eef536b64 Use custom wildcard wrapper chars from sd-dynamic-prompts if the option is set
Closes #286
2024-05-04 14:12:14 +02:00
DominikDoom
0d24e697d2 Update README.md 2024-04-16 13:28:33 +02:00
DominikDoom
a27633da55 Remove hardcoded preview url modifiers, use ResultType instead
Fixes #284
2024-04-15 18:48:44 +02:00
DominikDoom
4cd6174a22 Trying out a hopefully more robust import fix 2024-04-14 12:14:29 +02:00
DominikDoom
9155e4d42c Prevent db response errors from breaking the regular completion 2024-04-14 12:01:15 +02:00
DominikDoom
700642a400 Attempt to fix import error described in #283 2024-04-13 20:27:47 +02:00
DominikDoom
1b592dbf56 Add safety check for db access that was not yet handled by db_request 2024-04-13 19:26:20 +02:00
DominikDoom
d1eea880f3 Rename API call functions on JS side to prevent name conflicts
Fixes #282
2024-04-13 17:27:10 +02:00
DominikDoom
119a3ad51f Merge branch 'feature-sort-by-frequent-use' 2024-04-13 15:08:57 +02:00
DominikDoom
c820a22149 Merge branch 'main' into feature-sort-by-frequent-use 2024-04-13 15:06:53 +02:00
DominikDoom
ef59cff651 Move last used date check guard to SQL side, implement max cap
- Server side date comparison and cap check further improve js sort performance
- The alias check has also been moved out of calculateUsageBias to support the new cap system
2024-03-16 16:44:43 +01:00
DominikDoom
a454383c43 Merge branch 'main' into feature-sort-by-frequent-use 2024-03-03 13:52:32 +01:00
DominikDoom
4c2ef8f770 Merge branch 'main' into feature-sort-by-frequent-use 2024-02-09 19:23:52 +01:00
DominikDoom
7437850600 Merge branch 'main' into feature-sort-by-frequent-use 2024-02-04 14:46:29 +01:00
DominikDoom
f810b2dd8f Merge branch 'main' into feature-sort-by-frequent-use 2024-01-27 12:39:40 +01:00
DominikDoom
95200e82e1 Merge branch 'main' into feature-sort-by-frequent-use 2024-01-26 17:04:53 +01:00
DominikDoom
a966be7546 Merge branch 'main' into feature-sort-by-frequent-use 2024-01-26 16:21:15 +01:00
DominikDoom
342fbc9041 Pre-calculate usage bias for all results instead of in the sort function
Roughly doubles the sort performance
2024-01-19 21:10:09 +01:00
DominikDoom
d496569c9a Cache sort key for small performance increase 2024-01-19 20:17:14 +01:00
DominikDoom
30c9593d3d Merge branch 'main' into feature-sort-by-frequent-use 2023-12-12 14:23:18 +01:00
DominikDoom
57076060df Merge branch 'main' into feature-sort-by-frequent-use 2023-12-11 11:43:26 +01:00
DominikDoom
20b6635a2a WIP usage info table
Might get replaced with gradio depending on how well it works
2023-12-04 15:00:19 +01:00
DominikDoom
1fe8f26670 Add explanatory tooltip and inline reset ability
Also add tooltip for wiki links
2023-12-04 13:56:15 +01:00
DominikDoom
e82e958c3e Fix alias check for non-aliased tag types 2023-11-29 18:15:59 +01:00
DominikDoom
2dd48eab79 Fix error with db return value for no matches 2023-11-29 18:14:14 +01:00
DominikDoom
4df90f5c95 Don't frequency sort alias results by default
with an option to enable it if desired
2023-11-29 18:04:50 +01:00
DominikDoom
a156214a48 Last used & min count settings
Also some performance improvements
2023-11-29 17:45:51 +01:00
DominikDoom
15478e73b5 Count positive / negative prompt usage separately 2023-11-29 15:22:41 +01:00
DominikDoom
434301738a Merge branch 'main' into feature-sort-by-frequent-use 2023-11-05 13:30:51 +01:00
DominikDoom
4fba7baa69 Merge branch 'main' into feature-sort-by-frequent-use 2023-10-06 18:36:24 +02:00
DominikDoom
7128efc4f4 Apply same fix to extra tags
Count now defaults to max safe integer, which simplifies the sort function
Before, it resulted in really bad performance
2023-10-02 00:45:48 +02:00
DominikDoom
bd0ddfbb24 Fix embeddings not at top
(only affecting the "include embeddings in normal results" option)
2023-10-02 00:16:58 +02:00
DominikDoom
3108daf0e8 Remove kaomoji inclusion in < search
because it interfered with use count searching and is not commonly needed
2023-10-01 23:51:35 +02:00
DominikDoom
363895494b Fix hide after insert race condition 2023-10-01 23:17:12 +02:00
DominikDoom
04551a8132 Don't await increase, limit to 2k for performance 2023-10-01 22:59:28 +02:00
DominikDoom
ffc0e378d3 Add different sorting functions 2023-10-01 22:44:35 +02:00
DominikDoom
440f109f1f Use POST + body to get around URL length limit 2023-10-01 22:30:47 +02:00
DominikDoom
80fb247dbe Sort results by usage count 2023-10-01 21:44:24 +02:00
DominikDoom
d7e98200a8 Use count increase logic 2023-09-26 12:20:15 +02:00
DominikDoom
ac790c8ede Return dict instead of array for clarity 2023-09-26 12:12:46 +02:00
DominikDoom
22365ec8d6 Add missing type return to list request 2023-09-26 12:02:36 +02:00
DominikDoom
030a83aa4d Use query parameter instead of path to fix wildcard subfolder issues 2023-09-26 11:55:12 +02:00
DominikDoom
460d32a4ed Ensure proper reload, fix error message 2023-09-26 11:45:42 +02:00
DominikDoom
581bf1e6a4 Use composite key with name & type to prevent collisions 2023-09-26 11:35:24 +02:00
DominikDoom
74ea5493e5 Add rest of utils functions 2023-09-26 10:58:46 +02:00
DominikDoom
6cf9acd6ab Catch sqlite exceptions, add tag list endpoint 2023-09-24 20:06:40 +02:00
DominikDoom
109a8a155e Change endpoint name for consistency 2023-09-24 18:00:41 +02:00
DominikDoom
3caa1b51ed Add db to gitignore 2023-09-24 17:59:39 +02:00
DominikDoom
b44c36425a Fix db load version comparison, add sort options 2023-09-24 17:59:14 +02:00
DominikDoom
1e81403180 Safety catches for DB API access 2023-09-24 16:50:39 +02:00
DominikDoom
0f487a5c5c WIP database setup inspired by ImageBrowser 2023-09-24 16:28:32 +02:00
DominikDoom
2baa12fea3 Merge branch 'main' into feature-sort-by-frequent-use 2023-09-24 15:34:18 +02:00
DominikDoom
67eeb5fbf6 Merge branch 'main' into feature-sort-by-frequent-use 2023-09-19 12:14:12 +02:00
DominikDoom
11ffed8afc Merge branch 'feature-sorting' into feature-sort-by-frequent-use 2023-09-15 16:37:34 +02:00
DominikDoom
0a8e7d7d84 Stub API setup for tag usage stats 2023-09-12 14:10:15 +02:00
15 changed files with 548840 additions and 268405 deletions

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
tags/temp/
__pycache__/
tags/tag_frequency.db

View File

@@ -20,18 +20,15 @@ Booru style tag autocompletion for the AUTOMATIC1111 Stable Diffusion WebUI
</div>
<br/>
#### ⚠️ Notice:
I am currently looking for feedback on a new feature I'm working on and want to release soon.<br/>
Please check [the announcement post](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/discussions/270) for more info if you are interested to help.
# 📄 Description
Tag Autocomplete is an extension for the popular [AUTOMATIC1111 web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for Stable Diffusion.
You can install it using the inbuilt available extensions list, clone the files manually as described [below](#-installation), or use a pre-packaged version from [Releases](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/releases).
It displays autocompletion hints for recognized tags from "image booru" boards such as Danbooru, which are primarily used for browsing Anime-style illustrations.
Since some Stable Diffusion models were trained using this information, for example [Waifu Diffusion](https://github.com/harubaru/waifu-diffusion) and many of the NAI-descendant models or merges, using exact tags in prompts can often improve composition and consistency.
Since most custom Stable Diffusion models were trained using this information or merged with ones that did, using exact tags in prompts can often improve composition and consistency, even if the model itself has a photorealistic style.
You can install it using the inbuilt available extensions list, clone the files manually as described [below](#-installation), or use a pre-packaged version from [Releases](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/releases).
Disclaimer: The default tag lists contain NSFW terms, please use them responsibly.
<br/>

View File

@@ -24,7 +24,8 @@ class AutocompleteResult {
// Additional info, only used in some cases
category = null;
count = null;
count = Number.MAX_SAFE_INTEGER;
usageBias = null;
aliases = null;
meta = null;
hash = null;

View File

@@ -61,7 +61,7 @@ async function loadCSV(path) {
}
// Fetch API
async function fetchAPI(url, json = true, cache = false) {
async function fetchTacAPI(url, json = true, cache = false) {
if (!cache) {
const appendChar = url.includes("?") ? "&" : "?";
url += `${appendChar}${new Date().getTime()}`
@@ -80,8 +80,12 @@ async function fetchAPI(url, json = true, cache = false) {
return await response.text();
}
async function postAPI(url, body) {
let response = await fetch(url, { method: "POST", body: body });
async function postTacAPI(url, body = null) {
let response = await fetch(url, {
method: "POST",
headers: {'Content-Type': 'application/json'},
body: body
});
if (response.status != 200) {
console.error(`Error posting to API endpoint "${url}": ` + response.status, response.statusText);
@@ -91,9 +95,20 @@ async function postAPI(url, body) {
return await response.json();
}
async function putTacAPI(url, body = null) {
let response = await fetch(url, { method: "PUT", body: body });
if (response.status != 200) {
console.error(`Error putting to API endpoint "${url}": ` + response.status, response.statusText);
return null;
}
return await response.json();
}
// Extra network preview thumbnails
async function getExtraNetworkPreviewURL(filename, type) {
const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);
async function getTacExtraNetworkPreviewURL(filename, type) {
const previewJSON = await fetchTacAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);
if (previewJSON?.url) {
const properURL = `sd_extra_networks/thumb?filename=${previewJSON.url}`;
if ((await fetch(properURL)).status == 200) {
@@ -180,6 +195,114 @@ function flatten(obj, roots = [], sep = ".") {
);
}
// Calculate biased tag score based on post count and frequent usage
function calculateUsageBias(result, count, uses) {
// Check setting conditions
if (uses < TAC_CFG.frequencyMinCount) {
uses = 0;
} else if (uses != 0) {
result.usageBias = true;
}
switch (TAC_CFG.frequencyFunction) {
case "Logarithmic (weak)":
return Math.log(1 + count) + Math.log(1 + uses);
case "Logarithmic (strong)":
return Math.log(1 + count) + 2 * Math.log(1 + uses);
case "Usage first":
return uses;
default:
return count;
}
}
// Beautify return type for easier parsing
function mapUseCountArray(useCounts, posAndNeg = false) {
return useCounts.map(useCount => {
if (posAndNeg) {
return {
"name": useCount[0],
"type": useCount[1],
"count": useCount[2],
"negCount": useCount[3],
"lastUseDate": useCount[4]
}
}
return {
"name": useCount[0],
"type": useCount[1],
"count": useCount[2],
"lastUseDate": useCount[3]
}
});
}
// Call API endpoint to increase bias of tag in the database
function increaseUseCount(tagName, type, negative = false) {
postTacAPI(`tacapi/v1/increase-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`);
}
// Get use count of tag from the database
async function getUseCount(tagName, type, negative = false) {
const response = await fetchTacAPI(`tacapi/v1/get-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`, true, false);
// Guard for no db
if (response == null) return null;
// Result
return response["result"];
}
async function getUseCounts(tagNames, types, negative = false) {
// While semantically weird, we have to use POST here for the body, as urls are limited in length
const body = JSON.stringify({"tagNames": tagNames, "tagTypes": types, "neg": negative});
const response = await postTacAPI(`tacapi/v1/get-use-count-list`, body)
// Guard for no db
if (response == null) return null;
// Results
return mapUseCountArray(response["result"]);
}
async function getAllUseCounts() {
const response = await fetchTacAPI(`tacapi/v1/get-all-use-counts`);
// Guard for no db
if (response == null) return null;
// Results
return mapUseCountArray(response["result"], true);
}
async function resetUseCount(tagName, type, resetPosCount, resetNegCount) {
await putTacAPI(`tacapi/v1/reset-use-count?tagname=${tagName}&ttype=${type}&pos=${resetPosCount}&neg=${resetNegCount}`);
}
function createTagUsageTable(tagCounts) {
// Create table
let tagTable = document.createElement("table");
tagTable.innerHTML =
`<thead>
<tr>
<td>Name</td>
<td>Type</td>
<td>Count(+)</td>
<td>Count(-)</td>
<td>Last used</td>
</tr>
</thead>`;
tagTable.id = "tac_tagUsageTable"
tagCounts.forEach(t => {
let tr = document.createElement("tr");
// Fill values
let values = [t.name, t.type-1, t.count, t.negCount, t.lastUseDate]
values.forEach(v => {
let td = document.createElement("td");
td.innerText = v;
tr.append(td);
});
// Add delete/reset button
let delButton = document.createElement("button");
delButton.innerText = "🗑️";
delButton.title = "Reset count";
tr.append(delButton);
tagTable.append(tr)
});
return tagTable;
}
// Sliding window function to get possible combination groups of an array
function toNgrams(inputArray, size) {
@@ -242,12 +365,19 @@ function getSortFunction() {
let criterion = TAC_CFG.modelSortOrder || "Name";
const textSort = (a, b, reverse = false) => {
const textHolderA = a.type === ResultType.chant ? a.aliases : a.text;
const textHolderB = b.type === ResultType.chant ? b.aliases : b.text;
// Assign keys so next sort is faster
if (!a.sortKey) {
a.sortKey = a.type === ResultType.chant
? a.aliases
: a.text;
}
if (!b.sortKey) {
b.sortKey = b.type === ResultType.chant
? b.aliases
: b.text;
}
const aKey = a.sortKey || textHolderA;
const bKey = b.sortKey || textHolderB;
return reverse ? bKey.localeCompare(aKey) : aKey.localeCompare(bKey);
return reverse ? b.sortKey.localeCompare(a.sortKey) : a.sortKey.localeCompare(b.sortKey);
}
const numericSort = (a, b, reverse = false) => {
const noKey = reverse ? "-1" : Number.MAX_SAFE_INTEGER;

View File

@@ -50,7 +50,7 @@ async function load() {
async function sanitize(tagType, text) {
if (tagType === ResultType.lora) {
let multiplier = TAC_CFG.extraNetworksDefaultMultiplier;
let info = await fetchAPI(`tacapi/v1/lora-info/${text}`)
let info = await fetchTacAPI(`tacapi/v1/lora-info/${text}`)
if (info && info["preferred weight"]) {
multiplier = info["preferred weight"];
}

View File

@@ -50,7 +50,7 @@ async function load() {
async function sanitize(tagType, text) {
if (tagType === ResultType.lyco) {
let multiplier = TAC_CFG.extraNetworksDefaultMultiplier;
let info = await fetchAPI(`tacapi/v1/lyco-info/${text}`)
let info = await fetchTacAPI(`tacapi/v1/lyco-info/${text}`)
if (info && info["preferred weight"]) {
multiplier = info["preferred weight"];
}

View File

@@ -1,14 +1,14 @@
// Regex
const WC_REGEX = /\b__([^,]+)__([^, ]*)\b/g;
const WC_REGEX = new RegExp(/__([^,]+)__([^, ]*)/g);
// Trigger conditions
const WC_TRIGGER = () => TAC_CFG.useWildcards && [...tagword.matchAll(WC_REGEX)].length > 0;
const WC_FILE_TRIGGER = () => TAC_CFG.useWildcards && (tagword.startsWith("__") && !tagword.endsWith("__") || tagword === "__");
const WC_TRIGGER = () => TAC_CFG.useWildcards && [...tagword.matchAll(new RegExp(WC_REGEX.source.replaceAll("__", escapeRegExp(TAC_CFG.wcWrap)), "g"))].length > 0;
const WC_FILE_TRIGGER = () => TAC_CFG.useWildcards && (tagword.startsWith(TAC_CFG.wcWrap) && !tagword.endsWith(TAC_CFG.wcWrap) || tagword === TAC_CFG.wcWrap);
class WildcardParser extends BaseTagParser {
async parse() {
// Show wildcards from a file with that name
let wcMatch = [...tagword.matchAll(WC_REGEX)]
let wcMatch = [...tagword.matchAll(new RegExp(WC_REGEX.source.replaceAll("__", escapeRegExp(TAC_CFG.wcWrap)), "g"))];
let wcFile = wcMatch[0][1];
let wcWord = wcMatch[0][2];
@@ -38,7 +38,7 @@ class WildcardParser extends BaseTagParser {
}
wildcards = wildcards.concat(getDescendantProp(yamlWildcards[basePath], fileName));
} else {
const fileContent = (await fetchAPI(`tacapi/v1/wildcard-contents?basepath=${basePath}&filename=${fileName}.txt`, false))
const fileContent = (await fetchTacAPI(`tacapi/v1/wildcard-contents?basepath=${basePath}&filename=${fileName}.txt`, false))
.split("\n")
.filter(x => x.trim().length > 0 && !x.startsWith('#')); // Remove empty lines and comments
wildcards = wildcards.concat(fileContent);
@@ -64,8 +64,8 @@ class WildcardFileParser extends BaseTagParser {
parse() {
// Show available wildcard files
let tempResults = [];
if (tagword !== "__") {
let lmb = (x) => x[1].toLowerCase().includes(tagword.replace("__", ""))
if (tagword !== TAC_CFG.wcWrap) {
let lmb = (x) => x[1].toLowerCase().includes(tagword.replace(TAC_CFG.wcWrap, ""))
tempResults = wildcardFiles.filter(lmb).concat(wildcardExtFiles.filter(lmb)) // Filter by tagword
} else {
tempResults = wildcardFiles.concat(wildcardExtFiles);
@@ -151,7 +151,7 @@ async function load() {
function sanitize(tagType, text) {
if (tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard) {
return `__${text}__`;
return `${TAC_CFG.wcWrap}${text}${TAC_CFG.wcWrap}`;
} else if (tagType === ResultType.wildcardTag) {
return text;
}

View File

@@ -86,6 +86,10 @@ const autocompleteCSS = `
white-space: nowrap;
color: var(--meta-text-color);
}
.acMetaText.biased::before {
content: "✨";
margin-right: 2px;
}
.acWikiLink {
padding: 0.5rem;
margin: -0.5rem 0 -0.5rem -0.5rem;
@@ -221,9 +225,16 @@ async function syncOptions() {
showWikiLinks: opts["tac_showWikiLinks"],
showExtraNetworkPreviews: opts["tac_showExtraNetworkPreviews"],
modelSortOrder: opts["tac_modelSortOrder"],
frequencySort: opts["tac_frequencySort"],
frequencyFunction: opts["tac_frequencyFunction"],
frequencyMinCount: opts["tac_frequencyMinCount"],
frequencyMaxAge: opts["tac_frequencyMaxAge"],
frequencyRecommendCap: opts["tac_frequencyRecommendCap"],
frequencyIncludeAlias: opts["tac_frequencyIncludeAlias"],
useStyleVars: opts["tac_useStyleVars"],
// Insertion related settings
replaceUnderscores: opts["tac_replaceUnderscores"],
replaceUnderscoresExclusionList: opts["tac_undersocreReplacementExclusionList"],
escapeParentheses: opts["tac_escapeParentheses"],
appendComma: opts["tac_appendComma"],
appendSpace: opts["tac_appendSpace"],
@@ -231,6 +242,7 @@ async function syncOptions() {
wildcardCompletionMode: opts["tac_wildcardCompletionMode"],
modelKeywordCompletion: opts["tac_modelKeywordCompletion"],
modelKeywordLocation: opts["tac_modelKeywordLocation"],
wcWrap: opts["dp_parser_wildcard_wrap"] || "__", // to support custom wrapper chars set by dp_parser
// Alias settings
alias: {
searchByAlias: opts["tac_alias.searchByAlias"],
@@ -403,7 +415,7 @@ const COMPLETED_WILDCARD_REGEX = /__[^\s,_][^\t\n\r,_]*[^\s,_]__[^\s,_]*/g;
const STYLE_VAR_REGEX = /\$\(?[^$|\[\],\s]*\)?/g;
const NORMAL_TAG_REGEX = /[^\s,|<>\[\]:]+_\([^\s,|<>\[\]:]*\)?|[^\s,|<>():\[\]]+|</g;
const RUBY_TAG_REGEX = /[\w\d<][\w\d' \-?!/$%]{2,}>?/g;
const TAG_REGEX = new RegExp(`${POINTY_REGEX.source}|${COMPLETED_WILDCARD_REGEX.source}|${STYLE_VAR_REGEX.source}|${NORMAL_TAG_REGEX.source}`, "g");
const TAG_REGEX = () => { return new RegExp(`${POINTY_REGEX.source}|${COMPLETED_WILDCARD_REGEX.source.replaceAll("__", escapeRegExp(TAC_CFG.wcWrap))}|${STYLE_VAR_REGEX.source}|${NORMAL_TAG_REGEX.source}`, "g"); }
// On click, insert the tag into the prompt textbox with respect to the cursor position
async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithoutChoice = false) {
@@ -419,8 +431,12 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
if (sanitizeResults && sanitizeResults.length > 0) {
sanitizedText = sanitizeResults[0];
} else {
sanitizedText = TAC_CFG.replaceUnderscores ? text.replaceAll("_", " ") : text;
const excluded_tags = TAC_CFG.replaceUnderscoresExclusionList?.split(',').map(s => s.trim()) || [];
if (TAC_CFG.replaceUnderscores && !excluded_tags.includes(sanitizedText)) {
sanitizedText = text.replaceAll("_", " ")
} else {
sanitizedText = text;
}
if (TAC_CFG.escapeParentheses && tagType === ResultType.tag) {
sanitizedText = sanitizedText
.replaceAll("(", "\\(")
@@ -459,13 +475,44 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
// Don't cut off the __ at the end if it is already the full path
if (firstDifference > 0 && firstDifference < longestResult) {
// +2 because the sanitized text already has the __ at the start but the matched text doesn't
sanitizedText = sanitizedText.substring(0, firstDifference + 2);
sanitizedText = sanitizedText.substring(0, firstDifference + TAC_CFG.wcWrap.length);
} else if (firstDifference === 0) {
sanitizedText = tagword;
}
}
}
// Frequency db update
if (TAC_CFG.frequencySort) {
let name = null;
switch (tagType) {
case ResultType.wildcardFile:
case ResultType.yamlWildcard:
// We only want to update the frequency for a full wildcard, not partial paths
if (sanitizedText.endsWith(TAC_CFG.wcWrap))
name = text
break;
case ResultType.chant:
// Chants use a slightly different format
name = result.aliases;
break;
default:
name = text;
break;
}
if (name && name.length > 0) {
// Check if it's a negative prompt
let textAreaId = getTextAreaIdentifier(textArea);
let isNegative = textAreaId.includes("n");
// Sanitize name for API call
name = encodeURIComponent(name)
// Call API & update db
increaseUseCount(name, tagType, isNegative)
}
}
var prompt = textArea.value;
// Edit prompt text
@@ -511,7 +558,7 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
let keywords = null;
// Check built-in activation words first
if (tagType === ResultType.lora || tagType === ResultType.lyco) {
let info = await fetchAPI(`tacapi/v1/lora-info/${result.text}`)
let info = await fetchTacAPI(`tacapi/v1/lora-info/${result.text}`)
if (info && info["activation text"]) {
keywords = info["activation text"];
}
@@ -523,7 +570,7 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
// No match, try to find a sha256 match from the cache file
if (!nameDict) {
const sha256 = await fetchAPI(`/tacapi/v1/lora-cached-hash/${result.text}`)
const sha256 = await fetchTacAPI(`/tacapi/v1/lora-cached-hash/${result.text}`)
if (sha256) {
nameDict = modelKeywordDict.get(sha256);
}
@@ -574,12 +621,14 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
tacSelfTrigger = true;
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python.
// Uses a built-in method from the webui's ui.js which also already accounts for event target
if (tagType === ResultType.wildcardTag || tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard)
tacSelfTrigger = true;
updateInput(textArea);
// Update previous tags with the edited prompt to prevent re-searching the same term
let weightedTags = [...newPrompt.matchAll(WEIGHT_REGEX)]
.map(match => match[1]);
let tags = newPrompt.match(TAG_REGEX)
let tags = newPrompt.match(TAG_REGEX())
if (weightedTags !== null) {
tags = tags.filter(tag => !weightedTags.some(weighted => tag.includes(weighted)))
.concat(weightedTags);
@@ -682,28 +731,39 @@ function addResultsToList(textArea, results, tagword, resetList) {
}
// Add wiki link if the setting is enabled and a supported tag set loaded
if (TAC_CFG.showWikiLinks
&& (result.type === ResultType.tag)
&& (tagFileName.toLowerCase().startsWith("danbooru") || tagFileName.toLowerCase().startsWith("e621"))) {
if (
TAC_CFG.showWikiLinks &&
result.type === ResultType.tag &&
(tagFileName.toLowerCase().startsWith("danbooru") ||
tagFileName.toLowerCase().startsWith("e621"))
) {
let wikiLink = document.createElement("a");
wikiLink.classList.add("acWikiLink");
wikiLink.innerText = "?";
wikiLink.title = "Open external wiki page for this tag";
let linkPart = displayText;
// Only use alias result if it is one
if (displayText.includes("➝"))
linkPart = displayText.split(" ➝ ")[1];
if (displayText.includes("➝")) linkPart = displayText.split(" ➝ ")[1];
// Remove any trailing translations
if (linkPart.includes("[")) {
linkPart = linkPart.split("[")[0]
linkPart = linkPart.split("[")[0];
}
linkPart = encodeURIComponent(linkPart);
// Set link based on selected file
let tagFileNameLower = tagFileName.toLowerCase();
if (tagFileNameLower.startsWith("danbooru")) {
if (tagFileNameLower.startsWith("danbooru_e621_merged")) {
// Use danbooru for categories 0-5, e621 for 6+
// Based on the merged categories from https://github.com/DraconicDragon/dbr-e621-lists-archive/tree/main/tag-lists/danbooru_e621_merged
// Danbooru is also the fallback if result.category is not set
wikiLink.href =
result.category && result.category >= 6
? `https://e621.net/wiki_pages/${linkPart}`
: `https://danbooru.donmai.us/wiki_pages/${linkPart}`;
} else if (tagFileNameLower.startsWith("danbooru")) {
wikiLink.href = `https://danbooru.donmai.us/wiki_pages/${linkPart}`;
} else if (tagFileNameLower.startsWith("e621")) {
wikiLink.href = `https://e621.net/wiki_pages/${linkPart}`;
@@ -733,7 +793,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
}
// Post count
if (result.count && !isNaN(result.count)) {
if (result.count && !isNaN(result.count) && result.count !== Number.MAX_SAFE_INTEGER) {
let postCount = result.count;
let formatter;
@@ -765,8 +825,56 @@ function addResultsToList(textArea, results, tagword, resetList) {
flexDiv.appendChild(metaDiv);
}
// Add listener
li.addEventListener("click", function () { insertTextAtCursor(textArea, result, tagword); });
// Add small ✨ marker to indicate usage sorting
if (result.usageBias) {
flexDiv.querySelector(".acMetaText").classList.add("biased");
flexDiv.title = "✨ Frequent tag. Ctrl/Cmd + click to reset usage count."
}
// Check if it's a negative prompt
let isNegative = textAreaId.includes("n");
// Add click listener
li.addEventListener("click", (e) => {
if (e.ctrlKey || e.metaKey) {
resetUseCount(result.text, result.type, !isNegative, isNegative);
flexDiv.querySelector(".acMetaText").classList.remove("biased");
} else {
insertTextAtCursor(textArea, result, tagword);
}
});
// Add delayed hover listener for extra network previews
if (
TAC_CFG.showExtraNetworkPreviews &&
[
ResultType.embedding,
ResultType.hypernetwork,
ResultType.lora,
ResultType.lyco,
].includes(result.type)
) {
li.addEventListener("mouseover", async () => {
const me = this;
let hoverTimeout;
hoverTimeout = setTimeout(async () => {
// If the tag we hover over is already selected, do nothing
if (selectedTag && selectedTag === i) return;
oldSelectedTag = selectedTag;
selectedTag = i;
// Update selection without scrolling to the item (since we would
// immediately trigger the next scroll as the items move under the cursor)
updateSelectionStyle(textArea, selectedTag, oldSelectedTag, false);
}, 400);
// Reset delay timer if we leave the item
me.addEventListener("mouseout", () => {
clearTimeout(hoverTimeout);
});
});
}
// Add element to list
resultsList.appendChild(li);
}
@@ -779,7 +887,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
}
}
async function updateSelectionStyle(textArea, newIndex, oldIndex) {
async function updateSelectionStyle(textArea, newIndex, oldIndex, scroll = true) {
let textAreaId = getTextAreaIdentifier(textArea);
let resultDiv = gradioApp().querySelector('.autocompleteResults' + textAreaId);
let resultsList = resultDiv.querySelector('ul');
@@ -794,40 +902,25 @@ async function updateSelectionStyle(textArea, newIndex, oldIndex) {
let selected = items[newIndex];
selected.classList.add('selected');
// Set scrolltop to selected item
resultDiv.scrollTop = selected.offsetTop - resultDiv.offsetTop;
// Set scrolltop to selected item
if (scroll) resultDiv.scrollTop = selected.offsetTop - resultDiv.offsetTop;
}
// Show preview if enabled and the selected type supports it
if (newIndex !== null) {
let selected = items[newIndex];
let previewTypes = ["v1 Embedding", "v2 Embedding", "Hypernetwork", "Lora", "Lyco"];
let selectedType = selected.querySelector(".acMetaText").innerText;
let selectedFilename = selected.querySelector(".acListItem").innerText;
let selectedResult = results[newIndex];
let selectedType = selectedResult.type;
// These types support previews (others could technically too, but are not native to the webui gallery)
let previewTypes = [ResultType.embedding, ResultType.hypernetwork, ResultType.lora, ResultType.lyco];
let previewDiv = gradioApp().querySelector(`.autocompleteParent${textAreaId} .sideInfo`);
if (TAC_CFG.showExtraNetworkPreviews && previewTypes.includes(selectedType)) {
let shorthandType = "";
switch (selectedType) {
case "v1 Embedding":
case "v2 Embedding":
shorthandType = "embed";
break;
case "Hypernetwork":
shorthandType = "hyper";
break;
case "Lora":
shorthandType = "lora";
break;
case "Lyco":
shorthandType = "lyco";
break;
}
let img = previewDiv.querySelector("img");
let url = await getExtraNetworkPreviewURL(selectedFilename, shorthandType);
// String representation of our type enum
const typeString = Object.keys(ResultType)[selectedType - 1].toLowerCase();
// Get image from API
let url = await getTacExtraNetworkPreviewURL(selectedResult.text, typeString);
if (url) {
img.src = url;
previewDiv.style.display = "block";
@@ -995,7 +1088,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
// We also match for the weighting format (e.g. "tag:1.0") here, and combine the two to get the full tag word set
let weightedTags = [...prompt.matchAll(WEIGHT_REGEX)]
.map(match => match[1]);
let tags = prompt.match(TAG_REGEX)
let tags = prompt.match(TAG_REGEX())
if (weightedTags !== null && tags !== null) {
tags = tags.filter(tag => !weightedTags.some(weighted => tag.includes(weighted) && !tag.startsWith("<[") && !tag.startsWith("$(")))
.concat(weightedTags);
@@ -1034,6 +1127,9 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
resultCountBeforeNormalTags = 0;
tagword = tagword.toLowerCase().replace(/[\n\r]/g, "");
// Needed for slicing check later
let normalTags = false;
// Process all parsers
let resultCandidates = (await processParsers(textArea, prompt))?.filter(x => x.length > 0);
// If one ore more result candidates match, use their results
@@ -1043,32 +1139,12 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
// Sort results, but not if it's umi tags since they are sorted by count
if (!(resultCandidates.length === 1 && results[0].type === ResultType.umiWildcard))
results = results.sort(getSortFunction());
// Since some tags are kaomoji, we have to add the normal results in some cases
if (tagword.startsWith("<") || tagword.startsWith("*<")) {
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
if (tagword.startsWith("*")) {
tagword = tagword.slice(1);
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
} else {
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
}
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, TAC_CFG.maxResults);
genericResults.forEach(g => {
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
result.category = g[1];
result.count = g[2];
result.aliases = g[3];
results.push(result);
});
}
}
// Else search the normal tag list
if (!resultCandidates || resultCandidates.length === 0
|| (TAC_CFG.includeEmbeddingsInNormalResults && !(tagword.startsWith("<") || tagword.startsWith("*<")))
) {
normalTags = true;
resultCountBeforeNormalTags = results.length;
// Create escaped search regex with support for * as a start placeholder
@@ -1123,11 +1199,6 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
results = results.concat(extraResults);
}
}
// Slice if the user has set a max result count
if (!TAC_CFG.showAllResults) {
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
}
}
// Guard for empty results
@@ -1137,6 +1208,57 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
return;
}
// Sort again with frequency / usage count if enabled
if (TAC_CFG.frequencySort) {
// Split our results into a list of names and types
let tagNames = [];
let aliasNames = [];
let types = [];
// Limit to 2k for performance reasons
const aliasTypes = [ResultType.tag, ResultType.extra];
results.slice(0,2000).forEach(r => {
const name = r.type === ResultType.chant ? r.aliases : r.text;
// Add to alias list or tag list depending on if the name includes the tagword
// (the same criteria is used in the filter in calculateUsageBias)
if (aliasTypes.includes(r.type) && !name.includes(tagword)) {
aliasNames.push(name);
} else {
tagNames.push(name);
}
types.push(r.type);
});
// Check if it's a negative prompt
let textAreaId = getTextAreaIdentifier(textArea);
let isNegative = textAreaId.includes("n");
// Request use counts from the DB
const names = TAC_CFG.frequencyIncludeAlias ? tagNames.concat(aliasNames) : tagNames;
const counts = await getUseCounts(names, types, isNegative) || [];
// Pre-calculate weights to prevent duplicate work
const resultBiasMap = new Map();
results.forEach(result => {
const name = result.type === ResultType.chant ? result.aliases : result.text;
const type = result.type;
// Find matching pair from DB results
const useStats = counts.find(c => c.name === name && c.type === type);
const uses = useStats?.count || 0;
// Calculate & set weight
const weight = calculateUsageBias(result, result.count, uses)
resultBiasMap.set(result, weight);
});
// Actual sorting with the pre-calculated weights
results = results.sort((a, b) => {
return resultBiasMap.get(b) - resultBiasMap.get(a);
});
}
// Slice if the user has set a max result count and we are not in a extra networks / wildcard list
if (!TAC_CFG.showAllResults && normalTags) {
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
}
addResultsToList(textArea, results, tagword, true);
showResults(textArea);
}
@@ -1272,7 +1394,7 @@ async function refreshTacTempFiles(api = false) {
}
if (api) {
await postAPI("tacapi/v1/refresh-temp-files", null);
await postTacAPI("tacapi/v1/refresh-temp-files");
await reload();
} else {
setTimeout(async () => {
@@ -1282,7 +1404,7 @@ async function refreshTacTempFiles(api = false) {
}
async function refreshEmbeddings() {
await postAPI("tacapi/v1/refresh-embeddings", null);
await postTacAPI("tacapi/v1/refresh-embeddings", null);
embeddings = [];
await processQueue(QUEUE_FILE_LOAD, null);
console.log("TAC: Refreshed embeddings");
@@ -1376,9 +1498,16 @@ async function setup() {
gradioApp().querySelector("#refresh_tac_refreshTempFiles")?.addEventListener("click", refreshTacTempFiles);
// Also add listener for external network refresh button (plus triggering python code)
["#img2img_extra_refresh", "#txt2img_extra_refresh"].forEach(e => {
gradioApp().querySelector(e)?.addEventListener("click", ()=>{
refreshTacTempFiles(true);
let alreadyAdded = new Set();
["#img2img_extra_refresh", "#txt2img_extra_refresh", ".extra-network-control--refresh"].forEach(e => {
const elems = gradioApp().querySelectorAll(e);
elems.forEach(elem => {
if (!elem || alreadyAdded.has(elem)) return;
alreadyAdded.add(elem);
elem.addEventListener("click", ()=>{
refreshTacTempFiles(true);
});
});
})

View File

@@ -2,24 +2,72 @@
# to a temporary file to expose it to the javascript side
import glob
import importlib
import json
import sqlite3
import sys
import urllib.parse
from asyncio import sleep
from pathlib import Path
import gradio as gr
import yaml
from fastapi import FastAPI
from fastapi.responses import Response, FileResponse, JSONResponse
from modules import script_callbacks, sd_hijack, shared, hashes
from fastapi.responses import FileResponse, JSONResponse, Response
from modules import hashes, script_callbacks, sd_hijack, sd_models, shared
from pydantic import BaseModel
from scripts.model_keyword_support import (get_lora_simple_hash,
load_hash_cache, update_hash_cache,
write_model_keyword_path)
from scripts.shared_paths import *
try:
try:
from scripts import tag_frequency_db as tdb
except ModuleNotFoundError:
from inspect import currentframe, getframeinfo
filename = getframeinfo(currentframe()).filename
parent = Path(filename).resolve().parent
sys.path.append(str(parent))
import tag_frequency_db as tdb
# Ensure the db dependency is reloaded on script reload
importlib.reload(tdb)
db = tdb.TagFrequencyDb()
if int(db.version) != int(tdb.db_ver):
raise ValueError("Database version mismatch")
except (ImportError, ValueError, sqlite3.Error) as e:
print(f"Tag Autocomplete: Tag frequency database error - \"{e}\"")
db = None
def get_embed_db(sd_model=None):
"""Returns the embedding database, if available."""
try:
return sd_hijack.model_hijack.embedding_db
except Exception:
try: # sd next with diffusers backend
sdnext_model = sd_model if sd_model is not None else shared.sd_model
return sdnext_model.embedding_db
except Exception:
try: # forge webui
forge_model = sd_model if sd_model is not None else sd_models.model_data.get_sd_model()
if type(forge_model).__name__ == "FakeInitialModel":
return None
else:
processer = getattr(forge_model, "text_processing_engine", getattr(forge_model, "text_processing_engine_l"))
return processer.embeddings
except Exception:
return None
# Attempt to get embedding load function, using the same call as api.
try:
load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings
embed_db = get_embed_db()
if embed_db is not None:
load_textual_inversion_embeddings = embed_db.load_textual_inversion_embeddings
else:
load_textual_inversion_embeddings = lambda *args, **kwargs: None
except Exception as e: # Not supported.
load_textual_inversion_embeddings = lambda *args, **kwargs: None
print("Tag Autocomplete: Cannot reload embeddings instantly:", e)
@@ -27,8 +75,8 @@ except Exception as e: # Not supported.
# Sorting functions for extra networks / embeddings stuff
sort_criteria = {
"Name": lambda path, name, subpath: name.lower() if subpath else path.stem.lower(),
"Date Modified (newest first)": lambda path, name, subpath: path.stat().st_mtime,
"Date Modified (oldest first)": lambda path, name, subpath: path.stat().st_mtime
"Date Modified (newest first)": lambda path, name, subpath: path.stat().st_mtime if path.exists() else name.lower(),
"Date Modified (oldest first)": lambda path, name, subpath: path.stat().st_mtime if path.exists() else name.lower()
}
def sort_models(model_list, sort_method = None, name_has_subpath = False):
@@ -86,7 +134,11 @@ def is_umi_format(data):
"""Returns True if the YAML file is in UMI format."""
issue_found = False
for item in data:
if not (data[item] and 'Tags' in data[item] and isinstance(data[item]['Tags'], list)):
try:
if not (data[item] and 'Tags' in data[item] and isinstance(data[item]['Tags'], list)):
issue_found = True
break
except:
issue_found = True
break
return not issue_found
@@ -108,9 +160,12 @@ def parse_dynamic_prompt_format(yaml_wildcards, data, path):
elif not (isinstance(value, list) and all(isinstance(v, str) for v in value)):
del d[key]
recurse_dict(data)
# Add to yaml_wildcards
yaml_wildcards[path.name] = data
try:
recurse_dict(data)
# Add to yaml_wildcards
yaml_wildcards[path.name] = data
except:
return
def get_yaml_wildcards():
@@ -135,9 +190,13 @@ def get_yaml_wildcards():
parse_dynamic_prompt_format(yaml_wildcards, data, path)
else:
print('No data found in ' + path.name)
except (yaml.YAMLError, UnicodeDecodeError) as e:
except (yaml.YAMLError, UnicodeDecodeError, AttributeError, TypeError) as e:
# YAML file not in wildcard format or couldn't be read
print(f'Issue in parsing YAML file {path.name}: {e}')
continue
except Exception as e:
# Something else went wrong, just skip
continue
# Sort by count
umi_sorted = sorted(umi_tags.items(), key=lambda item: item[1], reverse=True)
@@ -166,35 +225,45 @@ def get_embeddings(sd_model):
results = []
try:
# The sd_model embedding_db reference only exists in sd.next with diffusers backend
try:
loaded_sdnext = sd_model.embedding_db.word_embeddings
skipped_sdnext = sd_model.embedding_db.skipped_embeddings
except (NameError, AttributeError):
loaded_sdnext = {}
skipped_sdnext = {}
embed_db = get_embed_db(sd_model)
# Re-register callback if needed
global load_textual_inversion_embeddings
if embed_db is not None and load_textual_inversion_embeddings != embed_db.load_textual_inversion_embeddings:
load_textual_inversion_embeddings = embed_db.load_textual_inversion_embeddings
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
loaded = loaded | loaded_sdnext
skipped = skipped | skipped_sdnext
loaded = embed_db.word_embeddings
skipped = embed_db.skipped_embeddings
# Add embeddings to the correct list
for key, emb in (loaded | skipped).items():
if emb.filename is None:
continue
if emb.shape is None:
emb_unknown.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), ""))
elif emb.shape == V1_SHAPE:
emb_v1.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "v1"))
elif emb.shape == V2_SHAPE:
emb_v2.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "v2"))
elif emb.shape == VXL_SHAPE:
emb_vXL.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "vXL"))
for key, emb in (skipped | loaded).items():
filename = getattr(emb, "filename", None)
if filename is None:
if emb.shape is None:
emb_unknown.append((Path(key), key, ""))
elif emb.shape == V1_SHAPE:
emb_v1.append((Path(key), key, "v1"))
elif emb.shape == V2_SHAPE:
emb_v2.append((Path(key), key, "v2"))
elif emb.shape == VXL_SHAPE:
emb_vXL.append((Path(key), key, "vXL"))
else:
emb_unknown.append((Path(key), key, ""))
else:
emb_unknown.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), ""))
if emb.filename is None:
continue
if emb.shape is None:
emb_unknown.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), ""))
elif emb.shape == V1_SHAPE:
emb_v1.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "v1"))
elif emb.shape == V2_SHAPE:
emb_v2.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "v2"))
elif emb.shape == VXL_SHAPE:
emb_vXL.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "vXL"))
else:
emb_unknown.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), ""))
results = sort_models(emb_v1) + sort_models(emb_v2) + sort_models(emb_vXL) + sort_models(emb_unknown)
except AttributeError:
@@ -265,7 +334,7 @@ try:
import sys
from modules import extensions
sys.path.append(Path(extensions.extensions_builtin_dir).joinpath("Lora").as_posix())
import lora # pyright: ignore [reportMissingImports]
import lora # pyright: ignore [reportMissingImports]
def _get_lora():
return [
@@ -406,8 +475,11 @@ def refresh_embeddings(force: bool, *args, **kwargs):
# Fix for SD.Next infinite refresh loop due to gradio not updating after model load on demand.
# This will just skip embedding loading if no model is loaded yet (or there really are no embeddings).
# Try catch is just for safety incase sd_hijack access fails for some reason.
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
embed_db = get_embed_db()
if embed_db is None:
return
loaded = embed_db.word_embeddings
skipped = embed_db.skipped_embeddings
if len((loaded | skipped)) > 0:
load_textual_inversion_embeddings(force_reload=force)
get_embeddings(None)
@@ -420,7 +492,8 @@ def refresh_temp_files(*args, **kwargs):
if skip_wildcard_refresh:
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
write_temp_files(skip_wildcard_refresh)
refresh_embeddings(force=True)
force_embed_refresh = getattr(shared.opts, "tac_forceRefreshEmbeddings", False)
refresh_embeddings(force=force_embed_refresh)
def write_style_names(*args, **kwargs):
styles = get_style_names()
@@ -488,6 +561,13 @@ def on_ui_settings():
return self
shared.OptionInfo.needs_restart = needs_restart
# Dictionary of function options and their explanations
frequency_sort_functions = {
"Logarithmic (weak)": "Will respect the base order and slightly prefer often used tags",
"Logarithmic (strong)": "Same as Logarithmic (weak), but with a stronger bias",
"Usage first": "Will list used tags by frequency before all others",
}
tac_options = {
# Main tag file
"tac_tagFile": shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files),
@@ -510,6 +590,7 @@ def on_ui_settings():
"tac_wildcardExclusionList": shared.OptionInfo("", "Wildcard folder exclusion list").info("Add folder names that shouldn't be searched for wildcards, separated by comma.").needs_restart(),
"tac_skipWildcardRefresh": shared.OptionInfo(False, "Don't re-scan for wildcard files when pressing the extra networks refresh button").info("Useful to prevent hanging if you use a very large wildcard collection."),
"tac_useEmbeddings": shared.OptionInfo(True, "Search for embeddings"),
"tac_forceRefreshEmbeddings": shared.OptionInfo(False, "Force refresh embeddings when pressing the extra networks refresh button").info("Turn this on if you have issues with new embeddings not registering correctly in TAC. Warning: Seems to cause reloading issues in gradio for some users."),
"tac_includeEmbeddingsInNormalResults": shared.OptionInfo(False, "Include embeddings in normal tag results").info("The 'JumpTo...' keybinds (End & Home key by default) will select the first non-embedding result of their direction on the first press for quick navigation in longer lists."),
"tac_useHypernetworks": shared.OptionInfo(True, "Search for hypernetworks"),
"tac_useLoras": shared.OptionInfo(True, "Search for Loras"),
@@ -519,8 +600,16 @@ def on_ui_settings():
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
"tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like <a href=\"https://github.com/SirVeggie/extension-style-vars\" target=\"_blank\">style-vars</a> to actually apply the styles before generating."),
# Frequency sorting settings
"tac_frequencySort": shared.OptionInfo(True, "Locally record tag usage and sort frequent tags higher").info("Will also work for extra networks, keeping the specified base order"),
"tac_frequencyFunction": shared.OptionInfo("Logarithmic (weak)", "Function to use for frequency sorting", gr.Dropdown, lambda: {"choices": list(frequency_sort_functions.keys())}).info("; ".join([f'<b>{key}</b>: {val}' for key, val in frequency_sort_functions.items()])),
"tac_frequencyMinCount": shared.OptionInfo(3, "Minimum number of uses for a tag to be considered frequent").info("Tags with less uses than this will not be sorted higher, even if the sorting function would normally result in a higher position."),
"tac_frequencyMaxAge": shared.OptionInfo(30, "Maximum days since last use for a tag to be considered frequent").info("Similar to the above, tags that haven't been used in this many days will not be sorted higher. Set to 0 to disable."),
"tac_frequencyRecommendCap": shared.OptionInfo(10, "Maximum number of recommended tags").info("Limits the maximum number of recommended tags to not drown out normal results. Set to 0 to disable."),
"tac_frequencyIncludeAlias": shared.OptionInfo(False, "Frequency sorting matches aliases for frequent tags").info("Tag frequency will be increased for the main tag even if an alias is used for completion. This option can be used to override the default behavior of alias results being ignored for frequency sorting."),
# Insertion related settings
"tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"),
"tac_undersocreReplacementExclusionList": shared.OptionInfo("0_0,(o)_(o),+_+,+_-,._.,<o>_<o>,<|>_<|>,=_=,>_<,3_3,6_9,>_o,@_@,^_^,o_o,u_u,x_x,|_|,||_||", "Underscore replacement exclusion list").info("Add tags that shouldn't have underscores replaced with spaces, separated by comma."),
"tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"),
"tac_appendComma": shared.OptionInfo(True, "Append comma on tag autocompletion"),
"tac_appendSpace": shared.OptionInfo(True, "Append space on tag autocompletion").info("will append after comma if the above is enabled"),
@@ -597,6 +686,23 @@ def on_ui_settings():
"9": ["#df3647", "#8e1c2b"],
"10": ["#c98f2b", "#7b470e"],
"11": ["#e87ebe", "#a83583"]
},
"danbooru_e621_merged": {
"-1": ["red", "maroon"],
"0": ["lightblue", "dodgerblue"],
"1": ["indianred", "firebrick"],
"3": ["violet", "darkorchid"],
"4": ["lightgreen", "darkgreen"],
"5": ["orange", "darkorange"],
"6": ["red", "maroon"],
"7": ["lightblue", "dodgerblue"],
"8": ["gold", "goldenrod"],
"9": ["gold", "goldenrod"],
"10": ["violet", "darkorchid"],
"11": ["lightgreen", "darkgreen"],
"12": ["tomato", "darksalmon"],
"14": ["whitesmoke", "black"],
"15": ["seagreen", "darkseagreen"]
}
}\
"""
@@ -658,6 +764,7 @@ def api_tac(_: gr.Blocks, app: FastAPI):
@app.post("/tacapi/v1/refresh-temp-files")
async def api_refresh_temp_files():
await sleep(0) # might help with refresh blocking gradio
refresh_temp_files()
@app.post("/tacapi/v1/refresh-embeddings")
@@ -689,9 +796,9 @@ def api_tac(_: gr.Blocks, app: FastAPI):
return LORA_PATH
elif type == "lyco":
return LYCO_PATH
elif type == "hyper":
elif type == "hypernetwork":
return HYP_PATH
elif type == "embed":
elif type == "embedding":
return EMB_PATH
else:
return None
@@ -736,5 +843,62 @@ def api_tac(_: gr.Blocks, app: FastAPI):
return Response(status_code=200) # Success
else:
return Response(status_code=304) # Not modified
def db_request(func, get = False):
if db is not None:
try:
if get:
ret = func()
if ret is list:
ret = [{"name": t[0], "type": t[1], "count": t[2], "lastUseDate": t[3]} for t in ret]
return JSONResponse({"result": ret})
else:
func()
except sqlite3.Error as e:
return JSONResponse({"error": e.__cause__}, status_code=500)
else:
return JSONResponse({"error": "Database not initialized"}, status_code=500)
@app.post("/tacapi/v1/increase-use-count")
async def increase_use_count(tagname: str, ttype: int, neg: bool):
db_request(lambda: db.increase_tag_count(tagname, ttype, neg))
@app.get("/tacapi/v1/get-use-count")
async def get_use_count(tagname: str, ttype: int, neg: bool):
return db_request(lambda: db.get_tag_count(tagname, ttype, neg), get=True)
# Small dataholder class
class UseCountListRequest(BaseModel):
tagNames: list[str]
tagTypes: list[int]
neg: bool = False
# Semantically weird to use post here, but it's required for the body on js side
@app.post("/tacapi/v1/get-use-count-list")
async def get_use_count_list(body: UseCountListRequest):
# If a date limit is set > 0, pass it to the db
date_limit = getattr(shared.opts, "tac_frequencyMaxAge", 30)
date_limit = date_limit if date_limit > 0 else None
if db:
count_list = list(db.get_tag_counts(body.tagNames, body.tagTypes, body.neg, date_limit))
else:
count_list = None
# If a limit is set, return at max the top n results by count
if count_list and len(count_list):
limit = int(min(getattr(shared.opts, "tac_frequencyRecommendCap", 10), len(count_list)))
# Sort by count and return the top n
if limit > 0:
count_list = sorted(count_list, key=lambda x: x[2], reverse=True)[:limit]
return db_request(lambda: count_list, get=True)
@app.put("/tacapi/v1/reset-use-count")
async def reset_use_count(tagname: str, ttype: int, pos: bool, neg: bool):
db_request(lambda: db.reset_tag_count(tagname, ttype, pos, neg))
@app.get("/tacapi/v1/get-all-use-counts")
async def get_all_tag_counts():
return db_request(lambda: db.get_all_tags(), get=True)
script_callbacks.on_app_started(api_tac)

190
scripts/tag_frequency_db.py Normal file
View File

@@ -0,0 +1,190 @@
import sqlite3
from contextlib import contextmanager
from scripts.shared_paths import TAGS_PATH
db_file = TAGS_PATH.joinpath("tag_frequency.db")
timeout = 30
db_ver = 1
@contextmanager
def transaction(db=db_file):
"""Context manager for database transactions.
Ensures that the connection is properly closed after the transaction.
"""
try:
conn = sqlite3.connect(db, timeout=timeout)
conn.isolation_level = None
cursor = conn.cursor()
cursor.execute("BEGIN")
yield cursor
cursor.execute("COMMIT")
except sqlite3.Error as e:
print("Tag Autocomplete: Frequency database error:", e)
finally:
if conn:
conn.close()
class TagFrequencyDb:
"""Class containing creation and interaction methods for the tag frequency database"""
def __init__(self) -> None:
self.version = self.__check()
def __check(self):
if not db_file.exists():
print("Tag Autocomplete: Creating frequency database")
with transaction() as cursor:
self.__create_db(cursor)
self.__update_db_data(cursor, "version", db_ver)
print("Tag Autocomplete: Database successfully created")
return self.__get_version()
def __create_db(self, cursor: sqlite3.Cursor):
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS db_data (
key TEXT PRIMARY KEY,
value TEXT
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS tag_frequency (
name TEXT NOT NULL,
type INT NOT NULL,
count_pos INT,
count_neg INT,
last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (name, type)
)
"""
)
def __update_db_data(self, cursor: sqlite3.Cursor, key, value):
cursor.execute(
"""
INSERT OR REPLACE
INTO db_data (key, value)
VALUES (?, ?)
""",
(key, value),
)
def __get_version(self):
db_version = None
with transaction() as cursor:
cursor.execute(
"""
SELECT value
FROM db_data
WHERE key = 'version'
"""
)
db_version = cursor.fetchone()
return db_version[0] if db_version else 0
def get_all_tags(self):
with transaction() as cursor:
cursor.execute(
f"""
SELECT name, type, count_pos, count_neg, last_used
FROM tag_frequency
WHERE count_pos > 0 OR count_neg > 0
ORDER BY count_pos + count_neg DESC
"""
)
tags = cursor.fetchall()
return tags
def get_tag_count(self, tag, ttype, negative=False):
count_str = "count_neg" if negative else "count_pos"
with transaction() as cursor:
cursor.execute(
f"""
SELECT {count_str}, last_used
FROM tag_frequency
WHERE name = ? AND type = ?
""",
(tag, ttype),
)
tag_count = cursor.fetchone()
if tag_count:
return tag_count[0], tag_count[1]
else:
return 0, None
def get_tag_counts(self, tags: list[str], ttypes: list[str], negative=False, date_limit=None):
count_str = "count_neg" if negative else "count_pos"
with transaction() as cursor:
for tag, ttype in zip(tags, ttypes):
if date_limit is not None:
cursor.execute(
f"""
SELECT {count_str}, last_used
FROM tag_frequency
WHERE name = ? AND type = ?
AND last_used > datetime('now', '-' || ? || ' days')
""",
(tag, ttype, date_limit),
)
else:
cursor.execute(
f"""
SELECT {count_str}, last_used
FROM tag_frequency
WHERE name = ? AND type = ?
""",
(tag, ttype),
)
tag_count = cursor.fetchone()
if tag_count:
yield (tag, ttype, tag_count[0], tag_count[1])
else:
yield (tag, ttype, 0, None)
def increase_tag_count(self, tag, ttype, negative=False):
pos_count = self.get_tag_count(tag, ttype, False)[0]
neg_count = self.get_tag_count(tag, ttype, True)[0]
if negative:
neg_count += 1
else:
pos_count += 1
with transaction() as cursor:
cursor.execute(
f"""
INSERT OR REPLACE
INTO tag_frequency (name, type, count_pos, count_neg)
VALUES (?, ?, ?, ?)
""",
(tag, ttype, pos_count, neg_count),
)
def reset_tag_count(self, tag, ttype, positive=True, negative=False):
if positive and negative:
set_str = "count_pos = 0, count_neg = 0"
elif positive:
set_str = "count_pos = 0"
elif negative:
set_str = "count_neg = 0"
with transaction() as cursor:
cursor.execute(
f"""
UPDATE tag_frequency
SET {set_str}
WHERE name = ? AND type = ?
""",
(tag, ttype),
)

File diff suppressed because it is too large Load Diff

221787
tags/danbooru_e621_merged.csv Normal file

File diff suppressed because one or more lines are too long

View File

@@ -28,5 +28,17 @@
"terms": "Water, Magic, Fancy",
"content": "(extremely detailed CG unity 8k wallpaper), (masterpiece), (best quality), (ultra-detailed), (best illustration),(best shadow), (an extremely delicate and beautiful), classic, dynamic angle, floating, fine detail, Depth of field, classic, (painting), (sketch), (bloom), (shine), glinting stars,\n\na girl, solo, bare shoulders, flat chest, diamond and glaring eyes, beautiful detailed cold face, very long blue and sliver hair, floating black feathers, wavy hair, extremely delicate and beautiful girls, beautiful detailed eyes, glowing eyes,\n\nriver, (forest),palace, (fairyland,feather,flowers, nature),(sunlight),Hazy fog, mist",
"color": 5
},
{
"name": "Pony-Positive",
"terms": "Pony,Score,Positive,Quality",
"content": "score_9, score_8_up, score_7_up, score_6_up, source_anime, source_furry, source_pony, source_cartoon",
"color": 1
},
{
"name": "Pony-Negative",
"terms": "Pony,Score,Negative,Quality",
"content": "score_1, score_2, score_3, score_4, score_5, source_anime, source_furry, source_pony, source_cartoon",
"color": 3
}
]

File diff suppressed because it is too large Load Diff

200358
tags/e621.csv

File diff suppressed because one or more lines are too long