From 5ef346cde3078537ee79ba72dc9b1cfd8c4793b6 Mon Sep 17 00:00:00 2001 From: DominikDoom <34448969+DominikDoom@users.noreply.github.com> Date: Mon, 11 Dec 2023 11:37:12 +0100 Subject: [PATCH] Attempt to use the build-in Lora.networks Lora/LyCORIS models lists (#258) Co-authored-by: Midcoastal --- scripts/shared_paths.py | 27 +++++----- scripts/tag_autocomplete_helper.py | 80 ++++++++++++++++++++++++++---- 2 files changed, 85 insertions(+), 22 deletions(-) diff --git a/scripts/shared_paths.py b/scripts/shared_paths.py index 971fbc9..ca4803b 100644 --- a/scripts/shared_paths.py +++ b/scripts/shared_paths.py @@ -6,31 +6,34 @@ try: from modules.paths import extensions_dir, script_path # Webui root path - FILE_DIR = Path(script_path) + FILE_DIR = Path(script_path).absolute() # The extension base path - EXT_PATH = Path(extensions_dir) + EXT_PATH = Path(extensions_dir).absolute() except ImportError: # Webui root path FILE_DIR = Path().absolute() # The extension base path - EXT_PATH = FILE_DIR.joinpath("extensions") + EXT_PATH = FILE_DIR.joinpath("extensions").absolute() # Tags base path -TAGS_PATH = Path(scripts.basedir()).joinpath("tags") +TAGS_PATH = Path(scripts.basedir()).joinpath("tags").absolute() # The path to the folder containing the wildcards and embeddings -WILDCARD_PATH = FILE_DIR.joinpath("scripts/wildcards") -EMB_PATH = Path(shared.cmd_opts.embeddings_dir) -HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir) +WILDCARD_PATH = FILE_DIR.joinpath("scripts/wildcards").absolute() +EMB_PATH = Path(shared.cmd_opts.embeddings_dir).absolute() +HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir).absolute() try: - LORA_PATH = Path(shared.cmd_opts.lora_dir) + LORA_PATH = Path(shared.cmd_opts.lora_dir).absolute() except AttributeError: LORA_PATH = None try: - LYCO_PATH = Path(shared.cmd_opts.lyco_dir_backcompat) + try: + LYCO_PATH = Path(shared.cmd_opts.lyco_dir_backcompat).absolute() + except: + LYCO_PATH = Path(shared.cmd_opts.lyco_dir).absolute() # attempt original non-backcompat path except AttributeError: LYCO_PATH = None @@ -49,7 +52,7 @@ def find_ext_wildcard_paths(): getattr(shared.cmd_opts, "wildcards_dir", None), # Cmd arg from the wildcard extension getattr(opts, "wildcard_dir", None), # Custom path from sd-dynamic-prompts ] - for path in [Path(p) for p in custom_paths if p is not None]: + for path in [Path(p).absolute() for p in custom_paths if p is not None]: if path.exists(): found.append(path) @@ -61,8 +64,8 @@ WILDCARD_EXT_PATHS = find_ext_wildcard_paths() # The path to the temporary files # In the webui root, on windows it exists by default, on linux it doesn't -STATIC_TEMP_PATH = FILE_DIR.joinpath("tmp") -TEMP_PATH = TAGS_PATH.joinpath("temp") # Extension specific temp files +STATIC_TEMP_PATH = FILE_DIR.joinpath("tmp").absolute() +TEMP_PATH = TAGS_PATH.joinpath("temp").absolute() # Extension specific temp files # Make sure these folders exist if not TEMP_PATH.exists(): diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index 38e34bc..029d650 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -202,14 +202,77 @@ def get_hypernetworks(): return sort_models(all_hypernetworks) model_keyword_installed = write_model_keyword_path() + + +def _get_lora(): + """ + Write a list of all lora. + Fallback method for when the built-in Lora.networks module is not available. + """ + # Get a list of all lora in the folder + lora_paths = [ + Path(l) + for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True) + ] + # Get hashes + valid_loras = [ + lf + for lf in lora_paths + if lf.suffix in {".safetensors", ".ckpt", ".pt"} and lf.is_file() + ] + + return valid_loras + + +def _get_lyco(): + """ + Write a list of all LyCORIS/LOHA from https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris + Fallback method for when the built-in Lora.networks module is not available. + """ + # Get a list of all LyCORIS in the folder + lyco_paths = [ + Path(ly) + for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True) + ] + + # Get hashes + valid_lycos = [ + lyf + for lyf in lyco_paths + if lyf.suffix in {".safetensors", ".ckpt", ".pt"} and lyf.is_file() + ] + return valid_lycos + + +# Attempt to use the build-in Lora.networks Lora/LyCORIS models lists. +try: + import importlib + lora_networks = importlib.import_module("extensions-builtin.Lora.networks") + + def _get_lora(): + return [ + Path(model.filename).absolute() + for model in lora_networks.available_networks.values() + if Path(model.filename).absolute().is_relative_to(LORA_PATH) + ] + + def _get_lyco(): + return [ + Path(model.filename).absolute() + for model in lora_networks.available_networks.values() + if Path(model.filename).absolute().is_relative_to(LYCO_PATH) + ] + +except Exception as e: + pass + # no need to report + # print(f'Exception setting-up performant fetchers: {e}') + + def get_lora(): """Write a list of all lora""" - global model_keyword_installed - - # Get a list of all lora in the folder - lora_paths = [Path(l) for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)] # Get hashes - valid_loras = [lf for lf in lora_paths if lf.suffix in {".safetensors", ".ckpt", ".pt"} and lf.is_file()] + valid_loras = _get_lora() loras_with_hash = [] for l in valid_loras: name = l.relative_to(LORA_PATH).as_posix() @@ -224,12 +287,8 @@ def get_lora(): def get_lyco(): """Write a list of all LyCORIS/LOHA from https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris""" - - # Get a list of all LyCORIS in the folder - lyco_paths = [Path(ly) for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)] - # Get hashes - valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"} and lyf.is_file()] + valid_lycos = _get_lyco() lycos_with_hash = [] for ly in valid_lycos: name = ly.relative_to(LYCO_PATH).as_posix() @@ -241,6 +300,7 @@ def get_lyco(): # Sort return sort_models(lycos_with_hash) + def write_tag_base_path(): """Writes the tag base path to a fixed location temporary file""" with open(STATIC_TEMP_PATH.joinpath('tagAutocompletePath.txt'), 'w', encoding="utf-8") as f: