Merge branch 'main' into feature-sort-by-frequent-use

This commit is contained in:
DominikDoom
2023-12-12 14:23:18 +01:00
2 changed files with 32 additions and 6 deletions

View File

@@ -1348,6 +1348,13 @@ async function refreshTacTempFiles(api = false) {
}
}
async function refreshEmbeddings() {
await postAPI("tacapi/v1/refresh-embeddings", null);
embeddings = [];
await processQueue(QUEUE_FILE_LOAD, null);
console.log("TAC: Refreshed embeddings");
}
function addAutocompleteToArea(area) {
// Return if autocomplete is disabled for the current area type in config
let textAreaId = getTextAreaIdentifier(area);
@@ -1452,6 +1459,7 @@ async function setup() {
if (mutation.type === "attributes" && mutation.attributeName === "title") {
currentModelHash = mutation.target.title;
updateModelName();
refreshEmbeddings();
}
}
});

View File

@@ -262,20 +262,22 @@ def _get_lyco():
# Attempt to use the build-in Lora.networks Lora/LyCORIS models lists.
try:
import importlib
lora_networks = importlib.import_module("extensions-builtin.Lora.networks")
import sys
from modules import extensions
sys.path.append(Path(extensions.extensions_builtin_dir).joinpath("Lora").as_posix())
import lora # pyright: ignore [reportMissingImports]
def _get_lora():
return [
Path(model.filename).absolute()
for model in lora_networks.available_networks.values()
for model in lora.available_loras.values()
if Path(model.filename).absolute().is_relative_to(LORA_PATH)
]
def _get_lyco():
return [
Path(model.filename).absolute()
for model in lora_networks.available_networks.values()
for model in lora.available_loras.values()
if Path(model.filename).absolute().is_relative_to(LYCO_PATH)
]
@@ -379,12 +381,24 @@ if EMB_PATH.exists():
# Get embeddings after the model loaded callback
script_callbacks.on_model_loaded(get_embeddings)
def refresh_embeddings(force: bool, *args, **kwargs):
try:
# Fix for SD.Next infinite refresh loop due to gradio not updating after model load on demand.
# This will just skip embedding loading if no model is loaded yet (or there really are no embeddings).
# Try catch is just for safety incase sd_hijack access fails for some reason.
loaded = sd_hijack.model_hijack.embedding_db.word_embeddings
skipped = sd_hijack.model_hijack.embedding_db.skipped_embeddings
if len((loaded | skipped)) > 0:
load_textual_inversion_embeddings(force_reload=force)
get_embeddings(None)
except Exception:
pass
def refresh_temp_files(*args, **kwargs):
global WILDCARD_EXT_PATHS
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
load_textual_inversion_embeddings(force_reload = True) # Instant embedding reload.
write_temp_files()
get_embeddings(shared.sd_model)
refresh_embeddings(force=True)
def write_temp_files():
# Write wildcards to wc.txt if found
@@ -597,6 +611,10 @@ def api_tac(_: gr.Blocks, app: FastAPI):
async def api_refresh_temp_files():
refresh_temp_files()
@app.post("/tacapi/v1/refresh-embeddings")
async def api_refresh_embeddings():
refresh_embeddings(force=False)
@app.get("/tacapi/v1/lora-info/{lora_name}")
async def get_lora_info(lora_name):
return await get_json_info(LORA_PATH, lora_name)