mirror of
https://github.com/DominikDoom/a1111-sd-webui-tagcomplete.git
synced 2026-01-26 11:09:54 +00:00
Auto-refresh embedding list after model change
Uses own API endpoint and doesn't force-reload to skip unneeded work (only works for A1111 as SD.Next model change detection isn't implemented yet)
This commit is contained in:
@@ -365,11 +365,7 @@ if EMB_PATH.exists():
|
||||
# Get embeddings after the model loaded callback
|
||||
script_callbacks.on_model_loaded(get_embeddings)
|
||||
|
||||
def refresh_temp_files(*args, **kwargs):
|
||||
global WILDCARD_EXT_PATHS
|
||||
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
|
||||
write_temp_files()
|
||||
|
||||
def refresh_embeddings(force: bool, *args, **kwargs):
|
||||
try:
|
||||
# Fix for SD.Next infinite refresh loop due to gradio not updating after model load on demand.
|
||||
# This will just skip embedding loading if no model is loaded yet (or there really are no embeddings).
|
||||
@@ -377,11 +373,17 @@ def refresh_temp_files(*args, **kwargs):
|
||||
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
|
||||
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
|
||||
if len((loaded | skipped)) > 0:
|
||||
load_textual_inversion_embeddings(force_reload = True)
|
||||
load_textual_inversion_embeddings(force_reload=force)
|
||||
get_embeddings(None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def refresh_temp_files(*args, **kwargs):
|
||||
global WILDCARD_EXT_PATHS
|
||||
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
|
||||
write_temp_files()
|
||||
refresh_embeddings(force=True)
|
||||
|
||||
def write_temp_files():
|
||||
# Write wildcards to wc.txt if found
|
||||
if WILDCARD_PATH.exists():
|
||||
@@ -580,6 +582,10 @@ def api_tac(_: gr.Blocks, app: FastAPI):
|
||||
async def api_refresh_temp_files():
|
||||
refresh_temp_files()
|
||||
|
||||
@app.post("/tacapi/v1/refresh-embeddings")
|
||||
async def api_refresh_embeddings():
|
||||
refresh_embeddings(force=False)
|
||||
|
||||
@app.get("/tacapi/v1/lora-info/{lora_name}")
|
||||
async def get_lora_info(lora_name):
|
||||
return await get_json_info(LORA_PATH, lora_name)
|
||||
|
||||
Reference in New Issue
Block a user