mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
fix API get/refresh embeddings (#2271)
This commit is contained in:
@@ -21,7 +21,9 @@ from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing,
|
||||
from modules.api import models
|
||||
from modules.shared import opts
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images, process_extra_images
|
||||
from modules.textual_inversion.textual_inversion import create_embedding
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
from PIL import PngImagePlugin
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
@@ -265,6 +267,10 @@ class Api:
|
||||
if not self.default_script_arg_img2img:
|
||||
self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
|
||||
|
||||
self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False)
|
||||
|
||||
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
@@ -744,8 +750,6 @@ class Api:
|
||||
return styleList
|
||||
|
||||
def get_embeddings(self):
|
||||
db = sd_hijack.model_hijack.embedding_db
|
||||
|
||||
def convert_embedding(embedding):
|
||||
return {
|
||||
"step": embedding.step,
|
||||
@@ -759,13 +763,13 @@ class Api:
|
||||
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
|
||||
|
||||
return {
|
||||
"loaded": convert_embeddings(db.word_embeddings),
|
||||
"skipped": convert_embeddings(db.skipped_embeddings),
|
||||
"loaded": convert_embeddings(self.embedding_db.word_embeddings),
|
||||
"skipped": convert_embeddings(self.embedding_db.skipped_embeddings),
|
||||
}
|
||||
|
||||
def refresh_embeddings(self):
|
||||
with self.queue_lock:
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False)
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
with self.queue_lock:
|
||||
@@ -778,15 +782,14 @@ class Api:
|
||||
def create_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin(job="create_embedding")
|
||||
filename = create_embedding(**args) # create empty embedding
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||
filename = modules.textual_inversion.textual_inversion.create_embedding(**args) # create empty embedding
|
||||
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False) # reload embeddings so new one can be immediately used
|
||||
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||
except AssertionError as e:
|
||||
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
|
||||
def create_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin(job="create_hypernetwork")
|
||||
|
||||
Reference in New Issue
Block a user