From 3496fa58d9b789d91e86a6e1c617e151e06e36e9 Mon Sep 17 00:00:00 2001 From: DominikDoom Date: Sat, 8 Jul 2023 16:35:56 +0200 Subject: [PATCH] Add trigger word completion using the model-keyword extension Works for both the built-in and user defined list Restructure some of the python helper for path reusability --- javascript/__globals.js | 2 + javascript/_result.js | 1 + javascript/ext_loras.js | 7 ++- javascript/ext_lycos.js | 7 ++- javascript/ext_modelKeyword.js | 29 ++++++++++ javascript/tagAutocomplete.js | 17 +++++- scripts/model_keyword_support.py | 76 ++++++++++++++++++++++++ scripts/shared_paths.py | 47 +++++++++++++++ scripts/tag_autocomplete_helper.py | 93 ++++++++++++++---------------- 9 files changed, 221 insertions(+), 58 deletions(-) create mode 100644 javascript/ext_modelKeyword.js create mode 100644 scripts/model_keyword_support.py create mode 100644 scripts/shared_paths.py diff --git a/javascript/__globals.js b/javascript/__globals.js index 166744b..a066ce2 100644 --- a/javascript/__globals.js +++ b/javascript/__globals.js @@ -1,6 +1,7 @@ // Core components var TAC_CFG = null; var tagBasePath = ""; +var modelKeywordPath = ""; // Tag completion data loaded from files var allTags = []; @@ -14,6 +15,7 @@ var embeddings = []; var hypernetworks = []; var loras = []; var lycos = []; +var modelKeywordDict = new Map(); var chants = []; // Selected model info for black/whitelisting diff --git a/javascript/_result.js b/javascript/_result.js index 0bf02d1..6953ba6 100644 --- a/javascript/_result.js +++ b/javascript/_result.js @@ -25,6 +25,7 @@ class AutocompleteResult { count = null; aliases = null; meta = null; + hash = null; // Constructor constructor(text, type) { diff --git a/javascript/ext_loras.js b/javascript/ext_loras.js index dff4a1b..eb72644 100644 --- a/javascript/ext_loras.js +++ b/javascript/ext_loras.js @@ -8,7 +8,7 @@ class LoraParser extends BaseTagParser { if (tagword !== "<" && tagword !== " x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm); - tempResults = loras.filter(x => filterCondition(x)); // Filter by tagword + tempResults = loras.filter(x => filterCondition(x[0])); // Filter by tagword } else { tempResults = loras; } @@ -16,8 +16,9 @@ class LoraParser extends BaseTagParser { // Add final results let finalResults = []; tempResults.forEach(t => { - let result = new AutocompleteResult(t.trim(), ResultType.lora) + let result = new AutocompleteResult(t[0].trim(), ResultType.lora) result.meta = "Lora"; + result.hash = t[1]; finalResults.push(result); }); @@ -30,7 +31,7 @@ async function load() { try { loras = (await readFile(`${tagBasePath}/temp/lora.txt`)).split("\n") .filter(x => x.trim().length > 0) // Remove empty lines - .map(x => x.trim()); // Remove carriage returns and padding if it exists + .map(x => x.trim().split(",")); // Remove carriage returns and padding if it exists, split into name, hash pairs } catch (e) { console.error("Error loading lora.txt: " + e); } diff --git a/javascript/ext_lycos.js b/javascript/ext_lycos.js index 86f7552..5effb9d 100644 --- a/javascript/ext_lycos.js +++ b/javascript/ext_lycos.js @@ -8,7 +8,7 @@ class LycoParser extends BaseTagParser { if (tagword !== "<" && tagword !== " x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm); - tempResults = lycos.filter(x => filterCondition(x)); // Filter by tagword + tempResults = lycos.filter(x => filterCondition(x[0])); // Filter by tagword } else { tempResults = lycos; } @@ -16,8 +16,9 @@ class LycoParser extends BaseTagParser { // Add final results let finalResults = []; tempResults.forEach(t => { - let result = new AutocompleteResult(t.trim(), ResultType.lyco) + let result = new AutocompleteResult(t[0].trim(), ResultType.lyco) result.meta = "Lyco"; + result.hash = t[1]; finalResults.push(result); }); @@ -30,7 +31,7 @@ async function load() { try { lycos = (await readFile(`${tagBasePath}/temp/lyco.txt`)).split("\n") .filter(x => x.trim().length > 0) // Remove empty lines - .map(x => x.trim()); // Remove carriage returns and padding if it exists + .map(x => x.trim().split(",")); // Remove carriage returns and padding if it exists, split into name, hash pairs } catch (e) { console.error("Error loading lyco.txt: " + e); } diff --git a/javascript/ext_modelKeyword.js b/javascript/ext_modelKeyword.js new file mode 100644 index 0000000..8fe98c2 --- /dev/null +++ b/javascript/ext_modelKeyword.js @@ -0,0 +1,29 @@ +async function load() { + let modelKeywordParts = (await readFile(`tmp/modelKeywordPath.txt`)).split(",") + modelKeywordPath = modelKeywordParts[0]; + let customFileExists = modelKeywordParts[1] === "True"; + + if (modelKeywordPath.length > 0 && modelKeywordDict.size === 0) { + try { + let lines = (await readFile(`${modelKeywordPath}/lora-keyword.txt`)).split("\n"); + // Add custom user keywords if the file exists + if (customFileExists) + lines = lines.concat((await readFile(`${modelKeywordPath}/lora-keyword-user.txt`)).split("\n")); + + lines = lines.filter(x => x.trim().length > 0 && x.trim()[0] !== "#") // Remove empty lines and comments + + // Add to the dict + lines.forEach(line => { + const parts = line.split(","); + const hash = parts[0]; + const keywords = parts[1].replaceAll("| ", ", ").replaceAll("|", ", ").trim(); + + modelKeywordDict.set(hash, keywords); + }); + } catch (e) { + console.error("Error loading model-keywords list: " + e); + } + } +} + +QUEUE_FILE_LOAD.push(load); \ No newline at end of file diff --git a/javascript/tagAutocomplete.js b/javascript/tagAutocomplete.js index 6d3360a..3fa622b 100644 --- a/javascript/tagAutocomplete.js +++ b/javascript/tagAutocomplete.js @@ -202,6 +202,7 @@ async function syncOptions() { appendSpace: opts["tac_appendSpace"], alwaysSpaceAtEnd: opts["tac_alwaysSpaceAtEnd"], wildcardCompletionMode: opts["tac_wildcardCompletionMode"], + modelKeywordCompletion: opts["tac_modelKeywordCompletion"], // Alias settings alias: { searchByAlias: opts["tac_alias.searchByAlias"], @@ -441,8 +442,22 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout // Add back start var newPrompt = prompt.substring(0, editStart) + insert + prompt.substring(editEnd); + + // Add lora/lyco keywords if enabled and found + let keywordsLength = 0; + if (TAC_CFG.modelKeywordCompletion && modelKeywordPath.length > 0 && (tagType === ResultType.lora || tagType === ResultType.lyco)) { + if (result.hash && result.hash !== "NOFILE" && result.hash.length > 0) { + let keywords = modelKeywordDict.get(result.hash); + if (keywords && keywords.length > 0) { + newPrompt = `${keywords}, ${newPrompt}`; + keywordsLength = keywords.length + 2; // +2 for the comma and space + } + } + } + + // Insert into prompt textbox and reposition cursor textArea.value = newPrompt; - textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length; + textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length + keywordsLength; textArea.selectionEnd = textArea.selectionStart // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python. diff --git a/scripts/model_keyword_support.py b/scripts/model_keyword_support.py new file mode 100644 index 0000000..897eee8 --- /dev/null +++ b/scripts/model_keyword_support.py @@ -0,0 +1,76 @@ +# This file provides support for the model-keyword extension to add known lora keywords on completion + +import hashlib +from pathlib import Path + +from scripts.shared_paths import EXT_PATH, STATIC_TEMP_PATH, TEMP_PATH + +# Set up our hash cache +known_hashes_file = TEMP_PATH.joinpath("known_lora_hashes.txt") +known_hashes_file.touch() +file_needs_update = False + +# Load the hashes from the file +hash_dict = {} + + +def load_hash_cache(): + with open(known_hashes_file, "r") as file: + for line in file: + name, hash, mtime = line.replace("\n", "").split(",") + hash_dict[name] = (hash, mtime) + + +def update_hash_cache(): + global file_needs_update + if file_needs_update: + with open(known_hashes_file, "w") as file: + for name, (hash, mtime) in hash_dict.items(): + file.write(f"{name},{hash},{mtime}\n") + + +# Copy of the fast inaccurate hash function from the extension +# with some modifications to load from and write to the cache +def get_lora_simple_hash(path): + global file_needs_update + mtime = str(Path(path).stat().st_mtime) + filename = Path(path).name + + if filename in hash_dict: + (hash, old_mtime) = hash_dict[filename] + if mtime == old_mtime: + return hash + try: + with open(path, "rb") as file: + m = hashlib.sha256() + + file.seek(0x100000) + m.update(file.read(0x10000)) + hash = m.hexdigest()[0:8] + + hash_dict[filename] = (hash, mtime) + file_needs_update = True + + return hash + except FileNotFoundError: + return "NOFILE" + + +# Find the path of the original model-keyword extension +def write_model_keyword_path(): + # Ensure the file exists even if the extension is not installed + mk_path = STATIC_TEMP_PATH.joinpath("modelKeywordPath.txt") + mk_path.write_text("") + + base_keywords = list(EXT_PATH.glob("*/lora-keyword.txt")) + custom_keywords = list(EXT_PATH.glob("*/lora-keyword-user.txt")) + custom_found = custom_keywords is not None and len(custom_keywords) > 0 + if base_keywords is not None and len(base_keywords) > 0: + with open(mk_path, "w", encoding="utf-8") as f: + f.write(f"{base_keywords[0].parent.as_posix()},{custom_found}") + return True + else: + print( + "Tag Autocomplete: Could not locate model-keyword extension, LORA/LYCO trigger word completion will be unavailable." + ) + return False diff --git a/scripts/shared_paths.py b/scripts/shared_paths.py new file mode 100644 index 0000000..591df23 --- /dev/null +++ b/scripts/shared_paths.py @@ -0,0 +1,47 @@ +from pathlib import Path +from modules import scripts, shared + +try: + from modules.paths import extensions_dir, script_path + + # Webui root path + FILE_DIR = Path(script_path) + + # The extension base path + EXT_PATH = Path(extensions_dir) +except ImportError: + # Webui root path + FILE_DIR = Path().absolute() + # The extension base path + EXT_PATH = FILE_DIR.joinpath('extensions') + +# Tags base path +TAGS_PATH = Path(scripts.basedir()).joinpath('tags') + +# The path to the folder containing the wildcards and embeddings +WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards') +EMB_PATH = Path(shared.cmd_opts.embeddings_dir) +HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir) + +try: + LORA_PATH = Path(shared.cmd_opts.lora_dir) +except AttributeError: + LORA_PATH = None + +try: + LYCO_PATH = Path(shared.cmd_opts.lyco_dir) +except AttributeError: + LYCO_PATH = None + +def find_ext_wildcard_paths(): + """Returns the path to the extension wildcards folder""" + found = list(EXT_PATH.glob('*/wildcards/')) + return found + + +# The path to the extension wildcards folder +WILDCARD_EXT_PATHS = find_ext_wildcard_paths() + +# The path to the temporary files +STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp') # In the webui root, on windows it exists by default, on linux it doesn't +TEMP_PATH = TAGS_PATH.joinpath('temp') # Extension specific temp files \ No newline at end of file diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index 01e570d..4a1821d 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -6,52 +6,12 @@ from pathlib import Path import gradio as gr import yaml -from modules import script_callbacks, scripts, sd_hijack, shared +from modules import script_callbacks, sd_hijack, shared -try: - from modules.paths import extensions_dir, script_path - - # Webui root path - FILE_DIR = Path(script_path) - - # The extension base path - EXT_PATH = Path(extensions_dir) -except ImportError: - # Webui root path - FILE_DIR = Path().absolute() - # The extension base path - EXT_PATH = FILE_DIR.joinpath('extensions') - -# Tags base path -TAGS_PATH = Path(scripts.basedir()).joinpath('tags') - -# The path to the folder containing the wildcards and embeddings -WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards') -EMB_PATH = Path(shared.cmd_opts.embeddings_dir) -HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir) - -try: - LORA_PATH = Path(shared.cmd_opts.lora_dir) -except AttributeError: - LORA_PATH = None - -try: - LYCO_PATH = Path(shared.cmd_opts.lyco_dir) -except AttributeError: - LYCO_PATH = None - -def find_ext_wildcard_paths(): - """Returns the path to the extension wildcards folder""" - found = list(EXT_PATH.glob('*/wildcards/')) - return found - - -# The path to the extension wildcards folder -WILDCARD_EXT_PATHS = find_ext_wildcard_paths() - -# The path to the temporary files -STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp') # In the webui root, on windows it exists by default, on linux it doesn't -TEMP_PATH = TAGS_PATH.joinpath('temp') # Extension specific temp files +from scripts.model_keyword_support import (get_lora_simple_hash, + load_hash_cache, update_hash_cache, + write_model_keyword_path) +from scripts.shared_paths import * def get_wildcards(): @@ -171,23 +131,47 @@ def get_hypernetworks(): # Remove file extensions return sorted([h[:h.rfind('.')] for h in all_hypernetworks], key=lambda x: x.lower()) +model_keyword_installed = write_model_keyword_path() def get_lora(): """Write a list of all lora""" + global model_keyword_installed # Get a list of all lora in the folder lora_paths = [Path(l) for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)] - all_lora = [str(l.name) for l in lora_paths if l.suffix in {".safetensors", ".ckpt", ".pt"}] - # Remove file extensions - return sorted([l[:l.rfind('.')] for l in all_lora], key=lambda x: x.lower()) + # Get hashes + valid_loras = [lf for lf in lora_paths if lf.suffix in {".safetensors", ".ckpt", ".pt"}] + hashes = {} + for l in valid_loras: + name = l.name[:l.name.rfind('.')] + if model_keyword_installed: + hashes[name] = get_lora_simple_hash(l) + else: + hashes[name] = "" + + # Sort + sorted_loras = dict(sorted(hashes.items())) + # Add hashes and return + return [f"{name},{hash}" for name, hash in sorted_loras.items()] + def get_lyco(): """Write a list of all LyCORIS/LOHA from https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris""" # Get a list of all LyCORIS in the folder lyco_paths = [Path(ly) for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)] - all_lyco = [str(ly.name) for ly in lyco_paths if ly.suffix in {".safetensors", ".ckpt", ".pt"}] - # Remove file extensions - return sorted([ly[:ly.rfind('.')] for ly in all_lyco], key=lambda x: x.lower()) + + # Get hashes + valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"}] + hashes = {} + for ly in valid_lycos: + name = ly.name[:ly.name.rfind('.')] + hashes[name] = get_lora_simple_hash(ly) + + # Sort + sorted_lycos = dict(sorted(hashes.items())) + # Add hashes and return + return [f"{name},{hash}" for name, hash in sorted_lycos.items()] + def write_tag_base_path(): """Writes the tag base path to a fixed location temporary file""" @@ -276,6 +260,9 @@ def write_temp_files(): if hypernets: write_to_temp_file('hyp.txt', hypernets) + if model_keyword_installed: + load_hash_cache() + if LORA_PATH is not None and LORA_PATH.exists(): lora = get_lora() if lora: @@ -286,6 +273,9 @@ def write_temp_files(): if lyco: write_to_temp_file('lyco.txt', lyco) + if model_keyword_installed: + update_hash_cache() + write_temp_files() @@ -334,6 +324,7 @@ def on_ui_settings(): "tac_appendComma": shared.OptionInfo(True, "Append comma on tag autocompletion"), "tac_appendSpace": shared.OptionInfo(True, "Append space on tag autocompletion").info("will append after comma if the above is enabled"), "tac_alwaysSpaceAtEnd": shared.OptionInfo(True, "Always append space if inserting at the end of the textbox").info("takes precedence over the regular space setting for that position"), + "tac_modelKeywordCompletion": shared.OptionInfo(False, "Try to add known trigger words for LORA/LyCO models", gr.Checkbox, lambda: {"interactive": model_keyword_installed}).info("Requires the model-keyword extension to be installed, but will work with it disabled"), "tac_wildcardCompletionMode": shared.OptionInfo("To next folder level", "How to complete nested wildcard paths", gr.Dropdown, lambda: {"choices": ["To next folder level","To first difference","Always fully"]}).info("e.g. \"hair/colours/light/...\""), # Alias settings "tac_alias.searchByAlias": shared.OptionInfo(True, "Search by alias"),