diff --git a/modules/api/api.py b/modules/api/api.py index 50ad9ef3..9754be03 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -21,7 +21,9 @@ from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing, from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images, process_extra_images -from modules.textual_inversion.textual_inversion import create_embedding +import modules.textual_inversion.textual_inversion +from modules.shared import cmd_opts + from PIL import PngImagePlugin from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -265,6 +267,10 @@ class Api: if not self.default_script_arg_img2img: self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner) + self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() + self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) + self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False) + def add_api_route(self, path: str, endpoint, **kwargs): @@ -744,8 +750,6 @@ class Api: return styleList def get_embeddings(self): - db = sd_hijack.model_hijack.embedding_db - def convert_embedding(embedding): return { "step": embedding.step, @@ -759,13 +763,13 @@ class Api: return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()} return { - "loaded": convert_embeddings(db.word_embeddings), - "skipped": convert_embeddings(db.skipped_embeddings), + "loaded": convert_embeddings(self.embedding_db.word_embeddings), + "skipped": convert_embeddings(self.embedding_db.skipped_embeddings), } def refresh_embeddings(self): with self.queue_lock: - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False) def refresh_checkpoints(self): with self.queue_lock: @@ -778,15 +782,14 @@ class Api: def create_embedding(self, args: dict): try: shared.state.begin(job="create_embedding") - filename = create_embedding(**args) # create empty embedding - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used + filename = modules.textual_inversion.textual_inversion.create_embedding(**args) # create empty embedding + self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False) # reload embeddings so new one can be immediately used return models.CreateResponse(info=f"create embedding filename: {filename}") except AssertionError as e: return models.TrainResponse(info=f"create embedding error: {e}") finally: shared.state.end() - def create_hypernetwork(self, args: dict): try: shared.state.begin(job="create_hypernetwork")