mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-26 19:19:57 +00:00
@@ -81,24 +81,17 @@ def get_ext_wildcard_tags():
|
||||
return output
|
||||
|
||||
|
||||
def get_embeddings():
|
||||
def get_embeddings(sd_model):
|
||||
"""Write a list of all embeddings with their version"""
|
||||
# Get a list of all embeddings in the folder
|
||||
embs_in_dir = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("**/*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}]
|
||||
# Remove file extensions
|
||||
embs_in_dir = [e[:e.rfind('.')] for e in embs_in_dir]
|
||||
|
||||
# Version constants
|
||||
V1_SHAPE = 768
|
||||
V2_SHAPE = 1024
|
||||
emb_v1 = []
|
||||
emb_v2 = []
|
||||
results = []
|
||||
|
||||
try:
|
||||
# Wait for all embeddings to be loaded
|
||||
while len(sd_hijack.model_hijack.embedding_db.word_embeddings) + len(sd_hijack.model_hijack.embedding_db.skipped_embeddings) < len(embs_in_dir):
|
||||
time.sleep(2) # Sleep for 2 seconds
|
||||
|
||||
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
|
||||
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
|
||||
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
|
||||
@@ -120,21 +113,27 @@ def get_embeddings():
|
||||
emb_v1 = list(emb_type_b.keys())
|
||||
elif (emb_b_shape == V2_SHAPE):
|
||||
emb_v2 = list(emb_type_b.keys())
|
||||
except AttributeError:
|
||||
print("tag_autocomplete_helper: Old webui version, using fallback for embedding completion.")
|
||||
|
||||
# Create a new list to store the modified strings
|
||||
results = []
|
||||
|
||||
# Iterate through each string in the big list
|
||||
for string in embs_in_dir:
|
||||
if string in emb_v1:
|
||||
results.append(string + ",v1")
|
||||
elif string in emb_v2:
|
||||
results.append(string + ",v2")
|
||||
# If the string is not in either, default to v1
|
||||
else:
|
||||
results.append(string + ",")
|
||||
# Get shape of current model
|
||||
#vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||
#model_shape = vec.shape[1]
|
||||
# Show relevant entries at the top
|
||||
#if (model_shape == V1_SHAPE):
|
||||
# results = [e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2]
|
||||
#elif (model_shape == V2_SHAPE):
|
||||
# results = [e + ",v2" for e in emb_v2] + [e + ",v1" for e in emb_v1]
|
||||
#else:
|
||||
# raise AttributeError # Fallback to old method
|
||||
results = sorted([e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2], key=lambda x: x.lower())
|
||||
except AttributeError:
|
||||
print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
|
||||
# Get a list of all embeddings in the folder
|
||||
all_embeds = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.glob("*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}]
|
||||
# Remove files with a size of 0
|
||||
all_embeds = [e for e in all_embeds if EMB_PATH.joinpath(e).stat().st_size > 0]
|
||||
# Remove file extensions
|
||||
all_embeds = [e[:e.rfind('.')] for e in all_embeds]
|
||||
results = [e + "," for e in all_embeds]
|
||||
|
||||
write_to_temp_file('emb.txt', results)
|
||||
|
||||
@@ -179,7 +178,9 @@ if not TEMP_PATH.exists():
|
||||
write_to_temp_file('wc.txt', [])
|
||||
write_to_temp_file('wce.txt', [])
|
||||
write_to_temp_file('wcet.txt', [])
|
||||
write_to_temp_file('emb.txt', [])
|
||||
# Only reload embeddings if the file doesn't exist, since they are already re-written on model load
|
||||
if not TEMP_PATH.joinpath("emb.txt").exists():
|
||||
write_to_temp_file('emb.txt', [])
|
||||
|
||||
# Write wildcards to wc.txt if found
|
||||
if WILDCARD_PATH.exists():
|
||||
@@ -199,9 +200,8 @@ if WILDCARD_EXT_PATHS is not None:
|
||||
|
||||
# Write embeddings to emb.txt if found
|
||||
if EMB_PATH.exists():
|
||||
# We need to load the embeddings in a separate thread since we wait for them to be checked (after the model loads)
|
||||
thread = threading.Thread(target=get_embeddings)
|
||||
thread.start()
|
||||
# Get embeddings after the model loaded callback
|
||||
script_callbacks.on_model_loaded(get_embeddings)
|
||||
|
||||
|
||||
# Register autocomplete options
|
||||
|
||||
Reference in New Issue
Block a user