Skipped embeddings now also hold shape info

so we don't need to guess the type anymore if the model didn't load any.
This commit is contained in:
Dominik Reh
2023-01-02 12:45:56 +01:00
parent 454c13ef6d
commit 5f2f746310

View File

@@ -97,8 +97,11 @@ def get_embeddings():
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
# Get the shape of the first item in the dict
emb_a_shape = -1
emb_b_shape = -1
if (len(emb_type_a) > 0):
emb_a_shape = next(iter(emb_type_a.items()))[1].shape
if (len(emb_type_b) > 0):
emb_b_shape = next(iter(emb_type_b.items()))[1].shape
# Add embeddings to the correct list
V1_SHAPE = 768
@@ -108,11 +111,14 @@ def get_embeddings():
if (emb_a_shape == V1_SHAPE):
emb_v1 = list(emb_type_a.keys())
emb_v2 = list(emb_type_b)
elif (emb_a_shape == V2_SHAPE):
emb_v1 = list(emb_type_b)
emb_v2 = list(emb_type_a.keys())
if (emb_b_shape == V1_SHAPE):
emb_v1 = list(emb_type_b.keys())
elif (emb_b_shape == V2_SHAPE):
emb_v2 = list(emb_type_b.keys())
# Create a new list to store the modified strings
results = []
@@ -123,7 +129,6 @@ def get_embeddings():
elif string in emb_v2:
results.append(string + ",v2")
# If the string is not in either, default to v1
# (we can't know what it is since the startup model loaded none of them, but it's probably v1 since v2 is newer)
else:
results.append(string + ",v1")