mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-26 11:09:54 +00:00
add support for hypernetworks and lora
This commit is contained in:
@@ -6,7 +6,9 @@ const ResultType = Object.freeze({
|
||||
"embedding": 2,
|
||||
"wildcardTag": 3,
|
||||
"wildcardFile": 4,
|
||||
"yamlWildcard": 5
|
||||
"yamlWildcard": 5,
|
||||
"hypernetworks": 6,
|
||||
"lora": 7
|
||||
});
|
||||
|
||||
// Class to hold result data and annotations to make it clearer to use
|
||||
|
||||
@@ -314,6 +314,10 @@ function insertTextAtCursor(textArea, result, tagword) {
|
||||
sanitizedText = text.replaceAll("_", " "); // Replace underscores only if the yaml tag is not using them
|
||||
} else if (tagType === ResultType.embedding) {
|
||||
sanitizedText = `${text.replace(/^.*?: /g, "")}`;
|
||||
} else if (tagType === ResultType.hypernetworks) {
|
||||
sanitizedText = `<hypernet:${text.replace(/^.*?: /g, "")}:1>`;
|
||||
} else if(tagType === ResultType.lora) {
|
||||
sanitizedText = `<lora:${text.replace(/^.*?: /g, "")}:1>`;
|
||||
} else {
|
||||
sanitizedText = CFG.replaceUnderscores ? text.replaceAll("_", " ") : text;
|
||||
}
|
||||
@@ -568,6 +572,8 @@ var wildcardExtFiles = [];
|
||||
var yamlWildcards = [];
|
||||
var umiPreviousTags = [];
|
||||
var embeddings = [];
|
||||
var hypernetworks = [];
|
||||
var lora = [];
|
||||
var results = [];
|
||||
var tagword = "";
|
||||
var originalTagword = "";
|
||||
@@ -831,11 +837,11 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
originalTagword = tagword;
|
||||
tagword = "";
|
||||
}
|
||||
} else if (CFG.useEmbeddings && tagword.match(/<[^,> ]*>?/g)) {
|
||||
} else if (CFG.useEmbeddings && tagword.match(/<e:[^,> ]*>?/g)) {
|
||||
// Show embeddings
|
||||
let tempResults = [];
|
||||
if (tagword !== "<") {
|
||||
let searchTerm = tagword.replace("<", "")
|
||||
if (tagword !== "<e:") {
|
||||
let searchTerm = tagword.replace("<e:", "")
|
||||
let versionString;
|
||||
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
|
||||
versionString = searchTerm.slice(0, 2);
|
||||
@@ -872,6 +878,89 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
|
||||
result.aliases = g[3];
|
||||
results.push(result);
|
||||
});
|
||||
|
||||
} else if(tagword.match(/<h:[^,> ]*>?/g)) {
|
||||
// Show hypernetworks
|
||||
let tempResults = [];
|
||||
if (tagword !== "<h:") {
|
||||
let searchTerm = tagword.replace("<h:", "")
|
||||
let versionString;
|
||||
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
|
||||
versionString = searchTerm.slice(0, 2);
|
||||
searchTerm = searchTerm.slice(2);
|
||||
}
|
||||
if (versionString)
|
||||
tempResults = hypernetworks.filter(x => x[0].toLowerCase().includes(searchTerm) && x[1] && x[1] === versionString); // Filter by tagword
|
||||
else
|
||||
tempResults = hypernetworks.filter(x => x[0].toLowerCase().includes(searchTerm)); // Filter by tagword
|
||||
} else {
|
||||
tempResults = hypernetworks;
|
||||
}
|
||||
// Since some tags are kaomoji, we have to still get the normal results first.
|
||||
// Create escaped search regex with support for * as a start placeholder
|
||||
let searchRegex;
|
||||
if (tagword.startsWith("*")) {
|
||||
tagword = tagword.slice(1);
|
||||
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
|
||||
} else {
|
||||
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
|
||||
}
|
||||
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, CFG.maxResults);
|
||||
|
||||
// Add final results
|
||||
tempResults.forEach(t => {
|
||||
let result = new AutocompleteResult(t[0].trim(), ResultType.hypernetworks)
|
||||
result.meta = t[1] + " Hypernetworks";
|
||||
results.push(result);
|
||||
});
|
||||
genericResults.forEach(g => {
|
||||
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
|
||||
result.category = g[1];
|
||||
result.count = g[2];
|
||||
result.aliases = g[3];
|
||||
results.push(result);
|
||||
});
|
||||
} else if(tagword.match(/<l:[^,> ]*>?/g)){
|
||||
// Show lora
|
||||
let tempResults = [];
|
||||
if (tagword !== "<l:") {
|
||||
let searchTerm = tagword.replace("<l:", "")
|
||||
let versionString;
|
||||
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
|
||||
versionString = searchTerm.slice(0, 2);
|
||||
searchTerm = searchTerm.slice(2);
|
||||
}
|
||||
if (versionString)
|
||||
tempResults = lora.filter(x => x[0].toLowerCase().includes(searchTerm) && x[1] && x[1] === versionString); // Filter by tagword
|
||||
else
|
||||
tempResults = lora.filter(x => x[0].toLowerCase().includes(searchTerm)); // Filter by tagword
|
||||
} else {
|
||||
tempResults = lora;
|
||||
}
|
||||
// Since some tags are kaomoji, we have to still get the normal results first.
|
||||
// Create escaped search regex with support for * as a start placeholder
|
||||
let searchRegex;
|
||||
if (tagword.startsWith("*")) {
|
||||
tagword = tagword.slice(1);
|
||||
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
|
||||
} else {
|
||||
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
|
||||
}
|
||||
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, CFG.maxResults);
|
||||
|
||||
// Add final results
|
||||
tempResults.forEach(t => {
|
||||
let result = new AutocompleteResult(t[0].trim(), ResultType.lora)
|
||||
result.meta = t[1] + " Lora";
|
||||
results.push(result);
|
||||
});
|
||||
genericResults.forEach(g => {
|
||||
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
|
||||
result.category = g[1];
|
||||
result.count = g[2];
|
||||
result.aliases = g[3];
|
||||
results.push(result);
|
||||
});
|
||||
} else {
|
||||
// Create escaped search regex with support for * as a start placeholder
|
||||
let searchRegex;
|
||||
@@ -1080,6 +1169,26 @@ async function setup() {
|
||||
console.error("Error loading embeddings.txt: " + e);
|
||||
}
|
||||
}
|
||||
// Load hypernetworks
|
||||
if (hypernetworks.length === 0) {
|
||||
try {
|
||||
hypernetworks = (await readFile(`${tagBasePath}/temp/hyp.txt?${new Date().getTime()}`)).split("\n")
|
||||
.filter(x => x.trim().length > 0) //Remove empty lines
|
||||
.map(x => x.trim().split(",")); // Split into name, version type pairs
|
||||
} catch (e) {
|
||||
console.error("Error loading hypernetworks.txt: " + e);
|
||||
}
|
||||
}
|
||||
// Load lora
|
||||
if (lora.length === 0) {
|
||||
try {
|
||||
lora = (await readFile(`${tagBasePath}/temp/lora.txt?${new Date().getTime()}`)).split("\n")
|
||||
.filter(x => x.trim().length > 0) // Remove empty lines
|
||||
.map(x => x.trim().split(",")); // Split into name, version type pairs
|
||||
} catch (e) {
|
||||
console.error("Error loading lora.txt: " + e);
|
||||
}
|
||||
}
|
||||
|
||||
// Find all textareas
|
||||
let textAreas = getTextAreas();
|
||||
|
||||
@@ -20,6 +20,8 @@ TAGS_PATH = Path(scripts.basedir()).joinpath('tags')
|
||||
# The path to the folder containing the wildcards and embeddings
|
||||
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
|
||||
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
|
||||
LORA_PATH = Path(shared.cmd_opts.lora_dir)
|
||||
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
|
||||
|
||||
|
||||
def find_ext_wildcard_paths():
|
||||
@@ -137,6 +139,36 @@ def get_embeddings(sd_model):
|
||||
|
||||
write_to_temp_file('emb.txt', results)
|
||||
|
||||
def get_hypernetworks(sd_model):
|
||||
"""Write a list of all hypernetworks"""
|
||||
|
||||
results = []
|
||||
|
||||
# Get a list of all hypernetworks in the folder
|
||||
all_hypernetworks = [str(h.relative_to(HYP_PATH)) for h in HYP_PATH.rglob("*") if h.suffix in {".pt"}]
|
||||
# Remove files with a size of 0
|
||||
all_hypernetworks = [h for h in all_hypernetworks if HYP_PATH.joinpath(h).stat().st_size > 0]
|
||||
# Remove file extensions
|
||||
all_hypernetworks = [h[:h.rfind('.')] for h in all_hypernetworks]
|
||||
results = [h + "," for h in all_hypernetworks]
|
||||
|
||||
write_to_temp_file('hyp.txt', results)
|
||||
|
||||
def get_lora(sd_model):
|
||||
"""Write a list of all lora"""
|
||||
|
||||
results = []
|
||||
|
||||
# Get a list of all lora in the folder
|
||||
all_lora = [str(l.relative_to(LORA_PATH)) for l in LORA_PATH.rglob("*") if l.suffix in {".safetensors"}]
|
||||
# Remove files with a size of 0
|
||||
all_lora = [l for l in all_lora if LORA_PATH.joinpath(l).stat().st_size > 0]
|
||||
# Remove file extensions
|
||||
all_lora = [l[:l.rfind('.')] for l in all_lora]
|
||||
results = [l + "," for l in all_lora]
|
||||
|
||||
write_to_temp_file('lora.txt', results)
|
||||
|
||||
|
||||
def write_tag_base_path():
|
||||
"""Writes the tag base path to a fixed location temporary file"""
|
||||
@@ -202,6 +234,8 @@ if WILDCARD_EXT_PATHS is not None:
|
||||
if EMB_PATH.exists():
|
||||
# Get embeddings after the model loaded callback
|
||||
script_callbacks.on_model_loaded(get_embeddings)
|
||||
script_callbacks.on_model_loaded(get_hypernetworks)
|
||||
script_callbacks.on_model_loaded(get_lora)
|
||||
|
||||
|
||||
# Register autocomplete options
|
||||
|
||||
Reference in New Issue
Block a user