fix API get/refresh embeddings (#2271)

This commit is contained in:
DenOfEquity
2024-11-06 18:24:28 +00:00
committed by GitHub
parent e2fe29c104
commit 329c3ca334

View File

@@ -21,7 +21,9 @@ from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing,
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images, process_extra_images
from modules.textual_inversion.textual_inversion import create_embedding
import modules.textual_inversion.textual_inversion
from modules.shared import cmd_opts
from PIL import PngImagePlugin
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -265,6 +267,10 @@ class Api:
if not self.default_script_arg_img2img:
self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False)
def add_api_route(self, path: str, endpoint, **kwargs):
@@ -744,8 +750,6 @@ class Api:
return styleList
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
def convert_embedding(embedding):
return {
"step": embedding.step,
@@ -759,13 +763,13 @@ class Api:
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
return {
"loaded": convert_embeddings(db.word_embeddings),
"skipped": convert_embeddings(db.skipped_embeddings),
"loaded": convert_embeddings(self.embedding_db.word_embeddings),
"skipped": convert_embeddings(self.embedding_db.skipped_embeddings),
}
def refresh_embeddings(self):
with self.queue_lock:
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False)
def refresh_checkpoints(self):
with self.queue_lock:
@@ -778,15 +782,14 @@ class Api:
def create_embedding(self, args: dict):
try:
shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
filename = modules.textual_inversion.textual_inversion.create_embedding(**args) # create empty embedding
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False) # reload embeddings so new one can be immediately used
return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e:
return models.TrainResponse(info=f"create embedding error: {e}")
finally:
shared.state.end()
def create_hypernetwork(self, args: dict):
try:
shared.state.begin(job="create_hypernetwork")