Auto-refresh embedding list after model change

Uses own API endpoint and doesn't force-reload to skip unneeded work
(only works for A1111 as SD.Next model change detection isn't implemented yet)
This commit is contained in:
DominikDoom
2023-12-12 14:13:13 +01:00
parent 886704e351
commit f840586b6b
2 changed files with 20 additions and 6 deletions

View File

@@ -1269,6 +1269,13 @@ async function refreshTacTempFiles(api = false) {
}
}
async function refreshEmbeddings() {
await postAPI("tacapi/v1/refresh-embeddings", null);
embeddings = [];
await processQueue(QUEUE_FILE_LOAD, null);
console.log("TAC: Refreshed embeddings");
}
function addAutocompleteToArea(area) {
// Return if autocomplete is disabled for the current area type in config
let textAreaId = getTextAreaIdentifier(area);
@@ -1373,6 +1380,7 @@ async function setup() {
if (mutation.type === "attributes" && mutation.attributeName === "title") {
currentModelHash = mutation.target.title;
updateModelName();
refreshEmbeddings();
}
}
});

View File

@@ -365,11 +365,7 @@ if EMB_PATH.exists():
# Get embeddings after the model loaded callback
script_callbacks.on_model_loaded(get_embeddings)
def refresh_temp_files(*args, **kwargs):
global WILDCARD_EXT_PATHS
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
write_temp_files()
def refresh_embeddings(force: bool, *args, **kwargs):
try:
# Fix for SD.Next infinite refresh loop due to gradio not updating after model load on demand.
# This will just skip embedding loading if no model is loaded yet (or there really are no embeddings).
@@ -377,11 +373,17 @@ def refresh_temp_files(*args, **kwargs):
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
if len((loaded | skipped)) > 0:
load_textual_inversion_embeddings(force_reload = True)
load_textual_inversion_embeddings(force_reload=force)
get_embeddings(None)
except Exception:
pass
def refresh_temp_files(*args, **kwargs):
global WILDCARD_EXT_PATHS
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
write_temp_files()
refresh_embeddings(force=True)
def write_temp_files():
# Write wildcards to wc.txt if found
if WILDCARD_PATH.exists():
@@ -580,6 +582,10 @@ def api_tac(_: gr.Blocks, app: FastAPI):
async def api_refresh_temp_files():
refresh_temp_files()
@app.post("/tacapi/v1/refresh-embeddings")
async def api_refresh_embeddings():
refresh_embeddings(force=False)
@app.get("/tacapi/v1/lora-info/{lora_name}")
async def get_lora_info(lora_name):
return await get_json_info(LORA_PATH, lora_name)