diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index b370c42..5bbbd3e 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -1,13 +1,16 @@ # This helper script scans folders for wildcards and embeddings and writes them # to a temporary file to expose it to the javascript side -import gradio as gr +import glob from pathlib import Path -from modules import scripts, script_callbacks, shared, sd_hijack + +import gradio as gr import yaml +from modules import script_callbacks, scripts, sd_hijack, shared try: - from modules.paths import script_path, extensions_dir + from modules.paths import extensions_dir, script_path + # Webui root path FILE_DIR = Path(script_path) @@ -159,7 +162,8 @@ def get_hypernetworks(): """Write a list of all hypernetworks""" # Get a list of all hypernetworks in the folder - all_hypernetworks = [str(h.name) for h in HYP_PATH.rglob("*") if h.suffix in {".pt"}] + hyp_paths = [Path(h) for h in glob.glob(HYP_PATH.joinpath("**/*").as_posix(), recursive=True)] + all_hypernetworks = [str(h.name) for h in hyp_paths if h.suffix in {".pt"}] # Remove file extensions return sorted([h[:h.rfind('.')] for h in all_hypernetworks], key=lambda x: x.lower()) @@ -167,7 +171,8 @@ def get_lora(): """Write a list of all lora""" # Get a list of all lora in the folder - all_lora = [str(l.name) for l in LORA_PATH.rglob("*") if l.suffix in {".safetensors", ".ckpt", ".pt"}] + lora_paths = [Path(l) for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)] + all_lora = [str(l.name) for l in lora_paths if l.suffix in {".safetensors", ".ckpt", ".pt"}] # Remove file extensions return sorted([l[:l.rfind('.')] for l in all_lora], key=lambda x: x.lower()) @@ -175,7 +180,8 @@ def get_lyco(): """Write a list of all LyCORIS/LOHA from https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris""" # Get a list of all LyCORIS in the folder - all_lyco = [str(ly.name) for ly in LYCO_PATH.rglob("*") if ly.suffix in {".safetensors", ".ckpt", ".pt"}] + lyco_paths = [Path(ly) for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)] + all_lyco = [str(ly.name) for ly in lyco_paths if ly.suffix in {".safetensors", ".ckpt", ".pt"}] # Remove file extensions return sorted([ly[:ly.rfind('.')] for ly in all_lyco], key=lambda x: x.lower())