diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py
index 781e819..f6b9dde 100644
--- a/scripts/tag_autocomplete_helper.py
+++ b/scripts/tag_autocomplete_helper.py
@@ -13,7 +13,7 @@ import gradio as gr
 import yaml
 from fastapi import FastAPI
 from fastapi.responses import Response, FileResponse, JSONResponse
-from modules import script_callbacks, sd_hijack, shared, hashes
+from modules import script_callbacks, sd_hijack, shared, hashes, sd_models
 from pydantic import BaseModel
 
 from scripts.model_keyword_support import (get_lora_simple_hash,
@@ -41,9 +41,32 @@ except (ImportError, ValueError, sqlite3.Error) as e:
     print(f"Tag Autocomplete: Tag frequency database error - \"{e}\"")
     db = None
 
+def get_embed_db(sd_model=None):
+    """Returns the embedding database, if available."""
+    try:
+        return sd_hijack.model_hijack.embedding_db
+    except Exception:
+        try:  # sd next with diffusers backend
+            sdnext_model = sd_model if sd_model is not None else shared.sd_model
+            return sdnext_model.embedding_db
+        except Exception:
+            try:  # forge webui
+                forge_model = sd_model if sd_model is not None else sd_models.model_data.get_sd_model()
+                if type(forge_model).__name__ == "FakeInitialModel":
+                    return None
+                else:
+                    processer = getattr(forge_model, "text_processing_engine", getattr(forge_model, "text_processing_engine_l"))
+                    return processer.embeddings
+            except Exception:
+                return None
+
 # Attempt to get embedding load function, using the same call as api.
 try:
-    load_textual_inversion_embeddings = sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings
+    embed_db = get_embed_db()
+    if embed_db is not None:
+        load_textual_inversion_embeddings = embed_db.load_textual_inversion_embeddings
+    else:
+        load_textual_inversion_embeddings = lambda *args, **kwargs: None
 except Exception as e: # Not supported.
     load_textual_inversion_embeddings = lambda *args, **kwargs: None
     print("Tag Autocomplete: Cannot reload embeddings instantly:", e)
@@ -190,19 +213,10 @@ def get_embeddings(sd_model):
 
     results = []
     try:
-        # The sd_model embedding_db reference only exists in sd.next with diffusers backend
-        try:
-            loaded_sdnext = sd_model.embedding_db.word_embeddings
-            skipped_sdnext = sd_model.embedding_db.skipped_embeddings
-        except (NameError, AttributeError):
-            loaded_sdnext = {}
-            skipped_sdnext = {}
+        embed_db = get_embed_db(sd_model)
 
-        # Get embedding dict from sd_hijack to separate v1/v2 embeddings
-        loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
-        skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
-        loaded = loaded | loaded_sdnext
-        skipped = skipped | skipped_sdnext
+        loaded = embed_db.word_embeddings
+        skipped = embed_db.skipped_embeddings
 
         # Add embeddings to the correct list
         for key, emb in (skipped | loaded).items():
@@ -430,8 +444,11 @@ def refresh_embeddings(force: bool, *args, **kwargs):
         # Fix for SD.Next infinite refresh loop due to gradio not updating after model load on demand.
         # This will just skip embedding loading if no model is loaded yet (or there really are no embeddings).
         # Try catch is just for safety incase sd_hijack access fails for some reason.
-        loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
-        skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
+        embed_db = get_embed_db()
+        if embed_db is None:
+            return
+        loaded = embed_db.word_embeddings
+        skipped = embed_db.skipped_embeddings
         if len((loaded | skipped)) > 0:
             load_textual_inversion_embeddings(force_reload=force)
             get_embeddings(None)
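Note (not part of the patch): the lookup order that get_embed_db() encodes - A1111's sd_hijack.model_hijack, then SD.Next's model-level embedding_db, then Forge's text processing engine - can be illustrated in isolation with a rough, hedged sketch. The names _StubDb, _StubHijack and resolve_embed_db below are hypothetical stand-ins for illustration only; they are not identifiers from the extension or from any webui, and only the fallback order mirrors the diff above.

    # Rough sketch of the backend fallback implemented by get_embed_db().
    # All classes and the function name here are hypothetical; only the
    # order (A1111 hijack -> SD.Next model -> Forge engine) follows the patch.

    class _StubDb:
        word_embeddings = {"example-embed": object()}
        skipped_embeddings = {}

    def resolve_embed_db(hijack=None, sdnext_model=None, forge_model=None):
        """Return the first reachable embedding database, else None."""
        if hijack is not None:
            return getattr(hijack, "embedding_db", None)        # A1111 / classic backend
        if sdnext_model is not None:
            return getattr(sdnext_model, "embedding_db", None)  # SD.Next diffusers backend
        if forge_model is not None:
            engine = (getattr(forge_model, "text_processing_engine", None)
                      or getattr(forge_model, "text_processing_engine_l", None))
            return getattr(engine, "embeddings", None)           # Forge webui
        return None

    class _StubHijack:
        embedding_db = _StubDb()

    db = resolve_embed_db(hijack=_StubHijack())
    print(sorted(db.word_embeddings))  # ['example-embed']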