mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-26 19:19:57 +00:00
Model sort selection.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# This helper script scans folders for wildcards and embeddings and writes them
|
||||
# to a temporary file to expose it to the javascript side
|
||||
|
||||
import os
|
||||
import glob
|
||||
import json
|
||||
import urllib.parse
|
||||
@@ -24,12 +25,48 @@ except Exception as e: # Not supported.
|
||||
load_textual_inversion_embeddings = lambda *args, **kwargs: None
|
||||
print("Tag Autocomplete: Cannot reload embeddings instantly:", e)
|
||||
|
||||
QUO = "\""
|
||||
EXTKEY = "tac"
|
||||
# EXTNAME = "Tag Autocomplete Helper"
|
||||
# Default values, because shared doesn't allocate a value automatically.
|
||||
# (id: def)
|
||||
DEXTSETV = {
|
||||
"sortModels": "default",
|
||||
}
|
||||
fseti = lambda x: shared.opts.data.get(EXTKEY + "_" + x, DEXTSETV[x])
|
||||
|
||||
def sort_models(lmodels, sort_method = None, indwrap = False):
|
||||
"""Sorts models according to setting.
|
||||
|
||||
Input: list of (full_path, display_name, {hash}) models.
|
||||
Returns models in the standard temp file format (ie name, hash).
|
||||
Default sort is lexicographical, mdate is by file modification date.
|
||||
Hash is optional, can be any textual value written after the name (eg v1/v2 for embeddings).
|
||||
For some reason only loras and lycos are wrapped in quote marks, so it's left to caller.
|
||||
Creep: Requires sort modifications on js side to preserve order during merge.
|
||||
"""
|
||||
if len(lmodels) == 0:
|
||||
return lmodels
|
||||
if sort_method is None:
|
||||
sort_method = fseti("sortModels")
|
||||
if sort_method == "modified_date":
|
||||
lsorted = sorted(lmodels, key=lambda x: os.path.getmtime(x[0]), reverse = True)
|
||||
else:
|
||||
lsorted = sorted(lmodels, key = lambda x: x[1].lower())
|
||||
if len(lsorted[0]) > 2:
|
||||
# lret = [f"\"{name}\",{hash}" for pt, name, hash in lsorted]
|
||||
lret = [f"{name},{hash}" for pt, name, hash in lsorted]
|
||||
else:
|
||||
lret = [name for pt, name in lsorted]
|
||||
return lret
|
||||
|
||||
def get_wildcards():
|
||||
"""Returns a list of all wildcards. Works on nested folders."""
|
||||
wildcard_files = list(WILDCARD_PATH.rglob("*.txt"))
|
||||
resolved = [w.relative_to(WILDCARD_PATH).as_posix(
|
||||
) for w in wildcard_files if w.name != "put wildcards here.txt"]
|
||||
return resolved
|
||||
resolved = [(w, w.relative_to(WILDCARD_PATH).as_posix())
|
||||
for w in wildcard_files
|
||||
if w.name != "put wildcards here.txt"]
|
||||
return sort_models(resolved)
|
||||
|
||||
|
||||
def get_ext_wildcards():
|
||||
@@ -38,7 +75,10 @@ def get_ext_wildcards():
|
||||
|
||||
for path in WILDCARD_EXT_PATHS:
|
||||
wildcard_files.append(path.as_posix())
|
||||
wildcard_files.extend(p.relative_to(path).as_posix() for p in path.rglob("*.txt") if p.name != "put wildcards here.txt")
|
||||
lfiles = [(w, w.relative_to(path).as_posix())
|
||||
for w in path.rglob("*.txt")
|
||||
if w.name != "put wildcards here.txt"]
|
||||
wildcard_files.extend(sort_models(lfiles))
|
||||
wildcard_files.append("-----")
|
||||
|
||||
return wildcard_files
|
||||
@@ -136,14 +176,14 @@ def get_embeddings(sd_model):
|
||||
|
||||
# Add embeddings to the correct list
|
||||
if (emb_a_shape == V1_SHAPE):
|
||||
emb_v1 = list(emb_type_a.keys())
|
||||
emb_v1 = [(v.filename, k, "v1") for (k,v) in emb_type_a.items()]
|
||||
elif (emb_a_shape == V2_SHAPE):
|
||||
emb_v2 = list(emb_type_a.keys())
|
||||
emb_v2 = [(v.filename, k, "v2") for (k,v) in emb_type_a.items()]
|
||||
|
||||
if (emb_b_shape == V1_SHAPE):
|
||||
emb_v1 = list(emb_type_b.keys())
|
||||
emb_v1 = [(v.filename, k, "v1") for (k,v) in emb_type_b.items()]
|
||||
elif (emb_b_shape == V2_SHAPE):
|
||||
emb_v2 = list(emb_type_b.keys())
|
||||
emb_v2 = [(v.filename, k, "v2") for (k,v) in emb_type_b.items()]
|
||||
|
||||
# Get shape of current model
|
||||
#vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||
@@ -155,7 +195,7 @@ def get_embeddings(sd_model):
|
||||
# results = [e + ",v2" for e in emb_v2] + [e + ",v1" for e in emb_v1]
|
||||
#else:
|
||||
# raise AttributeError # Fallback to old method
|
||||
results = sorted([e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2], key=lambda x: x.lower())
|
||||
results = sort_models(emb_v1) + sort_models(emb_v2)
|
||||
except AttributeError:
|
||||
print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
|
||||
# Get a list of all embeddings in the folder
|
||||
@@ -173,9 +213,10 @@ def get_hypernetworks():
|
||||
|
||||
# Get a list of all hypernetworks in the folder
|
||||
hyp_paths = [Path(h) for h in glob.glob(HYP_PATH.joinpath("**/*").as_posix(), recursive=True)]
|
||||
all_hypernetworks = [str(h.name) for h in hyp_paths if h.suffix in {".pt"}]
|
||||
all_hypernetworks = [h for h in hyp_paths if h.suffix in {".pt"}]
|
||||
# Remove file extensions
|
||||
return sorted([h[:h.rfind('.')] for h in all_hypernetworks], key=lambda x: x.lower())
|
||||
lfiles = [(h, os.path.splitext(h.name)[0]) for h in all_hypernetworks]
|
||||
return sort_models(lfiles)
|
||||
|
||||
model_keyword_installed = write_model_keyword_path()
|
||||
def get_lora():
|
||||
@@ -186,17 +227,17 @@ def get_lora():
|
||||
lora_paths = [Path(l) for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)]
|
||||
# Get hashes
|
||||
valid_loras = [lf for lf in lora_paths if lf.suffix in {".safetensors", ".ckpt", ".pt"}]
|
||||
hashes = {}
|
||||
lhashes = []
|
||||
for l in valid_loras:
|
||||
name = l.relative_to(LORA_PATH).as_posix()
|
||||
name = QUO + name + QUO # Wrapped in quote marks.
|
||||
if model_keyword_installed:
|
||||
hashes[name] = get_lora_simple_hash(l)
|
||||
vhash = get_lora_simple_hash(l)
|
||||
else:
|
||||
hashes[name] = ""
|
||||
vhash = ""
|
||||
lhashes.append((l, name, vhash))
|
||||
# Sort
|
||||
sorted_loras = dict(sorted(hashes.items()))
|
||||
# Add hashes and return
|
||||
return [f"\"{name}\",{hash}" for name, hash in sorted_loras.items()]
|
||||
return sort_models(lhashes)
|
||||
|
||||
|
||||
def get_lyco():
|
||||
@@ -207,19 +248,18 @@ def get_lyco():
|
||||
|
||||
# Get hashes
|
||||
valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"}]
|
||||
hashes = {}
|
||||
lhashes = []
|
||||
for ly in valid_lycos:
|
||||
name = ly.relative_to(LYCO_PATH).as_posix()
|
||||
name = QUO + name + QUO
|
||||
if model_keyword_installed:
|
||||
hashes[name] = get_lora_simple_hash(ly)
|
||||
vhash = get_lora_simple_hash(ly)
|
||||
else:
|
||||
hashes[name] = ""
|
||||
vhash = ""
|
||||
lhashes.append((ly, name, vhash))
|
||||
|
||||
# Sort
|
||||
sorted_lycos = dict(sorted(hashes.items()))
|
||||
# Add hashes and return
|
||||
return [f"\"{name}\",{hash}" for name, hash in sorted_lycos.items()]
|
||||
|
||||
return sort_models(lhashes)
|
||||
|
||||
def write_tag_base_path():
|
||||
"""Writes the tag base path to a fixed location temporary file"""
|
||||
@@ -397,6 +437,7 @@ def on_ui_settings():
|
||||
"tac_extra.addMode": shared.OptionInfo("Insert before", "Mode to add the extra tags to the main tag list", gr.Dropdown, lambda: {"choices": ["Insert before","Insert after"]}),
|
||||
# Chant settings
|
||||
"tac_chantFile": shared.OptionInfo("demo-chants.json", "Chant filename", gr.Dropdown, lambda: {"choices": json_files_withnone}, refresh=update_json_files).info("Chants are longer prompt presets"),
|
||||
"tac_sortModels": shared.OptionInfo("name", "Model sort order", gr.Dropdown, lambda: {"choices": ["name", "modified_date"]}).info("WIP: Order of appearance for models in dropdown"),
|
||||
}
|
||||
|
||||
# Add normal settings
|
||||
|
||||
Reference in New Issue
Block a user