Simplify lora and hypernetwork loading

This commit is contained in:
Dominik Reh
2023-01-24 14:08:11 +01:00
parent ae01f41f30
commit b29b496b88
3 changed files with 29 additions and 84 deletions

View File

@@ -7,7 +7,7 @@ const ResultType = Object.freeze({
"wildcardTag": 3,
"wildcardFile": 4,
"yamlWildcard": 5,
"hypernetworks": 6,
"hypernetwork": 6,
"lora": 7
});

View File

@@ -314,7 +314,7 @@ function insertTextAtCursor(textArea, result, tagword) {
sanitizedText = text.replaceAll("_", " "); // Replace underscores only if the yaml tag is not using them
} else if (tagType === ResultType.embedding) {
sanitizedText = `${text.replace(/^.*?: /g, "")}`;
} else if (tagType === ResultType.hypernetworks) {
} else if (tagType === ResultType.hypernetwork) {
sanitizedText = `<hypernet:${text.replace(/^.*?: /g, "")}:1>`;
} else if(tagType === ResultType.lora) {
sanitizedText = `<lora:${text.replace(/^.*?: /g, "")}:1>`;
@@ -573,7 +573,7 @@ var yamlWildcards = [];
var umiPreviousTags = [];
var embeddings = [];
var hypernetworks = [];
var lora = [];
var loras = [];
var results = [];
var tagword = "";
var originalTagword = "";
@@ -884,69 +884,26 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
let tempResults = [];
if (tagword !== "<h:") {
let searchTerm = tagword.replace("<h:", "")
let versionString;
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
versionString = searchTerm.slice(0, 2);
searchTerm = searchTerm.slice(2);
}
if (versionString)
tempResults = hypernetworks.filter(x => x[0].toLowerCase().includes(searchTerm) && x[1] && x[1] === versionString); // Filter by tagword
else
tempResults = hypernetworks.filter(x => x[0].toLowerCase().includes(searchTerm)); // Filter by tagword
tempResults = hypernetworks.filter(x => x[0].toLowerCase().includes(searchTerm)); // Filter by tagword
} else {
tempResults = hypernetworks;
}
// Since some tags are kaomoji, we have to still get the normal results first.
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
if (tagword.startsWith("*")) {
tagword = tagword.slice(1);
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
} else {
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
}
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, CFG.maxResults);
// Add final results
tempResults.forEach(t => {
let result = new AutocompleteResult(t[0].trim(), ResultType.hypernetworks)
let result = new AutocompleteResult(t[0].trim(), ResultType.hypernetwork)
result.meta = t[1] + " Hypernetworks";
results.push(result);
});
genericResults.forEach(g => {
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
result.category = g[1];
result.count = g[2];
result.aliases = g[3];
results.push(result);
});
} else if(tagword.match(/<l:[^,> ]*>?/g)){
// Show lora
let tempResults = [];
if (tagword !== "<l:") {
let searchTerm = tagword.replace("<l:", "")
let versionString;
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
versionString = searchTerm.slice(0, 2);
searchTerm = searchTerm.slice(2);
}
if (versionString)
tempResults = lora.filter(x => x[0].toLowerCase().includes(searchTerm) && x[1] && x[1] === versionString); // Filter by tagword
else
tempResults = lora.filter(x => x[0].toLowerCase().includes(searchTerm)); // Filter by tagword
tempResults = loras.filter(x => x[0].toLowerCase().includes(searchTerm)); // Filter by tagword
} else {
tempResults = lora;
tempResults = loras;
}
// Since some tags are kaomoji, we have to still get the normal results first.
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
if (tagword.startsWith("*")) {
tagword = tagword.slice(1);
searchRegex = new RegExp(`${escapeRegExp(tagword)}`, 'i');
} else {
searchRegex = new RegExp(`(^|[^a-zA-Z])${escapeRegExp(tagword)}`, 'i');
}
let genericResults = allTags.filter(x => x[0].toLowerCase().search(searchRegex) > -1).slice(0, CFG.maxResults);
// Add final results
tempResults.forEach(t => {
@@ -954,13 +911,6 @@ async function autocomplete(textArea, prompt, fixedTag = null) {
result.meta = t[1] + " Lora";
results.push(result);
});
genericResults.forEach(g => {
let result = new AutocompleteResult(g[0].trim(), ResultType.tag)
result.category = g[1];
result.count = g[2];
result.aliases = g[3];
results.push(result);
});
} else {
// Create escaped search regex with support for * as a start placeholder
let searchRegex;
@@ -1174,17 +1124,15 @@ async function setup() {
try {
hypernetworks = (await readFile(`${tagBasePath}/temp/hyp.txt?${new Date().getTime()}`)).split("\n")
.filter(x => x.trim().length > 0) //Remove empty lines
.map(x => x.trim().split(",")); // Split into name, version type pairs
} catch (e) {
console.error("Error loading hypernetworks.txt: " + e);
}
}
// Load lora
if (lora.length === 0) {
if (loras.length === 0) {
try {
lora = (await readFile(`${tagBasePath}/temp/lora.txt?${new Date().getTime()}`)).split("\n")
loras = (await readFile(`${tagBasePath}/temp/lora.txt?${new Date().getTime()}`)).split("\n")
.filter(x => x.trim().length > 0) // Remove empty lines
.map(x => x.trim().split(",")); // Split into name, version type pairs
} catch (e) {
console.error("Error loading lora.txt: " + e);
}

View File

@@ -139,35 +139,21 @@ def get_embeddings(sd_model):
write_to_temp_file('emb.txt', results)
def get_hypernetworks(sd_model):
def get_hypernetworks():
"""Write a list of all hypernetworks"""
results = []
# Get a list of all hypernetworks in the folder
all_hypernetworks = [str(h.relative_to(HYP_PATH)) for h in HYP_PATH.rglob("*") if h.suffix in {".pt"}]
# Remove files with a size of 0
all_hypernetworks = [h for h in all_hypernetworks if HYP_PATH.joinpath(h).stat().st_size > 0]
all_hypernetworks = [str(h.name) for h in HYP_PATH.rglob("*") if h.suffix in {".pt"}]
# Remove file extensions
all_hypernetworks = [h[:h.rfind('.')] for h in all_hypernetworks]
results = [h + "," for h in all_hypernetworks]
return [h[:h.rfind('.')] for h in all_hypernetworks]
write_to_temp_file('hyp.txt', results)
def get_lora(sd_model):
def get_lora():
"""Write a list of all lora"""
results = []
# Get a list of all lora in the folder
all_lora = [str(l.relative_to(LORA_PATH)) for l in LORA_PATH.rglob("*") if l.suffix in {".safetensors"}]
# Remove files with a size of 0
all_lora = [l for l in all_lora if LORA_PATH.joinpath(l).stat().st_size > 0]
all_lora = [str(l.name) for l in LORA_PATH.rglob("*") if l.suffix in {".safetensors", ".ckpt", ".pt"}]
# Remove file extensions
all_lora = [l[:l.rfind('.')] for l in all_lora]
results = [l + "," for l in all_lora]
write_to_temp_file('lora.txt', results)
return [l[:l.rfind('.')] for l in all_lora]
def write_tag_base_path():
@@ -210,6 +196,8 @@ if not TEMP_PATH.exists():
write_to_temp_file('wc.txt', [])
write_to_temp_file('wce.txt', [])
write_to_temp_file('wcet.txt', [])
write_to_temp_file('hyp.txt', [])
write_to_temp_file('lora.txt', [])
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
if not TEMP_PATH.joinpath("emb.txt").exists():
write_to_temp_file('emb.txt', [])
@@ -234,9 +222,16 @@ if WILDCARD_EXT_PATHS is not None:
if EMB_PATH.exists():
# Get embeddings after the model loaded callback
script_callbacks.on_model_loaded(get_embeddings)
script_callbacks.on_model_loaded(get_hypernetworks)
script_callbacks.on_model_loaded(get_lora)
if HYP_PATH.exists():
hypernets = get_hypernetworks()
if hypernets:
write_to_temp_file('hyp.txt', hypernets)
if LORA_PATH.exists():
lora = get_lora()
if lora:
write_to_temp_file('lora.txt', lora)
# Register autocomplete options
def on_ui_settings():
@@ -258,6 +253,8 @@ def on_ui_settings():
shared.opts.add_option("tac_delayTime", shared.OptionInfo(100, "Time in ms to wait before triggering completion again (Requires restart)", section=TAC_SECTION))
shared.opts.add_option("tac_useWildcards", shared.OptionInfo(True, "Search for wildcards", section=TAC_SECTION))
shared.opts.add_option("tac_useEmbeddings", shared.OptionInfo(True, "Search for embeddings", section=TAC_SECTION))
shared.opts.add_option("tac_useHypernetworks", shared.OptionInfo(True, "Search for hypernetworks", section=TAC_SECTION))
shared.opts.add_option("tac_useLora", shared.OptionInfo(True, "Search for Loras", section=TAC_SECTION))
shared.opts.add_option("tac_showWikiLinks", shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page (Warning: This is an external site and very likely contains NSFW examples!)", section=TAC_SECTION))
# Insertion related settings
shared.opts.add_option("tac_replaceUnderscores", shared.OptionInfo(True, "Replace underscores with spaces on insertion", section=TAC_SECTION))