mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-02-10 02:00:09 +00:00
Proper support for SDXL embeddings
Now in their own category, other embeddings don't get mislabeled anymore if an XL model is loaded
This commit is contained in:
@@ -11,12 +11,15 @@ class EmbeddingParser extends BaseTagParser {
|
||||
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
|
||||
versionString = searchTerm.slice(0, 2);
|
||||
searchTerm = searchTerm.slice(2);
|
||||
} else if (searchTerm.startsWith("vxl")) {
|
||||
versionString = searchTerm.slice(0, 3);
|
||||
searchTerm = searchTerm.slice(3);
|
||||
}
|
||||
|
||||
let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm);
|
||||
|
||||
if (versionString)
|
||||
tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2] === versionString); // Filter by tagword
|
||||
tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2].toLowerCase() === versionString.toLowerCase()); // Filter by tagword
|
||||
else
|
||||
tempResults = embeddings.filter(x => filterCondition(x)); // Filter by tagword
|
||||
} else {
|
||||
|
||||
@@ -156,44 +156,30 @@ def get_embeddings(sd_model):
|
||||
# Version constants
|
||||
V1_SHAPE = 768
|
||||
V2_SHAPE = 1024
|
||||
VXL_SHAPE = 2048
|
||||
emb_v1 = []
|
||||
emb_v2 = []
|
||||
emb_vXL = []
|
||||
results = []
|
||||
|
||||
try:
|
||||
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
|
||||
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
|
||||
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
|
||||
# Get the shape of the first item in the dict
|
||||
emb_a_shape = -1
|
||||
emb_b_shape = -1
|
||||
if (len(emb_type_a) > 0):
|
||||
emb_a_shape = next(iter(emb_type_a.items()))[1].shape
|
||||
if (len(emb_type_b) > 0):
|
||||
emb_b_shape = next(iter(emb_type_b.items()))[1].shape
|
||||
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
|
||||
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
|
||||
|
||||
# Add embeddings to the correct list
|
||||
if (emb_a_shape == V1_SHAPE):
|
||||
emb_v1 = [(Path(v.filename), k, "v1") for (k,v) in emb_type_a.items() if v.filename is not None]
|
||||
elif (emb_a_shape == V2_SHAPE):
|
||||
emb_v2 = [(Path(v.filename), k, "v2") for (k,v) in emb_type_a.items() if v.filename is not None]
|
||||
for key, emb in (loaded | skipped).items():
|
||||
if (emb.filename is None):
|
||||
continue
|
||||
|
||||
if (emb_b_shape == V1_SHAPE):
|
||||
emb_v1 = [(Path(v.filename), k, "v1") for (k,v) in emb_type_b.items() if v.filename is not None]
|
||||
elif (emb_b_shape == V2_SHAPE):
|
||||
emb_v2 = [(Path(v.filename), k, "v2") for (k,v) in emb_type_b.items() if v.filename is not None]
|
||||
if emb.shape == V1_SHAPE:
|
||||
emb_v1.append((Path(emb.filename), key, "v1"))
|
||||
elif emb.shape == V2_SHAPE:
|
||||
emb_v2.append((Path(emb.filename), key, "v2"))
|
||||
elif emb.shape == VXL_SHAPE:
|
||||
emb_vXL.append((Path(emb.filename), key, "vXL"))
|
||||
|
||||
# Get shape of current model
|
||||
#vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||
#model_shape = vec.shape[1]
|
||||
# Show relevant entries at the top
|
||||
#if (model_shape == V1_SHAPE):
|
||||
# results = [e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2]
|
||||
#elif (model_shape == V2_SHAPE):
|
||||
# results = [e + ",v2" for e in emb_v2] + [e + ",v1" for e in emb_v1]
|
||||
#else:
|
||||
# raise AttributeError # Fallback to old method
|
||||
results = sort_models(emb_v1) + sort_models(emb_v2)
|
||||
results = sort_models(emb_v1) + sort_models(emb_v2) + sort_models(emb_vXL)
|
||||
except AttributeError:
|
||||
print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
|
||||
# Get a list of all embeddings in the folder
|
||||
|
||||
Reference in New Issue
Block a user