From d55e6b5bfeaa1d0433b8d371f9ab4ef051a019db Mon Sep 17 00:00:00 2001 From: altoiddealer Date: Mon, 26 Aug 2024 18:08:02 -0400 Subject: [PATCH] Replace API sd-vae with sd-modules (#1463) --- modules/api/api.py | 13 +++++-------- modules/api/models.py | 2 +- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index c40b4966..b4ee221f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -19,7 +19,6 @@ from secrets import compare_digest import modules.shared as shared from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers from modules.api import models -from modules_forge import main_entry from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding @@ -224,7 +223,7 @@ class Api: self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem]) self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem]) - self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem]) + self.add_api_route("/sdapi/v1/sd-modules", self.get_sd_vaes_and_text_encoders, methods=["GET"], response_model=list[models.SDModuleItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem]) @@ -691,9 +690,7 @@ class Api: for k, v in req.items(): shared.opts.set(k, v, is_api=True) - main_entry.checkpoint_change(checkpoint_name) - # shared.opts.save(shared.config_filename) --- applied in checkpoint_change() - + shared.opts.save(shared.config_filename) return def get_cmd_flags(self): @@ -737,9 +734,9 @@ class Api: import modules.sd_models as sd_models return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": getattr(x, 'config', None)} for x in sd_models.checkpoints_list.values()] - def get_sd_vaes(self): - import modules.sd_vae as sd_vae - return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()] + def get_sd_vaes_and_text_encoders(self): + from modules_forge.main_entry import module_list + return [{"model_name": x, "filename": module_list[x]} for x in module_list.keys()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/api/models.py b/modules/api/models.py index 972bb1bd..9b6a8fe6 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -264,7 +264,7 @@ class SDModelItem(BaseModel): filename: str = Field(title="Filename") config: Optional[str] = Field(default=None, title="Config file") -class SDVaeItem(BaseModel): +class SDModuleItem(BaseModel): class Config: protected_namespaces = ()