diff --git a/javascript/tagAutocomplete.js b/javascript/tagAutocomplete.js index 37a20b6..2d0b3fc 100644 --- a/javascript/tagAutocomplete.js +++ b/javascript/tagAutocomplete.js @@ -1269,6 +1269,13 @@ async function refreshTacTempFiles(api = false) { } } +async function refreshEmbeddings() { + await postAPI("tacapi/v1/refresh-embeddings", null); + embeddings = []; + await processQueue(QUEUE_FILE_LOAD, null); + console.log("TAC: Refreshed embeddings"); +} + function addAutocompleteToArea(area) { // Return if autocomplete is disabled for the current area type in config let textAreaId = getTextAreaIdentifier(area); @@ -1373,6 +1380,7 @@ async function setup() { if (mutation.type === "attributes" && mutation.attributeName === "title") { currentModelHash = mutation.target.title; updateModelName(); + refreshEmbeddings(); } } }); diff --git a/scripts/tag_autocomplete_helper.py b/scripts/tag_autocomplete_helper.py index 1e0072d..994ce02 100644 --- a/scripts/tag_autocomplete_helper.py +++ b/scripts/tag_autocomplete_helper.py @@ -365,11 +365,7 @@ if EMB_PATH.exists(): # Get embeddings after the model loaded callback script_callbacks.on_model_loaded(get_embeddings) -def refresh_temp_files(*args, **kwargs): - global WILDCARD_EXT_PATHS - WILDCARD_EXT_PATHS = find_ext_wildcard_paths() - write_temp_files() - +def refresh_embeddings(force: bool, *args, **kwargs): try: # Fix for SD.Next infinite refresh loop due to gradio not updating after model load on demand. # This will just skip embedding loading if no model is loaded yet (or there really are no embeddings). @@ -377,11 +373,17 @@ def refresh_temp_files(*args, **kwargs): loaded = sd_hijack.model_hijack.embedding_db.word_embeddings skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings if len((loaded | skipped)) > 0: - load_textual_inversion_embeddings(force_reload = True) + load_textual_inversion_embeddings(force_reload=force) get_embeddings(None) except Exception: pass +def refresh_temp_files(*args, **kwargs): + global WILDCARD_EXT_PATHS + WILDCARD_EXT_PATHS = find_ext_wildcard_paths() + write_temp_files() + refresh_embeddings(force=True) + def write_temp_files(): # Write wildcards to wc.txt if found if WILDCARD_PATH.exists(): @@ -580,6 +582,10 @@ def api_tac(_: gr.Blocks, app: FastAPI): async def api_refresh_temp_files(): refresh_temp_files() + @app.post("/tacapi/v1/refresh-embeddings") + async def api_refresh_embeddings(): + refresh_embeddings(force=False) + @app.get("/tacapi/v1/lora-info/{lora_name}") async def get_lora_info(lora_name): return await get_json_info(LORA_PATH, lora_name)