mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-02-27 10:24:11 +00:00
Add fallback for embedding loading
Fixes error on outdated webuis, as mentioned in #98 and #99
This commit is contained in:
@@ -88,36 +88,40 @@ def get_embeddings():
|
||||
# Remove file extensions
|
||||
embs_in_dir = [e[:e.rfind('.')] for e in embs_in_dir]
|
||||
|
||||
# Wait for all embeddings to be loaded
|
||||
while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir):
|
||||
time.sleep(2) # Sleep for 2 seconds
|
||||
|
||||
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
|
||||
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
|
||||
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
|
||||
# Get the shape of the first item in the dict
|
||||
emb_a_shape = -1
|
||||
emb_b_shape = -1
|
||||
if (len(emb_type_a) > 0):
|
||||
emb_a_shape = next(iter(emb_type_a.items()))[1].shape
|
||||
if (len(emb_type_b) > 0):
|
||||
emb_b_shape = next(iter(emb_type_b.items()))[1].shape
|
||||
|
||||
# Add embeddings to the correct list
|
||||
# Version constants
|
||||
V1_SHAPE = 768
|
||||
V2_SHAPE = 1024
|
||||
emb_v1 = []
|
||||
emb_v2 = []
|
||||
|
||||
if (emb_a_shape == V1_SHAPE):
|
||||
emb_v1 = list(emb_type_a.keys())
|
||||
elif (emb_a_shape == V2_SHAPE):
|
||||
emb_v2 = list(emb_type_a.keys())
|
||||
try:
|
||||
# Wait for all embeddings to be loaded
|
||||
while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir):
|
||||
time.sleep(2) # Sleep for 2 seconds
|
||||
|
||||
if (emb_b_shape == V1_SHAPE):
|
||||
emb_v1 = list(emb_type_b.keys())
|
||||
elif (emb_b_shape == V2_SHAPE):
|
||||
emb_v2 = list(emb_type_b.keys())
|
||||
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
|
||||
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
|
||||
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
|
||||
# Get the shape of the first item in the dict
|
||||
emb_a_shape = -1
|
||||
emb_b_shape = -1
|
||||
if (len(emb_type_a) > 0):
|
||||
emb_a_shape = next(iter(emb_type_a.items()))[1].shape
|
||||
if (len(emb_type_b) > 0):
|
||||
emb_b_shape = next(iter(emb_type_b.items()))[1].shape
|
||||
|
||||
# Add embeddings to the correct list
|
||||
if (emb_a_shape == V1_SHAPE):
|
||||
emb_v1 = list(emb_type_a.keys())
|
||||
elif (emb_a_shape == V2_SHAPE):
|
||||
emb_v2 = list(emb_type_a.keys())
|
||||
|
||||
if (emb_b_shape == V1_SHAPE):
|
||||
emb_v1 = list(emb_type_b.keys())
|
||||
elif (emb_b_shape == V2_SHAPE):
|
||||
emb_v2 = list(emb_type_b.keys())
|
||||
except AttributeError:
|
||||
print("tag_autocomplete_helper: Old webui version, using fallback for embedding completion.")
|
||||
|
||||
# Create a new list to store the modified strings
|
||||
results = []
|
||||
@@ -130,7 +134,7 @@ def get_embeddings():
|
||||
results.append(string + ",v2")
|
||||
# If the string is not in either, default to v1
|
||||
else:
|
||||
results.append(string + ",v1")
|
||||
results.append(string + ",")
|
||||
|
||||
write_to_temp_file('emb.txt', results)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user