Rework sorting function to calculate keys instead of pre-sort the list

Rename added/changed variables to be clearer
This commit is contained in:
DominikDoom
2023-09-13 11:46:17 +02:00
parent 3953260485
commit 475ef59197

View File

@@ -25,40 +25,44 @@ except Exception as e: # Not supported.
load_textual_inversion_embeddings = lambda *args, **kwargs: None
print("Tag Autocomplete: Cannot reload embeddings instantly:", e)
QUO = "\""
EXTKEY = "tac"
# EXTNAME = "Tag Autocomplete Helper"
# Default values, because shared doesn't allocate a value automatically.
# (id: def)
DEXTSETV = {
"sortModels": "default",
# Sorting functions for extra networks / embeddings stuff
sort_criteria = {
"Name": {
"key": lambda path, name: name.lower() if Path(name).parts > 1 else path.stem.lower(),
"reverse": False
},
"Date Modified": {
"key": lambda path, name: path.stat().st_mtime,
"reverse": True
},
}
fseti = lambda x: shared.opts.data.get(EXTKEY + "_" + x, DEXTSETV[x])
def sort_models(lmodels, sort_method = None, indwrap = False):
"""Sorts models according to setting.
def sort_models(model_list, sort_method = None):
"""Sorts models according to the setting.
Input: list of (full_path, display_name, {hash}) models.
Returns models in the standard temp file format (ie name, hash).
Default sort is lexicographical, mdate is by file modification date.
Hash is optional, can be any textual value written after the name (eg v1/v2 for embeddings).
For some reason only loras and lycos are wrapped in quote marks, so it's left to caller.
Creep: Requires sort modifications on js side to preserve order during merge.
Returns models in the format of name, sort key, meta.
Meta is optional and can be a hash, version string or other required info.
Whether the currently selected sort method needs to be reversed is provided
by an API endpoint to reduce duplication in temp files.
"""
if len(lmodels) == 0:
return lmodels
if len(model_list) == 0:
return model_list
if sort_method is None:
sort_method = fseti("sortModels")
if sort_method == "modified_date":
lsorted = sorted(lmodels, key=lambda x: os.path.getmtime(x[0]), reverse = True)
else:
lsorted = sorted(lmodels, key = lambda x: x[1].lower())
if len(lsorted[0]) > 2:
# lret = [f"\"{name}\",{hash}" for pt, name, hash in lsorted]
lret = [f"{name},{hash}" for pt, name, hash in lsorted]
sort_method = getattr(shared.opts, "tac_modelSortOrder", "Name")
# Get sorting method from dictionary
sorter = sort_criteria[sort_method] if sort_criteria[sort_method] else sort_criteria['Name']
# During merging on the JS side we need to re-sort anyway, so here only the sort criteria are calculated.
# The list itself doesn't need to get sorted at this point.
if len(model_list[0]) > 2:
results = [f'{name},"{sorter["key"](path, name)}",{meta}' for path, name, meta in model_list]
else:
lret = [name for pt, name in lsorted]
return lret
results = [f'{name},"{sorter["key"](path, name)}"' for path, name in model_list]
return results
def get_wildcards():
"""Returns a list of all wildcards. Works on nested folders."""
@@ -75,10 +79,10 @@ def get_ext_wildcards():
for path in WILDCARD_EXT_PATHS:
wildcard_files.append(path.as_posix())
lfiles = [(w, w.relative_to(path).as_posix())
for w in path.rglob("*.txt")
if w.name != "put wildcards here.txt"]
wildcard_files.extend(sort_models(lfiles))
resolved = [(w, w.relative_to(path).as_posix())
for w in path.rglob("*.txt")
if w.name != "put wildcards here.txt"]
wildcard_files.extend(sort_models(resolved))
wildcard_files.append("-----")
return wildcard_files
@@ -176,14 +180,14 @@ def get_embeddings(sd_model):
# Add embeddings to the correct list
if (emb_a_shape == V1_SHAPE):
emb_v1 = [(v.filename, k, "v1") for (k,v) in emb_type_a.items()]
emb_v1 = [(Path(v.filename), k, "v1") for (k,v) in emb_type_a.items()]
elif (emb_a_shape == V2_SHAPE):
emb_v2 = [(v.filename, k, "v2") for (k,v) in emb_type_a.items()]
emb_v2 = [(Path(v.filename), k, "v2") for (k,v) in emb_type_a.items()]
if (emb_b_shape == V1_SHAPE):
emb_v1 = [(v.filename, k, "v1") for (k,v) in emb_type_b.items()]
emb_v1 = [(Path(v.filename), k, "v1") for (k,v) in emb_type_b.items()]
elif (emb_b_shape == V2_SHAPE):
emb_v2 = [(v.filename, k, "v2") for (k,v) in emb_type_b.items()]
emb_v2 = [(Path(v.filename), k, "v2") for (k,v) in emb_type_b.items()]
# Get shape of current model
#vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
@@ -213,10 +217,8 @@ def get_hypernetworks():
# Get a list of all hypernetworks in the folder
hyp_paths = [Path(h) for h in glob.glob(HYP_PATH.joinpath("**/*").as_posix(), recursive=True)]
all_hypernetworks = [h for h in hyp_paths if h.suffix in {".pt"}]
# Remove file extensions
lfiles = [(h, os.path.splitext(h.name)[0]) for h in all_hypernetworks]
return sort_models(lfiles)
all_hypernetworks = [(h, h.stem) for h in hyp_paths if h.suffix in {".pt"}]
return sort_models(all_hypernetworks)
model_keyword_installed = write_model_keyword_path()
def get_lora():
@@ -227,17 +229,17 @@ def get_lora():
lora_paths = [Path(l) for l in glob.glob(LORA_PATH.joinpath("**/*").as_posix(), recursive=True)]
# Get hashes
valid_loras = [lf for lf in lora_paths if lf.suffix in {".safetensors", ".ckpt", ".pt"}]
lhashes = []
loras_with_hash = []
for l in valid_loras:
name = l.relative_to(LORA_PATH).as_posix()
name = QUO + name + QUO # Wrapped in quote marks.
name = f'"{name}"'
if model_keyword_installed:
vhash = get_lora_simple_hash(l)
hash = get_lora_simple_hash(l)
else:
vhash = ""
lhashes.append((l, name, vhash))
hash = ""
loras_with_hash.append((l, name, hash))
# Sort
return sort_models(lhashes)
return sort_models(loras_with_hash)
def get_lyco():
@@ -248,18 +250,17 @@ def get_lyco():
# Get hashes
valid_lycos = [lyf for lyf in lyco_paths if lyf.suffix in {".safetensors", ".ckpt", ".pt"}]
lhashes = []
lycos_with_hash = []
for ly in valid_lycos:
name = ly.relative_to(LYCO_PATH).as_posix()
name = QUO + name + QUO
name = f'"{name}"'
if model_keyword_installed:
vhash = get_lora_simple_hash(ly)
hash = get_lora_simple_hash(ly)
else:
vhash = ""
lhashes.append((ly, name, vhash))
hash = ""
lycos_with_hash.append((ly, name, hash))
# Sort
return sort_models(lhashes)
return sort_models(lycos_with_hash)
def write_tag_base_path():
"""Writes the tag base path to a fixed location temporary file"""
@@ -415,6 +416,7 @@ def on_ui_settings():
"tac_useLycos": shared.OptionInfo(True, "Search for LyCORIS/LoHa"),
"tac_showWikiLinks": shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page").info("Warning: This is an external site and very likely contains NSFW examples!"),
"tac_showExtraNetworkPreviews": shared.OptionInfo(True, "Show preview thumbnails for extra networks if available"),
"tac_modelSortOrder": shared.OptionInfo("Name", "Model sort order", gr.Dropdown, lambda: {"choices": ["Name", "Date Modified"]}).info("Order for extra network models and wildcards in dropdown"),
# Insertion related settings
"tac_replaceUnderscores": shared.OptionInfo(True, "Replace underscores with spaces on insertion"),
"tac_escapeParentheses": shared.OptionInfo(True, "Escape parentheses on insertion"),
@@ -437,7 +439,6 @@ def on_ui_settings():
"tac_extra.addMode": shared.OptionInfo("Insert before", "Mode to add the extra tags to the main tag list", gr.Dropdown, lambda: {"choices": ["Insert before","Insert after"]}),
# Chant settings
"tac_chantFile": shared.OptionInfo("demo-chants.json", "Chant filename", gr.Dropdown, lambda: {"choices": json_files_withnone}, refresh=update_json_files).info("Chants are longer prompt presets"),
"tac_sortModels": shared.OptionInfo("name", "Model sort order", gr.Dropdown, lambda: {"choices": ["name", "modified_date"]}).info("WIP: Order of appearance for models in dropdown"),
}
# Add normal settings
@@ -524,6 +525,11 @@ def api_tac(_: gr.Blocks, app: FastAPI):
except Exception as e:
return JSONResponse({"error": e}, status_code=500)
@app.get("/tacapi/v1/sort-direction")
async def get_sort_direction():
criterium = getattr(shared.opts, "tac_modelSortOrder", "Name")
return sort_criteria[criterium]['reverse'] if sort_criteria[criterium] else sort_criteria['Name']['reverse']
@app.get("/tacapi/v1/lora-info/{lora_name}")
async def get_lora_info(lora_name):
return await get_json_info(LORA_PATH, lora_name)