diff --git a/javascript/__globals.js b/javascript/__globals.js index 8be8614..35b9dcf 100644 --- a/javascript/__globals.js +++ b/javascript/__globals.js @@ -11,6 +11,7 @@ var extras = []; var wildcardFiles = []; var wildcardExtFiles = []; var yamlWildcards = []; +var umiWildcards = []; var embeddings = []; var hypernetworks = []; var loras = []; diff --git a/javascript/_result.js b/javascript/_result.js index 6953ba6..96129f1 100644 --- a/javascript/_result.js +++ b/javascript/_result.js @@ -8,10 +8,11 @@ const ResultType = Object.freeze({ "wildcardTag": 4, "wildcardFile": 5, "yamlWildcard": 6, - "hypernetwork": 7, - "lora": 8, - "lyco": 9, - "chant": 10 + "umiWildcard": 7, + "hypernetwork": 8, + "lora": 9, + "lyco": 10, + "chant": 11 }); // Class to hold result data and annotations to make it clearer to use diff --git a/javascript/_utils.js b/javascript/_utils.js index 4fea691..2b063a5 100644 --- a/javascript/_utils.js +++ b/javascript/_utils.js @@ -109,6 +109,28 @@ function difference(a, b) { )].reduce((acc, [v, count]) => acc.concat(Array(Math.abs(count)).fill(v)), []); } +// Object flatten function adapted from https://stackoverflow.com/a/61602592 +// $roots keeps previous parent properties as they will be added as a prefix for each prop. +// $sep is just a preference if you want to seperate nested paths other than dot. +function flatten(obj, roots = [], sep = ".") { + return Object.keys(obj).reduce( + (memo, prop) => + Object.assign( + // create a new object + {}, + // include previously returned object + memo, + Object.prototype.toString.call(obj[prop]) === "[object Object]" + ? // keep working if value is an object + flatten(obj[prop], roots.concat([prop]), sep) + : // include current prop and value and prefix prop with the roots + { [roots.concat([prop]).join(sep)]: obj[prop] } + ), + {} + ); +} + + // Sliding window function to get possible combination groups of an array function toNgrams(inputArray, size) { return Array.from( diff --git a/javascript/ext_umi.js b/javascript/ext_umi.js index ad4ced5..2029f8c 100644 --- a/javascript/ext_umi.js +++ b/javascript/ext_umi.js @@ -74,7 +74,7 @@ class UmiParser extends BaseTagParser { //console.log({ matches }) const filteredWildcards = (tagword) => { - const wildcards = yamlWildcards.filter(x => { + const wildcards = umiWildcards.filter(x => { let tags = x[1]; const matchesNeg = matches.negative.length === 0 @@ -144,7 +144,7 @@ class UmiParser extends BaseTagParser { // Add final results let finalResults = []; tempResults.forEach(t => { - let result = new AutocompleteResult(t[0].trim(), ResultType.yamlWildcard) + let result = new AutocompleteResult(t[0].trim(), ResultType.umiWildcard) result.count = t[1]; finalResults.push(result); }); @@ -156,7 +156,7 @@ class UmiParser extends BaseTagParser { // Add final results let finalResults = []; filteredWildcardsSorted.forEach(t => { - let result = new AutocompleteResult(t[0].trim(), ResultType.yamlWildcard) + let result = new AutocompleteResult(t[0].trim(), ResultType.umiWildcard) result.count = t[1]; finalResults.push(result); }); @@ -171,7 +171,7 @@ class UmiParser extends BaseTagParser { // Add final results let finalResults = []; filteredWildcardsSorted.forEach(t => { - let result = new AutocompleteResult(t[0].trim(), ResultType.yamlWildcard) + let result = new AutocompleteResult(t[0].trim(), ResultType.umiWildcard) result.count = t[1]; finalResults.push(result); }); @@ -184,8 +184,8 @@ class UmiParser extends BaseTagParser { } function updateUmiTags( tagType, sanitizedText, newPrompt, textArea) { - // If it was a yaml wildcard, also update the umiPreviousTags - if (tagType === ResultType.yamlWildcard && originalTagword.length > 0) { + // If it was a umi wildcard, also update the umiPreviousTags + if (tagType === ResultType.umiWildcard && originalTagword.length > 0) { let umiSubPrompts = [...newPrompt.matchAll(UMI_PROMPT_REGEX)]; let umiTags = []; @@ -203,11 +203,11 @@ function updateUmiTags( tagType, sanitizedText, newPrompt, textArea) { } async function load() { - if (yamlWildcards.length === 0) { + if (umiWildcards.length === 0) { try { - let yamlTags = (await readFile(`${tagBasePath}/temp/wcet.txt`)).split("\n"); + let umiTags = (await readFile(`${tagBasePath}/temp/umi_tags.txt`)).split("\n"); // Split into tag, count pairs - yamlWildcards = yamlTags.map(x => x + umiWildcards = umiTags.map(x => x .trim() .split(",")) .map(([i, ...rest]) => [ @@ -218,14 +218,14 @@ async function load() { }, {}), ]); } catch (e) { - console.error("Error loading yaml wildcards: " + e); + console.error("Error loading umi wildcards: " + e); } } } function sanitize(tagType, text) { - // Replace underscores only if the yaml tag is not using them - if (tagType === ResultType.yamlWildcard && !yamlWildcards.includes(text)) { + // Replace underscores only if the umi tag is not using them + if (tagType === ResultType.umiWildcard && !umiWildcards.includes(text)) { return text.replaceAll("_", " "); } return null; diff --git a/javascript/ext_wildcards.js b/javascript/ext_wildcards.js index c594b94..cc3e47c 100644 --- a/javascript/ext_wildcards.js +++ b/javascript/ext_wildcards.js @@ -17,8 +17,22 @@ class WildcardParser extends BaseTagParser { // Use found wildcard file or look in external wildcard files let wcPair = wcFound || wildcardExtFiles.find(x => x[1].toLowerCase() === wcFile); - let wildcards = (await readFile(`${wcPair[0]}/${wcPair[1]}.txt`)).split("\n") + if (!wcPair || !wcPair[0] || !wcPair[1]) return []; + + let wildcards = []; + if (wcPair[0].endsWith(".yaml")) { + const getDescendantProp = (obj, desc) => { + const arr = desc.split("/"); + while (arr.length) { + obj = obj[arr.shift()]; + } + return obj; + } + wildcards = getDescendantProp(yamlWildcards[wcPair[0]], wcPair[1]); + } else { + wildcards = (await readFile(`${wcPair[0]}/${wcPair[1]}.txt`)).split("\n") .filter(x => x.trim().length > 0 && !x.startsWith('#')); // Remove empty lines and comments + } let finalResults = []; let tempResults = wildcards.filter(x => (wcWord !== null && wcWord.length > 0) ? x.toLowerCase().includes(wcWord) : x) // Filter by tagword @@ -46,11 +60,20 @@ class WildcardFileParser extends BaseTagParser { let finalResults = []; // Get final results tempResults.forEach(wcFile => { - let result = new AutocompleteResult(wcFile[1].trim(), ResultType.wildcardFile); - result.meta = "Wildcard file"; + let result = null; + if (wcFile[0].endsWith(".yaml")) { + result = new AutocompleteResult(wcFile[1].trim(), ResultType.yamlWildcard); + result.meta = "YAML wildcard collection"; + } else { + result = new AutocompleteResult(wcFile[1].trim(), ResultType.wildcardFile); + result.meta = "Wildcard file"; + } + finalResults.push(result); }); + finalResults.sort((a, b) => a.text.localeCompare(b.text)); + return finalResults; } } @@ -87,6 +110,17 @@ async function load() { wcExtFile = wcExtFile.map(x => [base, x]); wildcardExtFiles.push(...wcExtFile); } + + // Load the yaml wildcard json file and append it as a wildcard file, appending each key as a path component until we reach the end + yamlWildcards = await readFile(`${tagBasePath}/temp/wc_yaml.json`, true); + + // Append each key as a path component until we reach a leaf + Object.keys(yamlWildcards).forEach(file => { + const flattened = flatten(yamlWildcards[file], [], "/"); + Object.keys(flattened).forEach(key => { + wildcardExtFiles.push([file, key]); + }); + }); } catch (e) { console.error("Error loading wildcards: " + e); } @@ -94,7 +128,7 @@ async function load() { } function sanitize(tagType, text) { - if (tagType === ResultType.wildcardFile) { + if (tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard) { return `__${text}__`; } else if (tagType === ResultType.wildcardTag) { return text.replace(/^.*?: /g, ""); @@ -104,7 +138,7 @@ function sanitize(tagType, text) { function keepOpenIfWildcard(tagType, sanitizedText, newPrompt, textArea) { // If it's a wildcard, we want to keep the results open so the user can select another wildcard - if (tagType === ResultType.wildcardFile) { + if (tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard) { hideBlocked = true; autocomplete(textArea, newPrompt, sanitizedText); setTimeout(() => { hideBlocked = false; }, 450); diff --git a/javascript/tagAutocomplete.js b/javascript/tagAutocomplete.js index c4ec5b2..02455fa 100644 --- a/javascript/tagAutocomplete.js +++ b/javascript/tagAutocomplete.js @@ -375,7 +375,7 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout } } - if (tagType === ResultType.wildcardFile + if (tagType === ResultType.wildcardFile || tagType === ResultType.yamlWildcard && tabCompletedWithoutChoice && TAC_CFG.wildcardCompletionMode !== "Always fully" && sanitizedText.includes("/")) { @@ -402,9 +402,11 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout } } // Don't cut off the __ at the end if it is already the full path - if (firstDifference < longestResult) { + if (firstDifference > 0 && firstDifference < longestResult) { // +2 because the sanitized text already has the __ at the start but the matched text doesn't sanitizedText = sanitizedText.substring(0, firstDifference + 2); + } else if (firstDifference === 0) { + sanitizedText = tagword; } } } @@ -420,7 +422,7 @@ async function insertTextAtCursor(textArea, result, tagword, tabCompletedWithout var optionalSeparator = ""; let extraNetworkTypes = [ResultType.hypernetwork, ResultType.lora]; - let noCommaTypes = [ResultType.wildcardFile, ResultType.yamlWildcard].concat(extraNetworkTypes); + let noCommaTypes = [ResultType.wildcardFile, ResultType.yamlWildcard, ResultType.umiWildcard].concat(extraNetworkTypes); if (!noCommaTypes.includes(tagType)) { // Append comma if enabled and not already present let beforeComma = surrounding.match(new RegExp(`${escapeRegExp(tagword)}[,:]`, "i")) !== null; @@ -597,7 +599,8 @@ function addResultsToList(textArea, results, tagword, resetList) { // Print search term bolded in result itemText.innerHTML = displayText.replace(tagword, `${tagword}`); - if (result.type === ResultType.wildcardFile && itemText.innerHTML.includes("/")) { + const splitTypes = [ResultType.wildcardFile, ResultType.yamlWildcard] + if (splitTypes.includes(result.type) && itemText.innerHTML.includes("/")) { let parts = itemText.innerHTML.split("/"); let lastPart = parts[parts.length - 1]; parts = parts.slice(0, parts.length - 1); @@ -1114,7 +1117,7 @@ async function refreshTacTempFiles() { setTimeout(async () => { wildcardFiles = []; wildcardExtFiles = []; - yamlWildcards = []; + umiWildcards = []; embeddings = []; hypernetworks = []; loras = []; diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index a86e2d4..f6c391e 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -36,37 +36,73 @@ def get_ext_wildcards(): return wildcard_files +def is_umi_format(data): + """Returns True if the YAML file is in UMI format.""" + issue_found = False + for item in data: + if not (data[item] and 'Tags' in data[item] and isinstance(data[item]['Tags'], list)): + issue_found = True + break + return not issue_found -def get_ext_wildcard_tags(): +def parse_umi_format(umi_tags, count, data): + for item in data: + umi_tags[count] = ','.join(data[item]['Tags']) + count += 1 + + +def parse_dynamic_prompt_format(yaml_wildcards, data, path): + # Recurse subkeys, delete those without string lists as values + def recurse_dict(d: dict): + for key, value in d.copy().items(): + if isinstance(value, dict): + recurse_dict(value) + elif not (isinstance(value, list) and all(isinstance(v, str) for v in value)): + del d[key] + + recurse_dict(data) + # Add to yaml_wildcards + yaml_wildcards[path.name] = data + + +def get_yaml_wildcards(): """Returns a list of all tags found in extension YAML files found under a Tags: key.""" - wildcard_tags = {} # { tag: count } yaml_files = [] for path in WILDCARD_EXT_PATHS: yaml_files.extend(p for p in path.rglob("*.yml")) yaml_files.extend(p for p in path.rglob("*.yaml")) + + yaml_wildcards = {} + + umi_tags = {} # { tag: count } count = 0 + for path in yaml_files: try: with open(path, encoding="utf8") as file: data = yaml.safe_load(file) - if data: - for item in data: - if data[item] and 'Tags' in data[item] and isinstance(data[item]['Tags'], list): - wildcard_tags[count] = ','.join(data[item]['Tags']) - count += 1 - else: - print('Issue with tags found in ' + path.name + ' at item ' + item) + if (data): + if (is_umi_format(data)): + parse_umi_format(umi_tags, count, data) + else: + parse_dynamic_prompt_format(yaml_wildcards, data, path) else: print('No data found in ' + path.name) except yaml.YAMLError: - print('Issue in parsing YAML file ' + path.name ) + print('Issue in parsing YAML file ' + path.name) continue + # Sort by count - sorted_tags = sorted(wildcard_tags.items(), key=lambda item: item[1], reverse=True) - output = [] - for tag, count in sorted_tags: - output.append(f"{tag},{count}") - return output + umi_sorted = sorted(umi_tags.items(), key=lambda item: item[1], reverse=True) + umi_output = [] + for tag, count in umi_sorted: + umi_output.append(f"{tag},{count}") + + if (len(umi_output) > 0): + write_to_temp_file('umi_tags.txt', umi_output) + + with open(TEMP_PATH.joinpath("wc_yaml.json"), "w", encoding="utf-8") as file: + json.dump(yaml_wildcards, file, ensure_ascii=False) def get_embeddings(sd_model): @@ -226,7 +262,8 @@ if not TEMP_PATH.exists(): # even if no wildcards or embeddings are found write_to_temp_file('wc.txt', []) write_to_temp_file('wce.txt', []) -write_to_temp_file('wcet.txt', []) +write_to_temp_file('wc_yaml.json', []) +write_to_temp_file('umi_tags.txt', []) write_to_temp_file('hyp.txt', []) write_to_temp_file('lora.txt', []) write_to_temp_file('lyco.txt', []) @@ -240,6 +277,8 @@ if EMB_PATH.exists(): script_callbacks.on_model_loaded(get_embeddings) def refresh_temp_files(): + global WILDCARD_EXT_PATHS + WILDCARD_EXT_PATHS = find_ext_wildcard_paths() write_temp_files() get_embeddings(shared.sd_model) @@ -255,10 +294,8 @@ def write_temp_files(): wildcards_ext = get_ext_wildcards() if wildcards_ext: write_to_temp_file('wce.txt', wildcards_ext) - # Write yaml extension wildcards to wcet.txt if found - wildcards_yaml_ext = get_ext_wildcard_tags() - if wildcards_yaml_ext: - write_to_temp_file('wcet.txt', wildcards_yaml_ext) + # Write yaml extension wildcards to umi_tags.txt and wc_yaml.json if found + get_yaml_wildcards() if HYP_PATH.exists(): hypernets = get_hypernetworks()