Compare commits

...

12 Commits

Author SHA1 Message Date
Dominik Reh
e418a867b3 Merge branch 'hyp-lora-support' into main 2023-01-24 15:23:53 +01:00
Dominik Reh
040be35162 Don't escape parentheses for loras and hypernets 2023-01-24 15:03:56 +01:00
Dominik Reh
316d45e2fa Use extra network multiplier from settings 2023-01-24 15:03:35 +01:00
Dominik Reh
8ab0e2504b Fix meta display, add mixed results
< will show all three, while <e: <h: or <l: will limit it to that type.
2023-01-24 14:51:55 +01:00
Dominik Reh
b29b496b88 Simplify lora and hypernetwork loading 2023-01-24 14:08:11 +01:00
Dominik Reh
e144f0d388 Make script work without settings tab
Fixes #116
2023-01-24 13:08:43 +01:00
JM
ae01f41f30 add support for hypernetworks and lora 2023-01-22 19:24:59 +01:00
DominikDoom
fb27ac9187 Update README_ZH.md 2023-01-18 16:31:57 +01:00
DominikDoom
770bb495a5 Update README.md 2023-01-18 16:29:55 +01:00
Dominik Reh
7fdad1bf62 Add back ability to use hashes in black/whitelist
They are displayed in the UI after all, just not in the dropdown but at the bottom
2023-01-14 14:57:39 +01:00
Dominik Reh
a91a098243 Change blacklist to use model name instead of hash
Hotfix for recent webui changes to use proper sha256 hashes, which is currently not displayed in the UI
2023-01-14 14:24:44 +01:00
Dominik Reh
c663abcbcb Fix wiki links showing on embeddings & wildcards 2023-01-13 19:33:43 +01:00
5 changed files with 222 additions and 58 deletions

View File

@@ -41,8 +41,8 @@ git clone "https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git" extens
Or create a folder there manually and place the `javascript`, `scripts` and `tags` folders in it.
### In the root folder (old)
Copy the `javascript`, `scripts` and `tags` folder into your web UI installation root. It will run automatically the next time the web UI is started.
### In the root folder (legacy)
This installation method is for old webui versions pre-extension system, it will not work on current versions!
---

View File

@@ -39,8 +39,8 @@ git clone "https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git" extens
或者手动创建一个文件夹,将 `javascript``scripts``tags`文件夹放在其中。
### 在根目录下(方法)
只需要将`javascript``scripts``tags`文件夹复制到你的Web UI安装根目录下.下次启动Web UI时它将自动启动
### 在根目录下(过时的方法
这种安装方法适用于添加扩展系统之前的旧版webui在目前的版本上是行不通的
---
在这两种配置中,标签文件夹包含`colors.json`和脚本用于自动完成的标签数据。

View File

@@ -6,7 +6,9 @@ const ResultType = Object.freeze({
"embedding": 2,
"wildcardTag": 3,
"wildcardFile": 4,
"yamlWildcard": 5
"yamlWildcard": 5,
"hypernetwork": 6,
"lora": 7
});
// Class to hold result data and annotations to make it clearer to use

View File

