mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-27 03:29:55 +00:00
Compare commits
175 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8766965a30 | ||
|
|
34e68e1628 | ||
|
|
41d185b616 | ||
|
|
e0baa58ace | ||
|
|
c1ef12d887 | ||
|
|
4fc122de4b | ||
|
|
c341ccccb6 | ||
|
|
bda8701734 | ||
|
|
63fca457a7 | ||
|
|
38700d4743 | ||
|
|
bb492ba059 | ||
|
|
40ad070a02 | ||
|
|
209b1dd76b | ||
|
|
196fa19bfc | ||
|
|
6ffeeafc49 | ||
|
|
08b7c58ea7 | ||
|
|
6be91449f3 | ||
|
|
b515c15e01 | ||
|
|
827b99c961 | ||
|
|
49ec047af8 | ||
|
|
f94da07ed1 | ||
|
|
e2cfe7341b | ||
|
|
ce51ec52a2 | ||
|
|
f64d728ac6 | ||
|
|
1c6bba2a3d | ||
|
|
9a47c2ec2c | ||
|
|
fe32ad739d | ||
|
|
ade67e30a6 | ||
|
|
e9a21e7a55 | ||
|
|
3ef2a7d206 | ||
|
|
29b5bf0701 | ||
|
|
3eef536b64 | ||
|
|
0d24e697d2 | ||
|
|
a27633da55 | ||
|
|
4cd6174a22 | ||
|
|
9155e4d42c | ||
|
|
700642a400 | ||
|
|
1b592dbf56 | ||
|
|
d1eea880f3 | ||
|
|
119a3ad51f | ||
|
|
c820a22149 | ||
|
|
eb1e1820f9 | ||
|
|
ef59cff651 | ||
|
|
a454383c43 | ||
|
|
bec567fe26 | ||
|
|
d4041096c9 | ||
|
|
0903259ddf | ||
|
|
f3e64b1fa5 | ||
|
|
312cec5d71 | ||
|
|
b71e6339bd | ||
|
|
7ddbc3c0b2 | ||
|
|
4c2ef8f770 | ||
|
|
97c5e4f53c | ||
|
|
1d8d9f64b5 | ||
|
|
7437850600 | ||
|
|
829a4a7b89 | ||
|
|
22472ac8ad | ||
|
|
5f77fa26d3 | ||
|
|
f810b2dd8f | ||
|
|
08d3436f3b | ||
|
|
afa13306ef | ||
|
|
95200e82e1 | ||
|
|
a63ce64f4e | ||
|
|
a966be7546 | ||
|
|
d37e37acfa | ||
|
|
342fbc9041 | ||
|
|
d496569c9a | ||
|
|
7778142520 | ||
|
|
cde90c13c4 | ||
|
|
231b121fe0 | ||
|
|
c659ed2155 | ||
|
|
0a4c17cada | ||
|
|
6e65811d4a | ||
|
|
03673c060e | ||
|
|
1c11c4ad5a | ||
|
|
30c9593d3d | ||
|
|
f840586b6b | ||
|
|
886704e351 | ||
|
|
41626d22c3 | ||
|
|
57076060df | ||
|
|
5ef346cde3 | ||
|
|
edf76d9df2 | ||
|
|
837dc39811 | ||
|
|
f1870b7e87 | ||
|
|
20b6635a2a | ||
|
|
1fe8f26670 | ||
|
|
e82e958c3e | ||
|
|
2dd48eab79 | ||
|
|
4df90f5c95 | ||
|
|
a156214a48 | ||
|
|
15478e73b5 | ||
|
|
fcacf7dd66 | ||
|
|
82f819f336 | ||
|
|
effda54526 | ||
|
|
434301738a | ||
|
|
58804796f0 | ||
|
|
668ca800b8 | ||
|
|
a7233a594f | ||
|
|
4fba7baa69 | ||
|
|
5ebe22ddfc | ||
|
|
44c5450b28 | ||
|
|
5fd48f53de | ||
|
|
7128efc4f4 | ||
|
|
bd0ddfbb24 | ||
|
|
3108daf0e8 | ||
|
|
446ac14e7f | ||
|
|
363895494b | ||
|
|
04551a8132 | ||
|
|
ffc0e378d3 | ||
|
|
440f109f1f | ||
|
|
80fb247dbe | ||
|
|
b3e71e840d | ||
|
|
998514bebb | ||
|
|
d7e98200a8 | ||
|
|
ac790c8ede | ||
|
|
22365ec8d6 | ||
|
|
030a83aa4d | ||
|
|
460d32a4ed | ||
|
|
581bf1e6a4 | ||
|
|
74ea5493e5 | ||
|
|
94ec8884c3 | ||
|
|
6cf9acd6ab | ||
|
|
109a8a155e | ||
|
|
3caa1b51ed | ||
|
|
b44c36425a | ||
|
|
1e81403180 | ||
|
|
0f487a5c5c | ||
|
|
2baa12fea3 | ||
|
|
1a9157fe6e | ||
|
|
67eeb5fbf6 | ||
|
|
5911248ab9 | ||
|
|
1c693c0263 | ||
|
|
11ffed8afc | ||
|
|
cb54b66eda | ||
|
|
92a937ad01 | ||
|
|
ba9dce8d90 | ||
|
|
2622e1b596 | ||
|
|
b03b1a0211 | ||
|
|
3e33169a3a | ||
|
|
d8d991531a | ||
|
|
f626b9453d | ||
|
|
5067afeee9 | ||
|
|
018c6c8198 | ||
|
|
2846d79b7d | ||
|
|
783a847978 | ||
|
|
44effca702 | ||
|
|
475ef59197 | ||
|
|
3953260485 | ||
|
|
0a8e7d7d84 | ||
|
|
46d07d703a | ||
|
|
bd1dbe92c2 | ||
|
|
66fa745d6f | ||
|
|
37b5dca66e | ||
|
|
5db035cc3a | ||
|
|
90cf3147fd | ||
|
|
4d4f23e551 | ||
|
|
80b47c61bb | ||
|
|
57821aae6a | ||
|
|
e23bb6d4ea | ||
|
|
d4cca00575 | ||
|
|
86ea94a565 | ||
|
|
53f46c91a2 | ||
|
|
e5f93188c3 | ||
|
|
3e57842ac6 | ||
|
|
32c4589df3 | ||
|
|
5bbd97588c | ||
|
|
b2a663f7a7 | ||
|
|
6f93d19a2b | ||
|
|
79bab04fd2 | ||
|
|
5b69d1e622 | ||
|
|
651cf5fb46 | ||
|
|
5deb72cddf | ||
|
|
97ebe78205 | ||
|
|
b937e853c9 | ||
|
|
f63bbf947f |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
tags/temp/
|
||||
__pycache__/
|
||||
tags/tag_frequency.db
|
||||
|
||||
10
README.md
10
README.md
@@ -23,11 +23,12 @@ Booru style tag autocompletion for the AUTOMATIC1111 Stable Diffusion WebUI
|
||||
# 📄 Description
|
||||
|
||||
Tag Autocomplete is an extension for the popular [AUTOMATIC1111 web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for Stable Diffusion.
|
||||
You can install it using the inbuilt available extensions list, clone the files manually as described [below](#-installation), or use a pre-packaged version from [Releases](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/releases).
|
||||
|
||||
It displays autocompletion hints for recognized tags from "image booru" boards such as Danbooru, which are primarily used for browsing Anime-style illustrations.
|
||||
Since some Stable Diffusion models were trained using this information, for example [Waifu Diffusion](https://github.com/harubaru/waifu-diffusion) and many of the NAI-descendant models or merges, using exact tags in prompts can often improve composition and consistency.
|
||||
Since most custom Stable Diffusion models were trained using this information or merged with ones that did, using exact tags in prompts can often improve composition and consistency, even if the model itself has a photorealistic style.
|
||||
|
||||
You can install it using the inbuilt available extensions list, clone the files manually as described [below](#-installation), or use a pre-packaged version from [Releases](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/releases).
|
||||
Disclaimer: The default tag lists contain NSFW terms, please use them responsibly.
|
||||
|
||||
<br/>
|
||||
|
||||
@@ -474,7 +475,9 @@ You can also add this to your quicksettings bar to have the refresh button avail
|
||||
|
||||
# Translations
|
||||
An additional file can be added in the translation section, which will be used to translate both tags and aliases and also enables searching by translation.
|
||||
This file needs to be a CSV in the format `<English tag/alias>,<Translation>`, but for backwards compatibility with older files that used a three column format, you can turn on `Translation file uses old 3-column translation format instead of the new 2-column one` to support them. In that case, the second column will be unused and skipped during parsing.
|
||||
This file needs to be a CSV in the format `<English tag/alias>,<Translation>`. Some older files use a three column format, which requires a compatibility setting to be activated.
|
||||
You can find it under `Settings > Tag autocomplete > Translation filename > Translation file uses old 3-column translation format instead of the new 2-column one`.
|
||||
With it on, the second column will be unused and skipped during parsing.
|
||||
|
||||
Example with Chinese translation:
|
||||
|
||||
@@ -484,6 +487,7 @@ Example with Chinese translation:
|
||||
## List of translations
|
||||
- [🇨🇳 Chinese tags](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/discussions/23) by @HalfMAI, using machine translation and manual correction for the most common tags (uses legacy format)
|
||||
- [🇨🇳 Chinese tags](https://github.com/sgmklp/tag-for-autocompletion-with-translation) by @sgmklp, smaller set of manual translations based on https://github.com/zcyzcy88/TagTable
|
||||
- [🇯🇵 Japanese tags](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/discussions/265) by @applemango, both machine and human translations available
|
||||
|
||||
> ### 🫵 I need your help!
|
||||
> Translations are a community effort. If you have translated a tag file or want to create one, please open a Pull Request or Issue so your link can be added here.
|
||||
|
||||
@@ -410,8 +410,9 @@ https://www.w3.org/TR/uievents-key/#named-key-attribute-value
|
||||

|
||||
|
||||
## 翻訳リスト
|
||||
- [🇨🇳 Chinese tags](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/discussions/23) by @HalfMAI, 最も一般的なタグを機械翻訳と手作業で修正(レガシーフォーマットを使用)
|
||||
- [🇨🇳 Chinese tags](https://github.com/sgmklp/tag-for-autocompletion-with-translation) by @sgmklp, [こちら](https://github.com/zcyzcy88/TagTable)をベースにして、より小さくした手動での翻訳セット。
|
||||
- [🇨🇳 中国語訳](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/discussions/23) by @HalfMAI, 最も一般的なタグを機械翻訳と手作業で修正(レガシーフォーマットを使用)
|
||||
- [🇨🇳 中国語訳](https://github.com/sgmklp/tag-for-autocompletion-with-translation) by @sgmklp, [こちら](https://github.com/zcyzcy88/TagTable)をベースにして、より小さくした手動での翻訳セット。
|
||||
- [🇯🇵 日本語訳](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/discussions/265) by @applemango, 機械翻訳と人力翻訳の両方が利用可能。
|
||||
|
||||
> ### 🫵 あなたの助けが必要です!
|
||||
> 翻訳はコミュニティの努力により支えられています。もしあなたがタグファイルを翻訳したことがある場合、または作成したい場合は、あなたの成果をここに追加できるように、Pull RequestまたはIssueを開いてください。
|
||||
|
||||
@@ -13,6 +13,12 @@
|
||||
你可以按照[以下方法](#installation)下载或拷贝文件,也可以使用[Releases](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/releases)中打包好的文件。
|
||||
|
||||
## 常见问题 & 已知缺陷:
|
||||
- 很多中国用户都报告过此扩展名和其他扩展名的 JavaScript 文件被阻止的问题。
|
||||
常见的罪魁祸首是 IDM / Internet Download Manager 浏览器插件,它似乎出于安全目的阻止了本地文件请求。
|
||||
如果您安装了 IDM,请确保在使用 webui 时禁用以下插件:
|
||||
|
||||

|
||||
|
||||
- 当`replaceUnderscores`选项开启时, 脚本只会替换Tag的一部分如果Tag包含多个单词,比如将`atago (azur lane)`修改`atago`为`taihou`并使用自动补全时.会得到 `taihou (azur lane), lane)`的结果, 因为脚本没有把后面的部分认为成同一个Tag。
|
||||
|
||||
## 演示与截图
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
var TAC_CFG = null;
|
||||
var tagBasePath = "";
|
||||
var modelKeywordPath = "";
|
||||
var tacSelfTrigger = false;
|
||||
|
||||
// Tag completion data loaded from files
|
||||
var allTags = [];
|
||||
@@ -18,6 +19,7 @@ var loras = [];
|
||||
var lycos = [];
|
||||
var modelKeywordDict = new Map();
|
||||
var chants = [];
|
||||
var styleNames = [];
|
||||
|
||||
// Selected model info for black/whitelisting
|
||||
var currentModelHash = "";
|
||||
@@ -36,6 +38,7 @@ let hideBlocked = false;
|
||||
// Tag selection for keyboard navigation
|
||||
var selectedTag = null;
|
||||
var oldSelectedTag = null;
|
||||
var resultCountBeforeNormalTags = 0;
|
||||
|
||||
// Lora keyword undo/redo history
|
||||
var textBeforeKeywordInsertion = "";
|
||||
|
||||
@@ -12,7 +12,8 @@ const ResultType = Object.freeze({
|
||||
"hypernetwork": 8,
|
||||
"lora": 9,
|
||||
"lyco": 10,
|
||||
"chant": 11
|
||||
"chant": 11,
|
||||
"styleName": 12
|
||||
});
|
||||
|
||||
// Class to hold result data and annotations to make it clearer to use
|
||||
@@ -23,10 +24,12 @@ class AutocompleteResult {
|
||||
|
||||
// Additional info, only used in some cases
|
||||
category = null;
|
||||
count = null;
|
||||
count = Number.MAX_SAFE_INTEGER;
|
||||
usageBias = null;
|
||||
aliases = null;
|
||||
meta = null;
|
||||
hash = null;
|
||||
sortKey = null;
|
||||
|
||||
// Constructor
|
||||
constructor(text, type) {
|
||||
|
||||
@@ -9,7 +9,11 @@ const core = [
|
||||
"#img2img_prompt > label > textarea",
|
||||
"#txt2img_neg_prompt > label > textarea",
|
||||
"#img2img_neg_prompt > label > textarea",
|
||||
".prompt > label > textarea"
|
||||
".prompt > label > textarea",
|
||||
"#txt2img_edit_style_prompt > label > textarea",
|
||||
"#txt2img_edit_style_neg_prompt > label > textarea",
|
||||
"#img2img_edit_style_prompt > label > textarea",
|
||||
"#img2img_edit_style_neg_prompt > label > textarea"
|
||||
];
|
||||
|
||||
// Third party text area selectors
|
||||
@@ -57,6 +61,38 @@ const thirdParty = {
|
||||
"[id^=MD-i2i][id$=prompt] textarea",
|
||||
"[id^=MD-i2i][id$=prompt] input[type='text']"
|
||||
]
|
||||
},
|
||||
"adetailer-t2i": {
|
||||
"base": "#txt2img_script_container",
|
||||
"hasIds": true,
|
||||
"onDemand": true,
|
||||
"selectors": [
|
||||
"[id^=script_txt2img_adetailer_ad_prompt] textarea",
|
||||
"[id^=script_txt2img_adetailer_ad_negative_prompt] textarea"
|
||||
]
|
||||
},
|
||||
"adetailer-i2i": {
|
||||
"base": "#img2img_script_container",
|
||||
"hasIds": true,
|
||||
"onDemand": true,
|
||||
"selectors": [
|
||||
"[id^=script_img2img_adetailer_ad_prompt] textarea",
|
||||
"[id^=script_img2img_adetailer_ad_negative_prompt] textarea"
|
||||
]
|
||||
},
|
||||
"deepdanbooru-object-recognition": {
|
||||
"base": "#tab_deepdanboru_object_recg_tab",
|
||||
"hasIds": false,
|
||||
"selectors": [
|
||||
"Found tags",
|
||||
]
|
||||
},
|
||||
"TIPO": {
|
||||
"base": "#tab_txt2img",
|
||||
"hasIds": false,
|
||||
"selectors": [
|
||||
"Tag Prompt"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,7 +126,7 @@ function addOnDemandObservers(setupFunction) {
|
||||
|
||||
let base = gradioApp().querySelector(entry.base);
|
||||
if (!base) continue;
|
||||
|
||||
|
||||
let accordions = [...base?.querySelectorAll(".gradio-accordion")];
|
||||
if (!accordions) continue;
|
||||
|
||||
@@ -115,12 +151,12 @@ function addOnDemandObservers(setupFunction) {
|
||||
[...gradioApp().querySelectorAll(entry.selectors.join(", "))].forEach(x => setupFunction(x));
|
||||
} else { // Otherwise, we have to find the text areas by their adjacent labels
|
||||
let base = gradioApp().querySelector(entry.base);
|
||||
|
||||
|
||||
// Safety check
|
||||
if (!base) continue;
|
||||
|
||||
|
||||
let allTextAreas = [...base.querySelectorAll("textarea, input[type='text']")];
|
||||
|
||||
|
||||
// Filter the text areas where the adjacent label matches one of the selectors
|
||||
let matchingTextAreas = allTextAreas.filter(ta => [...ta.parentElement.childNodes].some(x => entry.selectors.includes(x.innerText)));
|
||||
matchingTextAreas.forEach(x => setupFunction(x));
|
||||
@@ -165,4 +201,4 @@ function getTextAreaIdentifier(textArea) {
|
||||
break;
|
||||
}
|
||||
return modifier;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
// Utility functions for tag autocomplete
|
||||
|
||||
// Parse the CSV file into a 2D array. Doesn't use regex, so it is very lightweight.
|
||||
// We are ignoring newlines in quote fields since we expect one-line entries and parsing would break for unclosed quotes otherwise
|
||||
function parseCSV(str) {
|
||||
var arr = [];
|
||||
var quote = false; // 'true' means we're inside a quoted field
|
||||
const arr = [];
|
||||
let quote = false; // 'true' means we're inside a quoted field
|
||||
|
||||
// Iterate over each character, keep track of current row and column (of the returned array)
|
||||
for (var row = 0, col = 0, c = 0; c < str.length; c++) {
|
||||
var cc = str[c], nc = str[c + 1]; // Current character, next character
|
||||
for (let row = 0, col = 0, c = 0; c < str.length; c++) {
|
||||
let cc = str[c], nc = str[c+1]; // Current character, next character
|
||||
arr[row] = arr[row] || []; // Create a new row if necessary
|
||||
arr[row][col] = arr[row][col] || ''; // Create a new column (start with empty string) if necessary
|
||||
|
||||
@@ -22,14 +23,12 @@ function parseCSV(str) {
|
||||
// If it's a comma and we're not in a quoted field, move on to the next column
|
||||
if (cc == ',' && !quote) { ++col; continue; }
|
||||
|
||||
// If it's a newline (CRLF) and we're not in a quoted field, skip the next character
|
||||
// and move on to the next row and move to column 0 of that new row
|
||||
if (cc == '\r' && nc == '\n' && !quote) { ++row; col = 0; ++c; continue; }
|
||||
// If it's a newline (CRLF), skip the next character and move on to the next row and move to column 0 of that new row
|
||||
if (cc == '\r' && nc == '\n') { ++row; col = 0; ++c; quote = false; continue; }
|
||||
|
||||
// If it's a newline (LF or CR) and we're not in a quoted field,
|
||||
// move on to the next row and move to column 0 of that new row
|
||||
if (cc == '\n' && !quote) { ++row; col = 0; continue; }
|
||||
if (cc == '\r' && !quote) { ++row; col = 0; continue; }
|
||||
// If it's a newline (LF or CR) move on to the next row and move to column 0 of that new row
|
||||
if (cc == '\n') { ++row; col = 0; quote = false; continue; }
|
||||
if (cc == '\r') { ++row; col = 0; quote = false; continue; }
|
||||
|
||||
// Otherwise, append the current character to the current column
|
||||
arr[row][col] += cc;
|
||||
@@ -41,7 +40,7 @@ function parseCSV(str) {
|
||||
async function readFile(filePath, json = false, cache = false) {
|
||||
if (!cache)
|
||||
filePath += `?${new Date().getTime()}`;
|
||||
|
||||
|
||||
let response = await fetch(`file=${filePath}`);
|
||||
|
||||
if (response.status != 200) {
|
||||
@@ -62,7 +61,7 @@ async function loadCSV(path) {
|
||||
}
|
||||
|
||||
// Fetch API
|
||||
async function fetchAPI(url, json = true, cache = false) {
|
||||
async function fetchTacAPI(url, json = true, cache = false) {
|
||||
if (!cache) {
|
||||
const appendChar = url.includes("?") ? "&" : "?";
|
||||
url += `${appendChar}${new Date().getTime()}`
|
||||
@@ -81,13 +80,70 @@ async function fetchAPI(url, json = true, cache = false) {
|
||||
return await response.text();
|
||||
}
|
||||
|
||||
// Extra network preview thumbnails
|
||||
async function getExtraNetworkPreviewURL(filename, type) {
|
||||
const previewJSON = await fetchAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);
|
||||
if (previewJSON?.url)
|
||||
return `file=${previewJSON.url}`;
|
||||
else
|
||||
async function postTacAPI(url, body = null) {
|
||||
let response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: body
|
||||
});
|
||||
|
||||
if (response.status != 200) {
|
||||
console.error(`Error posting to API endpoint "${url}": ` + response.status, response.statusText);
|
||||
return null;
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
async function putTacAPI(url, body = null) {
|
||||
let response = await fetch(url, { method: "PUT", body: body });
|
||||
|
||||
if (response.status != 200) {
|
||||
console.error(`Error putting to API endpoint "${url}": ` + response.status, response.statusText);
|
||||
return null;
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
// Extra network preview thumbnails
|
||||
async function getTacExtraNetworkPreviewURL(filename, type) {
|
||||
const previewJSON = await fetchTacAPI(`tacapi/v1/thumb-preview/${filename}?type=${type}`, true, true);
|
||||
if (previewJSON?.url) {
|
||||
const properURL = `sd_extra_networks/thumb?filename=${previewJSON.url}`;
|
||||
if ((await fetch(properURL)).status == 200) {
|
||||
return properURL;
|
||||
} else {
|
||||
// create blob url
|
||||
const blob = await (await fetch(`tacapi/v1/thumb-preview-blob/${filename}?type=${type}`)).blob();
|
||||
return URL.createObjectURL(blob);
|
||||
}
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
lastStyleRefresh = 0;
|
||||
// Refresh style file if needed
|
||||
async function refreshStyleNamesIfChanged() {
|
||||
// Only refresh once per second
|
||||
currentTimestamp = new Date().getTime();
|
||||
if (currentTimestamp - lastStyleRefresh < 1000) return;
|
||||
lastStyleRefresh = currentTimestamp;
|
||||
|
||||
const response = await fetch(`tacapi/v1/refresh-styles-if-changed?${new Date().getTime()}`)
|
||||
if (response.status === 304) {
|
||||
// Not modified
|
||||
} else if (response.status === 200) {
|
||||
// Reload
|
||||
QUEUE_FILE_LOAD.forEach(async fn => {
|
||||
if (fn.toString().includes("styleNames"))
|
||||
await fn.call(null, true);
|
||||
})
|
||||
} else {
|
||||
// Error
|
||||
console.error(`Error refreshing styles.txt: ` + response.status, response.statusText);
|
||||
}
|
||||
}
|
||||
|
||||
// Debounce function to prevent spamming the autocomplete function
|
||||
@@ -138,7 +194,115 @@ function flatten(obj, roots = [], sep = ".") {
|
||||
{}
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Calculate biased tag score based on post count and frequent usage
|
||||
function calculateUsageBias(result, count, uses) {
|
||||
// Check setting conditions
|
||||
if (uses < TAC_CFG.frequencyMinCount) {
|
||||
uses = 0;
|
||||
} else if (uses != 0) {
|
||||
result.usageBias = true;
|
||||
}
|
||||
|
||||
switch (TAC_CFG.frequencyFunction) {
|
||||
case "Logarithmic (weak)":
|
||||
return Math.log(1 + count) + Math.log(1 + uses);
|
||||
case "Logarithmic (strong)":
|
||||
return Math.log(1 + count) + 2 * Math.log(1 + uses);
|
||||
case "Usage first":
|
||||
return uses;
|
||||
default:
|
||||
return count;
|
||||
}
|
||||
}
|
||||
// Beautify return type for easier parsing
|
||||
function mapUseCountArray(useCounts, posAndNeg = false) {
|
||||
return useCounts.map(useCount => {
|
||||
if (posAndNeg) {
|
||||
return {
|
||||
"name": useCount[0],
|
||||
"type": useCount[1],
|
||||
"count": useCount[2],
|
||||
"negCount": useCount[3],
|
||||
"lastUseDate": useCount[4]
|
||||
}
|
||||
}
|
||||
return {
|
||||
"name": useCount[0],
|
||||
"type": useCount[1],
|
||||
"count": useCount[2],
|
||||
"lastUseDate": useCount[3]
|
||||
}
|
||||
});
|
||||
}
|
||||
// Call API endpoint to increase bias of tag in the database
|
||||
function increaseUseCount(tagName, type, negative = false) {
|
||||
postTacAPI(`tacapi/v1/increase-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`);
|
||||
}
|
||||
// Get use count of tag from the database
|
||||
async function getUseCount(tagName, type, negative = false) {
|
||||
const response = await fetchTacAPI(`tacapi/v1/get-use-count?tagname=${tagName}&ttype=${type}&neg=${negative}`, true, false);
|
||||
// Guard for no db
|
||||
if (response == null) return null;
|
||||
// Result
|
||||
return response["result"];
|
||||
}
|
||||
async function getUseCounts(tagNames, types, negative = false) {
|
||||
// While semantically weird, we have to use POST here for the body, as urls are limited in length
|
||||
const body = JSON.stringify({"tagNames": tagNames, "tagTypes": types, "neg": negative});
|
||||
const response = await postTacAPI(`tacapi/v1/get-use-count-list`, body)
|
||||
// Guard for no db
|
||||
if (response == null) return null;
|
||||
// Results
|
||||
return mapUseCountArray(response["result"]);
|
||||
}
|
||||
async function getAllUseCounts() {
|
||||
const response = await fetchTacAPI(`tacapi/v1/get-all-use-counts`);
|
||||
// Guard for no db
|
||||
if (response == null) return null;
|
||||
// Results
|
||||
return mapUseCountArray(response["result"], true);
|
||||
}
|
||||
async function resetUseCount(tagName, type, resetPosCount, resetNegCount) {
|
||||
await putTacAPI(`tacapi/v1/reset-use-count?tagname=${tagName}&ttype=${type}&pos=${resetPosCount}&neg=${resetNegCount}`);
|
||||
}
|
||||
|
||||
function createTagUsageTable(tagCounts) {
|
||||
// Create table
|
||||
let tagTable = document.createElement("table");
|
||||
tagTable.innerHTML =
|
||||
`<thead>
|
||||
<tr>
|
||||
<td>Name</td>
|
||||
<td>Type</td>
|
||||
<td>Count(+)</td>
|
||||
<td>Count(-)</td>
|
||||
<td>Last used</td>
|
||||
</tr>
|
||||
</thead>`;
|
||||
tagTable.id = "tac_tagUsageTable"
|
||||
|
||||
tagCounts.forEach(t => {
|
||||
let tr = document.createElement("tr");
|
||||
|
||||
// Fill values
|
||||
let values = [t.name, t.type-1, t.count, t.negCount, t.lastUseDate]
|
||||
values.forEach(v => {
|
||||
let td = document.createElement("td");
|
||||
td.innerText = v;
|
||||
tr.append(td);
|
||||
});
|
||||
// Add delete/reset button
|
||||
let delButton = document.createElement("button");
|
||||
delButton.innerText = "🗑️";
|
||||
delButton.title = "Reset count";
|
||||
tr.append(delButton);
|
||||
|
||||
tagTable.append(tr)
|
||||
});
|
||||
|
||||
return tagTable;
|
||||
}
|
||||
|
||||
// Sliding window function to get possible combination groups of an array
|
||||
function toNgrams(inputArray, size) {
|
||||
@@ -148,7 +312,11 @@ function toNgrams(inputArray, size) {
|
||||
);
|
||||
}
|
||||
|
||||
function escapeRegExp(string) {
|
||||
function escapeRegExp(string, wildcardMatching = false) {
|
||||
if (wildcardMatching) {
|
||||
// Escape all characters except asterisks and ?, which should be treated separately as placeholders.
|
||||
return string.replace(/[-[\]{}()+.,\\^$|#\s]/g, '\\$&').replace(/\*/g, '.*').replace(/\?/g, '.');
|
||||
}
|
||||
return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string
|
||||
}
|
||||
function escapeHTML(unsafeText) {
|
||||
@@ -192,6 +360,49 @@ function observeElement(element, property, callback, delay = 0) {
|
||||
}
|
||||
}
|
||||
|
||||
// Sort functions
|
||||
function getSortFunction() {
|
||||
let criterion = TAC_CFG.modelSortOrder || "Name";
|
||||
|
||||
const textSort = (a, b, reverse = false) => {
|
||||
// Assign keys so next sort is faster
|
||||
if (!a.sortKey) {
|
||||
a.sortKey = a.type === ResultType.chant
|
||||
? a.aliases
|
||||
: a.text;
|
||||
}
|
||||
if (!b.sortKey) {
|
||||
b.sortKey = b.type === ResultType.chant
|
||||
? b.aliases
|
||||
: b.text;
|
||||
}
|
||||
|
||||
return reverse ? b.sortKey.localeCompare(a.sortKey) : a.sortKey.localeCompare(b.sortKey);
|
||||
}
|
||||
const numericSort = (a, b, reverse = false) => {
|
||||
const noKey = reverse ? "-1" : Number.MAX_SAFE_INTEGER;
|
||||
let aParsed = parseFloat(a.sortKey || noKey);
|
||||
let bParsed = parseFloat(b.sortKey || noKey);
|
||||
|
||||
if (aParsed === bParsed) {
|
||||
return textSort(a, b, false);
|
||||
}
|
||||
|
||||
return reverse ? bParsed - aParsed : aParsed - bParsed;
|
||||
}
|
||||
|
||||
return (a, b) => {
|
||||
switch (criterion) {
|
||||
case "Date Modified (newest first)":
|
||||
return numericSort(a, b, true);
|
||||
case "Date Modified (oldest first)":
|
||||
return numericSort(a, b, false);
|
||||
default:
|
||||
return textSort(a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Queue calling function to process global queues
|
||||
async function processQueue(queue, context, ...args) {
|
||||
for (let i = 0; i < queue.length; i++) {
|
||||
|
||||
@@ -7,7 +7,10 @@ class ChantParser extends BaseTagParser {
|
||||
let tempResults = [];
|
||||
if (tagword !== "<" && tagword !== "<c:") {
|
||||
let searchTerm = tagword.replace("<chant:", "").replace("<c:", "").replace("<", "");
|
||||
let filterCondition = x => x.terms.toLowerCase().includes(searchTerm) || x.name.toLowerCase().includes(searchTerm);
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x.terms.toLowerCase()) || regex.test(x.name.toLowerCase());
|
||||
};
|
||||
tempResults = chants.filter(x => filterCondition(x)); // Filter by tagword
|
||||
} else {
|
||||
tempResults = chants;
|
||||
@@ -41,7 +44,7 @@ async function load() {
|
||||
|
||||
function sanitize(tagType, text) {
|
||||
if (tagType === ResultType.chant) {
|
||||
return text.replace(/^.*?: /g, "");
|
||||
return text;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
@@ -51,4 +54,4 @@ PARSERS.push(new ChantParser(CHANT_TRIGGER));
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_AFTER_CONFIG_CHANGE.push(load);
|
||||
QUEUE_AFTER_CONFIG_CHANGE.push(load);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const EMB_REGEX = /<(?!l:|h:|c:)[^,> ]*>?/g;
|
||||
const EMB_TRIGGER = () => TAC_CFG.useEmbeddings && tagword.match(EMB_REGEX);
|
||||
const EMB_TRIGGER = () => TAC_CFG.useEmbeddings && (tagword.match(EMB_REGEX) || TAC_CFG.includeEmbeddingsInNormalResults);
|
||||
|
||||
class EmbeddingParser extends BaseTagParser {
|
||||
parse() {
|
||||
@@ -11,12 +11,18 @@ class EmbeddingParser extends BaseTagParser {
|
||||
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
|
||||
versionString = searchTerm.slice(0, 2);
|
||||
searchTerm = searchTerm.slice(2);
|
||||
} else if (searchTerm.startsWith("vxl")) {
|
||||
versionString = searchTerm.slice(0, 3);
|
||||
searchTerm = searchTerm.slice(3);
|
||||
}
|
||||
|
||||
let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x[0].toLowerCase()) || regex.test(x[0].toLowerCase().replaceAll(" ", "_"));
|
||||
};
|
||||
|
||||
if (versionString)
|
||||
tempResults = embeddings.filter(x => filterCondition(x) && x[1] && x[1] === versionString); // Filter by tagword
|
||||
tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2].toLowerCase() === versionString.toLowerCase()); // Filter by tagword
|
||||
else
|
||||
tempResults = embeddings.filter(x => filterCondition(x)); // Filter by tagword
|
||||
} else {
|
||||
@@ -26,8 +32,13 @@ class EmbeddingParser extends BaseTagParser {
|
||||
// Add final results
|
||||
let finalResults = [];
|
||||
tempResults.forEach(t => {
|
||||
let result = new AutocompleteResult(t[0].trim(), ResultType.embedding)
|
||||
result.meta = t[1] + " Embedding";
|
||||
let lastDot = t[0].lastIndexOf(".") > -1 ? t[0].lastIndexOf(".") : t[0].length;
|
||||
let lastSlash = t[0].lastIndexOf("/") > -1 ? t[0].lastIndexOf("/") : -1;
|
||||
let name = t[0].trim().substring(lastSlash + 1, lastDot);
|
||||
|
||||
let result = new AutocompleteResult(name, ResultType.embedding)
|
||||
result.sortKey = t[1];
|
||||
result.meta = t[2] + " Embedding";
|
||||
finalResults.push(result);
|
||||
});
|
||||
|
||||
@@ -38,9 +49,9 @@ class EmbeddingParser extends BaseTagParser {
|
||||
async function load() {
|
||||
if (embeddings.length === 0) {
|
||||
try {
|
||||
embeddings = (await readFile(`${tagBasePath}/temp/emb.txt`)).split("\n")
|
||||
.filter(x => x.trim().length > 0) // Remove empty lines
|
||||
.map(x => x.trim().split(",")); // Split into name, version type pairs
|
||||
embeddings = (await loadCSV(`${tagBasePath}/temp/emb.txt`))
|
||||
.filter(x => x[0]?.trim().length > 0) // Remove empty lines
|
||||
.map(x => [x[0].trim(), x[1], x[2]]); // Return name, sortKey, hash tuples
|
||||
} catch (e) {
|
||||
console.error("Error loading embeddings.txt: " + e);
|
||||
}
|
||||
@@ -49,7 +60,7 @@ async function load() {
|
||||
|
||||
function sanitize(tagType, text) {
|
||||
if (tagType === ResultType.embedding) {
|
||||
return text.replace(/^.*?: /g, "");
|
||||
return text;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
@@ -58,4 +69,4 @@ PARSERS.push(new EmbeddingParser(EMB_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
|
||||
@@ -7,8 +7,11 @@ class HypernetParser extends BaseTagParser {
|
||||
let tempResults = [];
|
||||
if (tagword !== "<" && tagword !== "<h:" && tagword !== "<hypernet:") {
|
||||
let searchTerm = tagword.replace("<hypernet:", "").replace("<h:", "").replace("<", "");
|
||||
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
tempResults = hypernetworks.filter(x => filterCondition(x)); // Filter by tagword
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x.toLowerCase()) || regex.test(x.toLowerCase().replaceAll(" ", "_"));
|
||||
};
|
||||
tempResults = hypernetworks.filter(x => filterCondition(x[0])); // Filter by tagword
|
||||
} else {
|
||||
tempResults = hypernetworks;
|
||||
}
|
||||
@@ -16,8 +19,9 @@ class HypernetParser extends BaseTagParser {
|
||||
// Add final results
|
||||
let finalResults = [];
|
||||
tempResults.forEach(t => {
|
||||
let result = new AutocompleteResult(t.trim(), ResultType.hypernetwork)
|
||||
let result = new AutocompleteResult(t[0].trim(), ResultType.hypernetwork)
|
||||
result.meta = "Hypernetwork";
|
||||
result.sortKey = t[1];
|
||||
finalResults.push(result);
|
||||
});
|
||||
|
||||
@@ -28,9 +32,9 @@ class HypernetParser extends BaseTagParser {
|
||||
async function load() {
|
||||
if (hypernetworks.length === 0) {
|
||||
try {
|
||||
hypernetworks = (await readFile(`${tagBasePath}/temp/hyp.txt`)).split("\n")
|
||||
.filter(x => x.trim().length > 0) //Remove empty lines
|
||||
.map(x => x.trim()); // Remove carriage returns and padding if it exists
|
||||
hypernetworks = (await loadCSV(`${tagBasePath}/temp/hyp.txt`))
|
||||
.filter(x => x[0]?.trim().length > 0) //Remove empty lines
|
||||
.map(x => [x[0]?.trim(), x[1]]); // Remove carriage returns and padding if it exists
|
||||
} catch (e) {
|
||||
console.error("Error loading hypernetworks.txt: " + e);
|
||||
}
|
||||
@@ -48,4 +52,4 @@ PARSERS.push(new HypernetParser(HYP_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
|
||||
@@ -7,7 +7,10 @@ class LoraParser extends BaseTagParser {
|
||||
let tempResults = [];
|
||||
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lora:") {
|
||||
let searchTerm = tagword.replace("<lora:", "").replace("<l:", "").replace("<", "");
|
||||
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x.toLowerCase()) || regex.test(x.toLowerCase().replaceAll(" ", "_"));
|
||||
};
|
||||
tempResults = loras.filter(x => filterCondition(x[0])); // Filter by tagword
|
||||
} else {
|
||||
tempResults = loras;
|
||||
@@ -23,7 +26,8 @@ class LoraParser extends BaseTagParser {
|
||||
|
||||
let result = new AutocompleteResult(name, ResultType.lora)
|
||||
result.meta = "Lora";
|
||||
result.hash = t[1];
|
||||
result.sortKey = t[1];
|
||||
result.hash = t[2];
|
||||
finalResults.push(result);
|
||||
});
|
||||
|
||||
@@ -36,7 +40,7 @@ async function load() {
|
||||
try {
|
||||
loras = (await loadCSV(`${tagBasePath}/temp/lora.txt`))
|
||||
.filter(x => x[0]?.trim().length > 0) // Remove empty lines
|
||||
.map(x => [x[0]?.trim(), x[1]]); // Trim filenames and return the name, hash pairs
|
||||
.map(x => [x[0]?.trim(), x[1], x[2]]); // Trim filenames and return the name, sortKey, hash pairs
|
||||
} catch (e) {
|
||||
console.error("Error loading lora.txt: " + e);
|
||||
}
|
||||
@@ -46,7 +50,7 @@ async function load() {
|
||||
async function sanitize(tagType, text) {
|
||||
if (tagType === ResultType.lora) {
|
||||
let multiplier = TAC_CFG.extraNetworksDefaultMultiplier;
|
||||
let info = await fetchAPI(`tacapi/v1/lora-info/${text}`)
|
||||
let info = await fetchTacAPI(`tacapi/v1/lora-info/${text}`)
|
||||
if (info && info["preferred weight"]) {
|
||||
multiplier = info["preferred weight"];
|
||||
}
|
||||
@@ -60,4 +64,4 @@ PARSERS.push(new LoraParser(LORA_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
|
||||
@@ -5,9 +5,12 @@ class LycoParser extends BaseTagParser {
|
||||
parse() {
|
||||
// Show lyco
|
||||
let tempResults = [];
|
||||
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lyco:") {
|
||||
let searchTerm = tagword.replace("<lyco:", "").replace("<l:", "").replace("<", "");
|
||||
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lyco:" && tagword !== "<lora:") {
|
||||
let searchTerm = tagword.replace("<lyco:", "").replace("<lora:", "").replace("<l:", "").replace("<", "");
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x.toLowerCase()) || regex.test(x.toLowerCase().replaceAll(" ", "_"));
|
||||
};
|
||||
tempResults = lycos.filter(x => filterCondition(x[0])); // Filter by tagword
|
||||
} else {
|
||||
tempResults = lycos;
|
||||
@@ -23,7 +26,8 @@ class LycoParser extends BaseTagParser {
|
||||
|
||||
let result = new AutocompleteResult(name, ResultType.lyco)
|
||||
result.meta = "Lyco";
|
||||
result.hash = t[1];
|
||||
result.sortKey = t[1];
|
||||
result.hash = t[2];
|
||||
finalResults.push(result);
|
||||
});
|
||||
|
||||
@@ -36,7 +40,7 @@ async function load() {
|
||||
try {
|
||||
lycos = (await loadCSV(`${tagBasePath}/temp/lyco.txt`))
|
||||
.filter(x => x[0]?.trim().length > 0) // Remove empty lines
|
||||
.map(x => [x[0]?.trim(), x[1]]); // Trim filenames and return the name, hash pairs
|
||||
.map(x => [x[0]?.trim(), x[1], x[2]]); // Trim filenames and return the name, sortKey, hash pairs
|
||||
} catch (e) {
|
||||
console.error("Error loading lyco.txt: " + e);
|
||||
}
|
||||
@@ -46,12 +50,13 @@ async function load() {
|
||||
async function sanitize(tagType, text) {
|
||||
if (tagType === ResultType.lyco) {
|
||||
let multiplier = TAC_CFG.extraNetworksDefaultMultiplier;
|
||||
let info = await fetchAPI(`tacapi/v1/lyco-info/${text}`)
|
||||
let info = await fetchTacAPI(`tacapi/v1/lyco-info/${text}`)
|
||||
if (info && info["preferred weight"]) {
|
||||
multiplier = info["preferred weight"];
|
||||
}
|
||||
|
||||
return `<lyco:${text}:${multiplier}>`;
|
||||
let prefix = TAC_CFG.useLoraPrefixForLycos ? "lora" : "lyco";
|
||||
return `<${prefix}:${text}:${multiplier}>`;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
@@ -60,4 +65,4 @@ PARSERS.push(new LycoParser(LYCO_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
|
||||
@@ -20,7 +20,7 @@ async function load() {
|
||||
// Add to the dict
|
||||
csv_lines.forEach(parts => {
|
||||
const hash = parts[0];
|
||||
const keywords = parts[1].replaceAll("| ", ", ").replaceAll("|", ", ").trim();
|
||||
const keywords = parts[1]?.replaceAll("| ", ", ")?.replaceAll("|", ", ")?.trim();
|
||||
const lastSepIndex = parts[2]?.lastIndexOf("/") + 1 || parts[2]?.lastIndexOf("\\") + 1 || 0;
|
||||
const name = parts[2]?.substring(lastSepIndex).trim() || "none"
|
||||
|
||||
|
||||
70
javascript/ext_styles.js
Normal file
70
javascript/ext_styles.js
Normal file
@@ -0,0 +1,70 @@
|
||||
const STYLE_REGEX = /(\$(\d*)\(?)[^$|\[\],\s]*\)?/;
|
||||
const STYLE_TRIGGER = () => TAC_CFG.useStyleVars && tagword.match(STYLE_REGEX);
|
||||
|
||||
var lastStyleVarIndex = "";
|
||||
|
||||
class StyleParser extends BaseTagParser {
|
||||
async parse() {
|
||||
// Refresh if needed
|
||||
await refreshStyleNamesIfChanged();
|
||||
|
||||
// Show styles
|
||||
let tempResults = [];
|
||||
let matchGroups = tagword.match(STYLE_REGEX);
|
||||
|
||||
// Save index to insert again later or clear last one
|
||||
lastStyleVarIndex = matchGroups[2] ? matchGroups[2] : "";
|
||||
|
||||
if (tagword !== matchGroups[1]) {
|
||||
let searchTerm = tagword.replace(matchGroups[1], "");
|
||||
|
||||
let filterCondition = x => {
|
||||
let regex = new RegExp(escapeRegExp(searchTerm, true), 'i');
|
||||
return regex.test(x[0].toLowerCase()) || regex.test(x[0].toLowerCase().replaceAll(" ", "_"));
|
||||
};
|
||||
tempResults = styleNames.filter(x => filterCondition(x)); // Filter by tagword
|
||||
} else {
|
||||
tempResults = styleNames;
|
||||
}
|
||||
|
||||
// Add final results
|
||||
let finalResults = [];
|
||||
tempResults.forEach(t => {
|
||||
let result = new AutocompleteResult(t[0].trim(), ResultType.styleName)
|
||||
result.meta = "Style";
|
||||
finalResults.push(result);
|
||||
});
|
||||
|
||||
return finalResults;
|
||||
}
|
||||
}
|
||||
|
||||
async function load(force = false) {
|
||||
if (styleNames.length === 0 || force) {
|
||||
try {
|
||||
styleNames = (await loadCSV(`${tagBasePath}/temp/styles.txt`))
|
||||
.filter(x => x[0]?.trim().length > 0) // Remove empty lines
|
||||
.filter(x => x[0] !== "None") // Remove "None" style
|
||||
.map(x => [x[0].trim()]); // Trim name
|
||||
} catch (e) {
|
||||
console.error("Error loading styles.txt: " + e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function sanitize(tagType, text) {
|
||||
if (tagType === ResultType.styleName) {
|
||||
if (text.includes(" ")) {
|
||||
return `$${lastStyleVarIndex}(${text})`;
|
||||
} else {
|
||||
return`$${lastStyleVarIndex}${text}`
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
PARSERS.push(new StyleParser(STYLE_TRIGGER));
|
||||
|
||||
// Add our utility functions to their respective queues
|
||||
QUEUE_FILE_LOAD.push(load);
|
||||
QUEUE_SANITIZE.push(sanitize);
|
||||
@@ -7,7 +7,7 @@ class UmiParser extends BaseTagParser {
|
||||
parse(textArea, prompt) {
|
||||
// We are in a UMI yaml tag definition, parse further
|
||||
let umiSubPrompts = [...prompt.matchAll(UMI_PROMPT_REGEX)];
|
||||
|
||||
|
||||
let umiTags = [];
|
||||
let umiTagsWithOperators = []
|
||||
|
||||
@@ -15,7 +15,7 @@ class UmiParser extends BaseTagParser {
|
||||
|
||||
umiSubPrompts.forEach(umiSubPrompt => {
|
||||
umiTags = umiTags.concat([...umiSubPrompt[0].matchAll(UMI_TAG_REGEX)].map(x => x[1].toLowerCase()));
|
||||
|
||||
|
||||
const start = umiSubPrompt.index;
|
||||
const end = umiSubPrompt.index + umiSubPrompt[0].length;
|
||||
if (textArea.selectionStart >= start && textArea.selectionStart <= end) {
|
||||
@@ -113,7 +113,7 @@ class UmiParser extends BaseTagParser {
|
||||
|| !matches.all.includes(x[0])
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
if (umiTags.length > 0) {
|
||||
// Get difference for subprompt
|
||||
let tagCountChange = umiTags.length - umiPreviousTags.length;
|
||||
@@ -129,7 +129,7 @@ class UmiParser extends BaseTagParser {
|
||||
return;
|
||||
}
|
||||
|
||||
let umiTagword = diff[0] || '';
|
||||
let umiTagword = tagCountChange < 0 ? '' : diff[0] || '';
|
||||
let tempResults = [];
|
||||
if (umiTagword && umiTagword.length > 0) {
|
||||
umiTagword = umiTagword.toLowerCase().replace(/[\n\r]/g, "");
|
||||
@@ -149,10 +149,11 @@ class UmiParser extends BaseTagParser {
|
||||
finalResults.push(result);
|
||||
});
|
||||
|
||||
finalResults = finalResults.sort((a, b) => b.count - a.count);
|
||||
return finalResults;
|
||||
} else if (showAll) {
|
||||
let filteredWildcardsSorted = filteredWildcards("");
|
||||
|
||||
|
||||
// Add final results
|
||||
let finalResults = [];
|
||||
filteredWildcardsSorted.forEach(t => {
|
||||
@@ -160,14 +161,16 @@ class UmiParser extends BaseTagParser {
|
||||
result.count = t[1];
|
||||
finalResults.push(result);
|
||||
});
|
||||
|
||||
|
||||
originalTagword = tagword;
|
||||
tagword = "";
|
||||
|
||||
finalResults = finalResults.sort((a, b) => b.count - a.count);
|
||||
return finalResults;
|
||||
}
|
||||
} else {
|
||||
let filteredWildcardsSorted = filteredWildcards("");
|
||||
|
||||
|
||||
// Add final results
|
||||
let finalResults = [];
|
||||
filteredWildcardsSorted.forEach(t => {
|
||||
@@ -178,12 +181,14 @@ class UmiParser extends BaseTagParser {
|
||||
|
||||
originalTagword = tagword;
|
||||
tagword = "";
|
||||
|
||||
finalResults = finalResults.sort((a, b) => b.count - a.count);
|
||||
return finalResults;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function updateUmiTags( tagType, sanitizedText, newPrompt, textArea) {
|
||||
function updateUmiTags(tagType, sanitizedText, newPrompt, textArea) {
|
||||
// If it was a umi wildcard, also update the umiPreviousTags
|
||||
if (tagType === ResultType.umiWildcard && originalTagword.length > 0) {
|
||||
let umiSubPrompts = [...newPrompt.matchAll(UMI_PROMPT_REGEX)];
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
// Regex
|
||||
const WC_REGEX = /\b__([^,]+)__([^, ]*)\b/g;
|
||||
const WC_REGEX = new RegExp(/__([^,]+)__([^, ]*)/g);
|
||||
|
||||
// Trigger conditions
|
||||
const WC_TRIGGER = () => TAC_CFG.useWildcards && [...tagword.matchAll(WC_REGEX)].length > 0;
|
||||
const WC_FILE_TRIGGER = () => TAC_CFG.useWildcards && (tagword.startsWith("__") && !tagword.endsWith("__") || tagword === "__");
|
||||
const WC_TRIGGER = () => TAC_CFG.useWildcards && [...tagword.matchAll(new RegExp(WC_REGEX.source.replaceAll("__", escapeRegExp(TAC_CFG.wcWrap)), "g"))].length > 0;
|
||||
const WC_FILE_TRIGGER = () => TAC_CFG.useWildcards && (tagword.startsWith(TAC_CFG.wcWrap) && !tagword.endsWith(TAC_CFG.wcWrap) || tagword === TAC_CFG.wcWrap);
|
||||
|
||||
class WildcardParser extends BaseTagParser {
|
||||
async parse() {
|
||||
// Show wildcards from a file with that name
|
||||
let wcMatch = [...tagword.matchAll(WC_REGEX)]
|
||||
let wcMatch = [...tagword.matchAll(new RegExp(WC_REGEX.source.replaceAll("__", escapeRegExp(TAC_CFG.wcWrap)), "g"))];
|
||||
let wcFile = wcMatch[0][1];
|
||||
let wcWord = wcMatch[0][2];
|
||||
|
||||
@@ -19,13 +19,16 @@ class WildcardParser extends BaseTagParser {
|
||||
let wcPairs = wcFound || wildcardExtFiles.filter(x => x[1].toLowerCase() === wcFile);
|
||||
|
||||
if (!wcPairs) return [];
|
||||
|
||||
|
||||
let wildcards = [];
|
||||
for (let i = 0; i < wcPairs.length; i++) {
|
||||
const wcPair = wcPairs[i];
|
||||
if (!wcPair[0] || !wcPair[1]) continue;
|
||||
const basePath = wcPairs[i][0];
|
||||
const fileName = wcPairs[i][1];
|
||||
if (!basePath || !fileName) return;
|
||||
|
||||
if (wcPair[0].endsWith(".yaml")) {
|
||||
// YAML wildcards are already loaded as json, so we can get the values directly.
|
||||
// basePath is the name of the file in this case, and fileName the key
|
||||
if (basePath.endsWith(".yaml")) {
|
||||
const getDescendantProp = (obj, desc) => {
|
||||
const arr = desc.split("/");
|
||||
while (arr.length) {
|
||||
@@ -33,10 +36,11 @@ class WildcardParser extends BaseTagParser {
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
wildcards = wildcards.concat(getDescendantProp(yamlWildcards[wcPair[0]], wcPair[1]));
|
||||
wildcards = wildcards.concat(getDescendantProp(yamlWildcards[basePath], fileName));
|
||||
} else {
|
||||
const fileContent = (await readFile(`${wcPair[0]}/${wcPair[1]}.txt`)).split("\n")
|
||||
.filter(x => x.trim().length > 0 && !x.startsWith('#')); // Remove empty lines and comments
|
||||
const fileContent = (await fetchTacAPI(`tacapi/v1/wildcard-contents?basepath=${basePath}&filename=${fileName}.txt`, false))
|
||||
.split("\n")
|
||||
.filter(x => x.trim().length > 0 && !x.startsWith('#')); // Remove empty lines and comments
|
||||
wildcards = wildcards.concat(fileContent);
|
||||
}
|
||||
}
|
||||
@@ -60,8 +64,8 @@ class WildcardFileParser extends BaseTagParser {
|
||||
parse() {
|
||||
// Show available wildcard files
|
||||
let tempResults = [];
|
||||
if (tagword !== "__") {
|
||||
let lmb = (x) => x[1].toLowerCase().includes(tagword.replace("__", ""))
|
||||
if (tagword !== TAC_CFG.wcWrap) {
|
||||
let lmb = (x) => x[1].toLowerCase().includes(tagword.replace(TAC_CFG.wcWrap, ""))
|
||||
tempResults = wildcardFiles.filter(lmb).concat(wildcardExtFiles.filter(lmb)) // Filter by tagword
|
||||
} else {
|
||||
tempResults = wildcardFiles.concat(wildcardExtFiles);
|
||||
@@ -81,13 +85,14 @@ class WildcardFileParser extends BaseTagParser {
|
||||
} else {
|
||||
result = new AutocompleteResult(wcFile[1].trim(), ResultType.wildcardFile);
|
||||
result.meta = "Wildcard file";
|
||||
result.sortKey = wcFile[2].trim();
|
||||
}
|
||||
|
||||
|
||||
finalResults.push(result);
|
||||
alreadyAdded.set(wcFile[1], true);
|
||||
});
|
||||
|
||||
finalResults.sort((a, b) => a.text.localeCompare(b.text));
|
||||
finalResults.sort(getSortFunction());
|
||||
|
||||
return finalResults;
|
||||
}
|
||||
@@ -96,17 +101,19 @@ class WildcardFileParser extends BaseTagParser {
|
||||
async function load() {
|
||||
if (wildcardFiles.length === 0 && wildcardExtFiles.length === 0) {
|
||||
try {
|
||||
let wcFileArr = (await readFile(`${tagBasePath}/temp/wc.txt`)).split("\n");
|
||||
let wcBasePath = wcFileArr[0].trim(); // First line should be the base path
|
||||
wildcardFiles = wcFileArr.slice(1)
|
||||
.filter(x => x.trim().length > 0) // Remove empty lines
|
||||
.map(x => [wcBasePath, x.trim().replace(".txt", "")]); // Remove file extension & newlines
|
||||
let wcFileArr = await loadCSV(`${tagBasePath}/temp/wc.txt`);
|
||||
if (wcFileArr && wcFileArr.length > 0) {
|
||||
let wcBasePath = wcFileArr[0][0].trim(); // First line should be the base path
|
||||
wildcardFiles = wcFileArr.slice(1)
|
||||
.filter(x => x[0]?.trim().length > 0) //Remove empty lines
|
||||
.map(x => [wcBasePath, x[0]?.trim().replace(".txt", ""), x[1]]); // Remove file extension & newlines
|
||||
}
|
||||
|
||||
// To support multiple sources, we need to separate them using the provided "-----" strings
|
||||
let wcExtFileArr = (await readFile(`${tagBasePath}/temp/wce.txt`)).split("\n");
|
||||
let wcExtFileArr = await loadCSV(`${tagBasePath}/temp/wce.txt`);
|
||||
let splitIndices = [];
|
||||
for (let index = 0; index < wcExtFileArr.length; index++) {
|
||||
if (wcExtFileArr[index].trim() === "-----") {
|
||||
if (wcExtFileArr[index][0].trim() === "-----") {
|
||||
splitIndices.push(index);
|
||||
}
|
||||
}
|
||||
@@ -117,18 +124,18 @@ async function load() {
|
||||
let end = splitIndices[i];
|
||||
|
||||
let wcExtFile = wcExtFileArr.slice(start, end);
|
||||
let base = wcExtFile[0].trim() + "/";
|
||||
wcExtFile = wcExtFile.slice(1)
|
||||
.filter(x => x.trim().length > 0) // Remove empty lines
|
||||
.map(x => x.trim().replace(base, "").replace(".txt", "")); // Remove file extension & newlines;
|
||||
|
||||
wcExtFile = wcExtFile.map(x => [base, x]);
|
||||
wildcardExtFiles.push(...wcExtFile);
|
||||
if (wcExtFile && wcExtFile.length > 0) {
|
||||
let base = wcExtFile[0][0].trim() + "/";
|
||||
wcExtFile = wcExtFile.slice(1)
|
||||
.filter(x => x[0]?.trim().length > 0) //Remove empty lines
|
||||
.map(x => [base, x[0]?.trim().replace(base, "").replace(".txt", ""), x[1]]);
|
||||
wildcardExtFiles.push(...wcExtFile);
|
||||
}
|
||||
}
|
||||
|
||||
// Load the yaml wildcard json file and append it as a wildcard file, appending each key as a path component until we reach the end
|
||||
yamlWildcards = await readFile(`${tagBasePath}/temp/wc_yaml.json`, true);
|
||||
|
||||
|
||||
// Append each key as a path component until we reach a leaf
|
||||
Object.keys(yamlWildcards).forEach(file => {
|
||||
const flattened = flatten(yamlWildcards[file], [], "/");
|
||||
@@ -144,9 +151,9 @@ async function load() {
|
||||
|
||||
function sanitize(tagType, text) {
|
||||
if (tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard) {
|
||||
return `__${text}__`;
|
||||
return `${TAC_CFG.wcWrap}${text}${TAC_CFG.wcWrap}`;
|
||||
} else if (tagType === ResultType.wildcardTag) {
|
||||
return text.replace(/^.*?: /g, "");
|
||||
return text;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
@@ -155,7 +162,6 @@ function keepOpenIfWildcard(tagType, sanitizedText, newPrompt, textArea) {
|
||||
// If it's a wildcard, we want to keep the results open so the user can select another wildcard
|
||||
if (tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard) {
|
||||
hideBlocked = true;
|
||||
autocomplete(textArea, newPrompt, sanitizedText);
|
||||
setTimeout(() => { hideBlocked = false; }, 450);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const styleColors = {
|
||||
"--results-neutral-text": ["#e0e0e0","black"],
|
||||
"--results-bg": ["#0b0f19", "#ffffff"],
|
||||
"--results-border-color": ["#4b5563", "#e5e7eb"],
|
||||
"--results-border-width": ["1px", "1.5px"],
|
||||
@@ -30,11 +31,12 @@ const autocompleteCSS = `
|
||||
position: absolute;
|
||||
z-index: 999;
|
||||
max-width: calc(100% - 1.5rem);
|
||||
margin: 5px 0 0 0;
|
||||
flex-direction: column; /* Ensure children stack vertically */
|
||||
}
|
||||
.autocompleteResults {
|
||||
background-color: var(--results-bg) !important;
|
||||
border: var(--results-border-width) solid var(--results-border-color) !important;
|
||||
color: var(--results-neutral-text) !important;
|
||||
border-radius: 12px !important;
|
||||
height: fit-content;
|
||||
flex-basis: fit-content;
|
||||
@@ -42,6 +44,7 @@ const autocompleteCSS = `
|
||||
overflow-y: var(--results-overflow-y);
|
||||
overflow-x: hidden;
|
||||
word-break: break-word;
|
||||
margin-top: 10px; /* Margin to create space below the cursor */
|
||||
}
|
||||
.sideInfo {
|
||||
display: none;
|
||||
@@ -84,6 +87,14 @@ const autocompleteCSS = `
|
||||
white-space: nowrap;
|
||||
color: var(--meta-text-color);
|
||||
}
|
||||
.acMetaText.biased::before {
|
||||
content: "✨";
|
||||
margin-right: 2px;
|
||||
}
|
||||
.acMetaText span.used::after {
|
||||
content: "🔁";
|
||||
margin-right: 2px;
|
||||
}
|
||||
.acWikiLink {
|
||||
padding: 0.5rem;
|
||||
margin: -0.5rem 0 -0.5rem -0.5rem;
|
||||
@@ -211,19 +222,32 @@ async function syncOptions() {
|
||||
useWildcards: opts["tac_useWildcards"],
|
||||
sortWildcardResults: opts["tac_sortWildcardResults"],
|
||||
useEmbeddings: opts["tac_useEmbeddings"],
|
||||
includeEmbeddingsInNormalResults: opts["tac_includeEmbeddingsInNormalResults"],
|
||||
useHypernetworks: opts["tac_useHypernetworks"],
|
||||
useLoras: opts["tac_useLoras"],
|
||||
useLycos: opts["tac_useLycos"],
|
||||
useLycos: opts["tac_useLycos"],
|
||||
useLoraPrefixForLycos: opts["tac_useLoraPrefixForLycos"],
|
||||
showWikiLinks: opts["tac_showWikiLinks"],
|
||||
showExtraNetworkPreviews: opts["tac_showExtraNetworkPreviews"],
|
||||
modelSortOrder: opts["tac_modelSortOrder"],
|
||||
frequencySort: opts["tac_frequencySort"],
|
||||
frequencyFunction: opts["tac_frequencyFunction"],
|
||||
frequencyMinCount: opts["tac_frequencyMinCount"],
|
||||
frequencyMaxAge: opts["tac_frequencyMaxAge"],
|
||||
frequencyRecommendCap: opts["tac_frequencyRecommendCap"],
|
||||
frequencyIncludeAlias: opts["tac_frequencyIncludeAlias"],
|
||||
useStyleVars: opts["tac_useStyleVars"],
|
||||
// Insertion related settings
|
||||
replaceUnderscores: opts["tac_replaceUnderscores"],
|
||||
replaceUnderscoresExclusionList: opts["tac_undersocreReplacementExclusionList"],
|
||||
escapeParentheses: opts["tac_escapeParentheses"],
|
||||
appendComma: opts["tac_appendComma"],
|
||||
appendSpace: opts["tac_appendSpace"],
|
||||
alwaysSpaceAtEnd: opts["tac_alwaysSpaceAtEnd"],
|
||||
wildcardCompletionMode: opts["tac_wildcardCompletionMode"],
|
||||
modelKeywordCompletion: opts["tac_modelKeywordCompletion"],
|
||||
modelKeywordLocation: opts["tac_modelKeywordLocation"],
|
||||
wcWrap: opts["dp_parser_wildcard_wrap"] || "__", // to support custom wrapper chars set by dp_parser
|
||||
// Alias settings
|
||||
alias: {
|
||||
searchByAlias: opts["tac_alias.searchByAlias"],
|
||||
@@ -267,6 +291,17 @@ async function syncOptions() {
|
||||
await loadTags(newCFG);
|
||||
}
|
||||
|
||||
// Refresh temp files if model sort order changed
|
||||
// Contrary to the other loads, this one shouldn't happen on a first time load
|
||||
if (TAC_CFG && newCFG.modelSortOrder !== TAC_CFG.modelSortOrder) {
|
||||
const dropdown = gradioApp().querySelector("#setting_tac_modelSortOrder");
|
||||
dropdown.style.opacity = 0.5;
|
||||
dropdown.style.pointerEvents = "none";
|
||||
await refreshTacTempFiles(true);
|
||||
dropdown.style.opacity = null;
|
||||
dropdown.style.pointerEvents = null;
|
||||
}
|
||||
|
||||
// Update CSS if maxResults changed
|
||||
if (TAC_CFG && newCFG.maxResults !== TAC_CFG.maxResults) {
|
||||
gradioApp().querySelectorAll(".autocompleteResults").forEach(r => {
|
||||
@@ -328,10 +363,13 @@ function showResults(textArea) {
|
||||
parentDiv.style.display = "flex";
|
||||
|
||||
if (TAC_CFG.slidingPopup) {
|
||||
let caretPosition = getCaretCoordinates(textArea, textArea.selectionEnd).left;
|
||||
let offset = Math.min(textArea.offsetLeft - textArea.scrollLeft + caretPosition, textArea.offsetWidth - parentDiv.offsetWidth);
|
||||
|
||||
parentDiv.style.left = `${offset}px`;
|
||||
let caretPosition = getCaretCoordinates(textArea, textArea.selectionEnd);
|
||||
// Top cursor offset fix for SDNext modern UI, based on code by https://github.com/Nyx01
|
||||
let offsetTop = textArea.offsetTop + caretPosition.top - textArea.scrollTop + 10; // Adjust this value for desired distance below cursor
|
||||
let offsetLeft = Math.min(textArea.offsetLeft - textArea.scrollLeft + caretPosition.left, textArea.offsetWidth - parentDiv.offsetWidth);
|
||||
|
||||
parentDiv.style.top = `${offsetTop}px`; // Position below the cursor
|
||||
parentDiv.style.left = `${offsetLeft}px`;
|
||||
} else {
|
||||
if (parentDiv.style.left)
|
||||
parentDiv.style.removeProperty("left");
|
||||
@@ -346,9 +384,9 @@ function showResults(textArea) {
|
||||
function hideResults(textArea) {
|
||||
let textAreaId = getTextAreaIdentifier(textArea);
|
||||
let resultsDiv = gradioApp().querySelector('.autocompleteParent' + textAreaId);
|
||||
|
||||
|
||||
if (!resultsDiv) return;
|
||||
|
||||
|
||||
resultsDiv.style.display = "none";
|
||||
selectedTag = null;
|
||||
}
|
||||
@@ -358,12 +396,12 @@ function isEnabled() {
|
||||
if (TAC_CFG.activeIn.global) {
|
||||
// Skip check if the current model was not correctly detected, since it could wrongly disable the script otherwise
|
||||
if (!currentModelName || !currentModelHash) return true;
|
||||
|
||||
|
||||
let modelList = TAC_CFG.activeIn.modelList
|
||||
.split(",")
|
||||
.map(x => x.trim())
|
||||
.filter(x => x.length > 0);
|
||||
|
||||
|
||||
let shortHash = currentModelHash.substring(0, 10);
|
||||
let modelNameWithoutHash = currentModelName.replace(/\[.*\]$/g, "").trim();
|
||||
if (TAC_CFG.activeIn.modelListMode.toLowerCase() === "blacklist") {
|
||||
@@ -382,9 +420,10 @@ function isEnabled() {
|
||||
const WEIGHT_REGEX = /[([]([^()[\]:|]+)(?::(?:\d+(?:\.\d+)?|\.\d+))?[)\]]/g;
|
||||
const POINTY_REGEX = /<[^\s,<](?:[^\t\n\r,<>]*>|[^\t\n\r,> ]*)/g;
|
||||
const COMPLETED_WILDCARD_REGEX = /__[^\s,_][^\t\n\r,_]*[^\s,_]__[^\s,_]*/g;
|
||||
const NORMAL_TAG_REGEX = /[^\s,|<>)\]]+|</g;
|
||||
const STYLE_VAR_REGEX = /\$\(?[^$|\[\],\s]*\)?/g;
|
||||
const NORMAL_TAG_REGEX = /[^\s,|<>\[\]:]+_\([^\s,|<>\[\]:]*\)?|[^\s,|<>():\[\]]+|</g;
|
||||
const RUBY_TAG_REGEX = /[\w\d<][\w\d' \-?!/$%]{2,}>?/g;
|
||||
const TAG_REGEX = new RegExp(`${POINTY_REGEX.source}|${COMPLETED_WILDCARD_REGEX.source}|${NORMAL_TAG_REGEX.source}`, "g");
|
||||
const TAG_REGEX = () => { return new RegExp(`${POINTY_REGEX.source}|${COMPLETED_WILDCARD_REGEX.source.replaceAll("__", escapeRegExp(TAC_CFG.wcWrap))}|${STYLE_VAR_REGEX.source}|${NORMAL_TAG_REGEX.source}`, "g"); }
|
||||
|
||||
// On click, insert the tag into the prompt textbox with respect to the cursor position
|
||||
async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithoutChoice = false) {
|
||||
@@ -400,8 +439,12 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
|
||||
if (sanitizeResults && sanitizeResults.length > 0) {
|
||||
sanitizedText = sanitizeResults[0];
|
||||
} else {
|
||||
sanitizedText = TAC_CFG.replaceUnderscores ? text.replaceAll("_", " ") : text;
|
||||
|
||||
const excluded_tags = TAC_CFG.replaceUnderscoresExclusionList?.split(',').map(s => s.trim()) || [];
|
||||
if (TAC_CFG.replaceUnderscores && !excluded_tags.includes(sanitizedText)) {
|
||||
sanitizedText = text.replaceAll("_", " ")
|
||||
} else {
|
||||
sanitizedText = text;
|
||||
}
|
||||
if (TAC_CFG.escapeParentheses && tagType === ResultType.tag) {
|
||||
sanitizedText = sanitizedText
|
||||
.replaceAll("(", "\\(")
|
||||
@@ -440,13 +483,44 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
|
||||
// Don't cut off the __ at the end if it is already the full path
|
||||
if (firstDifference > 0 && firstDifference < longestResult) {
|
||||
// +2 because the sanitized text already has the __ at the start but the matched text doesn't
|
||||
sanitizedText = sanitizedText.substring(0, firstDifference + 2);
|
||||
sanitizedText = sanitizedText.substring(0, firstDifference + TAC_CFG.wcWrap.length);
|
||||
} else if (firstDifference === 0) {
|
||||
sanitizedText = tagword;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Frequency db update
|
||||
if (TAC_CFG.frequencySort) {
|
||||
let name = null;
|
||||
|
||||
switch (tagType) {
|
||||
case ResultType.wildcardFile:
|
||||
case ResultType.yamlWildcard:
|
||||
// We only want to update the frequency for a full wildcard, not partial paths
|
||||
if (sanitizedText.endsWith(TAC_CFG.wcWrap))
|
||||
name = text
|
||||
break;
|
||||
case ResultType.chant:
|
||||
// Chants use a slightly different format
|
||||
name = result.aliases;
|
||||
break;
|
||||
default:
|
||||
name = text;
|
||||
break;
|
||||
}
|
||||
|
||||
if (name && name.length > 0) {
|
||||
// Check if it's a negative prompt
|
||||
let textAreaId = getTextAreaIdentifier(textArea);
|
||||
let isNegative = textAreaId.includes("n");
|
||||
// Sanitize name for API call
|
||||
name = encodeURIComponent(name)
|
||||
// Call API & update db
|
||||
increaseUseCount(name, tagType, isNegative)
|
||||
}
|
||||
}
|
||||
|
||||
var prompt = textArea.value;
|
||||
|
||||
// Edit prompt text
|
||||
@@ -475,6 +549,10 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
|
||||
optionalSeparator = TAC_CFG.extraNetworksSeparator || " ";
|
||||
}
|
||||
|
||||
// Escape $ signs since they are special chars for the replace function
|
||||
// We need four since we're also escaping them in replaceAll in the first place
|
||||
sanitizedText = sanitizedText.replaceAll("$", "$$$$");
|
||||
|
||||
// Replace partial tag word with new text, add comma if needed
|
||||
let insert = surrounding.replace(match, sanitizedText + optionalSeparator);
|
||||
|
||||
@@ -488,16 +566,24 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
|
||||
let keywords = null;
|
||||
// Check built-in activation words first
|
||||
if (tagType === ResultType.lora || tagType === ResultType.lyco) {
|
||||
let info = await fetchAPI(`tacapi/v1/lora-info/${result.text}`)
|
||||
let info = await fetchTacAPI(`tacapi/v1/lora-info/${result.text}`)
|
||||
if (info && info["activation text"]) {
|
||||
keywords = info["activation text"];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (!keywords && modelKeywordPath.length > 0 && result.hash && result.hash !== "NOFILE" && result.hash.length > 0) {
|
||||
let nameDict = modelKeywordDict.get(result.hash);
|
||||
let names = [result.text + ".safetensors", result.text + ".pt", result.text + ".ckpt"];
|
||||
|
||||
// No match, try to find a sha256 match from the cache file
|
||||
if (!nameDict) {
|
||||
const sha256 = await fetchTacAPI(`/tacapi/v1/lora-cached-hash/${result.text}`)
|
||||
if (sha256) {
|
||||
nameDict = modelKeywordDict.get(sha256);
|
||||
}
|
||||
}
|
||||
|
||||
if (nameDict) {
|
||||
let found = false;
|
||||
names.forEach(name => {
|
||||
@@ -506,7 +592,7 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
|
||||
keywords = nameDict.get(name);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
if (!found)
|
||||
keywords = nameDict.get("none");
|
||||
}
|
||||
@@ -514,33 +600,64 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
|
||||
|
||||
if (keywords && keywords.length > 0) {
|
||||
textBeforeKeywordInsertion = newPrompt;
|
||||
|
||||
newPrompt = `${keywords}, ${newPrompt}`; // Insert keywords
|
||||
|
||||
|
||||
if (TAC_CFG.modelKeywordLocation === "Start of prompt")
|
||||
newPrompt = `${keywords}, ${newPrompt}`; // Insert keywords
|
||||
else if (TAC_CFG.modelKeywordLocation === "End of prompt")
|
||||
newPrompt = `${newPrompt}, ${keywords}`; // Insert keywords
|
||||
else {
|
||||
let keywordStart = prompt[editStart - 1] === " " ? editStart - 1 : editStart;
|
||||
newPrompt = prompt.substring(0, keywordStart) + `, ${keywords} ${insert}` + prompt.substring(editEnd);
|
||||
}
|
||||
|
||||
|
||||
textAfterKeywordInsertion = newPrompt;
|
||||
keywordInsertionUndone = false;
|
||||
setTimeout(() => lastEditWasKeywordInsertion = true, 200)
|
||||
|
||||
|
||||
keywordsLength = keywords.length + 2; // +2 for the comma and space
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Insert into prompt textbox and reposition cursor
|
||||
textArea.value = newPrompt;
|
||||
textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length + keywordsLength;
|
||||
textArea.selectionEnd = textArea.selectionStart
|
||||
|
||||
// Set self trigger flag to show wildcard contents after the filename was inserted
|
||||
if ([ResultType.wildcardFile, ResultType.yamlWildcard, ResultType.umiWildcard].includes(result.type))
|
||||
tacSelfTrigger = true;
|
||||
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python.
|
||||
// Uses a built-in method from the webui's ui.js which also already accounts for event target
|
||||
if (tagType === ResultType.wildcardTag || tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard)
|
||||
tacSelfTrigger = true;
|
||||
updateInput(textArea);
|
||||
|
||||
// Update previous tags with the edited prompt to prevent re-searching the same term
|
||||
let weightedTags = [...newPrompt.matchAll(WEIGHT_REGEX)]
|
||||
.map(match => match[1]);
|
||||
let tags = newPrompt.match(TAG_REGEX)
|
||||
if (weightedTags !== null) {
|
||||
tags = tags.filter(tag => !weightedTags.some(weighted => tag.includes(weighted)))
|
||||
.concat(weightedTags);
|
||||
let weightedTags = [...prompt.matchAll(WEIGHT_REGEX)]
|
||||
.map(match => match[1])
|
||||
.sort((a, b) => a.length - b.length);
|
||||
let tags = [...prompt.match(TAG_REGEX())].sort((a, b) => a.length - b.length);
|
||||
|
||||
if (weightedTags !== null && tags !== null) {
|
||||
// Create a working copy of the normal tags
|
||||
let workingTags = [...tags];
|
||||
|
||||
// For each weighted tag
|
||||
for (const weightedTag of weightedTags) {
|
||||
// Find first matching tag and remove it from working set
|
||||
const matchIndex = workingTags.findIndex(tag =>
|
||||
tag === weightedTag && !tag.startsWith("<[") && !tag.startsWith("$(")
|
||||
);
|
||||
|
||||
if (matchIndex !== -1) {
|
||||
// Remove the matched tag from the working set
|
||||
workingTags.splice(matchIndex, 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Combine filtered normal tags with weighted tags
|
||||
tags = workingTags.concat(weightedTags);
|
||||
}
|
||||
previousTags = tags;
|
||||
|
||||
@@ -575,6 +692,30 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
let tagColors = TAC_CFG.colorMap;
|
||||
let mode = (document.querySelector(".dark") || gradioApp().querySelector(".dark")) ? 0 : 1;
|
||||
let nextLength = Math.min(results.length, resultCount + TAC_CFG.resultStepLength);
|
||||
const IS_DAN_OR_E621_TAG_FILE = (tagFileName.toLowerCase().startsWith("danbooru") || tagFileName.toLowerCase().startsWith("e621"));
|
||||
|
||||
const tagCount = {};
|
||||
|
||||
// Indicate if tag was used before
|
||||
if (IS_DAN_OR_E621_TAG_FILE) {
|
||||
const prompt = textArea.value.trim();
|
||||
const tags = prompt.replaceAll('\n', ',').split(',').map(tag => tag.trim()).filter(tag => tag);
|
||||
|
||||
const unsanitizedTags = tags.map(tag => {
|
||||
const weightedTags = [...tag.matchAll(WEIGHT_REGEX)].flat();
|
||||
if (weightedTags.length === 2) {
|
||||
return weightedTags[1];
|
||||
} else {
|
||||
// normal tags
|
||||
return tag;
|
||||
}
|
||||
}).map(tag => tag.replaceAll(" ", "_").replaceAll("\\(", "(").replaceAll("\\)", ")"));
|
||||
|
||||
// Split tags by `,` and count tag
|
||||
for (const tag of unsanitizedTags) {
|
||||
tagCount[tag] = tagCount[tag] ? tagCount[tag] + 1 : 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (let i = resultCount; i < nextLength; i++) {
|
||||
let result = results[i];
|
||||
@@ -640,26 +781,43 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
}
|
||||
|
||||
// Add wiki link if the setting is enabled and a supported tag set loaded
|
||||
if (TAC_CFG.showWikiLinks
|
||||
&& (result.type === ResultType.tag)
|
||||
&& (tagFileName.toLowerCase().startsWith("danbooru") || tagFileName.toLowerCase().startsWith("e621"))) {
|
||||
if (
|
||||
TAC_CFG.showWikiLinks &&
|
||||
result.type === ResultType.tag &&
|
||||
IS_DAN_OR_E621_TAG_FILE
|
||||
) {
|
||||
let wikiLink = document.createElement("a");
|
||||
wikiLink.classList.add("acWikiLink");
|
||||
wikiLink.innerText = "?";
|
||||
wikiLink.title = "Open external wiki page for this tag";
|
||||
|
||||
let linkPart = displayText;
|
||||
// Only use alias result if it is one
|
||||
if (displayText.includes("➝"))
|
||||
linkPart = displayText.split(" ➝ ")[1];
|
||||
|
||||
if (displayText.includes("➝")) linkPart = displayText.split(" ➝ ")[1];
|
||||
|
||||
// Remove any trailing translations
|
||||
if (linkPart.includes("[")) {
|
||||
linkPart = linkPart.split("[")[0];
|
||||
}
|
||||
|
||||
linkPart = encodeURIComponent(linkPart);
|
||||
|
||||
// Set link based on selected file
|
||||
let tagFileNameLower = tagFileName.toLowerCase();
|
||||
if (tagFileNameLower.startsWith("danbooru")) {
|
||||
if (tagFileNameLower.startsWith("danbooru_e621_merged")) {
|
||||
// Use danbooru for categories 0-5, e621 for 6+
|
||||
// Based on the merged categories from https://github.com/DraconicDragon/dbr-e621-lists-archive/tree/main/tag-lists/danbooru_e621_merged
|
||||
// Danbooru is also the fallback if result.category is not set
|
||||
wikiLink.href =
|
||||
result.category && result.category >= 6
|
||||
? `https://e621.net/wiki_pages/${linkPart}`
|
||||
: `https://danbooru.donmai.us/wiki_pages/${linkPart}`;
|
||||
} else if (tagFileNameLower.startsWith("danbooru")) {
|
||||
wikiLink.href = `https://danbooru.donmai.us/wiki_pages/${linkPart}`;
|
||||
} else if (tagFileNameLower.startsWith("e621")) {
|
||||
wikiLink.href = `https://e621.net/wiki_pages/${linkPart}`;
|
||||
}
|
||||
|
||||
|
||||
wikiLink.target = "_blank";
|
||||
flexDiv.appendChild(wikiLink);
|
||||
}
|
||||
@@ -684,7 +842,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
}
|
||||
|
||||
// Post count
|
||||
if (result.count && !isNaN(result.count)) {
|
||||
if (result.count && !isNaN(result.count) && result.count !== Number.MAX_SAFE_INTEGER) {
|
||||
let postCount = result.count;
|
||||
let formatter;
|
||||
|
||||
@@ -712,12 +870,72 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
else if (result.meta.startsWith("v2"))
|
||||
itemText.classList.add("acEmbeddingV2");
|
||||
}
|
||||
|
||||
|
||||
flexDiv.appendChild(metaDiv);
|
||||
}
|
||||
|
||||
// Add listener
|
||||
li.addEventListener("click", function () { insertTextAtCursor(textArea, result, tagword); });
|
||||
// Add small ✨ marker to indicate usage sorting
|
||||
if (result.usageBias) {
|
||||
flexDiv.querySelector(".acMetaText").classList.add("biased");
|
||||
flexDiv.title = "✨ Frequent tag. Ctrl/Cmd + click to reset usage count.";
|
||||
}
|
||||
|
||||
// Add 🔁 to indicate if tag was used before
|
||||
if (IS_DAN_OR_E621_TAG_FILE && tagCount[result.text]) {
|
||||
// Fix PR#313#issuecomment-2592551794
|
||||
if (!(result.text === tagword && tagCount[result.text] === 1)) {
|
||||
const textNode = flexDiv.querySelector(".acMetaText");
|
||||
const span = document.createElement("span");
|
||||
textNode.insertBefore(span, textNode.firstChild);
|
||||
span.classList.add("used");
|
||||
span.title = "🔁 The prompt already contains this tag";
|
||||
}
|
||||
}
|
||||
|
||||
// Check if it's a negative prompt
|
||||
let isNegative = textAreaId.includes("n");
|
||||
|
||||
// Add click listener
|
||||
li.addEventListener("click", (e) => {
|
||||
if (e.ctrlKey || e.metaKey) {
|
||||
resetUseCount(result.text, result.type, !isNegative, isNegative);
|
||||
flexDiv.querySelector(".acMetaText").classList.remove("biased");
|
||||
} else {
|
||||
insertTextAtCursor(textArea, result, tagword);
|
||||
}
|
||||
});
|
||||
// Add delayed hover listener for extra network previews
|
||||
if (
|
||||
TAC_CFG.showExtraNetworkPreviews &&
|
||||
[
|
||||
ResultType.embedding,
|
||||
ResultType.hypernetwork,
|
||||
ResultType.lora,
|
||||
ResultType.lyco,
|
||||
].includes(result.type)
|
||||
) {
|
||||
li.addEventListener("mouseover", async () => {
|
||||
const me = this;
|
||||
let hoverTimeout;
|
||||
|
||||
hoverTimeout = setTimeout(async () => {
|
||||
// If the tag we hover over is already selected, do nothing
|
||||
if (selectedTag && selectedTag === i) return;
|
||||
|
||||
oldSelectedTag = selectedTag;
|
||||
selectedTag = i;
|
||||
|
||||
// Update selection without scrolling to the item (since we would
|
||||
// immediately trigger the next scroll as the items move under the cursor)
|
||||
updateSelectionStyle(textArea, selectedTag, oldSelectedTag, false);
|
||||
}, 400);
|
||||
// Reset delay timer if we leave the item
|
||||
me.addEventListener("mouseout", () => {
|
||||
clearTimeout(hoverTimeout);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Add element to list
|
||||
resultsList.appendChild(li);
|
||||
}
|
||||
@@ -730,7 +948,7 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
}
|
||||
}
|
||||
|
||||
async function updateSelectionStyle(textArea, newIndex, oldIndex) {
|
||||
async function updateSelectionStyle(textArea, newIndex, oldIndex, scroll = true) {
|
||||
let textAreaId = getTextAreaIdentifier(textArea);
|
||||
let resultDiv = gradioApp().querySelector('.autocompleteResults' + textAreaId);
|
||||
let resultsList = resultDiv.querySelector('ul');
|
||||
@@ -745,40 +963,25 @@ async function updateSelectionStyle(textArea, newIndex, oldIndex) {
|
||||
let selected = items[newIndex];
|
||||
selected.classList.add('selected');
|
||||
|
||||
// Set scrolltop to selected item
|
||||
resultDiv.scrollTop = selected.offsetTop - resultDiv.offsetTop;
|
||||
// Set scrolltop to selected item
|
||||
if (scroll) resultDiv.scrollTop = selected.offsetTop - resultDiv.offsetTop;
|
||||
}
|
||||
|
||||
// Show preview if enabled and the selected type supports it
|
||||
if (newIndex !== null) {
|
||||
let selected = items[newIndex];
|
||||
let previewTypes = ["v1 Embedding", "v2 Embedding", "Hypernetwork", "Lora", "Lyco"];
|
||||
let selectedType = selected.querySelector(".acMetaText").innerText;
|
||||
let selectedFilename = selected.querySelector(".acListItem").innerText;
|
||||
let selectedResult = results[newIndex];
|
||||
let selectedType = selectedResult.type;
|
||||
// These types support previews (others could technically too, but are not native to the webui gallery)
|
||||
let previewTypes = [ResultType.embedding, ResultType.hypernetwork, ResultType.lora, ResultType.lyco];
|
||||
|
||||
let previewDiv = gradioApp().querySelector(`.autocompleteParent${textAreaId} .sideInfo`);
|
||||
|
||||
if (TAC_CFG.showExtraNetworkPreviews && previewTypes.includes(selectedType)) {
|
||||
let shorthandType = "";
|
||||
switch (selectedType) {
|
||||
case "v1 Embedding":
|
||||
case "v2 Embedding":
|
||||
shorthandType = "embed";
|
||||
break;
|
||||
case "Hypernetwork":
|
||||
shorthandType = "hyper";
|
||||
break;
|
||||
case "Lora":
|
||||
shorthandType = "lora";
|
||||
break;
|
||||
case "Lyco":
|
||||
shorthandType = "lyco";
|
||||
break;
|
||||
}
|
||||
|
||||
let img = previewDiv.querySelector("img");
|
||||
|
||||
let url = await getExtraNetworkPreviewURL(selectedFilename, shorthandType);
|
||||
// String representation of our type enum
|
||||
const typeString = Object.keys(ResultType)[selectedType - 1].toLowerCase();
|
||||
// Get image from API
|
||||
let url = await getTacExtraNetworkPreviewURL(selectedResult.text, typeString);
|
||||
if (url) {
|
||||
img.src = url;
|
||||
previewDiv.style.display = "block";
|
||||
@@ -803,7 +1006,7 @@ function updateRuby(textArea, prompt) {
|
||||
ruby.setAttribute("class", `acRuby${typeClass} notranslate`);
|
||||
textArea.parentNode.appendChild(ruby);
|
||||
}
|
||||
|
||||
|
||||
ruby.innerText = prompt;
|
||||
|
||||
let bracketEscapedPrompt = prompt.replaceAll("\\(", "$").replaceAll("\\)", "%");
|
||||
@@ -821,9 +1024,9 @@ function updateRuby(textArea, prompt) {
|
||||
.replaceAll(" ", "_")
|
||||
.replaceAll("\\(", "(")
|
||||
.replaceAll("\\)", ")");
|
||||
|
||||
|
||||
const translation = translations?.get(tag) || translations?.get(unsanitizedTag);
|
||||
|
||||
|
||||
let escapedTag = escapeRegExp(tag);
|
||||
return { tag, escapedTag, translation };
|
||||
}
|
||||
@@ -839,14 +1042,14 @@ function updateRuby(textArea, prompt) {
|
||||
// First try to find direct matches
|
||||
[...rubyTags].forEach(tag => {
|
||||
let tuple = prepareTag(tag);
|
||||
|
||||
|
||||
if (tuple.translation) {
|
||||
html = replaceOccurences(html, tuple);
|
||||
} else {
|
||||
let subTags = tuple.tag.split(" ").filter(x => x.trim().length > 0);
|
||||
// Return if there is only one word
|
||||
if (subTags.length === 1) return;
|
||||
|
||||
|
||||
let subHtml = tag.replaceAll("$", "\\(").replaceAll("%", "\\)");
|
||||
|
||||
let translateNgram = (windows) => {
|
||||
@@ -861,14 +1064,14 @@ function updateRuby(textArea, prompt) {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
// Perform n-gram sliding window search
|
||||
translateNgram(toNgrams(subTags, 3));
|
||||
translateNgram(toNgrams(subTags, 2));
|
||||
translateNgram(toNgrams(subTags, 1));
|
||||
|
||||
let escapedTag = escapeRegExp(tuple.tag);
|
||||
|
||||
|
||||
let searchRegex = new RegExp(`(?<!<ruby>)(?:\\b)${escapedTag}(?:\\b|$|(?=[,|: \\t\\n\\r]))(?!<rt>)`, "g");
|
||||
html = html.replaceAll(searchRegex, subHtml);
|
||||
}
|
||||
@@ -905,6 +1108,7 @@ function checkKeywordInsertionUndo(textArea, event) {
|
||||
if (lastEditWasKeywordInsertion && !keywordInsertionUndone) {
|
||||
keywordInsertionUndone = true;
|
||||
textArea.value = textBeforeKeywordInsertion;
|
||||
tacSelfTrigger = true;
|
||||
updateInput(textArea);
|
||||
}
|
||||
break;
|
||||
@@ -912,6 +1116,7 @@ function checkKeywordInsertionUndo(textArea, event) {
|
||||
if (lastEditWasKeywordInsertion && keywordInsertionUndone) {
|
||||
keywordInsertionUndone = false;
|
||||
textArea.value = textAfterKeywordInsertion;
|
||||
tacSelfTrigger = true;
|
||||
updateInput(textArea);
|
||||
}
|
||||
case undefined:
|
||||
@@ -943,11 +1148,29 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
// Match tags with RegEx to get the last edited one
|
||||
// We also match for the weighting format (e.g. "tag:1.0") here, and combine the two to get the full tag word set
|
||||
let weightedTags = [...prompt.matchAll(WEIGHT_REGEX)]
|
||||
.map(match => match[1]);
|
||||
let tags = prompt.match(TAG_REGEX)
|
||||
.map(match => match[1])
|
||||
.sort((a, b) => a.length - b.length);
|
||||
let tags = [...prompt.match(TAG_REGEX())].sort((a, b) => a.length - b.length);
|
||||
|
||||
if (weightedTags !== null && tags !== null) {
|
||||
tags = tags.filter(tag => !weightedTags.some(weighted => tag.includes(weighted) && !tag.startsWith("<[")))
|
||||
.concat(weightedTags);
|
||||
// Create a working copy of the normal tags
|
||||
let workingTags = [...tags];
|
||||
|
||||
// For each weighted tag
|
||||
for (const weightedTag of weightedTags) {
|
||||
// Find first matching tag and remove it from working set
|
||||
const matchIndex = workingTags.findIndex(tag =>
|
||||
tag === weightedTag && !tag.startsWith("<[") && !tag.startsWith("$(")
|
||||
);
|
||||
|
||||
if (matchIndex !== -1) {
|
||||
// Remove the matched tag from the working set
|
||||
workingTags.splice(matchIndex, 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Combine filtered normal tags with weighted tags
|
||||
tags = workingTags.concat(weightedTags);
|
||||
}
|
||||
|
||||
// Guard for no tags
|
||||
@@ -980,46 +1203,29 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
}
|
||||
|
||||
results = [];
|
||||
resultCountBeforeNormalTags = 0;
|
||||
tagword = tagword.toLowerCase().replace(/[\n\r]/g, "");
|
||||
|
||||
// Needed for slicing check later
|
||||
let normalTags = false;
|
||||
|
||||
// Process all parsers
|
||||
let resultCandidates = await processParsers(textArea, prompt);
|
||||
let resultCandidates = (await processParsers(textArea, prompt))?.filter(x => x.length > 0);
|
||||
// If one ore more result candidates match, use their results
|
||||
if (resultCandidates && resultCandidates.length > 0) {
|
||||
// Flatten our candidate(s)
|
||||
results = resultCandidates.flat();
|
||||
// If there was more than one candidate, sort the results by text to mix them
|
||||
// instead of having them added in the order of the parsers
|
||||
let shouldSort = resultCandidates.length > 1;
|
||||
if (shouldSort) {
|
||||
results = results.sort((a, b) => {
|
||||
let sortByA = a.type === ResultType.chant ? a.aliases : a.text;
|
||||
let sortByB = b.type === ResultType.chant ? b.aliases : b.text;
|
||||
return sortByA.localeCompare(sortByB);
|
||||
});
|
||||
// Sort results, but not if it's umi tags since they are sorted by count
|
||||
if (!(resultCandidates.length === 1 && results[0].type === ResultType.umiWildcard))
|
||||
results = results.sort(getSortFunction());
|
||||
}
|
||||
// Else search the normal tag list
|
||||
if (!resultCandidates || resultCandidates.length === 0
|
||||
|| (TAC_CFG.includeEmbeddingsInNormalResults && !(tagword.startsWith("<") || tagword.startsWith("*<")))
|
||||
) {
|
||||
normalTags = true;
|
||||
resultCountBeforeNormalTags = results.length;
|
||||
|
||||
// Since some tags are kaomoji, we have to add the normal results in some cases
|
||||
if (tagword.startsWith("<") || tagword.startsWith("*<")) {
|
||||
// Create escaped search regex with support for * as a start placeholder
|
||||
let searchRegex;
|
||||
if (tagword.startsWith("*")) {
|
||||
tagword = tagword.slice(1);
|
||||
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
|
||||
} else {
|
||||
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
|
||||
}
|
||||
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, TAC_CFG.maxResults);
|
||||
|
||||
genericResults.forEach(g => {
|
||||
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
|
||||
result.category = g[1];
|
||||
result.count = g[2];
|
||||
result.aliases = g[3];
|
||||
results.push(result);
|
||||
});
|
||||
}
|
||||
}
|
||||
} else { // Else search the normal tag list
|
||||
// Create escaped search regex with support for * as a start placeholder
|
||||
let searchRegex;
|
||||
if (tagword.startsWith("*")) {
|
||||
@@ -1034,7 +1240,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
let aliasFilter = (x) => x[3] && x[3].toLowerCase().search(searchRegex) > -1;
|
||||
let translationFilter = (x) => (translations.has(x[0]) && translations.get(x[0]).toLowerCase().search(searchRegex) > -1)
|
||||
|| x[3] && x[3].split(",").some(y => translations.has(y) && translations.get(y).toLowerCase().search(searchRegex) > -1);
|
||||
|
||||
|
||||
let fil;
|
||||
if (TAC_CFG.alias.searchByAlias && TAC_CFG.translation.searchByTranslation)
|
||||
fil = (x) => baseFilter(x) || aliasFilter(x) || translationFilter(x);
|
||||
@@ -1072,11 +1278,6 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
results = results.concat(extraResults);
|
||||
}
|
||||
}
|
||||
|
||||
// Slice if the user has set a max result count
|
||||
if (!TAC_CFG.showAllResults) {
|
||||
results = results.slice(0, TAC_CFG.maxResults);
|
||||
}
|
||||
}
|
||||
|
||||
// Guard for empty results
|
||||
@@ -1086,6 +1287,57 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort again with frequency / usage count if enabled
|
||||
if (TAC_CFG.frequencySort) {
|
||||
// Split our results into a list of names and types
|
||||
let tagNames = [];
|
||||
let aliasNames = [];
|
||||
let types = [];
|
||||
// Limit to 2k for performance reasons
|
||||
const aliasTypes = [ResultType.tag, ResultType.extra];
|
||||
results.slice(0,2000).forEach(r => {
|
||||
const name = r.type === ResultType.chant ? r.aliases : r.text;
|
||||
// Add to alias list or tag list depending on if the name includes the tagword
|
||||
// (the same criteria is used in the filter in calculateUsageBias)
|
||||
if (aliasTypes.includes(r.type) && !name.includes(tagword)) {
|
||||
aliasNames.push(name);
|
||||
} else {
|
||||
tagNames.push(name);
|
||||
}
|
||||
types.push(r.type);
|
||||
});
|
||||
|
||||
// Check if it's a negative prompt
|
||||
let textAreaId = getTextAreaIdentifier(textArea);
|
||||
let isNegative = textAreaId.includes("n");
|
||||
|
||||
// Request use counts from the DB
|
||||
const names = TAC_CFG.frequencyIncludeAlias ? tagNames.concat(aliasNames) : tagNames;
|
||||
const counts = await getUseCounts(names, types, isNegative) || [];
|
||||
|
||||
// Pre-calculate weights to prevent duplicate work
|
||||
const resultBiasMap = new Map();
|
||||
results.forEach(result => {
|
||||
const name = result.type === ResultType.chant ? result.aliases : result.text;
|
||||
const type = result.type;
|
||||
// Find matching pair from DB results
|
||||
const useStats = counts.find(c => c.name === name && c.type === type);
|
||||
const uses = useStats?.count || 0;
|
||||
// Calculate & set weight
|
||||
const weight = calculateUsageBias(result, result.count, uses)
|
||||
resultBiasMap.set(result, weight);
|
||||
});
|
||||
// Actual sorting with the pre-calculated weights
|
||||
results = results.sort((a, b) => {
|
||||
return resultBiasMap.get(b) - resultBiasMap.get(a);
|
||||
});
|
||||
}
|
||||
|
||||
// Slice if the user has set a max result count and we are not in a extra networks / wildcard list
|
||||
if (!TAC_CFG.showAllResults && normalTags) {
|
||||
results = results.slice(0, TAC_CFG.maxResults + resultCountBeforeNormalTags);
|
||||
}
|
||||
|
||||
addResultsToList(textArea, results, tagword, true);
|
||||
showResults(textArea);
|
||||
}
|
||||
@@ -1093,7 +1345,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
function navigateInList(textArea, event) {
|
||||
// Return if the function is deactivated in the UI or the current model is excluded due to white/blacklist settings
|
||||
if (!isEnabled()) return;
|
||||
|
||||
|
||||
let keys = TAC_CFG.keymap;
|
||||
|
||||
// Close window if Home or End is pressed while not a keybinding, since it would break completion on leaving the original tag
|
||||
@@ -1108,12 +1360,17 @@ function navigateInList(textArea, event) {
|
||||
|
||||
if (!validKeys.includes(event.key)) return;
|
||||
if (!isVisible(textArea)) return
|
||||
// Return if ctrl key is pressed to not interfere with weight editing shortcut
|
||||
if (event.ctrlKey || event.altKey) return;
|
||||
// Add modifier keys to base as text+.
|
||||
let modKey = "";
|
||||
if (event.ctrlKey) modKey += "Ctrl+";
|
||||
if (event.altKey) modKey += "Alt+";
|
||||
if (event.shiftKey) modKey += "Shift+";
|
||||
if (event.metaKey) modKey += "Meta+";
|
||||
modKey += event.key;
|
||||
|
||||
oldSelectedTag = selectedTag;
|
||||
|
||||
switch (event.key) {
|
||||
switch (modKey) {
|
||||
case keys["MoveUp"]:
|
||||
if (selectedTag === null) {
|
||||
selectedTag = resultCount - 1;
|
||||
@@ -1143,10 +1400,25 @@ function navigateInList(textArea, event) {
|
||||
}
|
||||
break;
|
||||
case keys["JumpToStart"]:
|
||||
selectedTag = 0;
|
||||
if (TAC_CFG.includeEmbeddingsInNormalResults &&
|
||||
selectedTag > resultCountBeforeNormalTags &&
|
||||
resultCountBeforeNormalTags > 0
|
||||
) {
|
||||
selectedTag = resultCountBeforeNormalTags;
|
||||
} else {
|
||||
selectedTag = 0;
|
||||
}
|
||||
break;
|
||||
case keys["JumpToEnd"]:
|
||||
selectedTag = resultCount - 1;
|
||||
// Jump to the end of the list, or the end of embeddings if they are included in the normal results
|
||||
if (TAC_CFG.includeEmbeddingsInNormalResults &&
|
||||
selectedTag < resultCountBeforeNormalTags &&
|
||||
resultCountBeforeNormalTags > 0
|
||||
) {
|
||||
selectedTag = Math.min(resultCountBeforeNormalTags, resultCount - 1);
|
||||
} else {
|
||||
selectedTag = resultCount - 1;
|
||||
}
|
||||
break;
|
||||
case keys["ChooseSelected"]:
|
||||
if (selectedTag !== null) {
|
||||
@@ -1169,6 +1441,8 @@ function navigateInList(textArea, event) {
|
||||
case keys["Close"]:
|
||||
hideResults(textArea);
|
||||
break;
|
||||
default:
|
||||
if (event.ctrlKey || event.altKey || event.shiftKey || event.metaKey) return;
|
||||
}
|
||||
let moveKeys = [keys["MoveUp"], keys["MoveDown"], keys["JumpUp"], keys["JumpDown"], keys["JumpToStart"], keys["JumpToEnd"]];
|
||||
if (selectedTag === resultCount - 1 && moveKeys.includes(event.key)) {
|
||||
@@ -1183,8 +1457,8 @@ function navigateInList(textArea, event) {
|
||||
event.stopPropagation();
|
||||
}
|
||||
|
||||
async function refreshTacTempFiles() {
|
||||
setTimeout(async () => {
|
||||
async function refreshTacTempFiles(api = false) {
|
||||
const reload = async () => {
|
||||
wildcardFiles = [];
|
||||
wildcardExtFiles = [];
|
||||
umiWildcards = [];
|
||||
@@ -1196,7 +1470,23 @@ async function refreshTacTempFiles() {
|
||||
await processQueue(QUEUE_FILE_LOAD, null);
|
||||
|
||||
console.log("TAC: Refreshed temp files");
|
||||
}, 2000);
|
||||
}
|
||||
|
||||
if (api) {
|
||||
await postTacAPI("tacapi/v1/refresh-temp-files");
|
||||
await reload();
|
||||
} else {
|
||||
setTimeout(async () => {
|
||||
await reload();
|
||||
}, 2000);
|
||||
}
|
||||
}
|
||||
|
||||
async function refreshEmbeddings() {
|
||||
await postTacAPI("tacapi/v1/refresh-embeddings", null);
|
||||
embeddings = [];
|
||||
await processQueue(QUEUE_FILE_LOAD, null);
|
||||
console.log("TAC: Refreshed embeddings");
|
||||
}
|
||||
|
||||
function addAutocompleteToArea(area) {
|
||||
@@ -1219,8 +1509,19 @@ function addAutocompleteToArea(area) {
|
||||
|
||||
// Add autocomplete event listener
|
||||
area.addEventListener('input', (e) => {
|
||||
debounce(autocomplete(area, area.value), TAC_CFG.delayTime);
|
||||
updateRuby(area, area.value);
|
||||
|
||||
// Cancel autocomplete itself if the event has no inputType (e.g. because it was triggered by the updateInput() function)
|
||||
if (!e.inputType && !tacSelfTrigger) return;
|
||||
tacSelfTrigger = false;
|
||||
|
||||
// Block hide we are composing (IME), so enter doesn't close the results
|
||||
if (e.isComposing) {
|
||||
hideBlocked = true;
|
||||
setTimeout(() => { hideBlocked = false; }, 100);
|
||||
}
|
||||
|
||||
debounce(autocomplete(area, area.value), TAC_CFG.delayTime);
|
||||
checkKeywordInsertionUndo(area, e);
|
||||
});
|
||||
// Add focusout event listener
|
||||
@@ -1281,6 +1582,20 @@ async function setup() {
|
||||
// Listener for internal temp files refresh button
|
||||
gradioApp().querySelector("#refresh_tac_refreshTempFiles")?.addEventListener("click", refreshTacTempFiles);
|
||||
|
||||
// Also add listener for external network refresh button (plus triggering python code)
|
||||
let alreadyAdded = new Set();
|
||||
["#img2img_extra_refresh", "#txt2img_extra_refresh", ".extra-network-control--refresh"].forEach(e => {
|
||||
const elems = gradioApp().querySelectorAll(e);
|
||||
elems.forEach(elem => {
|
||||
if (!elem || alreadyAdded.has(elem)) return;
|
||||
|
||||
alreadyAdded.add(elem);
|
||||
elem.addEventListener("click", ()=>{
|
||||
refreshTacTempFiles(true);
|
||||
});
|
||||
});
|
||||
})
|
||||
|
||||
// Add mutation observer for the model hash text to also allow hash-based blacklist again
|
||||
let modelHashText = gradioApp().querySelector("#sd_checkpoint_hash");
|
||||
updateModelName();
|
||||
@@ -1291,6 +1606,7 @@ async function setup() {
|
||||
if (mutation.type === "attributes" && mutation.attributeName === "title") {
|
||||
currentModelHash = mutation.target.title;
|
||||
updateModelName();
|
||||
refreshEmbeddings();
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -1315,7 +1631,7 @@ async function setup() {
|
||||
let mode = (document.querySelector(".dark") || gradioApp().querySelector(".dark")) ? 0 : 1;
|
||||
// Check if we are on webkit
|
||||
let browser = navigator.userAgent.toLowerCase().indexOf('firefox') > -1 ? "firefox" : "other";
|
||||
|
||||
|
||||
let css = autocompleteCSS;
|
||||
// Replace vars with actual values (can't use actual css vars because of the way we inject the css)
|
||||
Object.keys(styleColors).forEach((key) => {
|
||||
@@ -1324,13 +1640,13 @@ async function setup() {
|
||||
Object.keys(browserVars).forEach((key) => {
|
||||
css = css.replaceAll(`var(${key})`, browserVars[key][browser]);
|
||||
})
|
||||
|
||||
|
||||
if (acStyle.styleSheet) {
|
||||
acStyle.styleSheet.cssText = css;
|
||||
} else {
|
||||
acStyle.appendChild(document.createTextNode(css));
|
||||
}
|
||||
gradioApp().appendChild(acStyle);
|
||||
document.head.appendChild(acStyle);
|
||||
|
||||
// Callback
|
||||
await processQueue(QUEUE_AFTER_SETUP, null);
|
||||
|
||||
@@ -16,6 +16,8 @@ hash_dict = {}
|
||||
|
||||
|
||||
def load_hash_cache():
|
||||
if not known_hashes_file.exists():
|
||||
known_hashes_file.touch()
|
||||
with open(known_hashes_file, "r", encoding="utf-8") as file:
|
||||
reader = csv.reader(
|
||||
file.readlines(), delimiter=",", quotechar='"', skipinitialspace=True
|
||||
@@ -28,6 +30,8 @@ def load_hash_cache():
|
||||
def update_hash_cache():
|
||||
global file_needs_update
|
||||
if file_needs_update:
|
||||
if not known_hashes_file.exists():
|
||||
known_hashes_file.touch()
|
||||
with open(known_hashes_file, "w", encoding="utf-8", newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
for name, (hash, mtime) in hash_dict.items():
|
||||
|
||||
@@ -1,35 +1,57 @@
|
||||
from pathlib import Path
|
||||
|
||||
from modules import scripts, shared
|
||||
|
||||
try:
|
||||
from modules.paths import extensions_dir, script_path
|
||||
|
||||
# Webui root path
|
||||
FILE_DIR = Path(script_path)
|
||||
FILE_DIR = Path(script_path).absolute()
|
||||
|
||||
# The extension base path
|
||||
EXT_PATH = Path(extensions_dir)
|
||||
EXT_PATH = Path(extensions_dir).absolute()
|
||||
except ImportError:
|
||||
# Webui root path
|
||||
FILE_DIR = Path().absolute()
|
||||
# The extension base path
|
||||
EXT_PATH = FILE_DIR.joinpath("extensions")
|
||||
EXT_PATH = FILE_DIR.joinpath("extensions").absolute()
|
||||
|
||||
# Tags base path
|
||||
TAGS_PATH = Path(scripts.basedir()).joinpath("tags")
|
||||
TAGS_PATH = Path(scripts.basedir()).joinpath("tags").absolute()
|
||||
|
||||
# The path to the folder containing the wildcards and embeddings
|
||||
WILDCARD_PATH = FILE_DIR.joinpath("scripts/wildcards")
|
||||
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
|
||||
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
|
||||
try: # SD.Next
|
||||
WILDCARD_PATH = Path(shared.opts.wildcards_dir).absolute()
|
||||
except Exception: # A1111
|
||||
WILDCARD_PATH = FILE_DIR.joinpath("scripts/wildcards").absolute()
|
||||
EMB_PATH = Path(shared.cmd_opts.embeddings_dir).absolute()
|
||||
|
||||
# Forge Classic detection
|
||||
try:
|
||||
from modules_forge.forge_version import version as forge_version
|
||||
IS_FORGE_CLASSIC = forge_version == "classic"
|
||||
except ImportError:
|
||||
IS_FORGE_CLASSIC = False
|
||||
|
||||
# Forge Classic skips it
|
||||
if not IS_FORGE_CLASSIC:
|
||||
try:
|
||||
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir).absolute()
|
||||
except AttributeError:
|
||||
HYP_PATH = None
|
||||
else:
|
||||
HYP_PATH = None
|
||||
|
||||
try:
|
||||
LORA_PATH = Path(shared.cmd_opts.lora_dir)
|
||||
LORA_PATH = Path(shared.cmd_opts.lora_dir).absolute()
|
||||
except AttributeError:
|
||||
LORA_PATH = None
|
||||
|
||||
try:
|
||||
LYCO_PATH = Path(shared.cmd_opts.lyco_dir)
|
||||
try:
|
||||
LYCO_PATH = Path(shared.cmd_opts.lyco_dir_backcompat).absolute()
|
||||
except:
|
||||
LYCO_PATH = Path(shared.cmd_opts.lyco_dir).absolute() # attempt original non-backcompat path
|
||||
except AttributeError:
|
||||
LYCO_PATH = None
|
||||
|
||||
@@ -37,6 +59,21 @@ except AttributeError:
|
||||
def find_ext_wildcard_paths():
|
||||
"""Returns the path to the extension wildcards folder"""
|
||||
found = list(EXT_PATH.glob("*/wildcards/"))
|
||||
# Try to find the wildcard path from the shared opts
|
||||
try:
|
||||
from modules.shared import opts
|
||||
except ImportError: # likely not in an a1111 context
|
||||
opts = None
|
||||
|
||||
# Append custom wildcard paths
|
||||
custom_paths = [
|
||||
getattr(shared.cmd_opts, "wildcards_dir", None), # Cmd arg from the wildcard extension
|
||||
getattr(opts, "wildcard_dir", None), # Custom path from sd-dynamic-prompts
|
||||
]
|
||||
for path in [Path(p).absolute() for p in custom_paths if p is not None]:
|
||||
if path.exists():
|
||||
found.append(path)
|
||||
|
||||
return found
|
||||
|
||||
|
||||
@@ -45,8 +82,8 @@ WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
|
||||
|
||||
# The path to the temporary files
|
||||
# In the webui root, on windows it exists by default, on linux it doesn't
|
||||
STATIC_TEMP_PATH = FILE_DIR.joinpath("tmp")
|
||||
TEMP_PATH = TAGS_PATH.joinpath("temp") # Extension specific temp files
|
||||
STATIC_TEMP_PATH = FILE_DIR.joinpath("tmp").absolute()
|
||||
TEMP_PATH = TAGS_PATH.joinpath("temp").absolute() # Extension specific temp files
|
||||
|
||||
# Make sure these folders exist
|
||||
if not TEMP_PATH.exists():
|
||||
|
||||
@@ -2,37 +2,130 @@
|
||||
# to a temporary file to expose it to the javascript side
|
||||
|
||||
import glob
|
||||
import importlib
|
||||
import json
|
||||
import sqlite3
|
||||
import sys
|
||||
import urllib.parse
|
||||
from asyncio import sleep
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import yaml
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from modules import script_callbacks, sd_hijack, shared
|
||||
from fastapi.responses import FileResponse, JSONResponse, Response
|
||||
from modules import hashes, script_callbacks, sd_hijack, sd_models, shared
|
||||
from pydantic import BaseModel
|
||||
|
||||
from scripts.model_keyword_support import (get_lora_simple_hash,
|
||||
load_hash_cache, update_hash_cache,
|
||||
write_model_keyword_path)
|
||||
from scripts.shared_paths import *
|
||||
|
||||
try:
|
||||
try:
|
||||
from scripts import tag_frequency_db as tdb
|
||||
except ModuleNotFoundError:
|
||||
from inspect import currentframe, getframeinfo
|
||||
filename = getframeinfo(currentframe()).filename
|
||||
parent = Path(filename).resolve().parent
|
||||
sys.path.append(str(parent))
|
||||
import tag_frequency_db as tdb
|
||||
|
||||
# Ensure the db dependency is reloaded on script reload
|
||||
importlib.reload(tdb)
|
||||
|
||||
db = tdb.TagFrequencyDb()
|
||||
if int(db.version) != int(tdb.db_ver):
|
||||
raise ValueError("Database version mismatch")
|
||||
except (ImportError, ValueError, sqlite3.Error) as e:
|
||||
print(f"Tag Autocomplete: Tag frequency database error - \"{e}\"")
|
||||
db = None
|
||||
|
||||
def get_embed_db(sd_model=None):
|
||||
"""Returns the embedding database, if available."""
|
||||
try:
|
||||
return sd_hijack.model_hijack.embedding_db
|
||||
except Exception:
|
||||
try: # sd next with diffusers backend
|
||||
sdnext_model = sd_model if sd_model is not None else shared.sd_model
|
||||
return sdnext_model.embedding_db
|
||||
except Exception:
|
||||
try: # forge webui
|
||||
forge_model = sd_model if sd_model is not None else sd_models.model_data.get_sd_model()
|
||||
if type(forge_model).__name__ == "FakeInitialModel":
|
||||
return None
|
||||
else:
|
||||
processer = getattr(forge_model, "text_processing_engine", getattr(forge_model, "text_processing_engine_l"))
|
||||
return processer.embeddings
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# Attempt to get embedding load function, using the same call as api.
|
||||
try:
|
||||
embed_db = get_embed_db()
|
||||
if embed_db is not None:
|
||||
load_textual_inversion_embeddings = embed_db.load_textual_inversion_embeddings
|
||||
else:
|
||||
load_textual_inversion_embeddings = lambda *args, **kwargs: None
|
||||
except Exception as e: # Not supported.
|
||||
load_textual_inversion_embeddings = lambda *args, **kwargs: None
|
||||
print("Tag Autocomplete: Cannot reload embeddings instantly:", e)
|
||||
|
||||
# Sorting functions for extra networks / embeddings stuff
|
||||
sort_criteria = {
|
||||
"Name": lambda path, name, subpath: name.lower() if subpath else path.stem.lower(),
|
||||
"Date Modified (newest first)": lambda path, name, subpath: path.stat().st_mtime if path.exists() else name.lower(),
|
||||
"Date Modified (oldest first)": lambda path, name, subpath: path.stat().st_mtime if path.exists() else name.lower()
|
||||
}
|
||||
|
||||
def sort_models(model_list, sort_method = None, name_has_subpath = False):
|
||||
"""Sorts models according to the setting.
|
||||
|
||||
Input: list of (full_path, display_name, {hash}) models.
|
||||
Returns models in the format of name, sort key, meta.
|
||||
Meta is optional and can be a hash, version string or other required info.
|
||||
"""
|
||||
if len(model_list) == 0:
|
||||
return model_list
|
||||
|
||||
if sort_method is None:
|
||||
sort_method = getattr(shared.opts, "tac_modelSortOrder", "Name")
|
||||
|
||||
# Get sorting method from dictionary
|
||||
sorter = sort_criteria.get(sort_method, sort_criteria["Name"])
|
||||
|
||||
# During merging on the JS side we need to re-sort anyway, so here only the sort criteria are calculated.
|
||||
# The list itself doesn't need to get sorted at this point.
|
||||
if len(model_list[0]) > 2:
|
||||
results = [f'{name},"{sorter(path, name, name_has_subpath)}",{meta}' for path, name, meta in model_list]
|
||||
else:
|
||||
results = [f'{name},"{sorter(path, name, name_has_subpath)}"' for path, name in model_list]
|
||||
return results
|
||||
|
||||
|
||||
def get_wildcards():
|
||||
"""Returns a list of all wildcards. Works on nested folders."""
|
||||
wildcard_files = list(WILDCARD_PATH.rglob("*.txt"))
|
||||
resolved = [w.relative_to(WILDCARD_PATH).as_posix(
|
||||
) for w in wildcard_files if w.name != "put wildcards here.txt"]
|
||||
return resolved
|
||||
resolved = [(w, w.relative_to(WILDCARD_PATH).as_posix())
|
||||
for w in wildcard_files
|
||||
if w.name != "put wildcards here.txt"
|
||||
and w.is_file()]
|
||||
return sort_models(resolved, name_has_subpath=True)
|
||||
|
||||
|
||||
def get_ext_wildcards():
|
||||
"""Returns a list of all extension wildcards. Works on nested folders."""
|
||||
wildcard_files = []
|
||||
|
||||
excluded_folder_names = [s.strip() for s in getattr(shared.opts, "tac_wildcardExclusionList", "").split(",")]
|
||||
for path in WILDCARD_EXT_PATHS:
|
||||
wildcard_files.append(path.as_posix())
|
||||
wildcard_files.extend(p.relative_to(path).as_posix() for p in path.rglob("*.txt") if p.name != "put wildcards here.txt")
|
||||
resolved = [(w, w.relative_to(path).as_posix())
|
||||
for w in path.rglob("*.txt")
|
||||
if w.name != "put wildcards here.txt"
|
||||
and not any(excluded in w.parts for excluded in excluded_folder_names)
|
||||
and w.is_file()]
|
||||
wildcard_files.extend(sort_models(resolved, name_has_subpath=True))
|
||||
wildcard_files.append("-----")
|
||||
|
||||
return wildcard_files
|
||||
@@ -41,16 +134,22 @@ def is_umi_format(data):
|
||||
"""Returns True if the YAML file is in UMI format."""
|
||||
issue_found = False
|
||||
for item in data:
|
||||
if not (data[item] and 'Tags' in data[item] and isinstance(data[item]['Tags'], list)):
|
||||
try:
|
||||
if not (data[item] and 'Tags' in data[item] and isinstance(data[item]['Tags'], list)):
|
||||
issue_found = True
|
||||
break
|
||||
except:
|
||||
issue_found = True
|
||||
break
|
||||
return not issue_found
|
||||
|
||||
def parse_umi_format(umi_tags, count, data):
|
||||
count = 0
|
||||
def parse_umi_format(umi_tags, data):
|
||||
global count
|
||||
for item in data:
|
||||
umi_tags[count] = ','.join(data[item]['Tags'])
|
||||
count += 1
|
||||
|
||||
|
||||
|
||||
def parse_dynamic_prompt_format(yaml_wildcards, data, path):
|
||||
# Recurse subkeys, delete those without string lists as values
|
||||
@@ -60,23 +159,25 @@ def parse_dynamic_prompt_format(yaml_wildcards, data, path):
|
||||
recurse_dict(value)
|
||||
elif not (isinstance(value, list) and all(isinstance(v, str) for v in value)):
|
||||
del d[key]
|
||||
|
||||
recurse_dict(data)
|
||||
# Add to yaml_wildcards
|
||||
yaml_wildcards[path.name] = data
|
||||
|
||||
try:
|
||||
recurse_dict(data)
|
||||
# Add to yaml_wildcards
|
||||
yaml_wildcards[path.name] = data
|
||||
except:
|
||||
return
|
||||
|
||||
|
||||
def get_yaml_wildcards():
|
||||
"""Returns a list of all tags found in extension YAML files found under a Tags: key."""
|
||||
yaml_files = []
|
||||
for path in WILDCARD_EXT_PATHS:
|
||||
yaml_files.extend(p for p in path.rglob("*.yml"))
|
||||
yaml_files.extend(p for p in path.rglob("*.yaml"))
|
||||
yaml_files.extend(p for p in path.rglob("*.yml") if p.is_file())
|
||||
yaml_files.extend(p for p in path.rglob("*.yaml") if p.is_file())
|
||||
|
||||
yaml_wildcards = {}
|
||||
|
||||
umi_tags = {} # { tag: count }
|
||||
count = 0
|
||||
|
||||
for path in yaml_files:
|
||||
try:
|
||||
@@ -84,21 +185,25 @@ def get_yaml_wildcards():
|
||||
data = yaml.safe_load(file)
|
||||
if (data):
|
||||
if (is_umi_format(data)):
|
||||
parse_umi_format(umi_tags, count, data)
|
||||
parse_umi_format(umi_tags, data)
|
||||
else:
|
||||
parse_dynamic_prompt_format(yaml_wildcards, data, path)
|
||||
else:
|
||||
print('No data found in ' + path.name)
|
||||
except yaml.YAMLError:
|
||||
print('Issue in parsing YAML file ' + path.name)
|
||||
except (yaml.YAMLError, UnicodeDecodeError, AttributeError, TypeError) as e:
|
||||
# YAML file not in wildcard format or couldn't be read
|
||||
print(f'Issue in parsing YAML file {path.name}: {e}')
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
# Something else went wrong, just skip
|
||||
continue
|
||||
|
||||
# Sort by count
|
||||
umi_sorted = sorted(umi_tags.items(), key=lambda item: item[1], reverse=True)
|
||||
umi_output = []
|
||||
for tag, count in umi_sorted:
|
||||
umi_output.append(f"{tag},{count}")
|
||||
|
||||
|
||||
if (len(umi_output) > 0):
|
||||
write_to_temp_file('umi_tags.txt', umi_output)
|
||||
|
||||
@@ -112,48 +217,59 @@ def get_embeddings(sd_model):
|
||||
# Version constants
|
||||
V1_SHAPE = 768
|
||||
V2_SHAPE = 1024
|
||||
VXL_SHAPE = 2048
|
||||
emb_v1 = []
|
||||
emb_v2 = []
|
||||
emb_vXL = []
|
||||
emb_unknown = []
|
||||
results = []
|
||||
|
||||
try:
|
||||
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
|
||||
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
|
||||
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
|
||||
# Get the shape of the first item in the dict
|
||||
emb_a_shape = -1
|
||||
emb_b_shape = -1
|
||||
if (len(emb_type_a) > 0):
|
||||
emb_a_shape = next(iter(emb_type_a.items()))[1].shape
|
||||
if (len(emb_type_b) > 0):
|
||||
emb_b_shape = next(iter(emb_type_b.items()))[1].shape
|
||||
embed_db = get_embed_db(sd_model)
|
||||
# Re-register callback if needed
|
||||
global load_textual_inversion_embeddings
|
||||
if embed_db is not None and load_textual_inversion_embeddings != embed_db.load_textual_inversion_embeddings:
|
||||
load_textual_inversion_embeddings = embed_db.load_textual_inversion_embeddings
|
||||
|
||||
loaded = embed_db.word_embeddings
|
||||
skipped = embed_db.skipped_embeddings
|
||||
|
||||
# Add embeddings to the correct list
|
||||
if (emb_a_shape == V1_SHAPE):
|
||||
emb_v1 = list(emb_type_a.keys())
|
||||
elif (emb_a_shape == V2_SHAPE):
|
||||
emb_v2 = list(emb_type_a.keys())
|
||||
for key, emb in (skipped | loaded).items():
|
||||
filename = getattr(emb, "filename", None)
|
||||
|
||||
if filename is None:
|
||||
if emb.shape is None:
|
||||
emb_unknown.append((Path(key), key, ""))
|
||||
elif emb.shape == V1_SHAPE:
|
||||
emb_v1.append((Path(key), key, "v1"))
|
||||
elif emb.shape == V2_SHAPE:
|
||||
emb_v2.append((Path(key), key, "v2"))
|
||||
elif emb.shape == VXL_SHAPE:
|
||||
emb_vXL.append((Path(key), key, "vXL"))
|
||||
else:
|
||||
emb_unknown.append((Path(key), key, ""))
|
||||
|
||||
else:
|
||||
if emb.filename is None:
|
||||
continue
|
||||
|
||||
if (emb_b_shape == V1_SHAPE):
|
||||
emb_v1 = list(emb_type_b.keys())
|
||||
elif (emb_b_shape == V2_SHAPE):
|
||||
emb_v2 = list(emb_type_b.keys())
|
||||
if emb.shape is None:
|
||||
emb_unknown.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), ""))
|
||||
elif emb.shape == V1_SHAPE:
|
||||
emb_v1.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "v1"))
|
||||
elif emb.shape == V2_SHAPE:
|
||||
emb_v2.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "v2"))
|
||||
elif emb.shape == VXL_SHAPE:
|
||||
emb_vXL.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), "vXL"))
|
||||
else:
|
||||
emb_unknown.append((Path(emb.filename), Path(emb.filename).relative_to(EMB_PATH).as_posix(), ""))
|
||||
|
||||
# Get shape of current model
|
||||
#vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||
#model_shape = vec.shape[1]
|
||||
# Show relevant entries at the top
|
||||
#if (model_shape == V1_SHAPE):
|
||||
# results = [e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2]
|
||||
#elif (model_shape == V2_SHAPE):
|
||||
# results = [e + ",v2" for e in emb_v2] + [e + ",v1" for e in emb_v1]
|
||||
#else:
|
||||
# raise AttributeError # Fallback to old method
|
||||
results = sorted([e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2], key=lambda x: x.lower())
|
||||
results = sort_models(emb_v1) + sort_models(emb_v2) + sort_models(emb_vXL) + sort_models(emb_unknown)
|
||||
except AttributeError:
|
||||
print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
|
||||
# Get a list of all embeddings in the folder
|
||||
all_embeds = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.rglob("*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}]
|
||||
all_embeds = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.rglob("*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'} and e.is_file()]
|
||||
# Remove files with a size of 0
|
||||
all_embeds = [e for e in all_embeds if EMB_PATH.joinpath(e).stat().st_size > 0]
|
||||
# Remove file extensions
|
||||
@@ -167,53 +283,129 @@ def get_hypernetworks():
|
||||
|
||||
# Get a list of all hypernetworks in the folder
|
||||
hyp_paths = [Path(h) for h in glob.glob(HYP_PATH.joinpath("**/*").as_posix(), recursive=True)]
|
||||
all_hypernetworks = [str(h.name) for h in hyp_paths if h.suffix in {".pt"}]
|
||||
# Remove file extensions
|
||||
return sorted([h[:h.rfind('.')] for h in all_hypernetworks], key=lambda x: x.lower())
|
||||
all_hypernetworks = [(h, h.stem) for h in hyp_paths if h.suffix in {".pt"} and h.is_file()]
|
||||
return sort_models(all_hypernetworks)
|
||||
|
||||
model_keyword_installed = write_model_keyword_path()
|
||||
|
||||
|
||||
def _get_lora():
|
||||
"""
|
||||
Write a list of all lora.
|
||||
Fallback method for when the built-in Lora.networks module is not available.
|
||||
"""
|
||||
# Get a list of all lora in the folder
|
||||
lora_paths = [
|
||||
Path(l)
|
||||
for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)
|
||||
]
|
||||
# Get hashes
|
||||
valid_loras = [
|
||||
lf
|
||||
for lf in lora_paths
|
||||
if lf.suffix in {".safetensors", ".ckpt", ".pt"} and lf.is_file()
|
||||
]
|
||||
|
||||
return valid_loras
|
||||
|
||||
|
||||
def _get_lyco():
|
||||
"""
|
||||
Write a list of all LyCORIS/LOHA from https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris
|
||||
Fallback method for when the built-in Lora.networks module is not available.
|
||||
"""
|
||||
# Get a list of all LyCORIS in the folder
|
||||
lyco_paths = [
|
||||
Path(ly)
|
||||
for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)
|
||||
]
|
||||
|
||||
# Get hashes
|
||||
valid_lycos = [
|
||||
lyf
|
||||
for lyf in lyco_paths
|
||||
if lyf.suffix in {".safetensors", ".ckpt", ".pt"} and lyf.is_file()
|
||||
]
|
||||
return valid_lycos
|
||||
|
||||
|
||||
# Attempt to use the build-in Lora.networks Lora/LyCORIS models lists.
|
||||
try:
|
||||
import sys
|
||||
from modules import extensions
|
||||
sys.path.append(Path(extensions.extensions_builtin_dir).joinpath("Lora").as_posix())
|
||||
import lora # pyright: ignore [reportMissingImports]
|
||||
|
||||
def _get_lora():
|
||||
return [
|
||||
Path(model.filename).absolute()
|
||||
for model in lora.available_loras.values()
|
||||
if Path(model.filename).absolute().is_relative_to(LORA_PATH)
|
||||
]
|
||||
|
||||
def _get_lyco():
|
||||
return [
|
||||
Path(model.filename).absolute()
|
||||
for model in lora.available_loras.values()
|
||||
if Path(model.filename).absolute().is_relative_to(LYCO_PATH)
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
# no need to report
|
||||
# print(f'Exception setting-up performant fetchers: {e}')
|
||||
|
||||
|
||||
def is_visible(p: Path) -> bool:
|
||||
if getattr(shared.opts, "extra_networks_hidden_models", "When searched") != "Never":
|
||||
return True
|
||||
for part in p.parts:
|
||||
if part.startswith('.'):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_lora():
|
||||
"""Write a list of all lora"""
|
||||
global model_keyword_installed
|
||||
|
||||
# Get a list of all lora in the folder
|
||||
lora_paths = [Path(l) for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)]
|
||||
# Get hashes
|
||||
valid_loras = [lf for lf in lora_paths if lf.suffix in {".safetensors", ".ckpt", ".pt"}]
|
||||
hashes = {}
|
||||
valid_loras = _get_lora()
|
||||
loras_with_hash = []
|
||||
for l in valid_loras:
|
||||
if not l.exists() or not l.is_file() or not is_visible(l):
|
||||
continue
|
||||
name = l.relative_to(LORA_PATH).as_posix()
|
||||
if model_keyword_installed:
|
||||
hashes[name] = get_lora_simple_hash(l)
|
||||
hash = get_lora_simple_hash(l)
|
||||
else:
|
||||
hashes[name] = ""
|
||||
hash = ""
|
||||
loras_with_hash.append((l, name, hash))
|
||||
# Sort
|
||||
sorted_loras = dict(sorted(hashes.items()))
|
||||
# Add hashes and return
|
||||
return [f"\"{name}\",{hash}" for name, hash in sorted_loras.items()]
|
||||
return sort_models(loras_with_hash)
|
||||
|
||||
|
||||
def get_lyco():
|
||||
"""Write a list of all LyCORIS/LOHA from https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris"""
|
||||
|
||||
# Get a list of all LyCORIS in the folder
|
||||
lyco_paths = [Path(ly) for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)]
|
||||
|
||||
# Get hashes
|
||||
valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"}]
|
||||
hashes = {}
|
||||
valid_lycos = _get_lyco()
|
||||
lycos_with_hash = []
|
||||
for ly in valid_lycos:
|
||||
if not ly.exists() or not ly.is_file() or not is_visible(ly):
|
||||
continue
|
||||
name = ly.relative_to(LYCO_PATH).as_posix()
|
||||
if model_keyword_installed:
|
||||
hashes[name] = get_lora_simple_hash(ly)
|
||||
hash = get_lora_simple_hash(ly)
|
||||
else:
|
||||
hashes[name] = ""
|
||||
|
||||
hash = ""
|
||||
lycos_with_hash.append((ly, name, hash))
|
||||
# Sort
|
||||
sorted_lycos = dict(sorted(hashes.items()))
|
||||
# Add hashes and return
|
||||
return [f"\"{name}\",{hash}" for name, hash in sorted_lycos.items()]
|
||||
return sort_models(lycos_with_hash)
|
||||
|
||||
def get_style_names():
|
||||
try:
|
||||
style_names: list[str] = shared.prompt_styles.styles.keys()
|
||||
style_names = sorted(style_names, key=len, reverse=True)
|
||||
return style_names
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def write_tag_base_path():
|
||||
"""Writes the tag base path to a fixed location temporary file"""
|
||||
@@ -229,19 +421,19 @@ def write_to_temp_file(name, data):
|
||||
|
||||
csv_files = []
|
||||
csv_files_withnone = []
|
||||
def update_tag_files():
|
||||
def update_tag_files(*args, **kwargs):
|
||||
"""Returns a list of all potential tag files"""
|
||||
global csv_files, csv_files_withnone
|
||||
files = [str(t.relative_to(TAGS_PATH)) for t in TAGS_PATH.glob("*.csv")]
|
||||
files = [str(t.relative_to(TAGS_PATH)) for t in TAGS_PATH.glob("*.csv") if t.is_file()]
|
||||
csv_files = files
|
||||
csv_files_withnone = ["None"] + files
|
||||
|
||||
json_files = []
|
||||
json_files_withnone = []
|
||||
def update_json_files():
|
||||
def update_json_files(*args, **kwargs):
|
||||
"""Returns a list of all potential json files"""
|
||||
global json_files, json_files_withnone
|
||||
files = [str(j.relative_to(TAGS_PATH)) for j in TAGS_PATH.glob("*.json")]
|
||||
files = [str(j.relative_to(TAGS_PATH)) for j in TAGS_PATH.glob("*.json") if j.is_file()]
|
||||
json_files = files
|
||||
json_files_withnone = ["None"] + files
|
||||
|
||||
@@ -268,6 +460,7 @@ write_to_temp_file('umi_tags.txt', [])
|
||||
write_to_temp_file('hyp.txt', [])
|
||||
write_to_temp_file('lora.txt', [])
|
||||
write_to_temp_file('lyco.txt', [])
|
||||
write_to_temp_file('styles.txt', [])
|
||||
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
|
||||
if not TEMP_PATH.joinpath("emb.txt").exists():
|
||||
write_to_temp_file('emb.txt', [])
|
||||
@@ -277,28 +470,59 @@ if EMB_PATH.exists():
|
||||
# Get embeddings after the model loaded callback
|
||||
script_callbacks.on_model_loaded(get_embeddings)
|
||||
|
||||
def refresh_temp_files():
|
||||
global WILDCARD_EXT_PATHS
|
||||
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
|
||||
write_temp_files()
|
||||
get_embeddings(shared.sd_model)
|
||||
def refresh_embeddings(force: bool, *args, **kwargs):
|
||||
try:
|
||||
# Fix for SD.Next infinite refresh loop due to gradio not updating after model load on demand.
|
||||
# This will just skip embedding loading if no model is loaded yet (or there really are no embeddings).
|
||||
# Try catch is just for safety incase sd_hijack access fails for some reason.
|
||||
embed_db = get_embed_db()
|
||||
if embed_db is None:
|
||||
return
|
||||
loaded = embed_db.word_embeddings
|
||||
skipped = embed_db.skipped_embeddings
|
||||
if len((loaded | skipped)) > 0:
|
||||
load_textual_inversion_embeddings(force_reload=force)
|
||||
get_embeddings(None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def write_temp_files():
|
||||
def refresh_temp_files(*args, **kwargs):
|
||||
global WILDCARD_EXT_PATHS
|
||||
skip_wildcard_refresh = getattr(shared.opts, "tac_skipWildcardRefresh", False)
|
||||
if skip_wildcard_refresh:
|
||||
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
|
||||
write_temp_files(skip_wildcard_refresh)
|
||||
force_embed_refresh = getattr(shared.opts, "tac_forceRefreshEmbeddings", False)
|
||||
refresh_embeddings(force=force_embed_refresh)
|
||||
|
||||
def write_style_names(*args, **kwargs):
|
||||
styles = get_style_names()
|
||||
if styles:
|
||||
write_to_temp_file('styles.txt', styles)
|
||||
|
||||
def write_temp_files(skip_wildcard_refresh = False):
|
||||
# Write wildcards to wc.txt if found
|
||||
if WILDCARD_PATH.exists():
|
||||
wildcards = [WILDCARD_PATH.relative_to(FILE_DIR).as_posix()] + get_wildcards()
|
||||
if WILDCARD_PATH.exists() and not skip_wildcard_refresh:
|
||||
try:
|
||||
# Attempt to create a relative path, but fall back to an absolute path if not possible
|
||||
relative_wildcard_path = WILDCARD_PATH.relative_to(FILE_DIR).as_posix()
|
||||
except ValueError:
|
||||
# If the paths are not relative, use the absolute path
|
||||
relative_wildcard_path = WILDCARD_PATH.as_posix()
|
||||
|
||||
wildcards = [relative_wildcard_path] + get_wildcards()
|
||||
if wildcards:
|
||||
write_to_temp_file('wc.txt', wildcards)
|
||||
|
||||
# Write extension wildcards to wce.txt if found
|
||||
if WILDCARD_EXT_PATHS is not None:
|
||||
if WILDCARD_EXT_PATHS is not None and not skip_wildcard_refresh:
|
||||
wildcards_ext = get_ext_wildcards()
|
||||
if wildcards_ext:
|
||||
write_to_temp_file('wce.txt', wildcards_ext)
|
||||
# Write yaml extension wildcards to umi_tags.txt and wc_yaml.json if found
|
||||
get_yaml_wildcards()
|
||||
|
||||
if HYP_PATH.exists():
|
||||
if HYP_PATH is not None and HYP_PATH.exists():
|
||||
hypernets = get_hypernetworks()
|
||||
if hypernets:
|
||||
write_to_temp_file('hyp.txt', hypernets)
|
||||
@@ -311,7 +535,7 @@ def write_temp_files():
|
||||
lora = get_lora()
|
||||
if lora:
|
||||
write_to_temp_file('lora.txt', lora)
|
||||
|
||||
|
||||
lyco_exists = LYCO_PATH is not None and LYCO_PATH.exists()
|
||||
if lyco_exists and not (lora_exists and LYCO_PATH.samefile(LORA_PATH)):
|
||||
lyco = get_lyco()
|
||||
@@ -323,6 +547,8 @@ def write_temp_files():
|
||||
if model_keyword_installed:
|
||||
update_hash_cache()
|
||||
|
||||
if shared.prompt_styles is not None:
|
||||
write_style_names()
|
||||
|
||||
write_temp_files()
|
||||
|
||||
@@ -342,6 +568,13 @@ def on_ui_settings():
|
||||
return self
|
||||
shared.OptionInfo.needs_restart = needs_restart
|
||||
|
||||
# Dictionary of function options and their explanations
|
||||
frequency_sort_functions = {
|
||||
"Logarithmic (weak)": "Will respect the base order and slightly prefer often used tags",
|
||||
"Logarithmic (strong)": "Same as Logarithmic (weak), but with a stronger bias",
|
||||
"Usage first": "Will list used tags by frequency before all others",
|
||||
}
|
||||
|
||||
tac_options = {
|
||||
# Main tag file
|
||||
"tac_tagFile": shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files),
|
||||
@@ -361,19 +594,35 @@ def on_ui_settings():
|
||||
"tac_delayTime": shared.OptionInfo(100, "Time in ms to wait before triggering completion again").needs_restart(),
|
||||
"tac_useWildcards": shared.OptionInfo(True, "Search for wildcards"),
|
||||
"tac_sortWildcardResults": shared.OptionInfo(True, "Sort wildcard file contents alphabetically").info("If your wildcard files have a specific custom order, disable this to keep it"),
|
||||
"tac_wildcardExclusionList": shared.OptionInfo("", "Wildcard folder exclusion list").info("Add folder names that shouldn't be searched for wildcards, separated by comma.").needs_restart(),
|
||||
"tac_skipWildcardRefresh": shared.OptionInfo(False, "Don't re-scan for wildcard files when pressing the extra networks refresh button").info("Useful to prevent hanging if you use a very large wildcard collection."),
|
||||
"tac_useEmbeddings": shared.OptionInfo(True, "Search for embeddings"),
|
||||
"tac_forceRefreshEmbeddings": shared.OptionInfo(False, "Force refresh embeddings when pressing the extra networks refresh button").info("Turn this on if you have issues with new embeddings not registering correctly in TAC. Warning: Seems to cause reloading issues in gradio for some users."),
|
||||
"tac_includeEmbeddingsInNormalResults": shared.OptionInfo(False, "Include embeddings in normal tag results").info("The 'JumpTo...' keybinds (End & Home key by default) will select the first non-embedding result of their direction on the first press for quick navigation in longer lists."),
|
||||
"tac_useHypernetworks": shared.OptionInfo(True, "Search for hypernetworks"),
|
||||
"tac_useLoras": shared.OptionInfo(True, "Search for Loras"),
|
||||
"tac_useLycos": shared.OptionInfo(True, "Search for LyCORIS/LoHa"),
|
||||
"tac_useLoraPrefixForLycos": shared.OptionInfo(True, "Use the '<lora:' prefix instead of '<lyco:' for models in the LyCORIS folder").info("The lyco prefix is included for backwards compatibility and not used anymore by default. Disable this if you are on an old webui version without built-in lyco support."),
|
||||
"tac_showWikiLinks": shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page").info("Warning: This is an external site and very likely contains NSFW examples!"),
|
||||
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
|
||||
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": list(sort_criteria.keys())}).info("Order for extra network models and wildcards in dropdown"),
|
||||
"tac_useStyleVars": shared.OptionInfo(False, "Search for webui style names").info("Suggests style names from the webui dropdown with '$'. Currently requires a secondary extension like <a href=\"https://github.com/SirVeggie/extension-style-vars\" target=\"_blank\">style-vars</a> to actually apply the styles before generating."),
|
||||
# Frequency sorting settings
|
||||
"tac_frequencySort": shared.OptionInfo(True, "Locally record tag usage and sort frequent tags higher").info("Will also work for extra networks, keeping the specified base order"),
|
||||
"tac_frequencyFunction": shared.OptionInfo("Logarithmic (weak)", "Function to use for frequency sorting", gr.Dropdown, lambda: {"choices": list(frequency_sort_functions.keys())}).info("; ".join([f'<b>{key}</b>: {val}' for key, val in frequency_sort_functions.items()])),
|
||||
"tac_frequencyMinCount": shared.OptionInfo(3, "Minimum number of uses for a tag to be considered frequent").info("Tags with less uses than this will not be sorted higher, even if the sorting function would normally result in a higher position."),
|
||||
"tac_frequencyMaxAge": shared.OptionInfo(30, "Maximum days since last use for a tag to be considered frequent").info("Similar to the above, tags that haven't been used in this many days will not be sorted higher. Set to 0 to disable."),
|
||||
"tac_frequencyRecommendCap": shared.OptionInfo(10, "Maximum number of recommended tags").info("Limits the maximum number of recommended tags to not drown out normal results. Set to 0 to disable."),
|
||||
"tac_frequencyIncludeAlias": shared.OptionInfo(False, "Frequency sorting matches aliases for frequent tags").info("Tag frequency will be increased for the main tag even if an alias is used for completion. This option can be used to override the default behavior of alias results being ignored for frequency sorting."),
|
||||
# Insertion related settings
|
||||
"tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"),
|
||||
"tac_undersocreReplacementExclusionList": shared.OptionInfo("0_0,(o)_(o),+_+,+_-,._.,<o>_<o>,<|>_<|>,=_=,>_<,3_3,6_9,>_o,@_@,^_^,o_o,u_u,x_x,|_|,||_||", "Underscore replacement exclusion list").info("Add tags that shouldn't have underscores replaced with spaces, separated by comma."),
|
||||
"tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"),
|
||||
"tac_appendComma": shared.OptionInfo(True, "Append comma on tag autocompletion"),
|
||||
"tac_appendSpace": shared.OptionInfo(True, "Append space on tag autocompletion").info("will append after comma if the above is enabled"),
|
||||
"tac_alwaysSpaceAtEnd": shared.OptionInfo(True, "Always append space if inserting at the end of the textbox").info("takes precedence over the regular space setting for that position"),
|
||||
"tac_modelKeywordCompletion": shared.OptionInfo("Never", "Try to add known trigger words for LORA/LyCO models", gr.Dropdown, lambda: {"choices": ["Never","Only user list","Always"]}).info("Will use & prefer the native activation keywords settable in the extra networks UI. Other functionality requires the <a href=\"https://github.com/mix1009/model-keyword\" target=\"_blank\">model-keyword</a> extension to be installed, but will work with it disabled.").needs_restart(),
|
||||
"tac_modelKeywordLocation": shared.OptionInfo("Start of prompt", "Where to insert the trigger keyword", gr.Dropdown, lambda: {"choices": ["Start of prompt","End of prompt","Before LORA/LyCO"]}).info("Only relevant if the above option is enabled"),
|
||||
"tac_wildcardCompletionMode": shared.OptionInfo("To next folder level", "How to complete nested wildcard paths", gr.Dropdown, lambda: {"choices": ["To next folder level","To first difference","Always fully"]}).info("e.g. \"hair/colours/light/...\""),
|
||||
# Alias settings
|
||||
"tac_alias.searchByAlias": shared.OptionInfo(True, "Search by alias"),
|
||||
@@ -430,6 +679,37 @@ def on_ui_settings():
|
||||
"6": ["red", "maroon"],
|
||||
"7": ["whitesmoke", "black"],
|
||||
"8": ["seagreen", "darkseagreen"]
|
||||
},
|
||||
"derpibooru": {
|
||||
"-1": ["red", "maroon"],
|
||||
"0": ["#60d160", "#3d9d3d"],
|
||||
"1": ["#fff956", "#918e2e"],
|
||||
"3": ["#fd9961", "#a14c2e"],
|
||||
"4": ["#cf5bbe", "#6c1e6c"],
|
||||
"5": ["#3c8ad9", "#1e5e93"],
|
||||
"6": ["#a6a6a6", "#555555"],
|
||||
"7": ["#47abc1", "#1f6c7c"],
|
||||
"8": ["#7871d0", "#392f7d"],
|
||||
"9": ["#df3647", "#8e1c2b"],
|
||||
"10": ["#c98f2b", "#7b470e"],
|
||||
"11": ["#e87ebe", "#a83583"]
|
||||
},
|
||||
"danbooru_e621_merged": {
|
||||
"-1": ["red", "maroon"],
|
||||
"0": ["lightblue", "dodgerblue"],
|
||||
"1": ["indianred", "firebrick"],
|
||||
"3": ["violet", "darkorchid"],
|
||||
"4": ["lightgreen", "darkgreen"],
|
||||
"5": ["orange", "darkorange"],
|
||||
"6": ["red", "maroon"],
|
||||
"7": ["lightblue", "dodgerblue"],
|
||||
"8": ["gold", "goldenrod"],
|
||||
"9": ["gold", "goldenrod"],
|
||||
"10": ["violet", "darkorchid"],
|
||||
"11": ["lightgreen", "darkgreen"],
|
||||
"12": ["tomato", "darksalmon"],
|
||||
"14": ["whitesmoke", "black"],
|
||||
"15": ["seagreen", "darkseagreen"]
|
||||
}
|
||||
}\
|
||||
"""
|
||||
@@ -444,53 +724,188 @@ def on_ui_settings():
|
||||
shared.opts.add_option("tac_colormap", shared.OptionInfo(colorDefault, colorLabel, gr.Textbox, section=TAC_SECTION))
|
||||
|
||||
shared.opts.add_option("tac_refreshTempFiles", shared.OptionInfo("Refresh TAC temp files", "Refresh internal temp files", gr.HTML, {}, refresh=refresh_temp_files, section=TAC_SECTION))
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
|
||||
def get_style_mtime():
|
||||
try:
|
||||
style_file = getattr(shared, "styles_filename", "styles.csv")
|
||||
# Check in case a list is returned
|
||||
if isinstance(style_file, list):
|
||||
style_file = style_file[0]
|
||||
|
||||
style_file = Path(FILE_DIR).joinpath(style_file)
|
||||
if Path.exists(style_file):
|
||||
return style_file.stat().st_mtime
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
last_style_mtime = get_style_mtime()
|
||||
|
||||
def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
async def get_json_info(base_path: Path, filename: str = None):
|
||||
if base_path is None or (not base_path.exists()):
|
||||
return json.dumps({})
|
||||
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
json_candidates = glob.glob(base_path.as_posix() + f"/**/{filename}.json", recursive=True)
|
||||
if json_candidates is not None and len(json_candidates) > 0:
|
||||
if json_candidates is not None and len(json_candidates) > 0 and Path(json_candidates[0]).is_file():
|
||||
return FileResponse(json_candidates[0])
|
||||
except Exception as e:
|
||||
return json.dumps({"error": e})
|
||||
|
||||
async def get_preview_thumbnail(base_path: Path, filename: str = None):
|
||||
return JSONResponse({"error": e}, status_code=500)
|
||||
|
||||
async def get_preview_thumbnail(base_path: Path, filename: str = None, blob: bool = False):
|
||||
if base_path is None or (not base_path.exists()):
|
||||
return json.dumps({})
|
||||
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
img_glob = glob.glob(base_path.as_posix() + f"/**/{filename}.*", recursive=True)
|
||||
img_candidates = [img for img in img_glob if Path(img).suffix in [".png", ".jpg", ".jpeg", ".webp"]]
|
||||
img_candidates = [img for img in img_glob if Path(img).suffix in [".png", ".jpg", ".jpeg", ".webp", ".gif"] and Path(img).is_file()]
|
||||
if img_candidates is not None and len(img_candidates) > 0:
|
||||
return JSONResponse({"url": urllib.parse.quote(img_candidates[0])})
|
||||
if blob:
|
||||
return FileResponse(img_candidates[0])
|
||||
else:
|
||||
return JSONResponse({"url": urllib.parse.quote(img_candidates[0])})
|
||||
except Exception as e:
|
||||
return json.dumps({"error": e})
|
||||
return JSONResponse({"error": e}, status_code=500)
|
||||
|
||||
@app.post("/tacapi/v1/refresh-temp-files")
|
||||
async def api_refresh_temp_files():
|
||||
await sleep(0) # might help with refresh blocking gradio
|
||||
refresh_temp_files()
|
||||
|
||||
@app.post("/tacapi/v1/refresh-embeddings")
|
||||
async def api_refresh_embeddings():
|
||||
refresh_embeddings(force=False)
|
||||
|
||||
@app.get("/tacapi/v1/lora-info/{lora_name}")
|
||||
async def get_lora_info(lora_name):
|
||||
return await get_json_info(LORA_PATH, lora_name)
|
||||
|
||||
|
||||
@app.get("/tacapi/v1/lyco-info/{lyco_name}")
|
||||
async def get_lyco_info(lyco_name):
|
||||
return await get_json_info(LYCO_PATH, lyco_name)
|
||||
|
||||
|
||||
@app.get("/tacapi/v1/lora-cached-hash/{lora_name}")
|
||||
async def get_lora_cached_hash(lora_name: str):
|
||||
path_glob = glob.glob(LORA_PATH.as_posix() + f"/**/{lora_name}.*", recursive=True)
|
||||
paths = [lora for lora in path_glob if Path(lora).suffix in [".safetensors", ".ckpt", ".pt"] and Path(lora).is_file()]
|
||||
if paths is not None and len(paths) > 0:
|
||||
path = paths[0]
|
||||
hash = hashes.sha256_from_cache(path, f"lora/{lora_name}", path.endswith(".safetensors"))
|
||||
if hash is not None:
|
||||
return hash
|
||||
|
||||
return None
|
||||
|
||||
def get_path_for_type(type):
|
||||
if type == "lora":
|
||||
return LORA_PATH
|
||||
elif type == "lyco":
|
||||
return LYCO_PATH
|
||||
elif type == "hypernetwork":
|
||||
return HYP_PATH
|
||||
elif type == "embedding":
|
||||
return EMB_PATH
|
||||
else:
|
||||
return None
|
||||
|
||||
@app.get("/tacapi/v1/thumb-preview/{filename}")
|
||||
async def get_thumb_preview(filename, type):
|
||||
if type == "lora":
|
||||
return await get_preview_thumbnail(LORA_PATH, filename)
|
||||
elif type == "lyco":
|
||||
return await get_preview_thumbnail(LYCO_PATH, filename)
|
||||
elif type == "hyper":
|
||||
return await get_preview_thumbnail(HYP_PATH, filename)
|
||||
elif type == "embed":
|
||||
return await get_preview_thumbnail(EMB_PATH, filename)
|
||||
return await get_preview_thumbnail(get_path_for_type(type), filename, False)
|
||||
|
||||
@app.get("/tacapi/v1/thumb-preview-blob/{filename}")
|
||||
async def get_thumb_preview_blob(filename, type):
|
||||
return await get_preview_thumbnail(get_path_for_type(type), filename, True)
|
||||
|
||||
@app.get("/tacapi/v1/wildcard-contents")
|
||||
async def get_wildcard_contents(basepath: str, filename: str):
|
||||
if basepath is None or basepath == "":
|
||||
return Response(status_code=404)
|
||||
|
||||
base = Path(basepath)
|
||||
if base is None or (not base.exists()):
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
wildcard_path = base.joinpath(filename)
|
||||
if wildcard_path.exists() and wildcard_path.is_file():
|
||||
return FileResponse(wildcard_path)
|
||||
else:
|
||||
return Response(status_code=404)
|
||||
except Exception as e:
|
||||
return JSONResponse({"error": e}, status_code=500)
|
||||
|
||||
@app.get("/tacapi/v1/refresh-styles-if-changed")
|
||||
async def refresh_styles_if_changed():
|
||||
global last_style_mtime
|
||||
|
||||
mtime = get_style_mtime()
|
||||
if mtime is not None and mtime > last_style_mtime:
|
||||
last_style_mtime = mtime
|
||||
# Update temp file
|
||||
if shared.prompt_styles is not None:
|
||||
write_style_names()
|
||||
|
||||
return Response(status_code=200) # Success
|
||||
else:
|
||||
return "Invalid type"
|
||||
return Response(status_code=304) # Not modified
|
||||
def db_request(func, get = False):
|
||||
if db is not None:
|
||||
try:
|
||||
if get:
|
||||
ret = func()
|
||||
if ret is list:
|
||||
ret = [{"name": t[0], "type": t[1], "count": t[2], "lastUseDate": t[3]} for t in ret]
|
||||
return JSONResponse({"result": ret})
|
||||
else:
|
||||
func()
|
||||
except sqlite3.Error as e:
|
||||
return JSONResponse({"error": e.__cause__}, status_code=500)
|
||||
else:
|
||||
return JSONResponse({"error": "Database not initialized"}, status_code=500)
|
||||
|
||||
@app.post("/tacapi/v1/increase-use-count")
|
||||
async def increase_use_count(tagname: str, ttype: int, neg: bool):
|
||||
db_request(lambda: db.increase_tag_count(tagname, ttype, neg))
|
||||
|
||||
script_callbacks.on_app_started(api_tac)
|
||||
@app.get("/tacapi/v1/get-use-count")
|
||||
async def get_use_count(tagname: str, ttype: int, neg: bool):
|
||||
return db_request(lambda: db.get_tag_count(tagname, ttype, neg), get=True)
|
||||
|
||||
# Small dataholder class
|
||||
class UseCountListRequest(BaseModel):
|
||||
tagNames: list[str]
|
||||
tagTypes: list[int]
|
||||
neg: bool = False
|
||||
|
||||
# Semantically weird to use post here, but it's required for the body on js side
|
||||
@app.post("/tacapi/v1/get-use-count-list")
|
||||
async def get_use_count_list(body: UseCountListRequest):
|
||||
# If a date limit is set > 0, pass it to the db
|
||||
date_limit = getattr(shared.opts, "tac_frequencyMaxAge", 30)
|
||||
date_limit = date_limit if date_limit > 0 else None
|
||||
|
||||
if db:
|
||||
count_list = list(db.get_tag_counts(body.tagNames, body.tagTypes, body.neg, date_limit))
|
||||
else:
|
||||
count_list = None
|
||||
|
||||
# If a limit is set, return at max the top n results by count
|
||||
if count_list and len(count_list):
|
||||
limit = int(min(getattr(shared.opts, "tac_frequencyRecommendCap", 10), len(count_list)))
|
||||
# Sort by count and return the top n
|
||||
if limit > 0:
|
||||
count_list = sorted(count_list, key=lambda x: x[2], reverse=True)[:limit]
|
||||
|
||||
return db_request(lambda: count_list, get=True)
|
||||
|
||||
@app.put("/tacapi/v1/reset-use-count")
|
||||
async def reset_use_count(tagname: str, ttype: int, pos: bool, neg: bool):
|
||||
db_request(lambda: db.reset_tag_count(tagname, ttype, pos, neg))
|
||||
|
||||
@app.get("/tacapi/v1/get-all-use-counts")
|
||||
async def get_all_tag_counts():
|
||||
return db_request(lambda: db.get_all_tags(), get=True)
|
||||
|
||||
script_callbacks.on_app_started(api_tac)
|
||||
|
||||
190
scripts/tag_frequency_db.py
Normal file
190
scripts/tag_frequency_db.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
|
||||
from scripts.shared_paths import TAGS_PATH
|
||||
|
||||
db_file = TAGS_PATH.joinpath("tag_frequency.db")
|
||||
timeout = 30
|
||||
db_ver = 1
|
||||
|
||||
|
||||
@contextmanager
|
||||
def transaction(db=db_file):
|
||||
"""Context manager for database transactions.
|
||||
Ensures that the connection is properly closed after the transaction.
|
||||
"""
|
||||
try:
|
||||
conn = sqlite3.connect(db, timeout=timeout)
|
||||
|
||||
conn.isolation_level = None
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("BEGIN")
|
||||
yield cursor
|
||||
cursor.execute("COMMIT")
|
||||
except sqlite3.Error as e:
|
||||
print("Tag Autocomplete: Frequency database error:", e)
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
class TagFrequencyDb:
|
||||
"""Class containing creation and interaction methods for the tag frequency database"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.version = self.__check()
|
||||
|
||||
def __check(self):
|
||||
if not db_file.exists():
|
||||
print("Tag Autocomplete: Creating frequency database")
|
||||
with transaction() as cursor:
|
||||
self.__create_db(cursor)
|
||||
self.__update_db_data(cursor, "version", db_ver)
|
||||
print("Tag Autocomplete: Database successfully created")
|
||||
|
||||
return self.__get_version()
|
||||
|
||||
def __create_db(self, cursor: sqlite3.Cursor):
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS db_data (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS tag_frequency (
|
||||
name TEXT NOT NULL,
|
||||
type INT NOT NULL,
|
||||
count_pos INT,
|
||||
count_neg INT,
|
||||
last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (name, type)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
def __update_db_data(self, cursor: sqlite3.Cursor, key, value):
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE
|
||||
INTO db_data (key, value)
|
||||
VALUES (?, ?)
|
||||
""",
|
||||
(key, value),
|
||||
)
|
||||
|
||||
def __get_version(self):
|
||||
db_version = None
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT value
|
||||
FROM db_data
|
||||
WHERE key = 'version'
|
||||
"""
|
||||
)
|
||||
db_version = cursor.fetchone()
|
||||
|
||||
return db_version[0] if db_version else 0
|
||||
|
||||
def get_all_tags(self):
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT name, type, count_pos, count_neg, last_used
|
||||
FROM tag_frequency
|
||||
WHERE count_pos > 0 OR count_neg > 0
|
||||
ORDER BY count_pos + count_neg DESC
|
||||
"""
|
||||
)
|
||||
tags = cursor.fetchall()
|
||||
|
||||
return tags
|
||||
|
||||
def get_tag_count(self, tag, ttype, negative=False):
|
||||
count_str = "count_neg" if negative else "count_pos"
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT {count_str}, last_used
|
||||
FROM tag_frequency
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
(tag, ttype),
|
||||
)
|
||||
tag_count = cursor.fetchone()
|
||||
|
||||
if tag_count:
|
||||
return tag_count[0], tag_count[1]
|
||||
else:
|
||||
return 0, None
|
||||
|
||||
def get_tag_counts(self, tags: list[str], ttypes: list[str], negative=False, date_limit=None):
|
||||
count_str = "count_neg" if negative else "count_pos"
|
||||
with transaction() as cursor:
|
||||
for tag, ttype in zip(tags, ttypes):
|
||||
if date_limit is not None:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT {count_str}, last_used
|
||||
FROM tag_frequency
|
||||
WHERE name = ? AND type = ?
|
||||
AND last_used > datetime('now', '-' || ? || ' days')
|
||||
""",
|
||||
(tag, ttype, date_limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT {count_str}, last_used
|
||||
FROM tag_frequency
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
(tag, ttype),
|
||||
)
|
||||
tag_count = cursor.fetchone()
|
||||
if tag_count:
|
||||
yield (tag, ttype, tag_count[0], tag_count[1])
|
||||
else:
|
||||
yield (tag, ttype, 0, None)
|
||||
|
||||
def increase_tag_count(self, tag, ttype, negative=False):
|
||||
pos_count = self.get_tag_count(tag, ttype, False)[0]
|
||||
neg_count = self.get_tag_count(tag, ttype, True)[0]
|
||||
|
||||
if negative:
|
||||
neg_count += 1
|
||||
else:
|
||||
pos_count += 1
|
||||
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
INSERT OR REPLACE
|
||||
INTO tag_frequency (name, type, count_pos, count_neg)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(tag, ttype, pos_count, neg_count),
|
||||
)
|
||||
|
||||
def reset_tag_count(self, tag, ttype, positive=True, negative=False):
|
||||
if positive and negative:
|
||||
set_str = "count_pos = 0, count_neg = 0"
|
||||
elif positive:
|
||||
set_str = "count_pos = 0"
|
||||
elif negative:
|
||||
set_str = "count_neg = 0"
|
||||
|
||||
with transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
UPDATE tag_frequency
|
||||
SET {set_str}
|
||||
WHERE name = ? AND type = ?
|
||||
""",
|
||||
(tag, ttype),
|
||||
)
|
||||
113301
tags/EnglishDictionary.csv
Normal file
113301
tags/EnglishDictionary.csv
Normal file
File diff suppressed because it is too large
Load Diff
238668
tags/danbooru.csv
238668
tags/danbooru.csv
File diff suppressed because it is too large
Load Diff
221787
tags/danbooru_e621_merged.csv
Normal file
221787
tags/danbooru_e621_merged.csv
Normal file
File diff suppressed because one or more lines are too long
@@ -28,5 +28,17 @@
|
||||
"terms": "Water, Magic, Fancy",
|
||||
"content": "(extremely detailed CG unity 8k wallpaper), (masterpiece), (best quality), (ultra-detailed), (best illustration),(best shadow), (an extremely delicate and beautiful), classic, dynamic angle, floating, fine detail, Depth of field, classic, (painting), (sketch), (bloom), (shine), glinting stars,\n\na girl, solo, bare shoulders, flat chest, diamond and glaring eyes, beautiful detailed cold face, very long blue and sliver hair, floating black feathers, wavy hair, extremely delicate and beautiful girls, beautiful detailed eyes, glowing eyes,\n\nriver, (forest),palace, (fairyland,feather,flowers, nature),(sunlight),Hazy fog, mist",
|
||||
"color": 5
|
||||
},
|
||||
{
|
||||
"name": "Pony-Positive",
|
||||
"terms": "Pony,Score,Positive,Quality",
|
||||
"content": "score_9, score_8_up, score_7_up, score_6_up, source_anime, source_furry, source_pony, source_cartoon",
|
||||
"color": 1
|
||||
},
|
||||
{
|
||||
"name": "Pony-Negative",
|
||||
"terms": "Pony,Score,Negative,Quality",
|
||||
"content": "score_1, score_2, score_3, score_4, score_5, source_anime, source_furry, source_pony, source_cartoon",
|
||||
"color": 3
|
||||
}
|
||||
]
|
||||
110665
tags/derpibooru.csv
Normal file
110665
tags/derpibooru.csv
Normal file
File diff suppressed because it is too large
Load Diff
200358
tags/e621.csv
200358
tags/e621.csv
File diff suppressed because one or more lines are too long
22419
tags/e621_sfw.csv
Normal file
22419
tags/e621_sfw.csv
Normal file
File diff suppressed because one or more lines are too long
160178
tags/noob_characters-chants.json
Normal file
160178
tags/noob_characters-chants.json
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user