Rework embedding load, now uses callback.

Should hopefully fix #100
This commit is contained in:
Dominik Reh
2023-01-03 17:30:30 +01:00
parent 552c6517b8
commit da9acfea2a

View File

@@ -81,24 +81,17 @@ def get_ext_wildcard_tags():
return output
def get_embeddings():
def get_embeddings(sd_model):
"""Write a list of all embeddings with their version"""
# Get a list of all embeddings in the folder
embs_in_dir = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}]
# Remove file extensions
embs_in_dir = [e[:e.rfind('.')] for e in embs_in_dir]
# Version constants
V1_SHAPE = 768
V2_SHAPE = 1024
emb_v1 = []
emb_v2 = []
results = []
try:
# Wait for all embeddings to be loaded
while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir):
time.sleep(2) # Sleep for 2 seconds
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
@@ -120,21 +113,27 @@ def get_embeddings():
emb_v1 = list(emb_type_b.keys())
elif (emb_b_shape == V2_SHAPE):
emb_v2 = list(emb_type_b.keys())
except AttributeError:
print("tag_autocomplete_helper: Old webui version, using fallback for embedding completion.")
# Create a new list to store the modified strings
results = []
# Iterate through each string in the big list
for string in embs_in_dir:
if string in emb_v1:
results.append(string + ",v1")
elif string in emb_v2:
results.append(string + ",v2")
# If the string is not in either, default to v1
else:
results.append(string + ",")
# Get shape of current model
#vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
#model_shape = vec.shape[1]
# Show relevant entries at the top
#if (model_shape == V1_SHAPE):
# results = [e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2]
#elif (model_shape == V2_SHAPE):
# results = [e + ",v2" for e in emb_v2] + [e + ",v1" for e in emb_v1]
#else:
# raise AttributeError # Fallback to old method
results = sorted([e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2], key=lambda x: x.lower())
except AttributeError:
print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
# Get a list of all embeddings in the folder
all_embeds = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}]
# Remove files with a size of 0
all_embeds = [e for e in all_embeds if EMB_PATH.joinpath(e).stat().st_size > 0]
# Remove file extensions
all_embeds = [e[:e.rfind('.')] for e in all_embeds]
results = [e + "," for e in all_embeds]
write_to_temp_file('emb.txt', results)
@@ -179,7 +178,9 @@ if not TEMP_PATH.exists():
write_to_temp_file('wc.txt', [])
write_to_temp_file('wce.txt', [])
write_to_temp_file('wcet.txt', [])
write_to_temp_file('emb.txt', [])
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
if not TEMP_PATH.joinpath("emb.txt").exists():
write_to_temp_file('emb.txt', [])
# Write wildcards to wc.txt if found
if WILDCARD_PATH.exists():
@@ -199,9 +200,8 @@ if WILDCARD_EXT_PATHS is not None:
# Write embeddings to emb.txt if found
if EMB_PATH.exists():
# We need to load the embeddings in a separate thread since we wait for them to be checked (after the model loads)
thread = threading.Thread(target=get_embeddings)
thread.start()
# Get embeddings after the model loaded callback
script_callbacks.on_model_loaded(get_embeddings)
# Register autocomplete options