From 6d789653b946b79cceb9b715bc75f16552eef584 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Tue, 6 Aug 2024 14:34:57 -0700 Subject: [PATCH] better model load logic --- modules/initialize.py | 4 ---- modules/sd_models.py | 17 +---------------- modules_forge/main_entry.py | 3 +++ 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/modules/initialize.py b/modules/initialize.py index 3cf71c66..199ccec9 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -120,10 +120,6 @@ def initialize_rest(*, reload_script_modules=False): sd_unet.list_unets() startup_timer.record("scripts list_unets") - from modules_forge import main_thread - import modules.sd_models - main_thread.async_run(modules.sd_models.model_data.get_sd_model) - from modules import shared_items shared_items.reload_hypernetworks() startup_timer.record("reload hypernetworks") diff --git a/modules/sd_models.py b/modules/sd_models.py index bb5bc34c..b5c04151 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -380,25 +380,10 @@ class SdModelData: self.sd_model = None self.loaded_sd_models = [] self.was_loaded_at_least_once = False - self.lock = threading.Lock() def get_sd_model(self): - if self.was_loaded_at_least_once: - return self.sd_model - if self.sd_model is None: - with self.lock: - if self.sd_model is not None or self.was_loaded_at_least_once: - return self.sd_model - - try: - load_model() - - except Exception as e: - errors.display(e, "loading stable diffusion model", full_traceback=True) - print("", file=sys.stderr) - print("Stable diffusion model failed to load", file=sys.stderr) - self.sd_model = None + raise ValueError('Something went wrong! Model is not loaded yet ...') return self.sd_model diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py index e98449b6..689eb1b9 100644 --- a/modules_forge/main_entry.py +++ b/modules_forge/main_entry.py @@ -66,4 +66,7 @@ def forge_main_entry(): ui_checkpoint.change(lambda x: main_thread.async_run(checkpoint_change, x), inputs=[ui_checkpoint], show_progress=False) ui_vae.change(lambda x: main_thread.async_run(vae_change, x), inputs=[ui_vae], show_progress=False) ui_clip_skip.change(lambda x: main_thread.async_run(clip_skip_change, x), inputs=[ui_clip_skip], show_progress=False) + + # Load Model + main_thread.async_run(sd_models.load_model) return