Add trigger word completion using the model-keyword extension

Works for both the built-in and user-defined lists
Restructure some of the python helpers for path reusability
This commit is contained in:
DominikDoom
2023-07-08 16:35:56 +02:00
parent 737b697357
commit 3496fa58d9
9 changed files with 221 additions and 58 deletions

View File

@@ -0,0 +1,76 @@
# This file provides support for the model-keyword extension to add known lora keywords on completion
import hashlib
from pathlib import Path
from scripts.shared_paths import EXT_PATH, STATIC_TEMP_PATH, TEMP_PATH
# Set up our hash cache file (created empty on first run if missing)
known_hashes_file = TEMP_PATH.joinpath("known_lora_hashes.txt")
known_hashes_file.touch()
# Dirty flag: set whenever hash_dict gains or updates an entry,
# checked by update_hash_cache() to decide whether to rewrite the file
file_needs_update = False
# In-memory cache mapping filename -> (hash, mtime), filled by load_hash_cache()
hash_dict = {}
def load_hash_cache():
    """Populate the in-memory hash_dict from the known-hashes cache file.

    Each line has the form "name,hash,mtime". Blank or malformed lines
    (e.g. from a truncated write) are skipped instead of crashing startup.
    """
    with open(known_hashes_file, "r", encoding="utf-8") as file:
        for line in file:
            parts = line.rstrip("\n").split(",")
            if len(parts) != 3:
                continue  # tolerate a corrupted/blank cache line
            name, file_hash, mtime = parts
            hash_dict[name] = (file_hash, mtime)
def update_hash_cache():
    """Persist hash_dict to the known-hashes file if anything changed.

    No-op when no entry was added or updated since the last load/write.
    """
    global file_needs_update
    if not file_needs_update:
        return
    with open(known_hashes_file, "w", encoding="utf-8") as file:
        for name, (file_hash, mtime) in hash_dict.items():
            file.write(f"{name},{file_hash},{mtime}\n")
    # Clear the dirty flag so repeated calls don't rewrite an unchanged file
    file_needs_update = False
# Copy of the fast inaccurate hash function from the model-keyword extension
# with some modifications to load from and write to the cache
def get_lora_simple_hash(path):
    """Return the short (8 hex char) hash for a lora file, using the mtime cache.

    Hashes 0x10000 bytes starting at offset 0x100000, matching the
    model-keyword extension. A cached hash is reused as long as the file's
    mtime is unchanged. Returns "NOFILE" if the file does not exist.
    """
    global file_needs_update
    file_path = Path(path)
    try:
        # stat() raises FileNotFoundError for missing files too,
        # so it must be inside the try (the original had it outside)
        mtime = str(file_path.stat().st_mtime)
        filename = file_path.name
        if filename in hash_dict:
            (cached_hash, old_mtime) = hash_dict[filename]
            if mtime == old_mtime:
                return cached_hash
        with open(file_path, "rb") as file:
            m = hashlib.sha256()
            file.seek(0x100000)
            m.update(file.read(0x10000))
            file_hash = m.hexdigest()[0:8]
        hash_dict[filename] = (file_hash, mtime)
        file_needs_update = True
        return file_hash
    except FileNotFoundError:
        return "NOFILE"
# Find the path of the original model-keyword extension
def write_model_keyword_path():
    """Locate the model-keyword extension and record its keyword-file folder.

    Writes "<folder>,<custom_found>" to tmp/modelKeywordPath.txt for the
    frontend to read. Returns True if the extension's lora-keyword.txt was
    found, else False (with a console notice).
    """
    # Ensure the file exists even if the extension is not installed
    mk_path = STATIC_TEMP_PATH.joinpath("modelKeywordPath.txt")
    mk_path.write_text("")
    # glob() always returns a list, so plain truthiness is the right check
    base_keywords = list(EXT_PATH.glob("*/lora-keyword.txt"))
    custom_keywords = list(EXT_PATH.glob("*/lora-keyword-user.txt"))
    custom_found = len(custom_keywords) > 0
    if base_keywords:
        with open(mk_path, "w", encoding="utf-8") as f:
            f.write(f"{base_keywords[0].parent.as_posix()},{custom_found}")
        return True
    print(
        "Tag Autocomplete: Could not locate model-keyword extension, LORA/LYCO trigger word completion will be unavailable."
    )
    return False

47
scripts/shared_paths.py Normal file
View File

@@ -0,0 +1,47 @@
from pathlib import Path
from modules import scripts, shared
try:
    from modules.paths import extensions_dir, script_path

    # Webui root path
    FILE_DIR = Path(script_path)
    # The extension base path
    EXT_PATH = Path(extensions_dir)
except ImportError:
    # Fallback for webui versions that don't expose modules.paths
    # Webui root path
    FILE_DIR = Path().absolute()
    # The extension base path
    EXT_PATH = FILE_DIR.joinpath('extensions')

# Tags base path
TAGS_PATH = Path(scripts.basedir()).joinpath('tags')

# The path to the folder containing the wildcards and embeddings
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)

# Lora/LyCORIS dirs are optional cmd_opts; None when this webui build lacks them
try:
    LORA_PATH = Path(shared.cmd_opts.lora_dir)
except AttributeError:
    LORA_PATH = None

try:
    LYCO_PATH = Path(shared.cmd_opts.lyco_dir)
except AttributeError:
    LYCO_PATH = None
def find_ext_wildcard_paths():
    """Return every wildcards folder provided by an installed extension."""
    return list(EXT_PATH.glob('*/wildcards/'))
# All wildcard folders contributed by other extensions
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()

