From da9acfea2a02b3ccfa696a8cd65c0e60e8a778bb Mon Sep 17 00:00:00 2001 From: Dominik Reh Date: Tue, 3 Jan 2023 17:30:30 +0100 Subject: [PATCH] Rework embedding load, now uses callback. Should hopefully fix #100 --- scripts/tag_autocomplete_helper.py | 54 +++++++++++++++--------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index 60799c1..eb023de 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -81,24 +81,17 @@ def get_ext_wildcard_tags(): return output -def get_embeddings(): +def get_embeddings(sd_model): """Write a list of all embeddings with their version""" - # Get a list of all embeddings in the folder - embs_in_dir = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}] - # Remove file extensions - embs_in_dir = [e[:e.rfind('.')] for e in embs_in_dir] # Version constants V1_SHAPE = 768 V2_SHAPE = 1024 emb_v1 = [] emb_v2 = [] + results = [] try: - # Wait for all embeddings to be loaded - while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir): - time.sleep(2) # Sleep for 2 seconds - # Get embedding dict from sd_hijack to separate v1/v2 embeddings emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings @@ -120,21 +113,27 @@ def get_embeddings(): emb_v1 = list(emb_type_b.keys()) elif (emb_b_shape == V2_SHAPE): emb_v2 = list(emb_type_b.keys()) - except AttributeError: - print("tag_autocomplete_helper: Old webui version, using fallback for embedding completion.") - # Create a new list to store the modified strings - results = [] - - # Iterate through each string in the big list - for string in embs_in_dir: - if string in emb_v1: - results.append(string + ",v1") - elif string in emb_v2: - results.append(string + ",v2") - # If the string is not in either, default to v1 - else: - results.append(string + ",") + # Get shape of current model + #vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1) + #model_shape = vec.shape[1] + # Show relevant entries at the top + #if (model_shape == V1_SHAPE): + # results = [e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2] + #elif (model_shape == V2_SHAPE): + # results = [e + ",v2" for e in emb_v2] + [e + ",v1" for e in emb_v1] + #else: + # raise AttributeError # Fallback to old method + results = sorted([e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2], key=lambda x: x.lower()) + except AttributeError: + print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.") + # Get a list of all embeddings in the folder + all_embeds = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}] + # Remove files with a size of 0 + all_embeds = [e for e in all_embeds if EMB_PATH.joinpath(e).stat().st_size > 0] + # Remove file extensions + all_embeds = [e[:e.rfind('.')] for e in all_embeds] + results = [e + "," for e in all_embeds] write_to_temp_file('emb.txt', results) @@ -179,7 +178,9 @@ if not TEMP_PATH.exists(): write_to_temp_file('wc.txt', []) write_to_temp_file('wce.txt', []) write_to_temp_file('wcet.txt', []) -write_to_temp_file('emb.txt', []) +# Only reload embeddings if the file doesn't exist, since they are already re-written on model load +if not TEMP_PATH.joinpath("emb.txt").exists(): + write_to_temp_file('emb.txt', []) # Write wildcards to wc.txt if found if WILDCARD_PATH.exists(): @@ -199,9 +200,8 @@ if WILDCARD_EXT_PATHS is not None: # Write embeddings to emb.txt if found if EMB_PATH.exists(): - # We need to load the embeddings in a separate thread since we wait for them to be checked (after the model loads) - thread = threading.Thread(target=get_embeddings) - thread.start() + # Get embeddings after the model loaded callback + script_callbacks.on_model_loaded(get_embeddings) # Register autocomplete options