Merge pull request #201 from DominikDoom/feature-trigger-word-completion

This commit is contained in:
DominikDoom
2023-07-22 15:36:43 +02:00
committed by GitHub
10 changed files with 337 additions and 60 deletions

View File

@@ -42,7 +42,7 @@ You can install it using the inbuilt available extensions list, clone the files
Tag autocomplete supports built-in completion for:
- 🏷️ **Danbooru & e621 tags** (Top 100k by post count, as of November 2022)
- ✳️ [**Wildcards**](#wildcards)
- [**Extra network**](#extra-networks-embeddings-hypernets-lora) filenames, including
- [**Extra network**](#extra-networks-embeddings-hypernets-lora-) filenames, including
- Textual Inversion embeddings [(jump to readme section)]
- Hypernetworks
- LoRA
@@ -123,6 +123,30 @@ Completion for these types is triggered by typing `<`. By default it will show t
- Or `<lora:` and `<lyco:` respectively for the long form
- `<h:` or `<hypernet:` will only show Hypernetworks
### Lora / Lyco trigger word completion
This is an advanced feature that will try to add known trigger words on autocompleting a Lora/Lyco.
It uses the list provided by the [model-keyword](https://github.com/mix1009/model-keyword/) extension, which thus needs to be installed to use this feature. The list is also regularly updated through it.
However, once installed, you can deactivate it if you want, since tag autocomplete only needs the local keyword lists it ships with, not the extension itself.
The used files are `lora-keywords.txt` and `lora-keywords-user.txt` in the model-keyword installation folder.
If the main file isn't found, the feature will simply deactivate itself, everything else should work normally.
To add custom mappings for unknown Loras, you can use the UI provided by model-keyword, it will automatically write it to the `lora-keywords-user.txt` for you (and create it if it doesn't exist).
The only issue is that it has no official support for the Lycoris extension and doesn't scan its folder for files, so to add them through the UI you will have to temporarily move them into the Lora model folder to be able to select them in model-keywords dropdown.
Some are already included in the default list though, so trying it out first is advisable.
<details>
<summary>Walkthorugh to add custom keywords</summary>
![image](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/assets/34448969/4302c44e-c632-473d-a14a-76f164f966cb)
</details>
After having added your custom keywords, you will need to either restart the UI or use the "Refresh TAC temp files" setting button.
Sometimes the inserted keywords can be wrong due to a hash collision, however model-keyword and tag autocomplete take the name of the file into account too if the collision is known.
If it still inserts something wrong or you simply don't want the keywords added that time, you can undo / redo it directly after as often as you want, until you type something else
(It uses the default undo/redo action of the browser, so <kbd>CTRL</kbd> + <kbd>Z</kbd>, context menu and mouse macros should all work).
### Embedding type filtering
Embeddings trained for Stable Diffusion 1.x or 2.x models respectively are incompatible with the other type. To make it easier to find valid embeds, they are categorized by "v1 Embedding" and "v2 Embedding", including a slight color difference. You can also filter your search to include only v1 or v2 embeddings by typing `<v1/2` or `<e:v1/2` followed by the actual search term.
@@ -285,6 +309,26 @@ Depending on the last setting, tag autocomplete will append a comma and space af
![insertEscape](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/assets/34448969/d28557be-6c75-43fd-bf17-c0609223b384)
</details>
<!-- Lora keywords -->
<details>
<summary>Lora / Lyco trigger word insertion</summary>
See [the detailed readme section](#lora--lyco-trigger-word-completion) for more info.
Selects the mode to use for Lora / Lyco trigger word insertion.
Needs the [model-keyword](https://github.com/mix1009/model-keyword/) extension to be installed, else it will do nothing.
- Never
- Will not complete trigger words, even if the model-keyword extension is installed
- Only user list
- Will only load the custom keywords specified in the lora-keyword-user.txt file and ignore the default list
- Always
- Will load and use both lists
Switching from "Never" to what you had before or back will not require a restart, but changing between the full and user only list will.
![loraKeywordCompletion](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/assets/34448969/8bec89ea-68f7-4783-b5cb-55869d9121a3)
</details>
<!-- Wildcard path mode -->
<details>
<summary>Wildcard path completion</summary>

View File

@@ -1,6 +1,7 @@
// Core components
var TAC_CFG = null;
var tagBasePath = "";
var modelKeywordPath = "";
// Tag completion data loaded from files
var allTags = [];
@@ -14,6 +15,7 @@ var embeddings = [];
var hypernetworks = [];
var loras = [];
var lycos = [];
var modelKeywordDict = new Map();
var chants = [];
// Selected model info for black/whitelisting
@@ -34,6 +36,12 @@ let hideBlocked = false;
var selectedTag = null;
var oldSelectedTag = null;
// Lora keyword undo/redo history
var textBeforeKeywordInsertion = "";
var textAfterKeywordInsertion = "";
var lastEditWasKeywordInsertion = false;
var keywordInsertionUndone = false;
// UMI
var umiPreviousTags = [];

View File

@@ -25,6 +25,7 @@ class AutocompleteResult {
count = null;
aliases = null;
meta = null;
hash = null;
// Constructor
constructor(text, type) {

View File

@@ -8,7 +8,7 @@ class LoraParser extends BaseTagParser {
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lora:") {
let searchTerm = tagword.replace("<lora:", "").replace("<l:", "").replace("<", "");
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
tempResults = loras.filter(x => filterCondition(x)); // Filter by tagword
tempResults = loras.filter(x => filterCondition(x[0])); // Filter by tagword
} else {
tempResults = loras;
}
@@ -16,8 +16,9 @@ class LoraParser extends BaseTagParser {
// Add final results
let finalResults = [];
tempResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.lora)
let result = new AutocompleteResult(t[0].trim(), ResultType.lora)
result.meta = "Lora";
result.hash = t[1];
finalResults.push(result);
});
@@ -30,7 +31,7 @@ async function load() {
try {
loras = (await readFile(`${tagBasePath}/temp/lora.txt`)).split("\n")
.filter(x => x.trim().length > 0) // Remove empty lines
.map(x => x.trim()); // Remove carriage returns and padding if it exists
.map(x => x.trim().split(",")); // Remove carriage returns and padding if it exists, split into name, hash pairs
} catch (e) {
console.error("Error loading lora.txt: " + e);
}

View File

@@ -8,7 +8,7 @@ class LycoParser extends BaseTagParser {
if (tagword !== "<" && tagword !== "<l:" && tagword !== "<lyco:") {
let searchTerm = tagword.replace("<lyco:", "").replace("<l:", "").replace("<", "");
let filterCondition = x => x.toLowerCase().includes(searchTerm) || x.toLowerCase().replaceAll(" ", "_").includes(searchTerm);
tempResults = lycos.filter(x => filterCondition(x)); // Filter by tagword
tempResults = lycos.filter(x => filterCondition(x[0])); // Filter by tagword
} else {
tempResults = lycos;
}
@@ -16,8 +16,9 @@ class LycoParser extends BaseTagParser {
// Add final results
let finalResults = [];
tempResults.forEach(t => {
let result = new AutocompleteResult(t.trim(), ResultType.lyco)
let result = new AutocompleteResult(t[0].trim(), ResultType.lyco)
result.meta = "Lyco";
result.hash = t[1];
finalResults.push(result);
});
@@ -30,7 +31,7 @@ async function load() {
try {
lycos = (await readFile(`${tagBasePath}/temp/lyco.txt`)).split("\n")
.filter(x => x.trim().length > 0) // Remove empty lines
.map(x => x.trim()); // Remove carriage returns and padding if it exists
.map(x => x.trim().split(",")); // Remove carriage returns and padding if it exists, split into name, hash pairs
} catch (e) {
console.error("Error loading lyco.txt: " + e);
}

View File

@@ -0,0 +1,43 @@
async function load() {
let modelKeywordParts = (await readFile(`tmp/modelKeywordPath.txt`)).split(",")
modelKeywordPath = modelKeywordParts[0];
let customFileExists = modelKeywordParts[1] === "True";
if (modelKeywordPath.length > 0 && modelKeywordDict.size === 0) {
try {
let lines = [];
// Only add default keywords if wanted by the user
if (TAC_CFG.modelKeywordCompletion !== "Only user list")
lines = (await readFile(`${modelKeywordPath}/lora-keyword.txt`)).split("\n");
// Add custom user keywords if the file exists
if (customFileExists)
lines = lines.concat((await readFile(`${modelKeywordPath}/lora-keyword-user.txt`)).split("\n"));
if (lines.length === 0) return;
lines = lines.filter(x => x.trim().length > 0 && x.trim()[0] !== "#") // Remove empty lines and comments
// Add to the dict
lines.forEach(line => {
const parts = line.split(",");
const hash = parts[0];
const keywords = parts[1].replaceAll("| ", ", ").replaceAll("|", ", ").trim();
const lastSepIndex = parts[2]?.lastIndexOf("/") + 1 || parts[2]?.lastIndexOf("\\") + 1 || 0;
const name = parts[2]?.substring(lastSepIndex).trim() || "none"
if (modelKeywordDict.has(hash) && name !== "none") {
// Add a new name key if the hash already exists
modelKeywordDict.get(hash).set(name, keywords);
} else {
// Create new hash entry
let map = new Map().set(name, keywords);
modelKeywordDict.set(hash, map);
}
});
} catch (e) {
console.error("Error loading model-keywords list: " + e);
}
}
}
QUEUE_FILE_LOAD.push(load);

View File

@@ -202,6 +202,7 @@ async function syncOptions() {
appendSpace: opts["tac_appendSpace"],
alwaysSpaceAtEnd: opts["tac_alwaysSpaceAtEnd"],
wildcardCompletionMode: opts["tac_wildcardCompletionMode"],
modelKeywordCompletion: opts["tac_modelKeywordCompletion"],
// Alias settings
alias: {
searchByAlias: opts["tac_alias.searchByAlias"],
@@ -441,8 +442,39 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout
// Add back start
var newPrompt = prompt.substring(0, editStart) + insert + prompt.substring(editEnd);
// Add lora/lyco keywords if enabled and found
let keywordsLength = 0;
if (TAC_CFG.modelKeywordCompletion !== "Never" && modelKeywordPath.length > 0 && (tagType === ResultType.lora || tagType === ResultType.lyco)) {
if (result.hash && result.hash !== "NOFILE" && result.hash.length > 0) {
let keywords = null;
let nameDict = modelKeywordDict.get(result.hash);
let name = result.text + ".safetensors";
if (nameDict) {
if (nameDict.has(name))
keywords = nameDict.get(name);
else
keywords = nameDict.get("none");
}
if (keywords && keywords.length > 0) {
textBeforeKeywordInsertion = newPrompt;
newPrompt = `${keywords}, ${newPrompt}`; // Insert keywords
textAfterKeywordInsertion = newPrompt;
keywordInsertionUndone = false;
setTimeout(() => lastEditWasKeywordInsertion = true, 200)
keywordsLength = keywords.length + 2; // +2 for the comma and space
}
}
}
// Insert into prompt textbox and reposition cursor
textArea.value = newPrompt;
textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length;
textArea.selectionStart = afterInsertCursorPos + optionalSeparator.length + keywordsLength;
textArea.selectionEnd = textArea.selectionStart
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure it's propagated back to python.
@@ -771,6 +803,37 @@ function rubyTagClicked(node, textBefore, prompt, textArea) {
textArea.setSelectionRange(startPos, endPos);
}
// Check if the last edit was the keyword insertion, and catch undo/redo in that case
function checkKeywordInsertionUndo(textArea, event) {
if (TAC_CFG.modelKeywordCompletion === "Never") return;
switch (event.inputType) {
case "historyUndo":
if (lastEditWasKeywordInsertion && !keywordInsertionUndone) {
keywordInsertionUndone = true;
textArea.value = textBeforeKeywordInsertion;
updateInput(textArea);
}
break;
case "historyRedo":
if (lastEditWasKeywordInsertion && keywordInsertionUndone) {
keywordInsertionUndone = false;
textArea.value = textAfterKeywordInsertion;
updateInput(textArea);
}
case undefined:
// undefined is caused by the updateInput event firing, so we just ignore it
break;
default:
// Everything else deactivates the keyword undo and returns to normal undo behavior
lastEditWasKeywordInsertion = false;
keywordInsertionUndone = false;
textBeforeKeywordInsertion = "";
textAfterKeywordInsertion = "";
break;
}
}
async function autocomplete(textArea, prompt, fixedTag = null) {
// Return if the function is deactivated in the UI
if (!isEnabled()) return;
@@ -1036,6 +1099,7 @@ async function refreshTacTempFiles() {
hypernetworks = [];
loras = [];
lycos = [];
modelKeywordDict.clear();
await processQueue(QUEUE_FILE_LOAD, null);
console.log("TAC: Refreshed temp files");
@@ -1061,9 +1125,10 @@ function addAutocompleteToArea(area) {
hideResults(area);
// Add autocomplete event listener
area.addEventListener('input', () => {
area.addEventListener('input', (e) => {
debounce(autocomplete(area, area.value), TAC_CFG.delayTime);
updateRuby(area, area.value);
checkKeywordInsertionUndo(area, e);
});
// Add focusout event listener
area.addEventListener('focusout', debounce(() => {

View File

@@ -0,0 +1,76 @@
# This file provides support for the model-keyword extension to add known lora keywords on completion
import hashlib
from pathlib import Path
from scripts.shared_paths import EXT_PATH, STATIC_TEMP_PATH, TEMP_PATH
# Set up our hash cache
known_hashes_file = TEMP_PATH.joinpath("known_lora_hashes.txt")
known_hashes_file.touch()
file_needs_update = False
# Load the hashes from the file
hash_dict = {}
def load_hash_cache():
with open(known_hashes_file, "r") as file:
for line in file:
name, hash, mtime = line.replace("\n", "").split(",")
hash_dict[name] = (hash, mtime)
def update_hash_cache():
global file_needs_update
if file_needs_update:
with open(known_hashes_file, "w") as file:
for name, (hash, mtime) in hash_dict.items():
file.write(f"{name},{hash},{mtime}\n")
# Copy of the fast inaccurate hash function from the extension
# with some modifications to load from and write to the cache
def get_lora_simple_hash(path):
global file_needs_update
mtime = str(Path(path).stat().st_mtime)
filename = Path(path).name
if filename in hash_dict:
(hash, old_mtime) = hash_dict[filename]
if mtime == old_mtime:
return hash
try:
with open(path, "rb") as file:
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
hash = m.hexdigest()[0:8]
hash_dict[filename] = (hash, mtime)
file_needs_update = True
return hash
except FileNotFoundError:
return "NOFILE"
# Find the path of the original model-keyword extension
def write_model_keyword_path():
# Ensure the file exists even if the extension is not installed
mk_path = STATIC_TEMP_PATH.joinpath("modelKeywordPath.txt")
mk_path.write_text("")
base_keywords = list(EXT_PATH.glob("*/lora-keyword.txt"))
custom_keywords = list(EXT_PATH.glob("*/lora-keyword-user.txt"))
custom_found = custom_keywords is not None and len(custom_keywords) > 0
if base_keywords is not None and len(base_keywords) > 0:
with open(mk_path, "w", encoding="utf-8") as f:
f.write(f"{base_keywords[0].parent.as_posix()},{custom_found}")
return True
else:
print(
"Tag Autocomplete: Could not locate model-keyword extension, LORA/LYCO trigger word completion will be unavailable."
)
return False

47
scripts/shared_paths.py Normal file
View File

@@ -0,0 +1,47 @@
from pathlib import Path
from modules import scripts, shared
try:
from modules.paths import extensions_dir, script_path
# Webui root path
FILE_DIR = Path(script_path)
# The extension base path
EXT_PATH = Path(extensions_dir)
except ImportError:
# Webui root path
FILE_DIR = Path().absolute()
# The extension base path
EXT_PATH = FILE_DIR.joinpath('extensions')
# Tags base path
TAGS_PATH = Path(scripts.basedir()).joinpath('tags')
# The path to the folder containing the wildcards and embeddings
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
try:
LORA_PATH = Path(shared.cmd_opts.lora_dir)
except AttributeError:
LORA_PATH = None
try:
LYCO_PATH = Path(shared.cmd_opts.lyco_dir)
except AttributeError:
LYCO_PATH = None
def find_ext_wildcard_paths():
"""Returns the path to the extension wildcards folder"""
found = list(EXT_PATH.glob('*/wildcards/'))
return found
# The path to the extension wildcards folder
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
# The path to the temporary files
STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp') # In the webui root, on windows it exists by default, on linux it doesn't
TEMP_PATH = TAGS_PATH.joinpath('temp') # Extension specific temp files

View File

@@ -6,52 +6,12 @@ from pathlib import Path
import gradio as gr
import yaml
from modules import script_callbacks, scripts, sd_hijack, shared
from modules import script_callbacks, sd_hijack, shared
try:
from modules.paths import extensions_dir, script_path
# Webui root path
FILE_DIR = Path(script_path)
# The extension base path
EXT_PATH = Path(extensions_dir)
except ImportError:
# Webui root path
FILE_DIR = Path().absolute()
# The extension base path
EXT_PATH = FILE_DIR.joinpath('extensions')
# Tags base path
TAGS_PATH = Path(scripts.basedir()).joinpath('tags')
# The path to the folder containing the wildcards and embeddings
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
try:
LORA_PATH = Path(shared.cmd_opts.lora_dir)
except AttributeError:
LORA_PATH = None
try:
LYCO_PATH = Path(shared.cmd_opts.lyco_dir)
except AttributeError:
LYCO_PATH = None
def find_ext_wildcard_paths():
"""Returns the path to the extension wildcards folder"""
found = list(EXT_PATH.glob('*/wildcards/'))
return found
# The path to the extension wildcards folder
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
# The path to the temporary files
STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp') # In the webui root, on windows it exists by default, on linux it doesn't
TEMP_PATH = TAGS_PATH.joinpath('temp') # Extension specific temp files
from scripts.model_keyword_support import (get_lora_simple_hash,
load_hash_cache, update_hash_cache,
write_model_keyword_path)
from scripts.shared_paths import *
def get_wildcards():
@@ -171,23 +131,47 @@ def get_hypernetworks():
# Remove file extensions
return sorted([h[:h.rfind('.')] for h in all_hypernetworks], key=lambda x: x.lower())
model_keyword_installed = write_model_keyword_path()
def get_lora():
"""Write a list of all lora"""
global model_keyword_installed
# Get a list of all lora in the folder
lora_paths = [Path(l) for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)]
all_lora = [str(l.name) for l in lora_paths if l.suffix in {".safetensors", ".ckpt", ".pt"}]
# Remove file extensions
return sorted([l[:l.rfind('.')] for l in all_lora], key=lambda x: x.lower())
# Get hashes
valid_loras = [lf for lf in lora_paths if lf.suffix in {".safetensors", ".ckpt", ".pt"}]
hashes = {}
for l in valid_loras:
name = l.name[:l.name.rfind('.')]
if model_keyword_installed:
hashes[name] = get_lora_simple_hash(l)
else:
hashes[name] = ""
# Sort
sorted_loras = dict(sorted(hashes.items()))
# Add hashes and return
return [f"{name},{hash}" for name, hash in sorted_loras.items()]
def get_lyco():
"""Write a list of all LyCORIS/LOHA from https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris"""
# Get a list of all LyCORIS in the folder
lyco_paths = [Path(ly) for ly in glob.glob(LYCO_PATH.joinpath("**/*").as_posix(), recursive=True)]
all_lyco = [str(ly.name) for ly in lyco_paths if ly.suffix in {".safetensors", ".ckpt", ".pt"}]
# Remove file extensions
return sorted([ly[:ly.rfind('.')] for ly in all_lyco], key=lambda x: x.lower())
# Get hashes
valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"}]
hashes = {}
for ly in valid_lycos:
name = ly.name[:ly.name.rfind('.')]
hashes[name] = get_lora_simple_hash(ly)
# Sort
sorted_lycos = dict(sorted(hashes.items()))
# Add hashes and return
return [f"{name},{hash}" for name, hash in sorted_lycos.items()]
def write_tag_base_path():
"""Writes the tag base path to a fixed location temporary file"""
@@ -276,6 +260,9 @@ def write_temp_files():
if hypernets:
write_to_temp_file('hyp.txt', hypernets)
if model_keyword_installed:
load_hash_cache()
if LORA_PATH is not None and LORA_PATH.exists():
lora = get_lora()
if lora:
@@ -286,6 +273,9 @@ def write_temp_files():
if lyco:
write_to_temp_file('lyco.txt', lyco)
if model_keyword_installed:
update_hash_cache()
write_temp_files()
@@ -334,6 +324,7 @@ def on_ui_settings():
"tac_appendComma": shared.OptionInfo(True, "Append comma on tag autocompletion"),
"tac_appendSpace": shared.OptionInfo(True, "Append space on tag autocompletion").info("will append after comma if the above is enabled"),
"tac_alwaysSpaceAtEnd": shared.OptionInfo(True, "Always append space if inserting at the end of the textbox").info("takes precedence over the regular space setting for that position"),
"tac_modelKeywordCompletion": shared.OptionInfo("Never", "Try to add known trigger words for LORA/LyCO models", gr.Dropdown, lambda: {"interactive": model_keyword_installed, "choices": ["Never","Only user list","Always"]}).info("Requires the <a href=\"https://github.com/mix1009/model-keyword\" target=\"_blank\">model-keyword</a> extension to be installed, but will work with it disabled.").needs_restart(),
"tac_wildcardCompletionMode": shared.OptionInfo("To next folder level", "How to complete nested wildcard paths", gr.Dropdown, lambda: {"choices": ["To next folder level","To first difference","Always fully"]}).info("e.g. \"hair/colours/light/...\""),
# Alias settings
"tac_alias.searchByAlias": shared.OptionInfo(True, "Search by alias"),