Show version info for embeddings

Also allows searching by version to quickly find v1 or v2 model embeddings
Closes #97
This commit is contained in:
Dominik Reh
2023-01-02 00:38:48 +01:00
parent b57042edd0
commit 6deefda279
2 changed files with 85 additions and 12 deletions

View File

@@ -3,8 +3,10 @@
import gradio as gr
from pathlib import Path
from modules import scripts, script_callbacks, shared
from modules import scripts, script_callbacks, shared, sd_hijack
import yaml
import time
import threading
# Webui root path
FILE_DIR = Path().absolute()
@@ -78,9 +80,54 @@ def get_ext_wildcard_tags():
output.append(f"{tag},{count}")
return output
def get_embeddings():
"""Returns a list of all embeddings"""
return [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png"}]
"""Write a list of all embeddings with their version"""
# Get a list of all embeddings in the folder
embs_in_dir = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}]
# Remove file extensions
embs_in_dir = [e[:e.rfind('.')] for e in embs_in_dir]
# Wait for all embeddings to be loaded
while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir):
time.sleep(2) # Sleep for 2 seconds
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
# Get the shape of the first item in the dict
emb_a_shape = -1
if (len(emb_type_a) > 0):
emb_a_shape = next(iter(emb_type_a.items()))[1].shape
# Add embeddings to the correct list
V1_SHAPE = 768
V2_SHAPE = 1024
emb_v1 = []
emb_v2 = []
if (emb_a_shape == V1_SHAPE):
emb_v1 = list(emb_type_a.keys())
emb_v2 = list(emb_type_b)
elif (emb_a_shape == V2_SHAPE):
emb_v1 = list(emb_type_b)
emb_v2 = list(emb_type_a.keys())
# Create a new list to store the modified strings
results = []
# Iterate through each string in the big list
for string in embs_in_dir:
if string in emb_v1:
results.append(string + ",v1")
elif string in emb_v2:
results.append(string + ",v2")
# If the string is not in either, default to v1
# (we can't know what it is since the startup model loaded none of them, but it's probably v1 since v2 is newer)
else:
results.append(string + ",v1")
write_to_temp_file('emb.txt', results)
def write_tag_base_path():
@@ -143,9 +190,10 @@ if WILDCARD_EXT_PATHS is not None:
# Write embeddings to emb.txt if found
if EMB_PATH.exists():
embeddings = get_embeddings()
if embeddings:
write_to_temp_file('emb.txt', embeddings)
# We need to load the embeddings in a separate thread since we wait for them to be checked (after the model loads)
thread = threading.Thread(target=get_embeddings)
thread.start()
# Register autocomplete options
def on_ui_settings():