@@ -176,6 +176,8 @@ async function syncOptions() {
delayTime: opts["tac_delayTime"],
useWildcards: opts["tac_useWildcards"],
useEmbeddings: opts["tac_useEmbeddings"],
useHypernetworks: opts["tac_useHypernetworks"],
useLoras: opts["tac_useLoras"],
showWikiLinks: opts["tac_showWikiLinks"],
// Insertion related settings
replaceUnderscores: opts["tac_replaceUnderscores"],
@@ -196,7 +198,9 @@ async function syncOptions() {
extra: {
extraFile: opts["tac_extra.extraFile"],
onlyAliasExtraFile: opts["tac_extra.onlyAliasExtraFile"]
}
},
// Settings not from tac but still used by the script
extraNetworksDefaultMultiplier: opts["extra_networks_default_multiplier"]
}
if (CFG && CFG.colors) {
@@ -267,6 +271,7 @@ function hideResults(textArea) {
}
var currentModelHash = "";
var currentModelName = "";
// Function to check activation criteria
function isEnabled() {
if (CFG.activeIn.global) {
@@ -275,13 +280,14 @@ function isEnabled() {
.map(x => x.trim())
.filter(x => x.length > 0);
let shortHash = currentModelHash.substring(0, 10);
if (CFG.activeIn.modelListMode.toLowerCase() === "blacklist") {
// If the current model is in the blacklist, disable
return !modelList.includes(currentModelHash);
return modelList.filter(x => x === currentModelName || x === currentModelHash || x === shortHash).length === 0;
} else {
// If the current model is in the whitelist, enable.
// An empty whitelist is ignored.
return modelList.length === 0 || modelList.includes(currentModelHash);
return modelList.length === 0 || modelList.filter(x => x === currentModelName || x === currentModelHash || x === shortHash).length > 0;
}
} else {
return false;
@@ -293,7 +299,6 @@ const TAG_REGEX = /(<[^\t\n\r,>]+>?|[^\s,|<>]+|<)/g
const WC_REGEX = /\b__([^, ]+)__([^, ]*)\b/g;
const UMI_PROMPT_REGEX = /<[^\s]*?\[[^,<>]*[\]|]?>?/gi;
const UMI_TAG_REGEX = /(?:\[|\||--)([^<>\[\]\-|]+)/gi;
const MODEL_HASH_REGEX = /\[(.+)\]/g;
let hideBlocked = false;
// On click, insert the tag into the prompt textbox with respect to the cursor position
@@ -313,11 +318,15 @@ function insertTextAtCursor(textArea, result, tagword) {
sanitizedText = text.replaceAll("_", " "); // Replace underscores only if the yaml tag is not using them
} else if (tagType === ResultType.embedding) {
sanitizedText = `${text.replace(/^.*?: /g, "")}`;
} else if (tagType === ResultType.hypernetwork) {
sanitizedText = `<hypernet:${text}:${CFG.extraNetworksDefaultMultiplier}>`;
} else if(tagType === ResultType.lora) {
sanitizedText = `<lora:${text}:${CFG.extraNetworksDefaultMultiplier}>`;
} else {
sanitizedText = CFG.replaceUnderscores ? text.replaceAll("_", " ") : text;
}
if (CFG.escapeParentheses) {
if (CFG.escapeParentheses && tagType === ResultType.tag) {
sanitizedText = sanitizedText
.replaceAll("(", "\\(")
.replaceAll(")", "\\)")
@@ -454,7 +463,9 @@ function addResultsToList(textArea, results, tagword, resetList) {
itemText.innerHTML = displayText.replace(tagword, `<b>${tagword}</b>`);
// Add wiki link if the setting is enabled and a supported tag set loaded
if (CFG.showWikiLinks && (tagFileName.toLowerCase().startsWith("danbooru") || tagFileName.toLowerCase().startsWith("e621"))) {
if (CFG.showWikiLinks
&& (result.type === ResultType.tag)
&& (tagFileName.toLowerCase().startsWith("danbooru") || tagFileName.toLowerCase().startsWith("e621"))) {
let wikiLink = document.createElement("a");
wikiLink.classList.add("acWikiLink");
wikiLink.innerText = "?";
@@ -480,41 +491,39 @@ function addResultsToList(textArea, results, tagword, resetList) {
// Add post count & color if it's a tag
// Wildcards & Embeds have no tag category
if (![ResultType.wildcardFile, ResultType.wildcardTag, ResultType.embedding].includes(result.type)) {
if (result.category) {
// Set the color of the tag
let cat = result.category;
let colorGroup = tagColors[tagFileName];
// Default to danbooru scheme if no matching one is found
if (!colorGroup)
colorGroup = tagColors["danbooru"];
if (result.category) {
// Set the color of the tag
let cat = result.category;
let colorGroup = tagColors[tagFileName];
// Default to danbooru scheme if no matching one is found
if (!colorGroup)
colorGroup = tagColors["danbooru"];
// Set tag type to invalid if not found
if (!colorGroup[cat])
cat = "-1";
// Set tag type to invalid if not found
if (!colorGroup[cat])
cat = "-1";
flexDiv.style = `color: ${colorGroup[cat][mode]};`;
}
flexDiv.style = `color: ${colorGroup[cat][mode]};`;
}
// Post count
if (result.count && !isNaN(result.count)) {
let postCount = result.count;
let formatter;
// Post count
if (result.count && !isNaN(result.count)) {
let postCount = result.count;
let formatter;
// Danbooru formats numbers with a padded fraction for 1M or 1k, but not for 10/100k
if (postCount >= 1000000 || (postCount >= 1000 && postCount < 10000))
formatter = Intl.NumberFormat("en", { notation: "compact", minimumFractionDigits: 1, maximumFractionDigits: 1 });
else
formatter = Intl.NumberFormat("en", {notation: "compact"});
let formattedCount = formatter.format(postCount);
let countDiv = document.createElement("div");
countDiv.textContent = formattedCount;
countDiv.classList.add("acMetaText");
flexDiv.appendChild(countDiv);
}
} else if (result.meta) { // Check if it is an embedding we have version info for
// Danbooru formats numbers with a padded fraction for 1M or 1k, but not for 10/100k
if (postCount >= 1000000 || (postCount >= 1000 && postCount < 10000))
formatter = Intl.NumberFormat("en", { notation: "compact", minimumFractionDigits: 1, maximumFractionDigits: 1 });
else
formatter = Intl.NumberFormat("en", {notation: "compact"});
let formattedCount = formatter.format(postCount);
let countDiv = document.createElement("div");
countDiv.textContent = formattedCount;
countDiv.classList.add("acMetaText");
flexDiv.appendChild(countDiv);
} else if (result.meta) { // Check if there is meta info to display
let metaDiv = document.createElement("div");
metaDiv.textContent = result.meta;
metaDiv.classList.add("acMetaText");
@@ -565,6 +574,8 @@ var wildcardExtFiles = [];
var yamlWildcards = [];
var umiPreviousTags = [];
var embeddings = [];
var hypernetworks = [];
var loras = [];
var results = [];
var tagword = "";
var originalTagword = "";
@@ -828,11 +839,11 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
originalTagword = tagword;
tagword = "";
}
} else if (CFG.useEmbeddings && tagword.match(/<[^,> ]*>?/g)) {
} else if (CFG.useEmbeddings && tagword.match(/<e:[^,> ]*>?/g)) {
// Show embeddings
let tempResults = [];
if (tagword !== "<") {
let searchTerm = tagword.replace("<", "")
if (tagword !== "<e:") {
let searchTerm = tagword.replace("<e:", "")
let versionString;
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
versionString = searchTerm.slice(0, 2);
@@ -845,6 +856,73 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
} else {
tempResults = embeddings;
}
// Add final results
tempResults.forEach(t => {
let result = new AutocompleteResult(t[0].trim(), ResultType.embedding)
result.meta = t[1] + " Embedding";
results.push(result);
});
} else if(CFG.useHypernetworks && tagword.match(/<h:[^,> ]*>?/g)) {
// Show hypernetworks
let tempResults = [];
if (tagword !== "<h:") {
let searchTerm = tagword.replace("<h:", "")
tempResults = hypernetworks.filter(x => x.toLowerCase().includes(searchTerm)); // Filter by tagword
} else {
tempResults = hypernetworks;
}
// Add final results
tempResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.hypernetwork)
result.meta = "Hypernetwork";
results.push(result);
});
} else if(CFG.useLoras && tagword.match(/<l:[^,> ]*>?/g)){
// Show lora
let tempResults = [];
if (tagword !== "<l:") {
let searchTerm = tagword.replace("<l:", "")
tempResults = loras.filter(x => x.toLowerCase().includes(searchTerm)); // Filter by tagword
} else {
tempResults = loras;
}
// Add final results
tempResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.lora)
result.meta = "Lora";
results.push(result);
});
} else if ((CFG.useEmbeddings || CFG.useHypernetworks || CFG.useLoras) && tagword.match(/<[^,> ]*>?/g)) {
// Embeddings, lora, wildcards all together with generic options
let tempEmbResults = [];
let tempHypResults = [];
let tempLoraResults = [];
if (tagword !== "<") {
let searchTerm = tagword.replace("<", "")
let versionString;
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
versionString = searchTerm.slice(0, 2);
searchTerm = searchTerm.slice(2);
}
if (versionString && CFG.useEmbeddings) {
// Version string is only for embeddings atm, so we don't search the other lists here.
tempEmbResults = embeddings.filter(x => x[0].toLowerCase().includes(searchTerm) && x[1] && x[1] === versionString); // Filter by tagword
} else {
tempEmbResults = embeddings.filter(x => x[0].toLowerCase().includes(searchTerm)); // Filter by tagword
tempHypResults = hypernetworks.filter(x => x.toLowerCase().includes(searchTerm)); // Filter by tagword
tempLoraResults = loras.filter(x => x.toLowerCase().includes(searchTerm)); // Filter by tagword
}
} else {
tempEmbResults = embeddings;
tempHypResults = hypernetworks;
tempLoraResults = loras;
}
// Since some tags are kaomoji, we have to still get the normal results first.
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
@@ -857,11 +935,32 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, CFG.maxResults);
// Add final results
tempResults.forEach(t => {
let result = new AutocompleteResult(t[0].trim(), ResultType.embedding)
result.meta = t[1] + " Embedding";
results.push(result);
});
let mixedResults = [];
if (CFG.useEmbeddings) {
tempEmbResults.forEach(t => {
let result = new AutocompleteResult(t[0].trim(), ResultType.embedding)
result.meta = t[1] + " Embedding";
mixedResults.push(result);
});
}
if (CFG.useHypernetworks) {
tempHypResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.hypernetwork)
result.meta = "Hypernetwork";
mixedResults.push(result);
});
}
if (CFG.useLoras) {
tempLoraResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.lora)
result.meta = "Lora";
mixedResults.push(result);
});
}
// Add all mixed results to the final results, sorted by name so that they aren't after one another.
results = mixedResults.sort((a, b) => a.text.localeCompare(b.text));
genericResults.forEach(g => {
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
result.category = g[1];
@@ -1077,13 +1176,31 @@ async function setup() {
console.error("Error loading embeddings.txt: " + e);
}
}
// Load hypernetworks
if (hypernetworks.length === 0) {
try {
hypernetworks = (await readFile(`${tagBasePath}/temp/hyp.txt?${new Date().getTime()}`)).split("\n")
.filter(x => x.trim().length > 0) //Remove empty lines
} catch (e) {
console.error("Error loading hypernetworks.txt: " + e);
}
}
// Load lora
if (loras.length === 0) {
try {
loras = (await readFile(`${tagBasePath}/temp/lora.txt?${new Date().getTime()}`)).split("\n")
.filter(x => x.trim().length > 0) // Remove empty lines
} catch (e) {
console.error("Error loading lora.txt: " + e);
}
}
// Find all textareas
let textAreas = getTextAreas();
// Add event listener to apply settings button so we can mirror the changes to our internal config
let applySettingsButton = gradioApp().querySelector("#tab_settings #settings_submit") || gradioApp().querySelector("#tab_settings > div > .gr-button-primary");
applySettingsButton.addEventListener("click", () => {
applySettingsButton?.addEventListener("click", () => {
// Wait 500ms to make sure the settings have been applied to the webui opts object
setTimeout(async () => {
await syncOptions();
@@ -1092,21 +1209,35 @@ async function setup() {
// Add change listener to our quicksettings to change our internal config without the apply button for them
let quicksettings = gradioApp().querySelector('#quicksettings');
let commonQueryPart = "[id^=setting_tac] > label >";
quicksettings.querySelectorAll(`${commonQueryPart} input, ${commonQueryPart} textarea, ${commonQueryPart} select`).forEach(e => {
quicksettings?.querySelectorAll(`${commonQueryPart} input, ${commonQueryPart} textarea, ${commonQueryPart} select`).forEach(e => {
e.addEventListener("change", () => {
setTimeout(async () => {
await syncOptions();
}, 500);
});
});
// Add change listener to model dropdown to react to model changes
let modelDropdown = gradioApp().querySelector("#setting_sd_model_checkpoint select");
currentModelHash = [...modelDropdown.value.matchAll(MODEL_HASH_REGEX)][0][1]; // Set initial model hash
modelDropdown.addEventListener("change", () => {
currentModelName = modelDropdown.value;
modelDropdown?.addEventListener("change", () => {
setTimeout(() => {
currentModelHash = [...modelDropdown.value.matchAll(MODEL_HASH_REGEX)][0][1];
currentModelName = modelDropdown.value;
}, 100);
});
// Add mutation observer for the model hash text to also allow hash-based blacklist again
let modelHashText = gradioApp().querySelector("#sd_checkpoint_hash");
if (modelHashText) {
currentModelHash = modelHashText.title
let modelHashObserver = new MutationObserver((mutationList, observer) => {
for (const mutation of mutationList) {
if (mutation.type === "attributes" && mutation.attributeName === "title") {
currentModelHash = mutation.target.title;
}
}
});
modelHashObserver.observe(modelHashText, { attributes: true });
}
// Not found, we're on a page without prompt textareas
if (textAreas.every(v => v === null || v === undefined)) return;

View File

@@ -20,6 +20,8 @@ TAGS_PATH = Path(scripts.basedir()).joinpath('tags')
# The path to the folder containing the wildcards and embeddings
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
LORA_PATH = Path(shared.cmd_opts.lora_dir)
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
def find_ext_wildcard_paths():
@@ -137,6 +139,22 @@ def get_embeddings(sd_model):
write_to_temp_file('emb.txt', results)
def get_hypernetworks():
"""Write a list of all hypernetworks"""
# Get a list of all hypernetworks in the folder
all_hypernetworks = [str(h.name) for h in HYP_PATH.rglob("*") if h.suffix in {".pt"}]
# Remove file extensions
return [h[:h.rfind('.')] for h in all_hypernetworks]
def get_lora():
"""Write a list of all lora"""
# Get a list of all lora in the folder
all_lora = [str(l.name) for l in LORA_PATH.rglob("*") if l.suffix in {".safetensors", ".ckpt", ".pt"}]
# Remove file extensions
return [l[:l.rfind('.')] for l in all_lora]
def write_tag_base_path():
"""Writes the tag base path to a fixed location temporary file"""
@@ -178,6 +196,8 @@ if not TEMP_PATH.exists():
write_to_temp_file('wc.txt', [])
write_to_temp_file('wce.txt', [])
write_to_temp_file('wcet.txt', [])
write_to_temp_file('hyp.txt', [])
write_to_temp_file('lora.txt', [])
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
if not TEMP_PATH.joinpath("emb.txt").exists():
write_to_temp_file('emb.txt', [])
@@ -202,7 +222,16 @@ if WILDCARD_EXT_PATHS is not None:
if EMB_PATH.exists():
# Get embeddings after the model loaded callback
script_callbacks.on_model_loaded(get_embeddings)
if HYP_PATH.exists():
hypernets = get_hypernetworks()
if hypernets:
write_to_temp_file('hyp.txt', hypernets)
if LORA_PATH.exists():
lora = get_lora()
if lora:
write_to_temp_file('lora.txt', lora)
# Register autocomplete options
def on_ui_settings():
@@ -215,8 +244,8 @@ def on_ui_settings():
shared.opts.add_option("tac_activeIn.img2img", shared.OptionInfo(True, "Active in img2img (Requires restart)", section=TAC_SECTION))
shared.opts.add_option("tac_activeIn.negativePrompts", shared.OptionInfo(True, "Active in negative prompts (Requires restart)", section=TAC_SECTION))
shared.opts.add_option("tac_activeIn.thirdParty", shared.OptionInfo(True, "Active in third party textboxes [Dataset Tag Editor] (Requires restart)", section=TAC_SECTION))
shared.opts.add_option("tac_activeIn.modelList", shared.OptionInfo("", "List of model hashes to use as black/whitelist, separated by commas.", section=TAC_SECTION))
shared.opts.add_option("tac_activeIn.modelListMode", shared.OptionInfo("Blacklist", "Mode to use for model hash list", gr.Dropdown, lambda: {"choices": ["Blacklist","Whitelist"]}, section=TAC_SECTION))
shared.opts.add_option("tac_activeIn.modelList", shared.OptionInfo("", "List of model names (with file extension) or their hashes to use as black/whitelist, separated by commas.", section=TAC_SECTION))
shared.opts.add_option("tac_activeIn.modelListMode", shared.OptionInfo("Blacklist", "Mode to use for model list", gr.Dropdown, lambda: {"choices": ["Blacklist","Whitelist"]}, section=TAC_SECTION))
# Results related settings
shared.opts.add_option("tac_maxResults", shared.OptionInfo(5, "Maximum results", section=TAC_SECTION))
shared.opts.add_option("tac_showAllResults", shared.OptionInfo(False, "Show all results", section=TAC_SECTION))
@@ -224,6 +253,8 @@ def on_ui_settings():
shared.opts.add_option("tac_delayTime", shared.OptionInfo(100, "Time in ms to wait before triggering completion again (Requires restart)", section=TAC_SECTION))
shared.opts.add_option("tac_useWildcards", shared.OptionInfo(True, "Search for wildcards", section=TAC_SECTION))
shared.opts.add_option("tac_useEmbeddings", shared.OptionInfo(True, "Search for embeddings", section=TAC_SECTION))
shared.opts.add_option("tac_useHypernetworks", shared.OptionInfo(True, "Search for hypernetworks", section=TAC_SECTION))
shared.opts.add_option("tac_useLoras", shared.OptionInfo(True, "Search for Loras", section=TAC_SECTION))
shared.opts.add_option("tac_showWikiLinks", shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page (Warning: This is an external site and very likely contains NSFW examples!)", section=TAC_SECTION))
# Insertion related settings
shared.opts.add_option("tac_replaceUnderscores", shared.OptionInfo(True, "Replace underscores with spaces on insertion", section=TAC_SECTION))