Partial embedding fixes for webui forge

Resolves some symptoms of #297, but doesn't fix the underlying cause
This commit is contained in:
DominikDoom
2024-08-11 14:08:31 +02:00
parent 1c6bba2a3d
commit f64d728ac6

View File

@@ -13,7 +13,7 @@ import gradio as gr
import yaml
from fastapi import FastAPI
from fastapi.responses import Response, FileResponse, JSONResponse
from modules import script_callbacks, sd_hijack, shared, hashes
from modules import script_callbacks, sd_hijack, shared, hashes, sd_models
from pydantic import BaseModel
from scripts.model_keyword_support import (get_lora_simple_hash,
@@ -41,9 +41,32 @@ except (ImportError, ValueError, sqlite3.Error) as e:
print(f"Tag Autocomplete: Tag frequency database error - \"{e}\"")
db = None
def get_embed_db(sd_model=None):
"""Returns the embedding database, if available."""
try:
return sd_hijack.model_hijack.embedding_db
except Exception:
try: # sd next with diffusers backend
sdnext_model = sd_model if sd_model is not None else shared.sd_model
return sdnext_model.embedding_db
except Exception:
try: # forge webui
forge_model = sd_model if sd_model is not None else sd_models.model_data.get_sd_model()
if type(forge_model).__name__ == "FakeInitialModel":
return None
else:
processer = getattr(forge_model, "text_processing_engine", getattr(forge_model, "text_processing_engine_l"))
return processer.embeddings
except Exception:
return None
# Attempt to get embedding load function, using the same call as api.
try:
load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings
embed_db = get_embed_db()
if embed_db is not None:
load_textual_inversion_embeddings = embed_db.load_textual_inversion_embeddings
else:
load_textual_inversion_embeddings = lambda *args, **kwargs: None
except Exception as e: # Not supported.
load_textual_inversion_embeddings = lambda *args, **kwargs: None
print("Tag Autocomplete: Cannot reload embeddings instantly:", e)
@@ -190,19 +213,10 @@ def get_embeddings(sd_model):
results = []
try:
# The sd_model embedding_db reference only exists in sd.next with diffusers backend
try:
loaded_sdnext = sd_model.embedding_db.word_embeddings
skipped_sdnext = sd_model.embedding_db.skipped_embeddings
except (NameError, AttributeError):
loaded_sdnext = {}
skipped_sdnext = {}
embed_db = get_embed_db(sd_model)
# Get embedding dict from sd_hijack to separate v1/v2 embeddings
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
loaded = loaded | loaded_sdnext
skipped = skipped | skipped_sdnext
loaded = embed_db.word_embeddings
skipped = embed_db.skipped_embeddings
# Add embeddings to the correct list
for key, emb in (skipped | loaded).items():
@@ -430,8 +444,11 @@ def refresh_embeddings(force: bool, *args, **kwargs):
# Fix for SD.Next infinite refresh loop due to gradio not updating after model load on demand.
# This will just skip embedding loading if no model is loaded yet (or there really are no embeddings).
# Try catch is just for safety incase sd_hijack access fails for some reason.
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
embed_db = get_embed_db()
if embed_db is None:
return
loaded = embed_db.word_embeddings
skipped = embed_db.skipped_embeddings
if len((loaded | skipped)) > 0:
load_textual_inversion_embeddings(force_reload=force)
get_embeddings(None)