diff --git a/javascript/tagAutocomplete.js b/javascript/tagAutocomplete.js index 9ebed8c..ec1a985 100644 --- a/javascript/tagAutocomplete.js +++ b/javascript/tagAutocomplete.js @@ -1348,6 +1348,13 @@ async function refreshTacTempFiles(api = false) { } } +async function refreshEmbeddings() { + await postAPI("tacapi/v1/refresh-embeddings", null); + embeddings = []; + await processQueue(QUEUE_FILE_LOAD, null); + console.log("TAC: Refreshed embeddings"); +} + function addAutocompleteToArea(area) { // Return if autocomplete is disabled for the current area type in config let textAreaId = getTextAreaIdentifier(area); @@ -1452,6 +1459,7 @@ async function setup() { if (mutation.type === "attributes" && mutation.attributeName === "title") { currentModelHash = mutation.target.title; updateModelName(); + refreshEmbeddings(); } } }); diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index 07057c5..0a0ae15 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -262,20 +262,22 @@ def _get_lyco(): # Attempt to use the build-in Lora.networks Lora/LyCORIS models lists. try: - import importlib - lora_networks = importlib.import_module("extensions-builtin.Lora.networks") + import sys + from modules import extensions + sys.path.append(Path(extensions.extensions_builtin_dir).joinpath("Lora").as_posix()) + import lora # pyright: ignore [reportMissingImports] def _get_lora(): return [ Path(model.filename).absolute() - for model in lora_networks.available_networks.values() + for model in lora.available_loras.values() if Path(model.filename).absolute().is_relative_to(LORA_PATH) ] def _get_lyco(): return [ Path(model.filename).absolute() - for model in lora_networks.available_networks.values() + for model in lora.available_loras.values() if Path(model.filename).absolute().is_relative_to(LYCO_PATH) ] @@ -379,12 +381,24 @@ if EMB_PATH.exists(): # Get embeddings after the model loaded callback script_callbacks.on_model_loaded(get_embeddings) +def refresh_embeddings(force: bool, *args, **kwargs): + try: + # Fix for SD.Next infinite refresh loop due to gradio not updating after model load on demand. + # This will just skip embedding loading if no model is loaded yet (or there really are no embeddings). + # Try catch is just for safety incase sd_hijack access fails for some reason. + loaded = sd_hijack.model_hijack.embedding_db.word_embeddings + skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings + if len((loaded | skipped)) > 0: + load_textual_inversion_embeddings(force_reload=force) + get_embeddings(None) + except Exception: + pass + def refresh_temp_files(*args, **kwargs): global WILDCARD_EXT_PATHS WILDCARD_EXT_PATHS = find_ext_wildcard_paths() - load_textual_inversion_embeddings(force_reload = True) # Instant embedding reload. write_temp_files() - get_embeddings(shared.sd_model) + refresh_embeddings(force=True) def write_temp_files(): # Write wildcards to wc.txt if found @@ -597,6 +611,10 @@ def api_tac(_: gr.Blocks, app: FastAPI): async def api_refresh_temp_files(): refresh_temp_files() + @app.post("/tacapi/v1/refresh-embeddings") + async def api_refresh_embeddings(): + refresh_embeddings(force=False) + @app.get("/tacapi/v1/lora-info/{lora_name}") async def get_lora_info(lora_name): return await get_json_info(LORA_PATH, lora_name)