mirror of https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-26 11:09:54 +00:00

Add trigger word completion using the model-keyword extension

Works for both the built-in and the user-defined keyword list. Also restructures some of the Python helpers for path reusability.
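The handshake between the Python and JavaScript sides goes through a small temp file: the Python helper writes tmp/modelKeywordPath.txt containing the model-keyword extension's folder and whether a user keyword file exists, and the JavaScript loader splits that on the comma. A hedged Python sketch of the write side, with invented paths:

    from pathlib import Path

    # Hypothetical install location of the model-keyword extension
    mk_dir = Path("extensions/model-keyword")
    has_user_list = (mk_dir / "lora-keyword-user.txt").exists()

    # Write "path,True/False"; Python's bool formatting ("True"/"False")
    # is exactly what the JS side compares against with === "True".
    Path("tmp").mkdir(exist_ok=True)
    Path("tmp/modelKeywordPath.txt").write_text(f"{mk_dir.as_posix()},{has_user_list}")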
@@ -1,6 +1,7 @@
 // Core components
 var TAC_CFG = null;
 var tagBasePath = "";
+var modelKeywordPath = "";

 // Tag completion data loaded from files
 var allTags = [];
@@ -14,6 +15,7 @@ var embeddings = [];
 var hypernetworks = [];
 var loras = [];
 var lycos = [];
+var modelKeywordDict = new Map();
 var chants = [];

 // Selected model info for black/whitelisting
@@ -25,6 +25,7 @@ class AutocompleteResult {
     count = null;
     aliases = null;
     meta = null;
+    hash = null;

     // Constructor
     constructor(text, type) {
@@ -8,7 +8,7 @@ class LoraParser extends BaseTagParser {
     if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lora:") {
         let searchTerm = tagword.replace("<lora:", "").replace("<l:", "").replace("<", "");
         let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
-        tempResults = loras.filter(x => filterCondition(x)); // Filter by tagword
+        tempResults = loras.filter(x => filterCondition(x[0])); // Filter by tagword
     } else {
         tempResults = loras;
     }
@@ -16,8 +16,9 @@ class LoraParser extends BaseTagParser {
     // Add final results
     let finalResults = [];
     tempResults.forEach(t => {
-        let result = new AutocompleteResult(t.trim(), ResultType.lora)
+        let result = new AutocompleteResult(t[0].trim(), ResultType.lora)
         result.meta = "Lora";
+        result.hash = t[1];
         finalResults.push(result);
     });
@@ -30,7 +31,7 @@ async function load() {
     try {
         loras = (await readFile(`${tagBasePath}/temp/lora.txt`)).split("\n")
            .filter(x => x.trim().length > 0) // Remove empty lines
-           .map(x => x.trim()); // Remove carriage returns and padding if it exists
+           .map(x => x.trim().split(",")); // Remove carriage returns and padding if it exists, split into name, hash pairs
     } catch (e) {
         console.error("Error loading lora.txt: " + e);
     }
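After this change each entry of temp/lora.txt is a "name,hash" pair rather than a bare name, which is why the parsers above index t[0] and t[1]. A minimal Python sketch of the same parse, with invented sample contents:

    # Parse the "name,hash" lines written to tags/temp/lora.txt.
    # The sample data is made up for illustration.
    sample = "add_detail,1a2b3c4d\nepi_noiseoffset2,deadbeef\n"

    loras = [line.strip().split(",") for line in sample.split("\n") if line.strip()]
    for name, short_hash in loras:
        print(name, short_hash)  # e.g. "add_detail 1a2b3c4d"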
@@ -8,7 +8,7 @@ class LycoParser extends BaseTagParser {
     if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lyco:") {
         let searchTerm = tagword.replace("<lyco:", "").replace("<l:", "").replace("<", "");
         let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
-        tempResults = lycos.filter(x => filterCondition(x)); // Filter by tagword
+        tempResults = lycos.filter(x => filterCondition(x[0])); // Filter by tagword
     } else {
         tempResults = lycos;
     }
@@ -16,8 +16,9 @@ class LycoParser extends BaseTagParser {
     // Add final results
     let finalResults = [];
     tempResults.forEach(t => {
-        let result = new AutocompleteResult(t.trim(), ResultType.lyco)
+        let result = new AutocompleteResult(t[0].trim(), ResultType.lyco)
         result.meta = "Lyco";
+        result.hash = t[1];
         finalResults.push(result);
     });
@@ -30,7 +31,7 @@ async function load() {
     try {
         lycos = (await readFile(`${tagBasePath}/temp/lyco.txt`)).split("\n")
            .filter(x => x.trim().length > 0) // Remove empty lines
-           .map(x => x.trim()); // Remove carriage returns and padding if it exists
+           .map(x => x.trim().split(",")); // Remove carriage returns and padding if it exists, split into name, hash pairs
     } catch (e) {
         console.error("Error loading lyco.txt: " + e);
     }
29  javascript/ext_modelKeyword.js  Normal file
@@ -0,0 +1,29 @@
+async function load() {
+    let modelKeywordParts = (await readFile(`tmp/modelKeywordPath.txt`)).split(",")
+    modelKeywordPath = modelKeywordParts[0];
+    let customFileExists = modelKeywordParts[1] === "True";
+
+    if (modelKeywordPath.length > 0 && modelKeywordDict.size === 0) {
+        try {
+            let lines = (await readFile(`${modelKeywordPath}/lora-keyword.txt`)).split("\n");
+            // Add custom user keywords if the file exists
+            if (customFileExists)
+                lines = lines.concat((await readFile(`${modelKeywordPath}/lora-keyword-user.txt`)).split("\n"));
+
+            lines = lines.filter(x => x.trim().length > 0 && x.trim()[0] !== "#") // Remove empty lines and comments
+
+            // Add to the dict
+            lines.forEach(line => {
+                const parts = line.split(",");
+                const hash = parts[0];
+                const keywords = parts[1].replaceAll("| ", ", ").replaceAll("|", ", ").trim();
+
+                modelKeywordDict.set(hash, keywords);
+            });
+        } catch (e) {
+            console.error("Error loading model-keywords list: " + e);
+        }
+    }
+}
+
+QUEUE_FILE_LOAD.push(load);
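model-keyword's list files are lines of the form "hash, keyword1|keyword2|...", and the loader above normalizes the pipe separators into comma-separated prompt text. The same normalization in Python, for illustration only (the input line is invented):

    # Mirror of the keyword normalization in ext_modelKeyword.js.
    line = "1a2b3c4d, masterpiece|best quality"

    parts = line.split(",")
    model_hash = parts[0]
    keywords = parts[1].replace("| ", ", ").replace("|", ", ").strip()
    print(model_hash, "->", keywords)  # 1a2b3c4d -> masterpiece, best quality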
@@ -202,6 +202,7 @@ async function syncOptions() {
         appendSpace: opts["tac_appendSpace"],
         alwaysSpaceAtEnd: opts["tac_alwaysSpaceAtEnd"],
         wildcardCompletionMode: opts["tac_wildcardCompletionMode"],
+        modelKeywordCompletion: opts["tac_modelKeywordCompletion"],
         // Alias settings
         alias: {
             searchByAlias: opts["tac_alias.searchByAlias"],
@@ -441,8 +442,22 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
     // Add back start
     var newPrompt = prompt.substring(0, editStart) + insert + prompt.substring(editEnd);

+    // Add lora/lyco keywords if enabled and found
+    let keywordsLength = 0;
+    if (TAC_CFG.modelKeywordCompletion && modelKeywordPath.length > 0 && (tagType === ResultType.lora || tagType === ResultType.lyco)) {
+        if (result.hash && result.hash !== "NOFILE" && result.hash.length > 0) {
+            let keywords = modelKeywordDict.get(result.hash);
+            if (keywords && keywords.length > 0) {
+                newPrompt = `${keywords}, ${newPrompt}`;
+                keywordsLength = keywords.length + 2; // +2 for the comma and space
+            }
+        }
+    }
+
     // Insert into prompt textbox and reposition cursor
     textArea.value = newPrompt;
-    textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length;
+    textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length + keywordsLength;
     textArea.selectionEnd = textArea.selectionStart

     // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python.
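Because the trigger words are prepended to the start of the prompt, every previously computed caret position shifts right by the inserted length; that is all keywordsLength accounts for. A small Python demonstration of the offset arithmetic (values invented):

    # Offset arithmetic behind keywordsLength, in Python for illustration.
    prompt = "<lora:add_detail:1>"
    keywords = "masterpiece"          # hypothetical trigger word
    cursor = len(prompt)              # caret right after the completed tag

    new_prompt = f"{keywords}, {prompt}"
    cursor += len(keywords) + 2       # +2 for the comma and space
    assert new_prompt[cursor - len(prompt):] == prompt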
76  scripts/model_keyword_support.py  Normal file
@@ -0,0 +1,76 @@
+# This file provides support for the model-keyword extension to add known lora keywords on completion
+
+import hashlib
+from pathlib import Path
+
+from scripts.shared_paths import EXT_PATH, STATIC_TEMP_PATH, TEMP_PATH
+
+# Set up our hash cache
+known_hashes_file = TEMP_PATH.joinpath("known_lora_hashes.txt")
+known_hashes_file.touch()
+file_needs_update = False
+
+# Load the hashes from the file
+hash_dict = {}
+
+
+def load_hash_cache():
+    with open(known_hashes_file, "r") as file:
+        for line in file:
+            name, hash, mtime = line.replace("\n", "").split(",")
+            hash_dict[name] = (hash, mtime)
+
+
+def update_hash_cache():
+    global file_needs_update
+    if file_needs_update:
+        with open(known_hashes_file, "w") as file:
+            for name, (hash, mtime) in hash_dict.items():
+                file.write(f"{name},{hash},{mtime}\n")
+
+
+# Copy of the fast inaccurate hash function from the extension
+# with some modifications to load from and write to the cache
+def get_lora_simple_hash(path):
+    global file_needs_update
+    mtime = str(Path(path).stat().st_mtime)
+    filename = Path(path).name
+
+    if filename in hash_dict:
+        (hash, old_mtime) = hash_dict[filename]
+        if mtime == old_mtime:
+            return hash
+    try:
+        with open(path, "rb") as file:
+            m = hashlib.sha256()
+
+            file.seek(0x100000)
+            m.update(file.read(0x10000))
+            hash = m.hexdigest()[0:8]
+
+            hash_dict[filename] = (hash, mtime)
+            file_needs_update = True
+
+            return hash
+    except FileNotFoundError:
+        return "NOFILE"
+
+
+# Find the path of the original model-keyword extension
+def write_model_keyword_path():
+    # Ensure the file exists even if the extension is not installed
+    mk_path = STATIC_TEMP_PATH.joinpath("modelKeywordPath.txt")
+    mk_path.write_text("")
+
+    base_keywords = list(EXT_PATH.glob("*/lora-keyword.txt"))
+    custom_keywords = list(EXT_PATH.glob("*/lora-keyword-user.txt"))
+    custom_found = custom_keywords is not None and len(custom_keywords) > 0
+    if base_keywords is not None and len(base_keywords) > 0:
+        with open(mk_path, "w", encoding="utf-8") as f:
+            f.write(f"{base_keywords[0].parent.as_posix()},{custom_found}")
+        return True
+    else:
+        print(
+            "Tag Autocomplete: Could not locate model-keyword extension, LORA/LYCO trigger word completion will be unavailable."
+        )
+        return False
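The "simple hash" deliberately fingerprints only 64 KiB (0x10000 bytes) starting at the 1 MiB offset (0x100000), so large model files are never read in full, and the mtime check against the cache avoids even that partial read on later scans. A hedged sketch of the intended call order, assuming the webui Python environment and an invented lora path:

    from scripts.model_keyword_support import (get_lora_simple_hash,
                                               load_hash_cache, update_hash_cache)

    load_hash_cache()                                  # read known_lora_hashes.txt once
    h = get_lora_simple_hash("models/Lora/add_detail.safetensors")  # invented path
    print(h)                                           # 8 hex chars, or "NOFILE"
    update_hash_cache()                                # persist only if something changed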
47  scripts/shared_paths.py  Normal file
@@ -0,0 +1,47 @@
+from pathlib import Path
+from modules import scripts, shared
+
+try:
+    from modules.paths import extensions_dir, script_path
+
+    # Webui root path
+    FILE_DIR = Path(script_path)
+
+    # The extension base path
+    EXT_PATH = Path(extensions_dir)
+except ImportError:
+    # Webui root path
+    FILE_DIR = Path().absolute()
+    # The extension base path
+    EXT_PATH = FILE_DIR.joinpath('extensions')
+
+# Tags base path
+TAGS_PATH = Path(scripts.basedir()).joinpath('tags')
+
+# The path to the folder containing the wildcards and embeddings
+WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
+EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
+HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
+
+try:
+    LORA_PATH = Path(shared.cmd_opts.lora_dir)
+except AttributeError:
+    LORA_PATH = None
+
+try:
+    LYCO_PATH = Path(shared.cmd_opts.lyco_dir)
+except AttributeError:
+    LYCO_PATH = None
+
+
+def find_ext_wildcard_paths():
+    """Returns the path to the extension wildcards folder"""
+    found = list(EXT_PATH.glob('*/wildcards/'))
+    return found
+
+
+# The path to the extension wildcards folder
+WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
+
+# The path to the temporary files
+STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp')  # In the webui root, on windows it exists by default, on linux it doesn't
+TEMP_PATH = TAGS_PATH.joinpath('temp')  # Extension specific temp files
@@ -6,52 +6,12 @@ from pathlib import Path

 import gradio as gr
 import yaml
-from modules import script_callbacks, scripts, sd_hijack, shared
+from modules import script_callbacks, sd_hijack, shared

-try:
-    from modules.paths import extensions_dir, script_path
-
-    # Webui root path
-    FILE_DIR = Path(script_path)
-
-    # The extension base path
-    EXT_PATH = Path(extensions_dir)
-except ImportError:
-    # Webui root path
-    FILE_DIR = Path().absolute()
-    # The extension base path
-    EXT_PATH = FILE_DIR.joinpath('extensions')
-
-# Tags base path
-TAGS_PATH = Path(scripts.basedir()).joinpath('tags')
-
-# The path to the folder containing the wildcards and embeddings
-WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
-EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
-HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
-
-try:
-    LORA_PATH = Path(shared.cmd_opts.lora_dir)
-except AttributeError:
-    LORA_PATH = None
-
-try:
-    LYCO_PATH = Path(shared.cmd_opts.lyco_dir)
-except AttributeError:
-    LYCO_PATH = None
-
-def find_ext_wildcard_paths():
-    """Returns the path to the extension wildcards folder"""
-    found = list(EXT_PATH.glob('*/wildcards/'))
-    return found
-
-
-# The path to the extension wildcards folder
-WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
-
-# The path to the temporary files
-STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp') # In the webui root, on windows it exists by default, on linux it doesn't
-TEMP_PATH = TAGS_PATH.joinpath('temp') # Extension specific temp files
+from scripts.model_keyword_support import (get_lora_simple_hash,
+                                           load_hash_cache, update_hash_cache,
+                                           write_model_keyword_path)
+from scripts.shared_paths import *


 def get_wildcards():
@@ -171,23 +131,47 @@ def get_hypernetworks():
     # Remove file extensions
     return sorted([h[:h.rfind('.')] for h in all_hypernetworks], key=lambda x: x.lower())

+model_keyword_installed = write_model_keyword_path()
+
 def get_lora():
     """Write a list of all lora"""
+    global model_keyword_installed
+
     # Get a list of all lora in the folder
     lora_paths = [Path(l) for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)]
-    all_lora = [str(l.name) for l in lora_paths if l.suffix in {".safetensors", ".ckpt", ".pt"}]
-    # Remove file extensions
-    return sorted([l[:l.rfind('.')] for l in all_lora], key=lambda x: x.lower())
+    # Get hashes
+    valid_loras = [lf for lf in lora_paths if lf.suffix in {".safetensors", ".ckpt", ".pt"}]
+    hashes = {}
+    for l in valid_loras:
+        name = l.name[:l.name.rfind('.')]
+        if model_keyword_installed:
+            hashes[name] = get_lora_simple_hash(l)
+        else:
+            hashes[name] = ""
+
+    # Sort
+    sorted_loras = dict(sorted(hashes.items()))
+    # Add hashes and return
+    return [f"{name},{hash}" for name, hash in sorted_loras.items()]


 def get_lyco():
     """Write a list of all LyCORIS/LOHA from https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris"""

     # Get a list of all LyCORIS in the folder
     lyco_paths = [Path(ly) for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)]
-    all_lyco = [str(ly.name) for ly in lyco_paths if ly.suffix in {".safetensors", ".ckpt", ".pt"}]
-    # Remove file extensions
-    return sorted([ly[:ly.rfind('.')] for ly in all_lyco], key=lambda x: x.lower())
+    # Get hashes
+    valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"}]
+    hashes = {}
+    for ly in valid_lycos:
+        name = ly.name[:ly.name.rfind('.')]
+        hashes[name] = get_lora_simple_hash(ly)
+
+    # Sort
+    sorted_lycos = dict(sorted(hashes.items()))
+    # Add hashes and return
+    return [f"{name},{hash}" for name, hash in sorted_lycos.items()]


 def write_tag_base_path():
     """Writes the tag base path to a fixed location temporary file"""
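If the model-keyword extension is missing, write_model_keyword_path() returns False and the hash column is simply left empty, so lora/lyco completion keeps working without trigger words. A sketch of the two line shapes produced, with invented entries:

    # JS later splits each lora.txt line on "," into [name, hash].
    with_keyword_ext = "add_detail,1a2b3c4d"   # model-keyword installed
    without_ext = "add_detail,"                # extension missing: empty hash

    for line in (with_keyword_ext, without_ext):
        name, short_hash = line.split(",")
        print(repr(name), repr(short_hash))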
@@ -276,6 +260,9 @@ def write_temp_files():
     if hypernets:
         write_to_temp_file('hyp.txt', hypernets)

+    if model_keyword_installed:
+        load_hash_cache()
+
     if LORA_PATH is not None and LORA_PATH.exists():
         lora = get_lora()
         if lora:
@@ -286,6 +273,9 @@ def write_temp_files():
     if lyco:
         write_to_temp_file('lyco.txt', lyco)

+    if model_keyword_installed:
+        update_hash_cache()
+

 write_temp_files()
@@ -334,6 +324,7 @@ def on_ui_settings():
     "tac_appendComma": shared.OptionInfo(True, "Append comma on tag autocompletion"),
     "tac_appendSpace": shared.OptionInfo(True, "Append space on tag autocompletion").info("will append after comma if the above is enabled"),
     "tac_alwaysSpaceAtEnd": shared.OptionInfo(True, "Always append space if inserting at the end of the textbox").info("takes precedence over the regular space setting for that position"),
+    "tac_modelKeywordCompletion": shared.OptionInfo(False, "Try to add known trigger words for LORA/LyCO models", gr.Checkbox, lambda: {"interactive": model_keyword_installed}).info("Requires the <a href=\"https://github.com/mix1009/model-keyword\" target=\"_blank\">model-keyword</a> extension to be installed, but will work with it disabled"),
     "tac_wildcardCompletionMode": shared.OptionInfo("To next folder level", "How to complete nested wildcard paths", gr.Dropdown, lambda: {"choices": ["To next folder level","To first difference","Always fully"]}).info("e.g. \"hair/colours/light/...\""),
     # Alias settings
     "tac_alias.searchByAlias": shared.OptionInfo(True, "Search by alias"),