Simplify lora and hypernetwork loading

This commit is contained in:
Dominik Reh
2023-01-24 14:08:11 +01:00
parent ae01f41f30
commit b29b496b88
3 changed files with 29 additions and 84 deletions

View File

@@ -139,35 +139,21 @@ def get_embeddings(sd_model):
write_to_temp_file('emb.txt', results)
def get_hypernetworks(sd_model):
def get_hypernetworks():
"""Write a list of all hypernetworks"""
results = []
# Get a list of all hypernetworks in the folder
all_hypernetworks = [str(h.relative_to(HYP_PATH)) for h in HYP_PATH.rglob("*") if h.suffix in {".pt"}]
# Remove files with a size of 0
all_hypernetworks = [h for h in all_hypernetworks if HYP_PATH.joinpath(h).stat().st_size > 0]
all_hypernetworks = [str(h.name) for h in HYP_PATH.rglob("*") if h.suffix in {".pt"}]
# Remove file extensions
all_hypernetworks = [h[:h.rfind('.')] for h in all_hypernetworks]
results = [h + "," for h in all_hypernetworks]
return [h[:h.rfind('.')] for h in all_hypernetworks]
write_to_temp_file('hyp.txt', results)
def get_lora(sd_model):
def get_lora():
"""Write a list of all lora"""
results = []
# Get a list of all lora in the folder
all_lora = [str(l.relative_to(LORA_PATH)) for l in LORA_PATH.rglob("*") if l.suffix in {".safetensors"}]
# Remove files with a size of 0
all_lora = [l for l in all_lora if LORA_PATH.joinpath(l).stat().st_size > 0]
all_lora = [str(l.name) for l in LORA_PATH.rglob("*") if l.suffix in {".safetensors", ".ckpt", ".pt"}]
# Remove file extensions
all_lora = [l[:l.rfind('.')] for l in all_lora]
results = [l + "," for l in all_lora]
write_to_temp_file('lora.txt', results)
return [l[:l.rfind('.')] for l in all_lora]
def write_tag_base_path():
@@ -210,6 +196,8 @@ if not TEMP_PATH.exists():
write_to_temp_file('wc.txt', [])
write_to_temp_file('wce.txt', [])
write_to_temp_file('wcet.txt', [])
write_to_temp_file('hyp.txt', [])
write_to_temp_file('lora.txt', [])
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
if not TEMP_PATH.joinpath("emb.txt").exists():
write_to_temp_file('emb.txt', [])
@@ -234,9 +222,16 @@ if WILDCARD_EXT_PATHS is not None:
if EMB_PATH.exists():
# Get embeddings after the model loaded callback
script_callbacks.on_model_loaded(get_embeddings)
script_callbacks.on_model_loaded(get_hypernetworks)
script_callbacks.on_model_loaded(get_lora)
if HYP_PATH.exists():
hypernets = get_hypernetworks()
if hypernets:
write_to_temp_file('hyp.txt', hypernets)
if LORA_PATH.exists():
lora = get_lora()
if lora:
write_to_temp_file('lora.txt', lora)
# Register autocomplete options
def on_ui_settings():
@@ -258,6 +253,8 @@ def on_ui_settings():
shared.opts.add_option("tac_delayTime", shared.OptionInfo(100, "Time in ms to wait before triggering completion again (Requires restart)", section=TAC_SECTION))
shared.opts.add_option("tac_useWildcards", shared.OptionInfo(True, "Search for wildcards", section=TAC_SECTION))
shared.opts.add_option("tac_useEmbeddings", shared.OptionInfo(True, "Search for embeddings", section=TAC_SECTION))
shared.opts.add_option("tac_useHypernetworks", shared.OptionInfo(True, "Search for hypernetworks", section=TAC_SECTION))
shared.opts.add_option("tac_useLora", shared.OptionInfo(True, "Search for Loras", section=TAC_SECTION))
shared.opts.add_option("tac_showWikiLinks", shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page (Warning: This is an external site and very likely contains NSFW examples!)", section=TAC_SECTION))
# Insertion related settings
shared.opts.add_option("tac_replaceUnderscores", shared.OptionInfo(True, "Replace underscores with spaces on insertion", section=TAC_SECTION))