Add trigger word completion using the model-keyword extension

Works for both the built-in and user defined list
Restructure some of the python helper for path reusability
This commit is contained in:
DominikDoom
2023-07-08 16:35:56 +02:00
parent 737b697357
commit 3496fa58d9
9 changed files with 221 additions and 58 deletions

View File

@@ -1,6 +1,7 @@
// Core components
var TAC_CFG = null;
var tagBasePath = "";
var modelKeywordPath = "";
// Tag completion data loaded from files
var allTags = [];
@@ -14,6 +15,7 @@ var embeddings = [];
var hypernetworks = [];
var loras = [];
var lycos = [];
var modelKeywordDict = new Map();
var chants = [];
// Selected model info for black/whitelisting

View File

@@ -25,6 +25,7 @@ class AutocompleteResult {
count = null;
aliases = null;
meta = null;
hash = null;
// Constructor
constructor(text, type) {

View File

@@ -8,7 +8,7 @@ class LoraParser extends BaseTagParser {
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lora:") {
let searchTerm = tagword.replace("<lora:", "").replace("<l:", "").replace("<", "");
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
tempResults = loras.filter(x => filterCondition(x)); // Filter by tagword
tempResults = loras.filter(x => filterCondition(x[0])); // Filter by tagword
} else {
tempResults = loras;
}
@@ -16,8 +16,9 @@ class LoraParser extends BaseTagParser {
// Add final results
let finalResults = [];
tempResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.lora)
let result = new AutocompleteResult(t[0].trim(), ResultType.lora)
result.meta = "Lora";
result.hash = t[1];
finalResults.push(result);
});
@@ -30,7 +31,7 @@ async function load() {
try {
loras = (await readFile(`${tagBasePath}/temp/lora.txt`)).split("\n")
.filter(x => x.trim().length > 0) // Remove empty lines
.map(x => x.trim()); // Remove carriage returns and padding if it exists
.map(x => x.trim().split(",")); // Remove carriage returns and padding if it exists, split into name, hash pairs
} catch (e) {
console.error("Error loading lora.txt: " + e);
}

View File

@@ -8,7 +8,7 @@ class LycoParser extends BaseTagParser {
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lyco:") {
let searchTerm = tagword.replace("<lyco:", "").replace("<l:", "").replace("<", "");
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
tempResults = lycos.filter(x => filterCondition(x)); // Filter by tagword
tempResults = lycos.filter(x => filterCondition(x[0])); // Filter by tagword
} else {
tempResults = lycos;
}
@@ -16,8 +16,9 @@ class LycoParser extends BaseTagParser {
// Add final results
let finalResults = [];
tempResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.lyco)
let result = new AutocompleteResult(t[0].trim(), ResultType.lyco)
result.meta = "Lyco";
result.hash = t[1];
finalResults.push(result);
});
@@ -30,7 +31,7 @@ async function load() {
try {
lycos = (await readFile(`${tagBasePath}/temp/lyco.txt`)).split("\n")
.filter(x => x.trim().length > 0) // Remove empty lines
.map(x => x.trim()); // Remove carriage returns and padding if it exists
.map(x => x.trim().split(",")); // Remove carriage returns and padding if it exists, split into name, hash pairs
} catch (e) {
console.error("Error loading lyco.txt: " + e);
}

View File

@@ -0,0 +1,29 @@
async function load() {
let modelKeywordParts = (await readFile(`tmp/modelKeywordPath.txt`)).split(",")
modelKeywordPath = modelKeywordParts[0];
let customFileExists = modelKeywordParts[1] === "True";
if (modelKeywordPath.length > 0 && modelKeywordDict.size === 0) {
try {
let lines = (await readFile(`${modelKeywordPath}/lora-keyword.txt`)).split("\n");
// Add custom user keywords if the file exists
if (customFileExists)
lines = lines.concat((await readFile(`${modelKeywordPath}/lora-keyword-user.txt`)).split("\n"));
lines = lines.filter(x => x.trim().length > 0 && x.trim()[0] !== "#") // Remove empty lines and comments
// Add to the dict
lines.forEach(line => {
const parts = line.split(",");
const hash = parts[0];
const keywords = parts[1].replaceAll("| ", ", ").replaceAll("|", ", ").trim();
modelKeywordDict.set(hash, keywords);
});
} catch (e) {
console.error("Error loading model-keywords list: " + e);
}
}
}
QUEUE_FILE_LOAD.push(load);

View File

@@ -202,6 +202,7 @@ async function syncOptions() {
appendSpace: opts["tac_appendSpace"],
alwaysSpaceAtEnd: opts["tac_alwaysSpaceAtEnd"],
wildcardCompletionMode: opts["tac_wildcardCompletionMode"],
modelKeywordCompletion: opts["tac_modelKeywordCompletion"],
// Alias settings
alias: {
searchByAlias: opts["tac_alias.searchByAlias"],
@@ -441,8 +442,22 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
// Add back start
var newPrompt = prompt.substring(0, editStart) + insert + prompt.substring(editEnd);
// Add lora/lyco keywords if enabled and found
let keywordsLength = 0;
if (TAC_CFG.modelKeywordCompletion && modelKeywordPath.length > 0 && (tagType === ResultType.lora || tagType === ResultType.lyco)) {
if (result.hash && result.hash !== "NOFILE" && result.hash.length > 0) {
let keywords = modelKeywordDict.get(result.hash);
if (keywords && keywords.length > 0) {
newPrompt = `${keywords}, ${newPrompt}`;
keywordsLength = keywords.length + 2; // +2 for the comma and space
}
}
}
// Insert into prompt textbox and reposition cursor
textArea.value = newPrompt;
textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length;
textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length + keywordsLength;
textArea.selectionEnd = textArea.selectionStart
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python.

View File

@@ -0,0 +1,76 @@
# This file provides support for the model-keyword extension to add known lora keywords on completion
import hashlib
from pathlib import Path
from scripts.shared_paths import EXT_PATH, STATIC_TEMP_PATH, TEMP_PATH
# Set up our hash cache
known_hashes_file = TEMP_PATH.joinpath("known_lora_hashes.txt")
known_hashes_file.touch()
file_needs_update = False
# Load the hashes from the file
hash_dict = {}
def load_hash_cache():
with open(known_hashes_file, "r") as file:
for line in file:
name, hash, mtime = line.replace("\n", "").split(",")
hash_dict[name] = (hash, mtime)
def update_hash_cache():
global file_needs_update
if file_needs_update:
with open(known_hashes_file, "w") as file:
for name, (hash, mtime) in hash_dict.items():
file.write(f"{name},{hash},{mtime}\n")
# Copy of the fast inaccurate hash function from the extension
# with some modifications to load from and write to the cache
def get_lora_simple_hash(path):
global file_needs_update
mtime = str(Path(path).stat().st_mtime)
filename = Path(path).name
if filename in hash_dict:
(hash, old_mtime) = hash_dict[filename]
if mtime == old_mtime:
return hash
try:
with open(path, "rb") as file:
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
hash = m.hexdigest()[0:8]
hash_dict[filename] = (hash, mtime)
file_needs_update = True
return hash
except FileNotFoundError:
return "NOFILE"
# Find the path of the original model-keyword extension
def write_model_keyword_path():
# Ensure the file exists even if the extension is not installed
mk_path = STATIC_TEMP_PATH.joinpath("modelKeywordPath.txt")
mk_path.write_text("")
base_keywords = list(EXT_PATH.glob("*/lora-keyword.txt"))
custom_keywords = list(EXT_PATH.glob("*/lora-keyword-user.txt"))
custom_found = custom_keywords is not None and len(custom_keywords) > 0
if base_keywords is not None and len(base_keywords) > 0:
with open(mk_path, "w", encoding="utf-8") as f:
f.write(f"{base_keywords[0].parent.as_posix()},{custom_found}")
return True
else:
print(
"Tag Autocomplete: Could not locate model-keyword extension, LORA/LYCO trigger word completion will be unavailable."
)
return False

47
scripts/shared_paths.py Normal file
View File

@@ -0,0 +1,47 @@
from pathlib import Path
from modules import scripts, shared
try:
from modules.paths import extensions_dir, script_path
# Webui root path
FILE_DIR = Path(script_path)
# The extension base path
EXT_PATH = Path(extensions_dir)
except ImportError:
# Webui root path
FILE_DIR = Path().absolute()
# The extension base path
EXT_PATH = FILE_DIR.joinpath('extensions')
# Tags base path
TAGS_PATH = Path(scripts.basedir()).joinpath('tags')
# The path to the folder containing the wildcards and embeddings
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
try:
LORA_PATH = Path(shared.cmd_opts.lora_dir)
except AttributeError:
LORA_PATH = None
try:
LYCO_PATH = Path(shared.cmd_opts.lyco_dir)
except AttributeError:
LYCO_PATH = None
def find_ext_wildcard_paths():
"""Returns the path to the extension wildcards folder"""
found = list(EXT_PATH.glob('*/wildcards/'))
return found
# The path to the extension wildcards folder
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
# The path to the temporary files
STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp') # In the webui root, on windows it exists by default, on linux it doesn't
TEMP_PATH = TAGS_PATH.joinpath('temp') # Extension specific temp files

View File

@@ -6,52 +6,12 @@ from pathlib import Path
import gradio as gr
import yaml
from modules import script_callbacks, scripts, sd_hijack, shared
from modules import script_callbacks, sd_hijack, shared
try:
from modules.paths import extensions_dir, script_path
# Webui root path
FILE_DIR = Path(script_path)
# The extension base path
EXT_PATH = Path(extensions_dir)
except ImportError:
# Webui root path
FILE_DIR = Path().absolute()
# The extension base path
EXT_PATH = FILE_DIR.joinpath('extensions')
# Tags base path
TAGS_PATH = Path(scripts.basedir()).joinpath('tags')
# The path to the folder containing the wildcards and embeddings
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
try:
LORA_PATH = Path(shared.cmd_opts.lora_dir)
except AttributeError:
LORA_PATH = None
try:
LYCO_PATH = Path(shared.cmd_opts.lyco_dir)
except AttributeError:
LYCO_PATH = None
def find_ext_wildcard_paths():
"""Returns the path to the extension wildcards folder"""
found = list(EXT_PATH.glob('*/wildcards/'))
return found
# The path to the extension wildcards folder
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
# The path to the temporary files
STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp') # In the webui root, on windows it exists by default, on linux it doesn't
TEMP_PATH = TAGS_PATH.joinpath('temp') # Extension specific temp files
from scripts.model_keyword_support import (get_lora_simple_hash,
load_hash_cache, update_hash_cache,
write_model_keyword_path)
from scripts.shared_paths import *
def get_wildcards():
@@ -171,23 +131,47 @@ def get_hypernetworks():
# Remove file extensions
return sorted([h[:h.rfind('.')] for h in all_hypernetworks], key=lambda x: x.lower())
model_keyword_installed = write_model_keyword_path()
def get_lora():
"""Write a list of all lora"""
global model_keyword_installed
# Get a list of all lora in the folder
lora_paths = [Path(l) for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)]
all_lora = [str(l.name) for l in lora_paths if l.suffix in {".safetensors", ".ckpt", ".pt"}]
# Remove file extensions
return sorted([l[:l.rfind('.')] for l in all_lora], key=lambda x: x.lower())
# Get hashes
valid_loras = [lf for lf in lora_paths if lf.suffix in {".safetensors", ".ckpt", ".pt"}]
hashes = {}
for l in valid_loras:
name = l.name[:l.name.rfind('.')]
if model_keyword_installed:
hashes[name] = get_lora_simple_hash(l)
else:
hashes[name] = ""
# Sort
sorted_loras = dict(sorted(hashes.items()))
# Add hashes and return
return [f"{name},{hash}" for name, hash in sorted_loras.items()]
def get_lyco():
"""Write a list of all LyCORIS/LOHA from https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris"""
# Get a list of all LyCORIS in the folder
lyco_paths = [Path(ly) for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)]
all_lyco = [str(ly.name) for ly in lyco_paths if ly.suffix in {".safetensors", ".ckpt", ".pt"}]
# Remove file extensions
return sorted([ly[:ly.rfind('.')] for ly in all_lyco], key=lambda x: x.lower())
# Get hashes
valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"}]
hashes = {}
for ly in valid_lycos:
name = ly.name[:ly.name.rfind('.')]
hashes[name] = get_lora_simple_hash(ly)
# Sort
sorted_lycos = dict(sorted(hashes.items()))
# Add hashes and return
return [f"{name},{hash}" for name, hash in sorted_lycos.items()]
def write_tag_base_path():
"""Writes the tag base path to a fixed location temporary file"""
@@ -276,6 +260,9 @@ def write_temp_files():
if hypernets:
write_to_temp_file('hyp.txt', hypernets)
if model_keyword_installed:
load_hash_cache()
if LORA_PATH is not None and LORA_PATH.exists():
lora = get_lora()
if lora:
@@ -286,6 +273,9 @@ def write_temp_files():
if lyco:
write_to_temp_file('lyco.txt', lyco)
if model_keyword_installed:
update_hash_cache()
write_temp_files()
@@ -334,6 +324,7 @@ def on_ui_settings():
"tac_appendComma": shared.OptionInfo(True, "Append comma on tag autocompletion"),
"tac_appendSpace": shared.OptionInfo(True, "Append space on tag autocompletion").info("will append after comma if the above is enabled"),
"tac_alwaysSpaceAtEnd": shared.OptionInfo(True, "Always append space if inserting at the end of the textbox").info("takes precedence over the regular space setting for that position"),
"tac_modelKeywordCompletion": shared.OptionInfo(False, "Try to add known trigger words for LORA/LyCO models", gr.Checkbox, lambda: {"interactive": model_keyword_installed}).info("Requires the <a href=\"https://github.com/mix1009/model-keyword\" target=\"_blank\">model-keyword</a> extension to be installed, but will work with it disabled"),
"tac_wildcardCompletionMode": shared.OptionInfo("To next folder level", "How to complete nested wildcard paths", gr.Dropdown, lambda: {"choices": ["To next folder level","To first difference","Always fully"]}).info("e.g. \"hair/colours/light/...\""),
# Alias settings
"tac_alias.searchByAlias": shared.OptionInfo(True, "Search by alias"),