# The path to the temporary files
STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp')  # In the webui root, on windows it exists by default, on linux it doesn't
TEMP_PATH = TAGS_PATH.joinpath('temp')  # Extension specific temp files

View File

@@ -6,52 +6,12 @@ from pathlib import Path
import gradio as gr
import yaml
from modules import script_callbacks, scripts, sd_hijack, shared
from modules import script_callbacks, sd_hijack, shared
try:
from modules.paths import extensions_dir, script_path
# Webui root path
FILE_DIR = Path(script_path)
# The extension base path
EXT_PATH = Path(extensions_dir)
except ImportError:
# Webui root path
FILE_DIR = Path().absolute()
# The extension base path
EXT_PATH = FILE_DIR.joinpath('extensions')
# Tags base path
TAGS_PATH = Path(scripts.basedir()).joinpath('tags')
# The path to the folder containing the wildcards and embeddings
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
try:
LORA_PATH = Path(shared.cmd_opts.lora_dir)
except AttributeError:
LORA_PATH = None
try:
LYCO_PATH = Path(shared.cmd_opts.lyco_dir)
except AttributeError:
LYCO_PATH = None
def find_ext_wildcard_paths():
"""Returns the path to the extension wildcards folder"""
found = list(EXT_PATH.glob('*/wildcards/'))
return found
# The path to the extension wildcards folder
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
# The path to the temporary files
STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp') # In the webui root, on windows it exists by default, on linux it doesn't
TEMP_PATH = TAGS_PATH.joinpath('temp') # Extension specific temp files
from scripts.model_keyword_support import (get_lora_simple_hash,
load_hash_cache, update_hash_cache,
write_model_keyword_path)
from scripts.shared_paths import *
def get_wildcards():
@@ -171,23 +131,47 @@ def get_hypernetworks():
# Remove file extensions
return sorted([h[:h.rfind('.')] for h in all_hypernetworks], key=lambda x: x.lower())
model_keyword_installed = write_model_keyword_path()
def get_lora():
    """Return a sorted "name,hash" list for every lora model file.

    The hash is the model-keyword short hash when that extension is
    installed, otherwise an empty string so the format stays uniform.
    """
    global model_keyword_installed
    # Get a list of all lora in the folder
    lora_paths = [Path(l) for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)]
    # NOTE: the stale pre-hash "return sorted(...)" left here by the diff made
    # everything below unreachable; removed so the hash list is actually built.
    valid_loras = [lf for lf in lora_paths if lf.suffix in {".safetensors", ".ckpt", ".pt"}]
    # Get hashes (empty placeholder when model-keyword isn't installed)
    hashes = {}
    for l in valid_loras:
        name = l.name[:l.name.rfind('.')]
        if model_keyword_installed:
            hashes[name] = get_lora_simple_hash(l)
        else:
            hashes[name] = ""
    # Sort by name
    sorted_loras = dict(sorted(hashes.items()))
    # Add hashes and return
    return [f"{name},{lora_hash}" for name, lora_hash in sorted_loras.items()]
def get_lyco():
    """Write a list of all LyCORIS/LOHA from https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris

    Returns a sorted "name,hash" list, mirroring get_lora().
    """
    # Get a list of all LyCORIS in the folder
    lyco_paths = [Path(ly) for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)]
    # NOTE: the stale pre-hash "return sorted(...)" left here by the diff made
    # everything below unreachable; removed so the hash list is actually built.
    valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"}]
    # Get hashes
    # NOTE(review): unlike get_lora(), this does not check model_keyword_installed
    # before hashing — presumably intentional, but verify against the caller.
    hashes = {}
    for ly in valid_lycos:
        name = ly.name[:ly.name.rfind('.')]
        hashes[name] = get_lora_simple_hash(ly)
    # Sort by name
    sorted_lycos = dict(sorted(hashes.items()))
    # Add hashes and return
    return [f"{name},{lyco_hash}" for name, lyco_hash in sorted_lycos.items()]
def write_tag_base_path():
"""Writes the tag base path to a fixed location temporary file"""
@@ -276,6 +260,9 @@ def write_temp_files():
if hypernets:
write_to_temp_file('hyp.txt', hypernets)
if model_keyword_installed:
load_hash_cache()
if LORA_PATH is not None and LORA_PATH.exists():
lora = get_lora()
if lora:
@@ -286,6 +273,9 @@ def write_temp_files():
if lyco:
write_to_temp_file('lyco.txt', lyco)
if model_keyword_installed:
update_hash_cache()
write_temp_files()
@@ -334,6 +324,7 @@ def on_ui_settings():
"tac_appendComma": shared.OptionInfo(True, "Append comma on tag autocompletion"),
"tac_appendSpace": shared.OptionInfo(True, "Append space on tag autocompletion").info("will append after comma if the above is enabled"),
"tac_alwaysSpaceAtEnd": shared.OptionInfo(True, "Always append space if inserting at the end of the textbox").info("takes precedence over the regular space setting for that position"),
"tac_modelKeywordCompletion": shared.OptionInfo(False, "Try to add known trigger words for LORA/LyCO models", gr.Checkbox, lambda: {"interactive": model_keyword_installed}).info("Requires the <a href=\"https://github.com/mix1009/model-keyword\" target=\"_blank\">model-keyword</a> extension to be installed, but will work with it disabled"),
"tac_wildcardCompletionMode": shared.OptionInfo("To next folder level", "How to complete nested wildcard paths", gr.Dropdown, lambda: {"choices": ["To next folder level","To first difference","Always fully"]}).info("e.g. \"hair/colours/light/...\""),
# Alias settings
"tac_alias.searchByAlias": shared.OptionInfo(True, "Search by alias"),