diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index d88b18e..60799c1 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -88,36 +88,40 @@ def get_embeddings(): # Remove file extensions embs_in_dir = [e[:e.rfind('.')] for e in embs_in_dir] - # Wait for all embeddings to be loaded - while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir): - time.sleep(2) # Sleep for 2 seconds - - # Get embedding dict from sd_hijack to separate v1/v2 embeddings - emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings - emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings - # Get the shape of the first item in the dict - emb_a_shape = -1 - emb_b_shape = -1 - if (len(emb_type_a) > 0): - emb_a_shape = next(iter(emb_type_a.items()))[1].shape - if (len(emb_type_b) > 0): - emb_b_shape = next(iter(emb_type_b.items()))[1].shape - - # Add embeddings to the correct list + # Version constants V1_SHAPE = 768 V2_SHAPE = 1024 emb_v1 = [] emb_v2 = [] - if (emb_a_shape == V1_SHAPE): - emb_v1 = list(emb_type_a.keys()) - elif (emb_a_shape == V2_SHAPE): - emb_v2 = list(emb_type_a.keys()) + try: + # Wait for all embeddings to be loaded + while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir): + time.sleep(2) # Sleep for 2 seconds - if (emb_b_shape == V1_SHAPE): - emb_v1 = list(emb_type_b.keys()) - elif (emb_b_shape == V2_SHAPE): - emb_v2 = list(emb_type_b.keys()) + # Get embedding dict from sd_hijack to separate v1/v2 embeddings + emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings + emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings + # Get the shape of the first item in the dict + emb_a_shape = -1 + emb_b_shape = -1 + if (len(emb_type_a) > 0): + emb_a_shape = next(iter(emb_type_a.items()))[1].shape + if (len(emb_type_b) > 0): + emb_b_shape = next(iter(emb_type_b.items()))[1].shape + + # Add embeddings to the correct list + if (emb_a_shape == V1_SHAPE): + emb_v1 = list(emb_type_a.keys()) + elif (emb_a_shape == V2_SHAPE): + emb_v2 = list(emb_type_a.keys()) + + if (emb_b_shape == V1_SHAPE): + emb_v1 = list(emb_type_b.keys()) + elif (emb_b_shape == V2_SHAPE): + emb_v2 = list(emb_type_b.keys()) + except AttributeError: + print("tag_autocomplete_helper: Old webui version, using fallback for embedding completion.") # Create a new list to store the modified strings results = [] @@ -130,7 +134,7 @@ def get_embeddings(): results.append(string + ",v2") # If the string is not in either, default to v1 else: - results.append(string + ",v1") + results.append(string + ",") write_to_temp_file('emb.txt', results)