Proper support for SDXL embeddings

Now in their own category, other embeddings don't get mislabeled anymore if an XL model is loaded
This commit is contained in:
DominikDoom
2023-09-26 14:14:20 +02:00
parent 94ec8884c3
commit 998514bebb
2 changed files with 18 additions and 29 deletions

View File

@@ -11,12 +11,15 @@ class EmbeddingParser extends BaseTagParser {
if (searchTerm.startsWith("v1") || searchTerm.startsWith("v2")) {
versionString = searchTerm.slice(0, 2);
searchTerm = searchTerm.slice(2);
} else if (searchTerm.startsWith("vxl")) {
versionString = searchTerm.slice(0, 3);
searchTerm = searchTerm.slice(3);
}
let filterCondition = x => x[0].toLowerCase().includes(searchTerm) || x[0].toLowerCase().replaceAll(" ", "_").includes(searchTerm);
if (versionString)
tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2] === versionString); // Filter by tagword
tempResults = embeddings.filter(x => filterCondition(x) && x[2] && x[2].toLowerCase() === versionString.toLowerCase()); // Filter by tagword
else
tempResults = embeddings.filter(x => filterCondition(x)); // Filter by tagword
} else {

View File

@@ -156,44 +156,30 @@ def get_embeddings(sd_model):
# Version constants
V1_SHAPE = 768
V2_SHAPE = 1024
VXL_SHAPE = 2048
emb_v1 = []
emb_v2 = []
emb_vXL = []
results = []
try:
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
# Get the shape of the first item in the dict
emb_a_shape = -1
emb_b_shape = -1
if (len(emb_type_a) > 0):
emb_a_shape = next(iter(emb_type_a.items()))[1].shape
if (len(emb_type_b) > 0):
emb_b_shape = next(iter(emb_type_b.items()))[1].shape
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
# Add embeddings to the correct list
if (emb_a_shape == V1_SHAPE):
emb_v1 = [(Path(v.filename), k, "v1") for (k,v) in emb_type_a.items() if v.filename is not None]
elif (emb_a_shape == V2_SHAPE):
emb_v2 = [(Path(v.filename), k, "v2") for (k,v) in emb_type_a.items() if v.filename is not None]
for key, emb in (loaded | skipped).items():
if (emb.filename is None):
continue
if (emb_b_shape == V1_SHAPE):
emb_v1 = [(Path(v.filename), k, "v1") for (k,v) in emb_type_b.items() if v.filename is not None]
elif (emb_b_shape == V2_SHAPE):
emb_v2 = [(Path(v.filename), k, "v2") for (k,v) in emb_type_b.items() if v.filename is not None]
if emb.shape == V1_SHAPE:
emb_v1.append((Path(emb.filename), key, "v1"))
elif emb.shape == V2_SHAPE:
emb_v2.append((Path(emb.filename), key, "v2"))
elif emb.shape == VXL_SHAPE:
emb_vXL.append((Path(emb.filename), key, "vXL"))
# Get shape of current model
#vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
#model_shape = vec.shape[1]
# Show relevant entries at the top
#if (model_shape == V1_SHAPE):
# results = [e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2]
#elif (model_shape == V2_SHAPE):
# results = [e + ",v2" for e in emb_v2] + [e + ",v1" for e in emb_v1]
#else:
# raise AttributeError # Fallback to old method
results = sort_models(emb_v1) + sort_models(emb_v2)
results = sort_models(emb_v1) + sort_models(emb_v2) + sort_models(emb_vXL)
except AttributeError:
print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
# Get a list of all embeddings in the folder