mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-27 03:29:55 +00:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
454c13ef6d | ||
|
|
6deefda279 | ||
|
|
b57042edd0 | ||
|
|
ceba61163e | ||
|
|
16201605d0 | ||
|
|
0c3397aee6 | ||
|
|
4f582f4528 | ||
|
|
d2b5142d7d | ||
|
|
f11abe60c2 | ||
|
|
16bf9d9a51 | ||
|
|
bdd8cf68c7 | ||
|
|
63a0d2e73e | ||
|
|
34ba08d804 | ||
|
|
f1a437ff48 | ||
|
|
97cbada882 | ||
|
|
860a4034bb | ||
|
|
255d7420fd | ||
|
|
6b34d8ccd1 | ||
|
|
b35ee10f8e | ||
|
|
fc8540589a | ||
|
|
3d1ca6893a | ||
|
|
73c3424ab3 | ||
|
|
5f8a5d468d | ||
|
|
4296d8e3b7 |
@@ -1,3 +1,5 @@
|
||||

|
||||
|
||||
# Booru tag autocompletion for A1111
|
||||
|
||||
[](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/releases)
|
||||
@@ -140,6 +142,9 @@ Count in the extra file is optional, since there isn't always a post count for c
|
||||
The extra files can also be used to just add new / custom tags not included in the main set, provided `onlyAliasExtraFile` is false.
|
||||
If an extra tag doesn't match any existing tag, it will be added to the list as a new tag instead. For this, it will need to include the post count and alias columns even if they don't contain anything, so it could be in the form of `tag,type,,`.
|
||||
|
||||
##### WARNING
|
||||
Do not use e621.csv or danbooru.csv as an extra file. Alias comparison has exponential runtime, so for the combination of danbooru+e621, it will need to do 10,000,000,000 (yes, ten billion) lookups and usually take multiple minutes to load.
|
||||
|
||||
## CSV tag data
|
||||
The script expects a CSV file with tags saved in the following way:
|
||||
```csv
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||

|
||||
|
||||
# Booru tag autocompletion for A1111
|
||||
|
||||
[](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/releases)
|
||||
|
||||
@@ -34,6 +34,10 @@ function getTextAreas() {
|
||||
textAreas = textAreas.concat([...gradioApp().querySelectorAll(entry.selectors.join(", "))]);
|
||||
} else { // Otherwise, we have to find the text areas by their adjacent labels
|
||||
let base = gradioApp().querySelector(entry.base);
|
||||
|
||||
// Safety check
|
||||
if (!base) continue;
|
||||
|
||||
let allTextAreas = [...base.querySelectorAll("textarea")];
|
||||
|
||||
// Filter the text areas where the adjacent label matches one of the selectors
|
||||
|
||||
@@ -7,7 +7,9 @@ const styleColors = {
|
||||
"--results-bg-odd": ["#111827", "#f9fafb"],
|
||||
"--results-hover": ["#1f2937", "#f5f6f8"],
|
||||
"--results-selected": ["#374151", "#e5e7eb"],
|
||||
"--post-count-color": ["#6b6f7b", "#a2a9b4"]
|
||||
"--post-count-color": ["#6b6f7b", "#a2a9b4"],
|
||||
"--embedding-v1-color": ["lightsteelblue", "#2b5797"],
|
||||
"--embedding-v2-color": ["skyblue", "#2d89ef"],
|
||||
}
|
||||
const browserVars = {
|
||||
"--results-overflow-y": {
|
||||
@@ -66,6 +68,12 @@ const autocompleteCSS = `
|
||||
flex-grow: 1;
|
||||
color: var(--post-count-color);
|
||||
}
|
||||
.acListItem.acEmbeddingV1 {
|
||||
color: var(--embedding-v1-color);
|
||||
}
|
||||
.acListItem.acEmbeddingV2 {
|
||||
color: var(--embedding-v2-color);
|
||||
}
|
||||
`;
|
||||
|
||||
// Parse the CSV file into a 2D array. Doesn't use regex, so it is very lightweight.
|
||||
@@ -132,7 +140,7 @@ var translations = new Map();
|
||||
|
||||
async function loadTags(c) {
|
||||
// Load main tags and aliases
|
||||
if (allTags.length === 0) {
|
||||
if (allTags.length === 0 && c.tagFile && c.tagFile !== "None") {
|
||||
try {
|
||||
allTags = await loadCSV(`${tagBasePath}/${c.tagFile}?${new Date().getTime()}`);
|
||||
} catch (e) {
|
||||
@@ -342,7 +350,10 @@ function escapeHTML(unsafeText) {
|
||||
}
|
||||
|
||||
const WEIGHT_REGEX = /[([]([^,()[\]:| ]+)(?::(?:\d+(?:\.\d+)?|\.\d+))?[)\]]/g;
|
||||
const TAG_REGEX = /([^\s,|]+)/g
|
||||
const TAG_REGEX = /(<[^\t\n\r,>]+>?|[^\s,|<>]+|<)/g
|
||||
const WC_REGEX = /\b__([^, ]+)__([^, ]*)\b/g;
|
||||
const UMI_PROMPT_REGEX = /<[^\s]*?\[[^,<>]*[\]|]?>?/gi;
|
||||
const UMI_TAG_REGEX = /(?:\[|\||--)([^<>\[\]\-|]+)/gi;
|
||||
let hideBlocked = false;
|
||||
|
||||
// On click, insert the tag into the prompt textbox with respect to the cursor position
|
||||
@@ -358,8 +369,10 @@ function insertTextAtCursor(textArea, result, tagword) {
|
||||
sanitizedText = "__" + text.replace("Wildcards: ", "") + "__";
|
||||
} else if (tagType === "wildcardTag") {
|
||||
sanitizedText = text.replace(/^.*?: /g, "");
|
||||
} else if (tagType === "yamlWildcard" && !yamlWildcards.includes(text)) {
|
||||
sanitizedText = text.replaceAll("_", " "); // Replace underscores only if the yaml tag is not using them
|
||||
} else if (tagType === "embedding") {
|
||||
sanitizedText = `<${text.replace(/^.*?: /g, "")}>`;
|
||||
sanitizedText = `${text.replace(/^.*?: /g, "")}`;
|
||||
} else {
|
||||
sanitizedText = CFG.replaceUnderscores ? text.replaceAll("_", " ") : text;
|
||||
}
|
||||
@@ -382,7 +395,7 @@ function insertTextAtCursor(textArea, result, tagword) {
|
||||
let afterInsertCursorPos = editStart + match.index + sanitizedText.length;
|
||||
|
||||
var optionalComma = "";
|
||||
if (CFG.appendComma && tagType !== "wildcardFile") {
|
||||
if (CFG.appendComma && tagType !== "wildcardFile" && tagType !== "yamlWildcard") {
|
||||
optionalComma = surrounding.match(new RegExp(`${escapeRegExp(tagword)}[,:]`, "i")) !== null ? "" : ", ";
|
||||
}
|
||||
|
||||
@@ -409,6 +422,26 @@ function insertTextAtCursor(textArea, result, tagword) {
|
||||
}
|
||||
previousTags = tags;
|
||||
|
||||
// If it was a yaml wildcard, also update the umiPreviousTags
|
||||
if (tagType === "yamlWildcard" && originalTagword.length > 0) {
|
||||
let editStart = Math.max(cursorPos - tagword.length, 0);
|
||||
let editEnd = Math.min(cursorPos + tagword.length, originalTagword.length);
|
||||
let surrounding = originalTagword.substring(editStart, editEnd);
|
||||
let match = surrounding.match(new RegExp(escapeRegExp(`${tagword}`), "i"));
|
||||
let insert = surrounding.replace(match, sanitizedText);
|
||||
|
||||
let umiSubPrompts = [...newPrompt.matchAll(UMI_PROMPT_REGEX)];
|
||||
|
||||
let umiTags = [];
|
||||
umiSubPrompts.forEach(umiSubPrompt => {
|
||||
umiTags = umiTags.concat([...umiSubPrompt[0].matchAll(UMI_TAG_REGEX)].map(x => x[1].toLowerCase()));
|
||||
});
|
||||
|
||||
umiPreviousTags = umiTags;
|
||||
|
||||
hideResults(textArea);
|
||||
}
|
||||
|
||||
// Hide results after inserting
|
||||
if (tagType === "wildcardFile") {
|
||||
// If it's a wildcard, we want to keep the results open so the user can select another wildcard
|
||||
@@ -490,18 +523,20 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
// Add post count & color if it's a tag
|
||||
// Wildcards & Embeds have no tag type
|
||||
if (!result[1].startsWith("wildcard") && result[1] !== "embedding") {
|
||||
// Set the color of the tag
|
||||
let tagType = result[1];
|
||||
let colorGroup = tagColors[tagFileName];
|
||||
// Default to danbooru scheme if no matching one is found
|
||||
if (!colorGroup)
|
||||
colorGroup = tagColors["danbooru"];
|
||||
if (!result[1].startsWith("yaml")) {
|
||||
// Set the color of the tag
|
||||
let tagType = result[1];
|
||||
let colorGroup = tagColors[tagFileName];
|
||||
// Default to danbooru scheme if no matching one is found
|
||||
if (!colorGroup)
|
||||
colorGroup = tagColors["danbooru"];
|
||||
|
||||
// Set tag type to invalid if not found
|
||||
if (!colorGroup[tagType])
|
||||
tagType = "-1";
|
||||
// Set tag type to invalid if not found
|
||||
if (!colorGroup[tagType])
|
||||
tagType = "-1";
|
||||
|
||||
itemText.style = `color: ${colorGroup[tagType][mode]};`;
|
||||
itemText.style = `color: ${colorGroup[tagType][mode]};`;
|
||||
}
|
||||
|
||||
// Post count
|
||||
if (result[2] && !isNaN(result[2])) {
|
||||
@@ -521,6 +556,17 @@ function addResultsToList(textArea, results, tagword, resetList) {
|
||||
countDiv.classList.add("acPostCount");
|
||||
flexDiv.appendChild(countDiv);
|
||||
}
|
||||
} else if (result[1] === "embedding" && result[2]) { // Check if it is an embedding we have version info for
|
||||
let versionDiv = document.createElement("div");
|
||||
versionDiv.textContent = result[2];
|
||||
versionDiv.classList.add("acPostCount");
|
||||
|
||||
if (result[2].startsWith("v1"))
|
||||
itemText.classList.add("acEmbeddingV1");
|
||||
else if (result[2].startsWith("v2"))
|
||||
itemText.classList.add("acEmbeddingV2");
|
||||
|
||||
flexDiv.appendChild(versionDiv);
|
||||
}
|
||||
|
||||
// Add listener
|
||||
@@ -555,9 +601,12 @@ function updateSelectionStyle(textArea, newIndex, oldIndex) {
|
||||
|
||||
var wildcardFiles = [];
|
||||
var wildcardExtFiles = [];
|
||||
var yamlWildcards = [];
|
||||
var umiPreviousTags = [];
|
||||
var embeddings = [];
|
||||
var results = [];
|
||||
var tagword = "";
|
||||
var originalTagword = "";
|
||||
var resultCount = 0;
|
||||
async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
// Return if the function is deactivated in the UI
|
||||
@@ -575,8 +624,8 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
let weightedTags = [...prompt.matchAll(WEIGHT_REGEX)]
|
||||
.map(match => match[1]);
|
||||
let tags = prompt.match(TAG_REGEX)
|
||||
if (weightedTags !== null) {
|
||||
tags = tags.filter(tag => !weightedTags.some(weighted => tag.includes(weighted)))
|
||||
if (weightedTags !== null && tags !== null) {
|
||||
tags = tags.filter(tag => !weightedTags.some(weighted => tag.includes(weighted) && !tag.startsWith("<[")))
|
||||
.concat(weightedTags);
|
||||
}
|
||||
|
||||
@@ -603,9 +652,9 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
|
||||
tagword = tagword.toLowerCase().replace(/[\n\r]/g, "");
|
||||
|
||||
if (CFG.useWildcards && [...tagword.matchAll(/\b__([^, ]+)__([^, ]*)\b/g)].length > 0) {
|
||||
if (CFG.useWildcards && [...tagword.matchAll(WC_REGEX)].length > 0) {
|
||||
// Show wildcards from a file with that name
|
||||
wcMatch = [...tagword.matchAll(/\b__([^, ]+)__([^, ]*)\b/g)]
|
||||
wcMatch = [...tagword.matchAll(WC_REGEX)]
|
||||
let wcFile = wcMatch[0][1];
|
||||
let wcWord = wcMatch[0][2];
|
||||
|
||||
@@ -632,11 +681,164 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
tempResults = wildcardFiles.concat(wildcardExtFiles);
|
||||
}
|
||||
results = tempResults.map(x => ["Wildcards: " + x[1].trim(), "wildcardFile"]); // Mark as wildcard
|
||||
} else if (CFG.useWildcards && [...tagword.matchAll(UMI_PROMPT_REGEX)].length > 0) {
|
||||
// We are in a UMI yaml tag definition, parse further
|
||||
let umiSubPrompts = [...prompt.matchAll(UMI_PROMPT_REGEX)];
|
||||
|
||||
let umiTags = [];
|
||||
let umiTagsWithOperators = []
|
||||
|
||||
const insertAt = (str,char,pos) => str.slice(0,pos) + char + str.slice(pos);
|
||||
|
||||
umiSubPrompts.forEach(umiSubPrompt => {
|
||||
umiTags = umiTags.concat([...umiSubPrompt[0].matchAll(UMI_TAG_REGEX)].map(x => x[1].toLowerCase()));
|
||||
|
||||
const start = umiSubPrompt.index;
|
||||
const end = umiSubPrompt.index + umiSubPrompt[0].length;
|
||||
if (textArea.selectionStart >= start && textArea.selectionStart <= end) {
|
||||
umiTagsWithOperators = insertAt(umiSubPrompt[0], '###', textArea.selectionStart - start);
|
||||
}
|
||||
});
|
||||
|
||||
const promptSplitToTags = umiTagsWithOperators.replace(']###[', '][').split("][");
|
||||
|
||||
const clean = (str) => str
|
||||
.replaceAll('>', '')
|
||||
.replaceAll('<', '')
|
||||
.replaceAll('[', '')
|
||||
.replaceAll(']', '')
|
||||
.trim();
|
||||
|
||||
const matches = promptSplitToTags.reduce((acc, curr) => {
|
||||
isOptional = curr.includes("|");
|
||||
isNegative = curr.startsWith("--");
|
||||
let out;
|
||||
if (isOptional) {
|
||||
out = {
|
||||
hasCursor: curr.includes("###"),
|
||||
tags: clean(curr).split('|').map(x => ({
|
||||
hasCursor: x.includes("###"),
|
||||
isNegative: x.startsWith("--"),
|
||||
tag: clean(x).replaceAll("###", '').replaceAll("--", '')
|
||||
}))
|
||||
};
|
||||
acc.optional.push(out);
|
||||
acc.all.push(...out.tags.map(x => x.tag));
|
||||
} else if (isNegative) {
|
||||
out = {
|
||||
hasCursor: curr.includes("###"),
|
||||
tags: clean(curr).replaceAll("###", '').split('|'),
|
||||
};
|
||||
out.tags = out.tags.map(x => x.startsWith("--") ? x.substring(2) : x);
|
||||
acc.negative.push(out);
|
||||
acc.all.push(...out.tags);
|
||||
} else {
|
||||
out = {
|
||||
hasCursor: curr.includes("###"),
|
||||
tags: clean(curr).replaceAll("###", '').split('|'),
|
||||
};
|
||||
acc.positive.push(out);
|
||||
acc.all.push(...out.tags);
|
||||
}
|
||||
return acc;
|
||||
}, { positive: [], negative: [], optional: [], all: [] });
|
||||
|
||||
//console.log({ matches })
|
||||
|
||||
const filteredWildcards = (tagword) => {
|
||||
const wildcards = yamlWildcards.filter(x => {
|
||||
let tags = x[1];
|
||||
const matchesNeg =
|
||||
matches.negative.length === 0
|
||||
|| matches.negative.every(x =>
|
||||
x.hasCursor
|
||||
|| x.tags.every(t => !tags[t])
|
||||
);
|
||||
if (!matchesNeg) return false;
|
||||
const matchesPos =
|
||||
matches.positive.length === 0
|
||||
|| matches.positive.every(x =>
|
||||
x.hasCursor
|
||||
|| x.tags.every(t => tags[t])
|
||||
);
|
||||
if (!matchesPos) return false;
|
||||
const matchesOpt =
|
||||
matches.optional.length === 0
|
||||
|| matches.optional.some(x =>
|
||||
x.tags.some(t =>
|
||||
t.hasCursor
|
||||
|| t.isNegative
|
||||
? !tags[t.tag]
|
||||
: tags[t.tag]
|
||||
));
|
||||
if (!matchesOpt) return false;
|
||||
return true;
|
||||
}).reduce((acc, val) => {
|
||||
Object.keys(val[1]).forEach(tag => acc[tag] = acc[tag] + 1 || 1);
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
return Object.entries(wildcards)
|
||||
.sort((a, b) => b[1] - a[1])
|
||||
.filter(x =>
|
||||
x[0] === tagword
|
||||
|| !matches.all.includes(x[0])
|
||||
);
|
||||
}
|
||||
|
||||
if (umiTags.length > 0) {
|
||||
// Get difference for subprompt
|
||||
let tagCountChange = umiTags.length - umiPreviousTags.length;
|
||||
let diff = difference(umiTags, umiPreviousTags);
|
||||
umiPreviousTags = umiTags;
|
||||
|
||||
// Show all condition
|
||||
let showAll = tagword.endsWith("[") || tagword.endsWith("[--") || tagword.endsWith("|");
|
||||
|
||||
// Exit early if the user closed the bracket manually
|
||||
if ((!diff || diff.length === 0 || (diff.length === 1 && tagCountChange < 0)) && !showAll) {
|
||||
if (!hideBlocked) hideResults(textArea);
|
||||
return;
|
||||
}
|
||||
|
||||
let umiTagword = diff[0] || '';
|
||||
let tempResults = [];
|
||||
if (umiTagword && umiTagword.length > 0) {
|
||||
umiTagword = umiTagword.toLowerCase().replace(/[\n\r]/g, "");
|
||||
originalTagword = tagword;
|
||||
tagword = umiTagword;
|
||||
let filteredWildcardsSorted = filteredWildcards(umiTagword);
|
||||
let searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(umiTagword)}`, 'i')
|
||||
let baseFilter = x => x[0].toLowerCase().search(searchRegex) > -1;
|
||||
let spaceIncludeFilter = x => x[0].toLowerCase().replaceAll(" ", "_").search(searchRegex) > -1;
|
||||
tempResults = filteredWildcardsSorted.filter(x => baseFilter(x) || spaceIncludeFilter(x)) // Filter by tagword
|
||||
results = tempResults.map(x => [x[0].trim(), "yamlWildcard", x[1]]); // Mark as yaml wildcard
|
||||
} else if (showAll) {
|
||||
let filteredWildcardsSorted = filteredWildcards("");
|
||||
results = filteredWildcardsSorted.map(x => [x[0].trim(), "yamlWildcard", x[1]]); // Mark as yaml wildcard
|
||||
originalTagword = tagword;
|
||||
tagword = "";
|
||||
}
|
||||
} else {
|
||||
let filteredWildcardsSorted = filteredWildcards("");
|
||||
results = filteredWildcardsSorted.map(x => [x[0].trim(), "yamlWildcard", x[1]]); // Mark as yaml wildcard
|
||||
originalTagword = tagword;
|
||||
tagword = "";
|
||||
}
|
||||
} else if (CFG.useEmbeddings && tagword.match(/<[^,> ]*>?/g)) {
|
||||
// Show embeddings
|
||||
let tempResults = [];
|
||||
if (tagword !== "<") {
|
||||
tempResults = embeddings.filter(x => x.toLowerCase().includes(tagword.replace("<", ""))) // Filter by tagword
|
||||
let searchTerm = tagword.replace("<", "")
|
||||
let versionString;
|
||||
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
|
||||
versionString = searchTerm.slice(0, 2);
|
||||
searchTerm = searchTerm.slice(2);
|
||||
}
|
||||
if (versionString)
|
||||
tempResults = embeddings.filter(x => x[0].toLowerCase().includes(searchTerm) && x[1] && x[1] === versionString); // Filter by tagword
|
||||
else
|
||||
tempResults = embeddings.filter(x => x[0].toLowerCase().includes(searchTerm)); // Filter by tagword
|
||||
} else {
|
||||
tempResults = embeddings;
|
||||
}
|
||||
@@ -650,7 +852,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
|
||||
}
|
||||
genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, CFG.maxResults);
|
||||
results = genericResults.concat(tempResults.map(x => ["Embeddings: " + x.trim(), "embedding"])); // Mark as embedding
|
||||
results = tempResults.map(x => [x[0].trim(), "embedding", x[1] + " Embedding"]).concat(genericResults); // Mark as embedding
|
||||
} else {
|
||||
// Create escaped search regex with support for * as a start placeholder
|
||||
let searchRegex;
|
||||
@@ -662,13 +864,13 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
}
|
||||
// If onlyShowAlias is enabled, we don't need to include normal results
|
||||
if (CFG.alias.onlyShowAlias) {
|
||||
results = allTags.filter(x => x[3] && x[3].toLowerCase().search(searchRegex) >- 1);
|
||||
results = allTags.filter(x => x[3] && x[3].toLowerCase().search(searchRegex) > -1);
|
||||
} else {
|
||||
// Else both normal tags and aliases/translations are included depending on the config
|
||||
let baseFilter = (x) => x[0].toLowerCase().search(searchRegex) >- 1;
|
||||
let aliasFilter = (x) => x[3] && x[3].toLowerCase().search(searchRegex) >- 1;
|
||||
let translationFilter = (x) => (translations.has(x[0]) && translations.get(x[0]).toLowerCase().search(searchRegex) >- 1)
|
||||
|| x[3] && x[3].split(",").some(y => translations.has(y) && translations.get(y).toLowerCase().search(searchRegex) >- 1);
|
||||
let baseFilter = (x) => x[0].toLowerCase().search(searchRegex) > -1;
|
||||
let aliasFilter = (x) => x[3] && x[3].toLowerCase().search(searchRegex) > -1;
|
||||
let translationFilter = (x) => (translations.has(x[0]) && translations.get(x[0]).toLowerCase().search(searchRegex) > -1)
|
||||
|| x[3] && x[3].split(",").some(y => translations.has(y) && translations.get(y).toLowerCase().search(searchRegex) > -1);
|
||||
|
||||
let fil;
|
||||
if (CFG.alias.searchByAlias && CFG.translation.searchByTranslation)
|
||||
@@ -690,6 +892,7 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
|
||||
// Guard for empty results
|
||||
if (!results.length) {
|
||||
//console.log('No results found for "' + tagword + '"');
|
||||
hideResults(textArea);
|
||||
return;
|
||||
}
|
||||
@@ -822,12 +1025,31 @@ async function setup() {
|
||||
console.error("Error loading wildcards: " + e);
|
||||
}
|
||||
}
|
||||
// Load yaml wildcards
|
||||
if (yamlWildcards.length === 0) {
|
||||
try {
|
||||
let yamlTags = (await readFile(`${tagBasePath}/temp/wcet.txt?${new Date().getTime()}`)).split("\n");
|
||||
// Split into tag, count pairs
|
||||
yamlWildcards = yamlTags.map(x => x
|
||||
.trim()
|
||||
.split(","))
|
||||
.map(([i, ...rest]) => [
|
||||
i,
|
||||
rest.reduce((a, b) => {
|
||||
a[b.toLowerCase()] = true;
|
||||
return a;
|
||||
}, {}),
|
||||
]);
|
||||
} catch (e) {
|
||||
console.error("Error loading yaml wildcards: " + e);
|
||||
}
|
||||
}
|
||||
// Load embeddings
|
||||
if (embeddings.length === 0) {
|
||||
try {
|
||||
embeddings = (await readFile(`${tagBasePath}/temp/emb.txt?${new Date().getTime()}`)).split("\n")
|
||||
.filter(x => x.trim().length > 0) // Remove empty lines
|
||||
.map(x => x.replace(".bin", "").replace(".pt", "").replace(".png", "")); // Remove file extensions
|
||||
.map(x => x.trim().split(",")); // Split into name, version type pairs
|
||||
} catch (e) {
|
||||
console.error("Error loading embeddings.txt: " + e);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
|
||||
import gradio as gr
|
||||
from pathlib import Path
|
||||
from modules import scripts, script_callbacks, shared
|
||||
from modules import scripts, script_callbacks, shared, sd_hijack
|
||||
import yaml
|
||||
import time
|
||||
import threading
|
||||
|
||||
# Webui root path
|
||||
FILE_DIR = Path().absolute()
|
||||
@@ -53,9 +56,78 @@ def get_ext_wildcards():
|
||||
return wildcard_files
|
||||
|
||||
|
||||
def get_ext_wildcard_tags():
|
||||
"""Returns a list of all tags found in extension YAML files found under a Tags: key."""
|
||||
wildcard_tags = {} # { tag: count }
|
||||
yaml_files = []
|
||||
for path in WILDCARD_EXT_PATHS:
|
||||
yaml_files.extend(p for p in path.rglob("*.yml"))
|
||||
yaml_files.extend(p for p in path.rglob("*.yaml"))
|
||||
count = 0
|
||||
for path in yaml_files:
|
||||
try:
|
||||
with open(path, encoding="utf8") as file:
|
||||
data = yaml.safe_load(file)
|
||||
for item in data:
|
||||
wildcard_tags[count] = ','.join(data[item]['Tags'])
|
||||
count += 1
|
||||
except yaml.YAMLError as exc:
|
||||
print(exc)
|
||||
# Sort by count
|
||||
sorted_tags = sorted(wildcard_tags.items(), key=lambda item: item[1], reverse=True)
|
||||
output = []
|
||||
for tag, count in sorted_tags:
|
||||
output.append(f"{tag},{count}")
|
||||
return output
|
||||
|
||||
|
||||
def get_embeddings():
|
||||
"""Returns a list of all embeddings"""
|
||||
return [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png"}]
|
||||
"""Write a list of all embeddings with their version"""
|
||||
# Get a list of all embeddings in the folder
|
||||
embs_in_dir = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}]
|
||||
# Remove file extensions
|
||||
embs_in_dir = [e[:e.rfind('.')] for e in embs_in_dir]
|
||||
|
||||
# Wait for all embeddings to be loaded
|
||||
while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir):
|
||||
time.sleep(2) # Sleep for 2 seconds
|
||||
|
||||
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
|
||||
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
|
||||
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
|
||||
# Get the shape of the first item in the dict
|
||||
emb_a_shape = -1
|
||||
if (len(emb_type_a) > 0):
|
||||
emb_a_shape = next(iter(emb_type_a.items()))[1].shape
|
||||
|
||||
# Add embeddings to the correct list
|
||||
V1_SHAPE = 768
|
||||
V2_SHAPE = 1024
|
||||
emb_v1 = []
|
||||
emb_v2 = []
|
||||
|
||||
if (emb_a_shape == V1_SHAPE):
|
||||
emb_v1 = list(emb_type_a.keys())
|
||||
emb_v2 = list(emb_type_b)
|
||||
elif (emb_a_shape == V2_SHAPE):
|
||||
emb_v1 = list(emb_type_b)
|
||||
emb_v2 = list(emb_type_a.keys())
|
||||
|
||||
# Create a new list to store the modified strings
|
||||
results = []
|
||||
|
||||
# Iterate through each string in the big list
|
||||
for string in embs_in_dir:
|
||||
if string in emb_v1:
|
||||
results.append(string + ",v1")
|
||||
elif string in emb_v2:
|
||||
results.append(string + ",v2")
|
||||
# If the string is not in either, default to v1
|
||||
# (we can't know what it is since the startup model loaded none of them, but it's probably v1 since v2 is newer)
|
||||
else:
|
||||
results.append(string + ",v1")
|
||||
|
||||
write_to_temp_file('emb.txt', results)
|
||||
|
||||
|
||||
def write_tag_base_path():
|
||||
@@ -97,6 +169,7 @@ if not TEMP_PATH.exists():
|
||||
# even if no wildcards or embeddings are found
|
||||
write_to_temp_file('wc.txt', [])
|
||||
write_to_temp_file('wce.txt', [])
|
||||
write_to_temp_file('wcet.txt', [])
|
||||
write_to_temp_file('emb.txt', [])
|
||||
|
||||
# Write wildcards to wc.txt if found
|
||||
@@ -110,18 +183,23 @@ if WILDCARD_EXT_PATHS is not None:
|
||||
wildcards_ext = get_ext_wildcards()
|
||||
if wildcards_ext:
|
||||
write_to_temp_file('wce.txt', wildcards_ext)
|
||||
# Write yaml extension wildcards to wcet.txt if found
|
||||
wildcards_yaml_ext = get_ext_wildcard_tags()
|
||||
if wildcards_yaml_ext:
|
||||
write_to_temp_file('wcet.txt', wildcards_yaml_ext)
|
||||
|
||||
# Write embeddings to emb.txt if found
|
||||
if EMB_PATH.exists():
|
||||
embeddings = get_embeddings()
|
||||
if embeddings:
|
||||
write_to_temp_file('emb.txt', embeddings)
|
||||
# We need to load the embeddings in a separate thread since we wait for them to be checked (after the model loads)
|
||||
thread = threading.Thread(target=get_embeddings)
|
||||
thread.start()
|
||||
|
||||
|
||||
# Register autocomplete options
|
||||
def on_ui_settings():
|
||||
TAC_SECTION = ("tac", "Tag Autocomplete")
|
||||
# Main tag file
|
||||
shared.opts.add_option("tac_tagFile", shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files}, refresh=update_tag_files, section=TAC_SECTION))
|
||||
shared.opts.add_option("tac_tagFile", shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files, section=TAC_SECTION))
|
||||
# Active in settings
|
||||
shared.opts.add_option("tac_active", shared.OptionInfo(True, "Enable Tag Autocompletion", section=TAC_SECTION))
|
||||
shared.opts.add_option("tac_activeIn.txt2img", shared.OptionInfo(True, "Active in txt2img (Requires restart)", section=TAC_SECTION))
|
||||
@@ -147,7 +225,7 @@ def on_ui_settings():
|
||||
shared.opts.add_option("tac_translation.oldFormat", shared.OptionInfo(False, "Translation file uses old 3-column translation format instead of the new 2-column one", section=TAC_SECTION))
|
||||
shared.opts.add_option("tac_translation.searchByTranslation", shared.OptionInfo(True, "Search by translation", section=TAC_SECTION))
|
||||
# Extra file settings
|
||||
shared.opts.add_option("tac_extra.extraFile", shared.OptionInfo("None", "Extra filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files, section=TAC_SECTION))
|
||||
shared.opts.add_option("tac_extra.extraFile", shared.OptionInfo("None", "Extra filename (do not use e621.csv here!)", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files, section=TAC_SECTION))
|
||||
shared.opts.add_option("tac_extra.onlyAliasExtraFile", shared.OptionInfo(False, "Extra file in alias only format", section=TAC_SECTION))
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
|
||||
Reference in New Issue
Block a user