diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index 24a107b..d88b18e 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -97,8 +97,11 @@ def get_embeddings(): emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings # Get the shape of the first item in the dict emb_a_shape = -1 + emb_b_shape = -1 if (len(emb_type_a) > 0): emb_a_shape = next(iter(emb_type_a.items()))[1].shape + if (len(emb_type_b) > 0): + emb_b_shape = next(iter(emb_type_b.items()))[1].shape # Add embeddings to the correct list V1_SHAPE = 768 @@ -108,11 +111,14 @@ def get_embeddings(): if (emb_a_shape == V1_SHAPE): emb_v1 = list(emb_type_a.keys()) - emb_v2 = list(emb_type_b) elif (emb_a_shape == V2_SHAPE): - emb_v1 = list(emb_type_b) emb_v2 = list(emb_type_a.keys()) + if (emb_b_shape == V1_SHAPE): + emb_v1 = list(emb_type_b.keys()) + elif (emb_b_shape == V2_SHAPE): + emb_v2 = list(emb_type_b.keys()) + # Create a new list to store the modified strings results = [] @@ -123,7 +129,6 @@ def get_embeddings(): elif string in emb_v2: results.append(string + ",v2") # If the string is not in either, default to v1 - # (we can't know what it is since the startup model loaded none of them, but it's probably v1 since v2 is newer) else: results.append(string + ",v1")